diff --git a/api/commands.py b/api/commands.py index 39a97e497f..a50b9b86e4 100644 --- a/api/commands.py +++ b/api/commands.py @@ -52,19 +52,13 @@ def reset_password(email, new_password, password_confirm): account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo( - click.style("Account not found for email: {}".format(email), fg="red") - ) + click.echo(click.style("Account not found for email: {}".format(email), fg="red")) return try: valid_password(new_password) except: - click.echo( - click.style( - "Invalid password. Must match {}".format(password_pattern), fg="red" - ) - ) + click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red")) return # generate password salt @@ -96,9 +90,7 @@ def reset_email(email, new_email, email_confirm): account = db.session.query(Account).filter(Account.email == email).one_or_none() if not account: - click.echo( - click.style("Account not found for email: {}".format(email), fg="red") - ) + click.echo(click.style("Account not found for email: {}".format(email), fg="red")) return try: @@ -132,34 +124,24 @@ def reset_encrypt_key_pair(): Only support SELF_HOSTED mode. """ if dify_config.EDITION != "SELF_HOSTED": - click.echo( - click.style("This command is only for SELF_HOSTED installations.", fg="red") - ) + click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red")) return tenants = db.session.query(Tenant).all() for tenant in tenants: if not tenant: - click.echo( - click.style("No workspaces found. Run /install first.", fg="red") - ) + click.echo(click.style("No workspaces found. Run /install first.", fg="red")) return tenant.encrypt_public_key = generate_key_pair(tenant.id) - db.session.query(Provider).filter( - Provider.provider_type == "custom", Provider.tenant_id == tenant.id - ).delete() - db.session.query(ProviderModel).filter( - ProviderModel.tenant_id == tenant.id - ).delete() + db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete() + db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete() db.session.commit() click.echo( click.style( - "Congratulations! The asymmetric key pair of workspace {} has been reset.".format( - tenant.id - ), + "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id), fg="green", ) ) @@ -209,15 +191,12 @@ def migrate_annotation_vector_database(): for app in apps: total_count = total_count + 1 click.echo( - f"Processing the {total_count} app {app.id}. " - + f"{create_count} created, {skipped_count} skipped." + f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped." ) try: click.echo("Creating app annotation index: {}".format(app.id)) app_annotation_setting = ( - db.session.query(AppAnnotationSetting) - .filter(AppAnnotationSetting.app_id == app.id) - .first() + db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first() ) if not app_annotation_setting: @@ -227,22 +206,13 @@ def migrate_annotation_vector_database(): # get dataset_collection_binding info dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( - DatasetCollectionBinding.id - == app_annotation_setting.collection_binding_id - ) + .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id) .first() ) if not dataset_collection_binding: - click.echo( - "App annotation collection binding not found: {}".format(app.id) - ) + click.echo("App annotation collection binding not found: {}".format(app.id)) continue - annotations = ( - db.session.query(MessageAnnotation) - .filter(MessageAnnotation.app_id == app.id) - .all() - ) + annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all() dataset = Dataset( id=app.id, tenant_id=app.tenant_id, @@ -264,24 +234,14 @@ def migrate_annotation_vector_database(): ) documents.append(document) - vector = Vector( - dataset, attributes=["doc_id", "annotation_id", "app_id"] - ) + vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"]) click.echo(f"Migrating annotations for app: {app.id}.") try: vector.delete() - click.echo( - click.style( - f"Deleted vector index for app {app.id}.", fg="green" - ) - ) + click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green")) except Exception as e: - click.echo( - click.style( - f"Failed to delete vector index for app {app.id}.", fg="red" - ) - ) + click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red")) raise e if documents: try: @@ -292,11 +252,7 @@ def migrate_annotation_vector_database(): ) ) vector.create(documents) - click.echo( - click.style( - f"Created vector index for app {app.id}.", fg="green" - ) - ) + click.echo(click.style(f"Created vector index for app {app.id}.", fg="green")) except Exception as e: click.echo( click.style( @@ -310,9 +266,7 @@ def migrate_annotation_vector_database(): except Exception as e: click.echo( click.style( - "Error creating app annotation index: {} {}".format( - e.__class__.__name__, str(e) - ), + "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)), fg="red", ) ) @@ -378,9 +332,7 @@ def migrate_knowledge_vector_database(): f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." ) try: - click.echo( - "Creating dataset vector database index: {}".format(dataset.id) - ) + click.echo("Creating dataset vector database index: {}".format(dataset.id)) if dataset.index_struct_dict: if dataset.index_struct_dict["type"] == vector_type: skipped_count = skipped_count + 1 @@ -393,10 +345,7 @@ def migrate_knowledge_vector_database(): if dataset.collection_binding_id: dataset_collection_binding = ( db.session.query(DatasetCollectionBinding) - .filter( - DatasetCollectionBinding.id - == dataset.collection_binding_id - ) + .filter(DatasetCollectionBinding.id == dataset.collection_binding_id) .one_or_none() ) if dataset_collection_binding: @@ -407,9 +356,7 @@ def migrate_knowledge_vector_database(): collection_name = Dataset.gen_collection_name_by_id(dataset_id) elif vector_type in lower_collection_vector_types: - collection_name = Dataset.gen_collection_name_by_id( - dataset_id - ).lower() + collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower() else: raise ValueError(f"Vector store {vector_type} is not supported.") @@ -508,9 +455,7 @@ def migrate_knowledge_vector_database(): db.session.rollback() click.echo( click.style( - "Error creating dataset index: {} {}".format( - e.__class__.__name__, str(e) - ), + "Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)), fg="red", ) ) @@ -572,9 +517,9 @@ def convert_to_agent_apps(): db.session.commit() # update conversation mode to agent - db.session.query(Conversation).filter( - Conversation.app_id == app.id - ).update({Conversation.mode: AppMode.AGENT_CHAT.value}) + db.session.query(Conversation).filter(Conversation.app_id == app.id).update( + {Conversation.mode: AppMode.AGENT_CHAT.value} + ) db.session.commit() click.echo(click.style("Converted app: {}".format(app.id), fg="green")) @@ -588,9 +533,7 @@ def convert_to_agent_apps(): click.echo( click.style( - "Conversion complete. Converted {} agent apps.".format( - len(proceeded_app_ids) - ), + "Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)), fg="green", ) ) @@ -723,15 +666,11 @@ def old_metadata_migration(): ) db.session.add(dataset_metadata_binding) else: - dataset_metadata_binding = ( - DatasetMetadataBinding.query.filter( - DatasetMetadataBinding.dataset_id - == document.dataset_id, - DatasetMetadataBinding.document_id == document.id, - DatasetMetadataBinding.metadata_id - == dataset_metadata.id, - ).first() - ) + dataset_metadata_binding = DatasetMetadataBinding.query.filter( + DatasetMetadataBinding.dataset_id == document.dataset_id, + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == dataset_metadata.id, + ).first() if not dataset_metadata_binding: dataset_metadata_binding = DatasetMetadataBinding( tenant_id=document.tenant_id, @@ -750,9 +689,7 @@ def old_metadata_migration(): @click.option("--email", prompt=True, help="Tenant account email.") @click.option("--name", prompt=True, help="Workspace name.") @click.option("--language", prompt=True, help="Account language, default: en-US.") -def create_tenant( - email: str, language: Optional[str] = None, name: Optional[str] = None -): +def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None): """ Create tenant account """ @@ -790,9 +727,7 @@ def create_tenant( click.echo( click.style( - "Account and tenant created.\nAccount: {}\nPassword: {}".format( - email, new_password - ), + "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password), fg="green", ) ) @@ -867,19 +802,13 @@ where sites.id is null limit 1000""" fg="red", ) ) - logging.exception( - f"Failed to fix app related site missing issue, app_id: {app_id}" - ) + logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}") continue if not processed_count: break - click.echo( - click.style( - "Fix for missing app-related sites completed successfully!", fg="green" - ) - ) + click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) @click.command( @@ -895,9 +824,7 @@ where sites.id is null limit 1000""" help="Type of login ID (phone or email)", ) @click.option("--organization-id", required=True, help="Organization ID") -def create_admin_account( - name: str, login_id: str, login_id_type: str, organization_id: str -): +def create_admin_account(name: str, login_id: str, login_id_type: str, organization_id: str): """ Create or update an admin account with a phone number or email for a specific organization. This command will create a new account if the login ID doesn't exist, @@ -907,46 +834,26 @@ def create_admin_account( # Check if organization exists from models.organization import Organization, OrganizationMember, OrganizationRole - organization = ( - db.session.query(Organization) - .filter(Organization.id == organization_id) - .first() - ) + organization = db.session.query(Organization).filter(Organization.id == organization_id).first() if not organization: - click.echo( - click.style( - f"Organization with ID {organization_id} not found.", fg="red" - ) - ) + click.echo(click.style(f"Organization with ID {organization_id} not found.", fg="red")) return # Get tenant from organization - tenant = ( - db.session.query(Tenant).filter(Tenant.id == organization.tenant_id).first() - ) + tenant = db.session.query(Tenant).filter(Tenant.id == organization.tenant_id).first() if not tenant: - click.echo( - click.style( - f"Tenant for organization {organization_id} not found.", fg="red" - ) - ) + click.echo(click.style(f"Tenant for organization {organization_id} not found.", fg="red")) return # Check if account exists with this login ID account = None if login_id_type == "phone": - account = ( - db.session.query(Account).filter(Account.phone == login_id).first() - ) + account = db.session.query(Account).filter(Account.phone == login_id).first() else: # email - account = ( - db.session.query(Account).filter(Account.email == login_id).first() - ) + account = db.session.query(Account).filter(Account.email == login_id).first() if account: - click.echo( - f"Account with {login_id_type} {login_id} already exists. Updating account..." - ) + click.echo(f"Account with {login_id_type} {login_id} already exists. Updating account...") # Update account account.name = name @@ -1010,9 +917,7 @@ def create_admin_account( if org_member: # Update role to admin org_member.role = OrganizationRole.ADMIN - click.echo( - f"Updated account role to {OrganizationRole.ADMIN} in organization {organization.name}" - ) + click.echo(f"Updated account role to {OrganizationRole.ADMIN} in organization {organization.name}") else: # Add account to organization with admin role org_member = OrganizationMember( @@ -1023,9 +928,7 @@ def create_admin_account( created_by=account.id, ) db.session.add(org_member) - click.echo( - f"Added account to organization {organization.name} with role {OrganizationRole.ADMIN}" - ) + click.echo(f"Added account to organization {organization.name} with role {OrganizationRole.ADMIN}") db.session.commit() @@ -1044,12 +947,8 @@ def create_admin_account( click.echo(click.style(f"Error: {str(e)}", fg="red")) -@click.command( - "create-organization", help="Create a new organization for multi-school support." -) -@click.option( - "--tenant-id", required=True, help="ID of the tenant that owns this organization" -) +@click.command("create-organization", help="Create a new organization for multi-school support.") +@click.option("--tenant-id", required=True, help="ID of the tenant that owns this organization") @click.option("--name", required=True, help="Name of the organization") @click.option("--code", required=True, help="Unique code for the organization") @click.option( @@ -1060,21 +959,15 @@ def create_admin_account( help="Type of organization", ) @click.option("--description", default="", help="Description of the organization") -@click.option( - "--email-domains", default="", help="Comma-separated list of allowed email domains" -) +@click.option("--email-domains", default="", help="Comma-separated list of allowed email domains") @click.option("--created-by", required=True, help="Account ID of the creator") -def create_organization_cmd( - tenant_id, name, code, org_type, description, email_domains, created_by -): +def create_organization_cmd(tenant_id, name, code, org_type, description, email_domains, created_by): """Create a new organization under a tenant for multi-school support""" try: # Check if code already exists from models.organization import Organization - existing = ( - db.session.query(Organization).filter(Organization.code == code).first() - ) + existing = db.session.query(Organization).filter(Organization.code == code).first() if existing: click.echo(f"Error: Organization with code '{code}' already exists") return @@ -1106,9 +999,7 @@ def create_organization_cmd( db.session.add(organization) db.session.commit() - click.echo( - f"Organization '{name}' (ID: {organization.id}) created successfully" - ) + click.echo(f"Organization '{name}' (ID: {organization.id}) created successfully") except Exception as e: db.session.rollback() @@ -1120,17 +1011,13 @@ def create_organization_cmd( @click.option("--name", help="New name for the organization") @click.option("--description", help="New description") @click.option("--email-domains", help="Comma-separated list of allowed email domains") -@click.option( - "--status", type=click.Choice(["active", "inactive"]), help="Organization status" -) +@click.option("--status", type=click.Choice(["active", "inactive"]), help="Organization status") def update_organization_cmd(org_id, name, description, email_domains, status): """Update an existing organization's configuration""" try: from models.organization import Organization - organization = ( - db.session.query(Organization).filter(Organization.id == org_id).first() - ) + organization = db.session.query(Organization).filter(Organization.id == org_id).first() if not organization: click.echo(f"Error: Organization with ID '{org_id}' not found") return @@ -1225,9 +1112,7 @@ def show_organization_cmd(org_id): try: from models.organization import Organization - organization = ( - db.session.query(Organization).filter(Organization.id == org_id).first() - ) + organization = db.session.query(Organization).filter(Organization.id == org_id).first() if not organization: click.echo(f"Error: Organization with ID '{org_id}' not found") @@ -1257,27 +1142,19 @@ def show_organization_cmd(org_id): @click.option( "--role", required=True, - type=click.Choice( - ["admin", "teacher", "student", "staff", "manager", "employee", "guest"] - ), + type=click.Choice(["admin", "teacher", "student", "staff", "manager", "employee", "guest"]), help="Role in the organization", ) @click.option("--department", help="Department within the organization") @click.option("--title", help="Job title or position") -@click.option( - "--is-default", is_flag=True, help="Set as the account's default organization" -) -def add_account_to_organization_cmd( - org_id, account_id, role, department, title, is_default -): +@click.option("--is-default", is_flag=True, help="Set as the account's default organization") +def add_account_to_organization_cmd(org_id, account_id, role, department, title, is_default): """Add an account to an organization with appropriate role and metadata""" try: from models.organization import Organization, OrganizationMember # Check if organization exists - organization = ( - db.session.query(Organization).filter(Organization.id == org_id).first() - ) + organization = db.session.query(Organization).filter(Organization.id == org_id).first() if not organization: click.echo(f"Error: Organization with ID '{org_id}' not found") return @@ -1299,9 +1176,7 @@ def add_account_to_organization_cmd( ) if existing: - click.echo( - "Account is already a member of this organization. Updating role and metadata." - ) + click.echo("Account is already a member of this organization. Updating role and metadata.") existing.role = role existing.department = department existing.title = title @@ -1374,9 +1249,7 @@ def upload_private_key_file_cloud_storage(tenant_id: Optional[str] = None): ) file_key = f"privkeys/{tenant_id}/private.pem" - file_content = Path( - f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file_key}" - ).read_bytes() + file_content = Path(f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file_key}").read_bytes() storage.save(filename=file_key, data=file_content) click.echo( click.style( @@ -1386,9 +1259,7 @@ def upload_private_key_file_cloud_storage(tenant_id: Optional[str] = None): ) -@click.command( - "upload-local-files-to-cloud-storage", help="upload local files to cloud storage" -) +@click.command("upload-local-files-to-cloud-storage", help="upload local files to cloud storage") def upload_local_files_to_cloud_storage(): """ upload local files to cloud storage @@ -1406,14 +1277,10 @@ def upload_local_files_to_cloud_storage(): batch_size = 100 processed_count = 0 while processed_count < total_count: - files: list[UploadFile] = ( - UploadFile.query.filter_by(storage_type="local").limit(batch_size).all() - ) + files: list[UploadFile] = UploadFile.query.filter_by(storage_type="local").limit(batch_size).all() for file in files: - target_filepath = ( - f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file.key}" - ) + target_filepath = f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file.key}" # if the file exists if not os.path.exists(target_filepath): @@ -1459,11 +1326,7 @@ def upload_local_files_to_cloud_storage(): processed_count += 1 if processed_count % 10 == 0 or processed_count == total_count: - click.echo( - click.style( - f"Processed {processed_count}/{total_count} files\n", fg="blue" - ) - ) + click.echo(click.style(f"Processed {processed_count}/{total_count} files\n", fg="blue")) time.sleep(3) click.echo( @@ -1564,9 +1427,7 @@ def install_plugins(input_file: str, output_file: str, workers: int): click.echo(click.style("Install plugins completed.", fg="green")) -@click.command( - "clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs." -) +@click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.") @click.option( "--days", prompt=True, @@ -1593,9 +1454,7 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) - click.echo( - click.style("Clear free plan tenant expired logs completed.", fg="green") - ) + click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) @click.option( @@ -1651,9 +1510,7 @@ def clear_orphaned_file_records(force: bool): ) ) for ids_table in ids_tables: - click.echo( - click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow") - ) + click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) click.echo("") click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) @@ -1704,9 +1561,7 @@ def clear_orphaned_file_records(force: bool): with db.engine.begin() as conn: rs = conn.execute(db.text(query)) for i in rs: - orphaned_message_files.append( - {"id": str(i[0]), "message_id": str(i[1])} - ) + orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])}) if orphaned_message_files: click.echo( @@ -1732,9 +1587,7 @@ def clear_orphaned_file_records(force: bool): abort=True, ) - click.echo( - click.style("- Deleting orphaned message_files records", fg="white") - ) + click.echo(click.style("- Deleting orphaned message_files records", fg="white")) query = "DELETE FROM message_files WHERE id IN :ids" with db.engine.begin() as conn: conn.execute( @@ -1755,11 +1608,7 @@ def clear_orphaned_file_records(force: bool): ) ) except Exception as e: - click.echo( - click.style( - f"Error deleting orphaned message_files records: {str(e)}", fg="red" - ) - ) + click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red")) # clean up the orphaned records in the rest of the *_files tables try: @@ -1776,14 +1625,8 @@ def clear_orphaned_file_records(force: bool): with db.engine.begin() as conn: rs = conn.execute(db.text(query)) for i in rs: - all_files_in_tables.append( - {"table": files_table["table"], "id": str(i[0]), "key": i[1]} - ) - click.echo( - click.style( - f"Found {len(all_files_in_tables)} files in tables.", fg="white" - ) - ) + all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) # fetch referred table and columns guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" @@ -1798,15 +1641,12 @@ def clear_orphaned_file_records(force: bool): ) ) query = ( - f"SELECT {ids_table['column']} FROM {ids_table['table']} " - f"WHERE {ids_table['column']} IS NOT NULL" + f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" ) with db.engine.begin() as conn: rs = conn.execute(db.text(query)) for i in rs: - all_ids_in_tables.append( - {"table": ids_table["table"], "id": str(i[0])} - ) + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) elif ids_table["type"] == "text": click.echo( click.style( @@ -1842,11 +1682,7 @@ def clear_orphaned_file_records(force: bool): for i in rs: for j in i[0]: all_ids_in_tables.append({"table": ids_table["table"], "id": j}) - click.echo( - click.style( - f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white" - ) - ) + click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) except Exception as e: click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) @@ -1864,9 +1700,7 @@ def clear_orphaned_file_records(force: bool): ) ) return - click.echo( - click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white") - ) + click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) for file in orphaned_files: click.echo(click.style(f"- orphaned file id: {file}", fg="black")) if not force: @@ -1888,13 +1722,9 @@ def clear_orphaned_file_records(force: bool): with db.engine.begin() as conn: conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) except Exception as e: - click.echo( - click.style(f"Error deleting orphaned file records: {str(e)}", fg="red") - ) + click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) return - click.echo( - click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green") - ) + click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) @click.option( @@ -1903,9 +1733,7 @@ def clear_orphaned_file_records(force: bool): is_flag=True, help="Skip user confirmation and force the command to execute.", ) -@click.command( - "remove-orphaned-files-on-storage", help="Remove orphaned files on the storage." -) +@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") def remove_orphaned_files_on_storage(force: bool): """ Remove orphaned files on the storage. @@ -1981,32 +1809,20 @@ def remove_orphaned_files_on_storage(force: bool): all_files_in_tables = [] try: for files_table in files_tables: - click.echo( - click.style( - f"- Listing files from table {files_table['table']}", fg="white" - ) - ) + click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" with db.engine.begin() as conn: rs = conn.execute(db.text(query)) for i in rs: all_files_in_tables.append(str(i[0])) - click.echo( - click.style( - f"Found {len(all_files_in_tables)} files in tables.", fg="white" - ) - ) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) except Exception as e: click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) all_files_on_storage = [] for storage_path in storage_paths: try: - click.echo( - click.style( - f"- Scanning files on storage path {storage_path}", fg="white" - ) - ) + click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) files = storage.scan(path=storage_path, files=True, directories=False) all_files_on_storage.extend(files) except FileNotFoundError as e: @@ -2025,18 +1841,12 @@ def remove_orphaned_files_on_storage(force: bool): ) ) continue - click.echo( - click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white") - ) + click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) # find orphaned files orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) if not orphaned_files: - click.echo( - click.style( - "No orphaned files found. There is nothing to remove.", fg="green" - ) - ) + click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green")) return click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) for file in orphaned_files: @@ -2057,18 +1867,10 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) except Exception as e: error_files += 1 - click.echo( - click.style( - f"- Error deleting orphaned file {file}: {str(e)}", fg="red" - ) - ) + click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) continue if error_files == 0: - click.echo( - click.style( - f"Removed {removed_files} orphaned files without errors.", fg="green" - ) - ) + click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) else: click.echo( click.style( diff --git a/api/controllers/admin/auth/login.py b/api/controllers/admin/auth/login.py index 54b65fae6b..25db885e2a 100644 --- a/api/controllers/admin/auth/login.py +++ b/api/controllers/admin/auth/login.py @@ -238,9 +238,7 @@ class LoginApi(Resource): AccountService.reset_login_error_rate_limit(login_id) # Generate token for the authenticated admin - token_pair = AccountService.login( - account, ip_address=extract_remote_ip(request) - ) + token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) response_data = token_pair.model_dump() diff --git a/api/controllers/admin/settings/settings.py b/api/controllers/admin/settings/settings.py index 2ffc13ee27..8abcf2684e 100644 --- a/api/controllers/admin/settings/settings.py +++ b/api/controllers/admin/settings/settings.py @@ -181,6 +181,6 @@ class OperationLogs(Resource): pass -api.add_resource(WatermarkSettings, '/settings/watermark') -api.add_resource(SystemInfo, '/settings/info') -api.add_resource(OperationLogs, '/settings/logs') +api.add_resource(WatermarkSettings, "/settings/watermark") +api.add_resource(SystemInfo, "/settings/info") +api.add_resource(OperationLogs, "/settings/logs") diff --git a/api/controllers/admin/stats/stats.py b/api/controllers/admin/stats/stats.py index 42e679670d..968b8ef000 100644 --- a/api/controllers/admin/stats/stats.py +++ b/api/controllers/admin/stats/stats.py @@ -109,15 +109,15 @@ class UserStats(Resource): """ try: # Parse date parameters - start_date_str = request.args.get('start_date') - end_date_str = request.args.get('end_date') + start_date_str = request.args.get("start_date") + end_date_str = request.args.get("end_date") if not start_date_str or not end_date_str: raise BadRequest("start_date and end_date are required") try: - start_date = datetime.strptime(start_date_str, '%Y-%m-%d') - end_date = datetime.strptime(end_date_str, '%Y-%m-%d') + start_date = datetime.strptime(start_date_str, "%Y-%m-%d") + end_date = datetime.strptime(end_date_str, "%Y-%m-%d") end_date = end_date.replace(hour=23, minute=59, second=59) except ValueError: raise BadRequest("Invalid date format. Use YYYY-MM-DD") @@ -187,15 +187,15 @@ class ConversationStats(Resource): """ try: # Parse date parameters - start_date_str = request.args.get('start_date') - end_date_str = request.args.get('end_date') + start_date_str = request.args.get("start_date") + end_date_str = request.args.get("end_date") if not start_date_str or not end_date_str: raise BadRequest("start_date and end_date are required") try: - start_date = datetime.strptime(start_date_str, '%Y-%m-%d') - end_date = datetime.strptime(end_date_str, '%Y-%m-%d') + start_date = datetime.strptime(start_date_str, "%Y-%m-%d") + end_date = datetime.strptime(end_date_str, "%Y-%m-%d") end_date = end_date.replace(hour=23, minute=59, second=59) except ValueError: raise BadRequest("Invalid date format. Use YYYY-MM-DD") @@ -215,6 +215,6 @@ class ConversationStats(Resource): return {"error": "An error occurred while processing the request"}, 500 -api.add_resource(RiskStats, '/stats/risk') -api.add_resource(UserStats, '/stats/user') -api.add_resource(ConversationStats, '/stats/conversation') +api.add_resource(RiskStats, "/stats/risk") +api.add_resource(UserStats, "/stats/user") +api.add_resource(ConversationStats, "/stats/conversation") diff --git a/api/controllers/admin/students/conversation.py b/api/controllers/admin/students/conversation.py index e642af01a3..a8329eb307 100644 --- a/api/controllers/admin/students/conversation.py +++ b/api/controllers/admin/students/conversation.py @@ -120,4 +120,4 @@ class StudentConversation(Resource): raise NotFound("Last Conversation Not Exists.") -api.add_resource(StudentConversation, '/students//conversation') +api.add_resource(StudentConversation, "/students//conversation") diff --git a/api/controllers/admin/students/students.py b/api/controllers/admin/students/students.py index 672fc7438d..d8bf013804 100644 --- a/api/controllers/admin/students/students.py +++ b/api/controllers/admin/students/students.py @@ -100,11 +100,11 @@ class StudentList(Resource): from flask import request # Get query parameters with defaults - health_status = request.args.get('health_status') - begin_date = request.args.get('begin_date') - end_date = request.args.get('end_date') - page = int(request.args.get('page', 1)) - limit = int(request.args.get('limit', 20)) + health_status = request.args.get("health_status") + begin_date = request.args.get("begin_date") + end_date = request.args.get("end_date") + page = int(request.args.get("page", 1)) + limit = int(request.args.get("limit", 20)) # Validate parameters if begin_date: @@ -122,13 +122,13 @@ class StudentList(Resource): # Build query filters filters = {} if health_status: - filters['health_status'] = health_status + filters["health_status"] = health_status if begin_date: - filters['last_chat_at__gte'] = begin_date + filters["last_chat_at__gte"] = begin_date if end_date: - filters['last_chat_at__lte'] = end_date + filters["last_chat_at__lte"] = end_date # Get students with pagination offset = (page - 1) * limit @@ -142,4 +142,4 @@ class StudentList(Resource): ) -api.add_resource(StudentList, '/students') +api.add_resource(StudentList, "/students") diff --git a/api/controllers/admin/wraps.py b/api/controllers/admin/wraps.py index 5af0dc3d38..017dd4120d 100644 --- a/api/controllers/admin/wraps.py +++ b/api/controllers/admin/wraps.py @@ -44,16 +44,20 @@ def validate_admin_token_and_extract_info(view: Optional[Callable] = None): raise Unauthorized("Invalid token: user not found") if account.status != AccountStatus.ACTIVE: raise Unauthorized("Invalid token: account is not active") - + # Check if user has admin role in their current organization - org_member = db.session.query(OrganizationMember).filter( - OrganizationMember.account_id == user_id, - OrganizationMember.organization_id == account.current_organization_id - ).first() - + org_member = ( + db.session.query(OrganizationMember) + .filter( + OrganizationMember.account_id == user_id, + OrganizationMember.organization_id == account.current_organization_id, + ) + .first() + ) + if not org_member: raise Unauthorized("Invalid token: user is not a member of any organization") - + # Check if the user has admin role if org_member.role != OrganizationRole.ADMIN: raise Unauthorized("Invalid token: account does not have admin privileges") diff --git a/api/controllers/inner_tools/answers_summary_analysis.py b/api/controllers/inner_tools/answers_summary_analysis.py index f08b0d370f..41810dcb0a 100644 --- a/api/controllers/inner_tools/answers_summary_analysis.py +++ b/api/controllers/inner_tools/answers_summary_analysis.py @@ -54,25 +54,17 @@ class AnswersSummaryAnalysisApi(Resource): return {"error": "exam_answers file_id is required"}, 400 # Read the exam answers file to get categories and correct answers - exam_answers_file_content, _ = self._read_file_with_encoding_detection( - exam_answers_file_id - ) + exam_answers_file_content, _ = self._read_file_with_encoding_detection(exam_answers_file_id) if not exam_answers_file_content: return {"error": "Failed to read exam answers file or file not found"}, 404 # Parse the exam answers file - exam_answers, categories, correct_answer = self._parse_exam_answers( - exam_answers_file_content - ) + exam_answers, categories, correct_answer = self._parse_exam_answers(exam_answers_file_content) if not categories or not correct_answer: - return { - "error": "Failed to parse categories and correct answers from exam file" - }, 400 + return {"error": "Failed to parse categories and correct answers from exam file"}, 400 # Read the user answers file content with encoding detection - user_answers_file_content, _ = self._read_file_with_encoding_detection( - user_answers_file_id - ) + user_answers_file_content, _ = self._read_file_with_encoding_detection(user_answers_file_id) if not user_answers_file_content: return {"error": "Failed to read user answers file or file not found"}, 404 @@ -82,9 +74,7 @@ class AnswersSummaryAnalysisApi(Resource): return {"error": "Failed to parse user answers from file"}, 400 # Calculate category statistics - summary_analysis = self._calculate_category_statistics( - user_answers, correct_answer, categories - ) + summary_analysis = self._calculate_category_statistics(user_answers, correct_answer, categories) # Return the response return jsonify( @@ -95,16 +85,12 @@ class AnswersSummaryAnalysisApi(Resource): } ) - def _read_file_with_encoding_detection( - self, file_id: str - ) -> tuple[Optional[str], Optional[str]]: + def _read_file_with_encoding_detection(self, file_id: str) -> tuple[Optional[str], Optional[str]]: """Read file content with automatic encoding detection. Supports both CSV and XLSX files, converting XLSX to CSV text format. """ try: - upload_file = ( - db.session.query(UploadFile).filter(UploadFile.id == file_id).first() - ) + upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first() if not upload_file: return None, None @@ -112,15 +98,12 @@ class AnswersSummaryAnalysisApi(Resource): file_content = storage.load_once(upload_file.key) # Check if the file is Excel (.xlsx) based on filename or mime type - file_extension = ( - upload_file.name.split(".")[-1].lower() if upload_file.name else "" - ) + file_extension = upload_file.name.split(".")[-1].lower() if upload_file.name else "" mime_type = upload_file.mime_type if upload_file.mime_type else "" is_excel = ( file_extension == "xlsx" - or mime_type - == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" + or mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ) if is_excel: @@ -174,9 +157,7 @@ class AnswersSummaryAnalysisApi(Resource): print(f"Error reading file: {str(e)}") return None, None - def _parse_exam_answers( - self, file_content: str - ) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[str]]: + def _parse_exam_answers(self, file_content: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[str]]: """Parse exam answers from the file content. Expected format is CSV with columns: @@ -207,9 +188,7 @@ class AnswersSummaryAnalysisApi(Resource): exam_answers = [] category_map = defaultdict(list) - correct_answers = [ - "" - ] * 1000 # Initialize with empty strings, we'll trim later + correct_answers = [""] * 1000 # Initialize with empty strings, we'll trim later max_question_num = 0 for row in csv_reader: @@ -247,9 +226,7 @@ class AnswersSummaryAnalysisApi(Resource): correct_answers = correct_answers[:max_question_num] # Convert category_map to the expected categories format - categories = [ - {"name": cat, "items": items} for cat, items in category_map.items() - ] + categories = [{"name": cat, "items": items} for cat, items in category_map.items()] return exam_answers, categories, correct_answers except Exception as e: @@ -281,9 +258,7 @@ class AnswersSummaryAnalysisApi(Resource): result = [] for row in csv_reader: - if ( - not row or len(row) < 4 - ): # Skip empty rows or rows with insufficient data + if not row or len(row) < 4: # Skip empty rows or rows with insufficient data continue # Extract student ID and name @@ -293,9 +268,7 @@ class AnswersSummaryAnalysisApi(Resource): # Extract answers (skip ID, name, and score columns) answers = [ans.strip() for ans in row[3:]] - result.append( - {"user_name": name, "code": student_id, "answers": answers} - ) + result.append({"user_name": name, "code": student_id, "answers": answers}) return result except Exception as e: @@ -399,9 +372,7 @@ class GenerateAnalysisReportApi(Resource): data = request.get_json() summary_analysis = data.get("summary_analysis") - school_name = data.get( - "school_name", "山东单县一中" - ) # Default value if not provided + school_name = data.get("school_name", "山东单县一中") # Default value if not provided html_template = data.get("html_template") if not summary_analysis: @@ -507,9 +478,7 @@ class GenerateAnalysisReportApi(Resource): # Create the HTML with the template template = Template(html_template) - html_content = template.render( - school_name=school_name, summary_analysis=summary_analysis - ) + html_content = template.render(school_name=school_name, summary_analysis=summary_analysis) # Generate PDF html = HTML(string=html_content) diff --git a/api/controllers/service_api_with_auth/auth/login.py b/api/controllers/service_api_with_auth/auth/login.py index 87ce9a3fc4..0fd0acf04f 100644 --- a/api/controllers/service_api_with_auth/auth/login.py +++ b/api/controllers/service_api_with_auth/auth/login.py @@ -198,7 +198,6 @@ class EmailCodeLoginApi(Resource): is_new_user = account is None if account is None: - # Create new account account = AccountService.create_account_in_tenant( tenant=tenant, @@ -212,9 +211,11 @@ class EmailCodeLoginApi(Resource): OrganizationService.assign_account_to_organization(account, organization.id) else: - - if (organization is not None and account.current_organization_id is not None - and account.current_organization_id != organization.id): + if ( + organization is not None + and account.current_organization_id is not None + and account.current_organization_id != organization.id + ): raise OrganizationMismatchError() connected_tenant = TenantService.get_join_tenants(account) diff --git a/api/controllers/service_api_with_auth/user/profile.py b/api/controllers/service_api_with_auth/user/profile.py index 249d21a475..abfe0c0f9b 100644 --- a/api/controllers/service_api_with_auth/user/profile.py +++ b/api/controllers/service_api_with_auth/user/profile.py @@ -89,33 +89,33 @@ class UserProfile(Resource): validated_data = {} # Validate username if provided - if 'username' in data: - username = data['username'] + if "username" in data: + username = data["username"] # Validate username (Chinese or English only, max 10 chars) - if not re.match(r'^[a-zA-Z\u4e00-\u9fa5]+$', username) or len(username) > 10: + if not re.match(r"^[a-zA-Z\u4e00-\u9fa5]+$", username) or len(username) > 10: return {"success": False, "message": "Invalid username format"}, 400 - validated_data['username'] = username + validated_data["username"] = username # Validate gender if provided - if 'gender' in data: - gender_str = data['gender'] + if "gender" in data: + gender_str = data["gender"] if gender_str not in ["unknown", "male", "female"]: return {"success": False, "message": "Invalid gender value"}, 400 - validated_data['gender'] = gender_str + validated_data["gender"] = gender_str # Validate major if provided - if 'major' in data: - major = data['major'] + if "major" in data: + major = data["major"] # Allow None as a valid value (to clear the field) if major is None: - validated_data['major'] = None + validated_data["major"] = None elif not isinstance(major, str): return {"success": False, "message": "Major must be a string value or null"}, 400 elif len(major) > 50: return {"success": False, "message": "Major exceeds maximum length of 50"}, 400 else: - validated_data['major'] = major + validated_data["major"] = major # Use the service to update user profile success, error = EndUserService.update_user_profile(end_user, validated_data) @@ -126,4 +126,4 @@ class UserProfile(Resource): return {"success": True} -api.add_resource(UserProfile, '/user/profile') +api.add_resource(UserProfile, "/user/profile") diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 3bad09edf4..5847a9a4af 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -27,15 +27,11 @@ def load_user_from_request(request_from_flask_login): raise Unauthorized("Invalid Authorization token.") else: if " " not in auth_header: - raise Unauthorized( - "Invalid Authorization header format. Expected 'Bearer ' format." - ) + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": - raise Unauthorized( - "Invalid Authorization header format. Expected 'Bearer ' format." - ) + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") decoded = PassportService().verify(auth_token) user_id = decoded.get("user_id") diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 697b3604d5..9bcc911c32 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -111,9 +111,7 @@ def init_app(app: DifyApp): ) as span: span.set_status(StatusCode.ERROR) span.record_exception(record.exc_info[1]) - span.set_attribute( - "exception.type", record.exc_info[0].__name__ - ) + span.set_attribute("exception.type", record.exc_info[0].__name__) span.set_attribute("exception.message", str(record.exc_info[1])) except Exception: pass @@ -198,9 +196,7 @@ def init_app(app: DifyApp): set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader])) if not is_celery_worker(): init_flask_instrumentor(app) - CeleryInstrumentor( - tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider() - ).instrument() + CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument() instrument_exception_logging() init_sqlalchemy_instrumentor(app) atexit.register(shutdown_tracer) @@ -221,6 +217,4 @@ def init_celery_worker(*args, **kwargs): metric_provider = get_meter_provider() if dify_config.DEBUG: logging.info("Initializing OpenTelemetry for Celery worker") - CeleryInstrumentor( - tracer_provider=tracer_provider, meter_provider=metric_provider - ).instrument() + CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument() diff --git a/api/extensions/ext_phone_sms.py b/api/extensions/ext_phone_sms.py index bd146dbdc4..bf32554627 100644 --- a/api/extensions/ext_phone_sms.py +++ b/api/extensions/ext_phone_sms.py @@ -47,11 +47,10 @@ class PhoneSms: access_key_secret=secret, ) # Endpoint 请参考 https://api.aliyun.com/product/Dysmsapi - config.endpoint = 'dysmsapi.aliyuncs.com' + config.endpoint = "dysmsapi.aliyuncs.com" return Dysmsapi20170525Client(config) def send_sms(self, phone_numbers: str, code: str) -> None: - if not self._client: raise ValueError("PhoneSms client is not initialized") @@ -63,7 +62,7 @@ class PhoneSms: ) response = self._client.send_sms_with_options(send_sms_request, util_models.RuntimeOptions()) - if response.body.code != 'OK': + if response.body.code != "OK": raise Exception(response.body.message) diff --git a/api/extensions/ext_swagger.py b/api/extensions/ext_swagger.py index 17c4452a33..d9d81e383b 100644 --- a/api/extensions/ext_swagger.py +++ b/api/extensions/ext_swagger.py @@ -2,23 +2,22 @@ from dify_app import DifyApp def init_app(app: DifyApp): - from flasgger import Swagger - app.config['SWAGGER'] = { - 'title': 'API Docs', - 'uiversion': 3, - 'url_prefix': '/openapi', - 'specs_route': '/', - 'static_url_path': '/flasgger_static', - 'securityDefinitions': { - 'ApiKeyAuth': { - 'type': 'apiKey', - 'name': 'Authorization', - 'in': 'header', - 'description': 'API Key Authorization header using Bearer scheme. Example: "Bearer {token}"' + app.config["SWAGGER"] = { + "title": "API Docs", + "uiversion": 3, + "url_prefix": "/openapi", + "specs_route": "/", + "static_url_path": "/flasgger_static", + "securityDefinitions": { + "ApiKeyAuth": { + "type": "apiKey", + "name": "Authorization", + "in": "header", + "description": 'API Key Authorization header using Bearer scheme. Example: "Bearer {token}"', } - } + }, } Swagger(app) diff --git a/api/models/organization.py b/api/models/organization.py index 8edead73e6..7b42bb2dc9 100644 --- a/api/models/organization.py +++ b/api/models/organization.py @@ -59,13 +59,13 @@ class Organization(db.Model): # type: ignore[name-defined] def allowed_email_domains(self) -> list[str]: """Get list of allowed email domains for this organization""" settings = self.settings_dict - return settings.get('allowed_email_domains', []) + return settings.get("allowed_email_domains", []) @allowed_email_domains.setter def allowed_email_domains(self, domains: list[str]): """Set allowed email domains for this organization""" settings = self.settings_dict - settings['allowed_email_domains'] = domains + settings["allowed_email_domains"] = domains self.settings_dict = settings @property @@ -78,7 +78,7 @@ class Organization(db.Model): # type: ignore[name-defined] if not self.is_email_restricted: return True - email_domain = email.split('@')[-1].lower() + email_domain = email.split("@")[-1].lower() return email_domain in self.allowed_email_domains @property diff --git a/api/schedule/user_profile_generate_task.py b/api/schedule/user_profile_generate_task.py index fc06c7f26d..165c414281 100644 --- a/api/schedule/user_profile_generate_task.py +++ b/api/schedule/user_profile_generate_task.py @@ -43,9 +43,7 @@ def user_profile_generate_task(): logger.info(f"No users to update. for app_id {app_id}") continue - logger.info( - f"Found {len(users_to_update)} users profile and memory updates. in app_id {app_id}" - ) + logger.info(f"Found {len(users_to_update)} users profile and memory updates. in app_id {app_id}") update_user_profile_for_appid(users_to_update) end_at = time.perf_counter() @@ -64,9 +62,7 @@ def update_user_profile_for_appid(users_to_update: list[EndUser]): batch = users_to_update[i : i + batch_size] try: for user in batch: - new_messages, latest_messages_created_at = fetch_new_messages_for_user( - user - ) + new_messages, latest_messages_created_at = fetch_new_messages_for_user(user) if len(new_messages) > 0: process_user_memory(user, new_messages) @@ -91,9 +87,7 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]: ) latest_message_query = latest_message_query.filter(Message.app_id == app_id) - latest_message_subquery = latest_message_query.group_by( - Message.from_end_user_id - ).subquery() + latest_message_subquery = latest_message_query.group_by(Message.from_end_user_id).subquery() # Then join with EndUser to find users who need memory updates users_query = ( @@ -106,8 +100,7 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]: EndUser.app_id == app_id, or_( EndUser.profile_updated_at.is_(None), - EndUser.profile_updated_at - < latest_message_subquery.c.latest_message_time, + EndUser.profile_updated_at < latest_message_subquery.c.latest_message_time, ), ) ) @@ -122,14 +115,10 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]: def fetch_new_messages_for_user(user: EndUser) -> tuple[str, datetime]: """Fetch new messages for a user.""" - message_query = db.session.query(Message).filter( - Message.from_end_user_id == user.id - ) + message_query = db.session.query(Message).filter(Message.from_end_user_id == user.id) message_query = message_query.filter(Message.app_id == user.app_id) if user.profile_updated_at: - message_query = message_query.filter( - Message.created_at > user.profile_updated_at - ) + message_query = message_query.filter(Message.created_at > user.profile_updated_at) new_messages = message_query.order_by(asc(Message.created_at)).all() if len(new_messages) == 0: @@ -150,9 +139,7 @@ def process_user_memory(user: EndUser, new_messages: str): memory_app_id = dify_config.USER_MEMORY_GENERATION_APP_ID if memory_app_id == "": - logger.warning( - "No memory generation app_id provided, skipping memory generation." - ) + logger.warning("No memory generation app_id provided, skipping memory generation.") return memory_app_model = db.session.query(App).filter(App.id == memory_app_id).first() @@ -195,18 +182,12 @@ def process_user_health_summary(user: EndUser, new_messages: str): health_summary_app_id = dify_config.USER_HEALTH_SUMMARY_GENERATION_APP_ID if health_summary_app_id == "": - logger.warning( - "No health summary app_id provided, skipping health summary generation." - ) + logger.warning("No health summary app_id provided, skipping health summary generation.") return - health_summary_app_model = ( - db.session.query(App).filter(App.id == health_summary_app_id).first() - ) + health_summary_app_model = db.session.query(App).filter(App.id == health_summary_app_id).first() if health_summary_app_model is None: - logger.error( - f"App not found for health summary generation app_id {health_summary_app_id}" - ) + logger.error(f"App not found for health summary generation app_id {health_summary_app_id}") return args = { @@ -237,9 +218,7 @@ def process_user_health_summary(user: EndUser, new_messages: str): result = response["data"]["outputs"]["result"] if result is None: - logger.warning( - f"Health summary generation failed with None result for user {user.id}" - ) + logger.warning(f"Health summary generation failed with None result for user {user.id}") return # preprocess result in case of ```json xxxx``` diff --git a/api/services/account_service.py b/api/services/account_service.py index e6a7f4c8dd..fb42e8c326 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -70,9 +70,7 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS) class AccountService: - reset_password_rate_limiter = RateLimiter( - prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1 - ) + reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1) email_code_login_rate_limiter = RateLimiter( prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 ) @@ -122,16 +120,12 @@ class AccountService: if account.status == AccountStatus.BANNED.value: raise Unauthorized("Account is banned.") - current_tenant = TenantAccountJoin.query.filter_by( - account_id=account.id, current=True - ).first() + current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() if current_tenant: account.current_tenant_id = current_tenant.tenant_id else: available_ta = ( - TenantAccountJoin.query.filter_by(account_id=account.id) - .order_by(TenantAccountJoin.id.asc()) - .first() + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() ) if not available_ta: return None @@ -140,9 +134,7 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta( - minutes=10 - ): + if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -150,9 +142,7 @@ class AccountService: @staticmethod def get_account_jwt_token(account: Account) -> str: - exp_dt = datetime.now(UTC) + timedelta( - minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES - ) + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, @@ -165,9 +155,7 @@ class AccountService: return token @staticmethod - def authenticate( - email: str, password: str, invite_token: Optional[str] = None - ) -> Account: + def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account: """authenticate account with email and password""" account = db.session.query(Account).filter_by(email=email).first() @@ -186,9 +174,7 @@ class AccountService: account.password = base64_password_hashed account.password_salt = base64_salt - if account.password is None or not compare_password( - password, account.password, account.password_salt - ): + if account.password is None or not compare_password(password, account.password, account.password_salt): raise AccountPasswordError("Invalid email or password.") if account.status == AccountStatus.PENDING.value: @@ -202,9 +188,7 @@ class AccountService: @staticmethod def update_account_password(account, password, new_password): """update account password""" - if account.password and not compare_password( - password, account.password, account.password_salt - ): + if account.password and not compare_password(password, account.password, account.password_salt): raise CurrentPasswordIncorrectError("Current password is incorrect.") # may be raised @@ -352,11 +336,9 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = ( - AccountIntegrate.query.filter_by( - account_id=account.id, provider=provider - ).first() - ) + account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( + account_id=account.id, provider=provider + ).first() if account_integrate: # If it exists, update the record @@ -376,9 +358,7 @@ class AccountService: db.session.commit() logging.info(f"Account {account.id} linked {provider} account {open_id}.") except Exception as e: - logging.exception( - f"Failed to link {provider} account {open_id} to Account {account.id}" - ) + logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}") raise LinkAccountIntegrateError("Failed to link account.") from e @staticmethod @@ -425,20 +405,14 @@ class AccountService: @staticmethod def logout(*, account: Account) -> None: - refresh_token = redis_client.get( - AccountService._get_account_refresh_token_key(account.id) - ) + refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id)) if refresh_token: - AccountService._delete_refresh_token( - refresh_token.decode("utf-8"), account.id - ) + AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id) @staticmethod def refresh_token(refresh_token: str) -> TokenPair: # Verify the refresh token - account_id = redis_client.get( - AccountService._get_refresh_token_key(refresh_token) - ) + account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token)) if not account_id: raise ValueError("Invalid refresh token") @@ -525,9 +499,7 @@ class AccountService: if email is None: raise ValueError("Email must be provided.") - if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith( - dify_config.DEBUG_ORG_EMAIL_DOMAIN - ): + if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith(dify_config.DEBUG_ORG_EMAIL_DOMAIN): code = dify_config.DEBUG_CODE_FOR_LOGIN elif cls.email_code_login_rate_limiter.is_rate_limited(email): from controllers.console.auth.error import ( @@ -659,9 +631,7 @@ class AccountService: redis_client.setex(freeze_key, 60 * 60, 1) return True else: - redis_client.setex( - hour_limit_key, 60 * 10, hour_limit_count + 1 - ) # first time limit 10 minutes + redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes # add hour limit count redis_client.incr(hour_limit_key) @@ -697,9 +667,7 @@ class AccountService: organization_id = admin_account.current_organization_id if not organization_id: - logging.warning( - f"Account {admin_account.id} is not a member of any organization." - ) + logging.warning(f"Account {admin_account.id} is not a member of any organization.") return None # If organization_id is provided, check if account is an admin member of that organization @@ -716,9 +684,7 @@ class AccountService: ) if not org_member: - logging.warning( - f"Account {admin_account.id} is not a member of any organization." - ) + logging.warning(f"Account {admin_account.id} is not a member of any organization.") return None return admin_account @@ -744,9 +710,7 @@ class AccountService: current_minute_count = int(current_minute_count) # check current hour count - if ( - current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE - ): # Use same limit as email + if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: # Use same limit as email hour_limit_count = redis_client.get(hour_limit_key) if hour_limit_count is None: hour_limit_count = 0 @@ -756,9 +720,7 @@ class AccountService: redis_client.setex(freeze_key, 60 * 60, 1) return True else: - redis_client.setex( - hour_limit_key, 60 * 10, hour_limit_count + 1 - ) # first time limit 10 minutes + redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes # add hour limit count redis_client.incr(hour_limit_key) @@ -823,11 +785,7 @@ class AccountService: Returns None if no admin account with this ID exists. Raises Unauthorized if account is banned. """ - account = ( - db.session.query(Account) - .filter((Account.email == login_id) | (Account.phone == login_id)) - .first() - ) + account = db.session.query(Account).filter((Account.email == login_id) | (Account.phone == login_id)).first() if not account: return None @@ -842,7 +800,6 @@ class AccountService: class TenantService: - @staticmethod def get_tenant_by_id(tenant_id: str) -> Tenant: return Tenant.query.filter_by(id=tenant_id).first() @@ -877,53 +834,38 @@ class TenantService: ): """Check if user have a workspace or not""" available_ta = ( - TenantAccountJoin.query.filter_by(account_id=account.id) - .order_by(TenantAccountJoin.id.asc()) - .first() + TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() ) if available_ta: return """Create owner tenant if not exist""" - if ( - not FeatureService.get_system_features().is_allow_create_workspace - and not is_setup - ): + if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: raise WorkSpaceNotAllowedCreateError() if name: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: - tenant = TenantService.create_tenant( - name=f"{account.name}'s Workspace", is_setup=is_setup - ) + tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup) TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant db.session.commit() tenant_was_created.send(tenant) @staticmethod - def create_tenant_member( - tenant: Tenant, account: Account, role: str = "normal" - ) -> TenantAccountJoin: + def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin: """Create tenant member""" if role == TenantAccountRole.OWNER.value: if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): logging.error(f"Tenant {tenant.id} has already an owner.") raise Exception("Tenant already has an owner.") - ta = ( - db.session.query(TenantAccountJoin) - .filter_by(tenant_id=tenant.id, account_id=account.id) - .first() - ) + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: ta.role = role else: - ta = TenantAccountJoin( - tenant_id=tenant.id, account_id=account.id, role=role - ) + ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role) db.session.add(ta) db.session.commit() @@ -949,9 +891,7 @@ class TenantService: if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, account_id=account.id - ).first() + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: tenant.role = ta.role else: @@ -978,9 +918,7 @@ class TenantService: ) if not tenant_account_join: - raise AccountNotLinkTenantError( - "Tenant not found or account is not a member of the tenant." - ) + raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: TenantAccountJoin.query.filter( TenantAccountJoin.account_id == account.id, @@ -1065,9 +1003,7 @@ class TenantService: return cast(int, db.session.query(func.count(Tenant.id)).scalar()) @staticmethod - def check_member_permission( - tenant: Tenant, operator: Account, member: Account | None, action: str - ) -> None: + def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: """Check member permission""" perms = { "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], @@ -1081,26 +1017,20 @@ class TenantService: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, account_id=operator.id - ).first() + ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: raise NoPermissionError(f"No permission to {action} member.") @staticmethod - def remove_member_from_tenant( - tenant: Tenant, account: Account, operator: Account - ) -> None: + def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None: """Remove member from tenant""" if operator.id == account.id: raise CannotOperateSelfError("Cannot operate self.") TenantService.check_member_permission(tenant, operator, account, "remove") - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, account_id=account.id - ).first() + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: raise MemberNotInTenantError("Member not in tenant.") @@ -1108,26 +1038,18 @@ class TenantService: db.session.commit() @staticmethod - def update_member_role( - tenant: Tenant, member: Account, new_role: str, operator: Account - ) -> None: + def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None: """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, account_id=member.id - ).first() + target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() if target_member_join.role == new_role: - raise RoleAlreadyAssignedError( - "The provided role is already assigned to the member." - ) + raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, role="owner" - ).first() + current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() current_owner_join.role = "admin" # Update the role of the target member @@ -1137,9 +1059,7 @@ class TenantService: @staticmethod def dissolve_tenant(tenant: Tenant, operator: Account) -> None: """Dissolve tenant""" - if not TenantService.check_member_permission( - tenant, operator, operator, "remove" - ): + if not TenantService.check_member_permission(tenant, operator, operator, "remove"): raise NoPermissionError("No permission to dissolve tenant.") db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.delete(tenant) @@ -1224,10 +1144,7 @@ class RegisterService: if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if ( - FeatureService.get_system_features().is_allow_create_workspace - and create_workspace_required - ): + if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant @@ -1281,9 +1198,7 @@ class RegisterService: TenantService.switch_tenant(account, tenant.id) else: TenantService.check_member_permission(tenant, inviter, account, "add") - ta = TenantAccountJoin.query.filter_by( - tenant_id=tenant.id, account_id=account.id - ).first() + ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) @@ -1330,9 +1245,7 @@ class RegisterService: def revoke_token(cls, workspace_id: str, email: str, token: str): if workspace_id and email: email_hash = sha256(email.encode()).hexdigest() - cache_key = "member_invite_token:{}, {}:{}".format( - workspace_id, email_hash, token - ) + cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token) redis_client.delete(cache_key) else: redis_client.delete(cls._get_invitation_token_key(token)) @@ -1347,9 +1260,7 @@ class RegisterService: tenant = ( db.session.query(Tenant) - .filter( - Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal" - ) + .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal") .first() ) diff --git a/api/services/end_user_service.py b/api/services/end_user_service.py index 92fbdae485..115fc005ea 100644 --- a/api/services/end_user_service.py +++ b/api/services/end_user_service.py @@ -32,8 +32,8 @@ class EndUserService: db.session.query( Message.from_end_user_id, func.count( - func.distinct(func.date(func.timezone('UTC+8', func.timezone('UTC', Message.created_at)))) - ).label('active_days'), + func.distinct(func.date(func.timezone("UTC+8", func.timezone("UTC", Message.created_at)))) + ).label("active_days"), ) .filter(Message.app_id == app_model.id) .group_by(Message.from_end_user_id) @@ -44,9 +44,9 @@ class EndUserService: subq = ( db.session.query( Conversation.from_end_user_id, - func.max(Conversation.created_at).label('last_chat_at'), - func.min(Conversation.created_at).label('first_chat_at'), - func.count(Message.id).label('total_messages'), + func.max(Conversation.created_at).label("last_chat_at"), + func.min(Conversation.created_at).label("first_chat_at"), + func.count(Message.id).label("total_messages"), ) .filter(Conversation.app_id == app_model.id) .join(Message, Message.conversation_id == Conversation.id) @@ -75,14 +75,14 @@ class EndUserService: # Apply filters filter_conditions = [] - if 'health_status' in filters: - filter_conditions.append(EndUser.health_status == filters['health_status']) + if "health_status" in filters: + filter_conditions.append(EndUser.health_status == filters["health_status"]) - if 'last_chat_at__gte' in filters: - filter_conditions.append(subq.c.last_chat_at >= filters['last_chat_at__gte']) + if "last_chat_at__gte" in filters: + filter_conditions.append(subq.c.last_chat_at >= filters["last_chat_at__gte"]) - if 'last_chat_at__lte' in filters: - filter_conditions.append(subq.c.last_chat_at <= filters['last_chat_at__lte']) + if "last_chat_at__lte" in filters: + filter_conditions.append(subq.c.last_chat_at <= filters["last_chat_at__lte"]) # Apply all filter conditions if filter_conditions: @@ -109,17 +109,17 @@ class EndUserService: # Convert to dictionary for JSON serialization end_user_dict = { - 'id': end_user.external_user_id, - 'email': end_user.email, - 'first_chat_at': end_user.first_chat_at, - 'last_chat_at': end_user.last_chat_at, - 'total_messages': end_user.total_messages, - 'active_days': end_user.active_days, - 'health_status': end_user.health_status, - 'topics': end_user.topics, - 'summary': end_user.summary, - 'major': end_user.major, - 'organization_id': end_user.organization_id, + "id": end_user.external_user_id, + "email": end_user.email, + "first_chat_at": end_user.first_chat_at, + "last_chat_at": end_user.last_chat_at, + "total_messages": end_user.total_messages, + "active_days": end_user.active_days, + "health_status": end_user.health_status, + "topics": end_user.topics, + "summary": end_user.summary, + "major": end_user.major, + "organization_id": end_user.organization_id, } users.append(end_user_dict) @@ -172,18 +172,18 @@ class EndUserService: """ try: # Update username if provided - if 'username' in profile_data: - end_user.name = profile_data['username'] + if "username" in profile_data: + end_user.name = profile_data["username"] # Update gender if provided - if 'gender' in profile_data: - gender_str = profile_data['gender'] + if "gender" in profile_data: + gender_str = profile_data["gender"] gender_map = {"unknown": 0, "male": 1, "female": 2} end_user.gender = gender_map[gender_str] # Update major if provided - if 'major' in profile_data: - major = profile_data['major'] + if "major" in profile_data: + major = profile_data["major"] # Create a new dictionary if extra_profile is None if end_user.extra_profile is None: @@ -191,7 +191,7 @@ class EndUserService: # Make a copy of the existing dictionary to ensure changes are detected extra_profile = dict(end_user.extra_profile) - extra_profile['major'] = major + extra_profile["major"] = major end_user.extra_profile = extra_profile # Force the change to be detected diff --git a/api/services/image_generation_service.py b/api/services/image_generation_service.py index 047328c8dd..4540d13ede 100644 --- a/api/services/image_generation_service.py +++ b/api/services/image_generation_service.py @@ -19,7 +19,6 @@ DEFAULT_IMAGE_EXTENSION = ".png" class ImageGenerationService: - generate_image_rate_limiter = RateLimiter( prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1 ) @@ -62,7 +61,6 @@ class ImageGenerationService: @staticmethod def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination: - query = ( db.session.query(UserGeneratedImage) .filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id) diff --git a/api/services/organization_service.py b/api/services/organization_service.py index eeb9759a11..e589cfbbc6 100644 --- a/api/services/organization_service.py +++ b/api/services/organization_service.py @@ -21,16 +21,16 @@ class OrganizationService: Returns: Organization or None if no match found """ - if not email or '@' not in email: + if not email or "@" not in email: return None # Get email domain - email_domain = email.split('@')[-1].lower() + email_domain = email.split("@")[-1].lower() # Get active organizations for this tenant organizations = ( db.session.query(Organization) - .filter(Organization.tenant_id == tenant_id, Organization.status == 'active') + .filter(Organization.tenant_id == tenant_id, Organization.status == "active") .all() ) @@ -186,7 +186,7 @@ class OrganizationService: """ return ( db.session.query(Organization) - .filter(Organization.tenant_id == tenant_id, Organization.status == 'active') + .filter(Organization.tenant_id == tenant_id, Organization.status == "active") .all() ) diff --git a/api/services/stats_service.py b/api/services/stats_service.py index bdba3a1c4c..627732c682 100644 --- a/api/services/stats_service.py +++ b/api/services/stats_service.py @@ -102,13 +102,13 @@ class StatsService: date_range = [] current_date = start_date while current_date <= end_date: - date_range.append(current_date.strftime('%Y-%m-%d')) + date_range.append(current_date.strftime("%Y-%m-%d")) current_date += timedelta(days=1) daily_stats = [] for date_str in date_range: - date = datetime.strptime(date_str, '%Y-%m-%d') + date = datetime.strptime(date_str, "%Y-%m-%d") next_date = date + timedelta(days=1) # Count active users (users who had a conversation on this date) @@ -154,8 +154,8 @@ class StatsService: active_user_ids_query = active_user_ids_query.filter(Message.organization_id == organization_id) # Get the intersection to find active new users - new_user_ids = [user_id for user_id, in new_user_ids_query.all()] - active_user_ids = [user_id for user_id, in active_user_ids_query.all()] + new_user_ids = [user_id for (user_id,) in new_user_ids_query.all()] + active_user_ids = [user_id for (user_id,) in active_user_ids_query.all()] # Count users who appear in both lists (created today AND active today) active_new_users = len(set(new_user_ids).intersection(set(active_user_ids))) @@ -184,13 +184,13 @@ class StatsService: date_range = [] current_date = start_date while current_date <= end_date: - date_range.append(current_date.strftime('%Y-%m-%d')) + date_range.append(current_date.strftime("%Y-%m-%d")) current_date += timedelta(days=1) daily_stats = [] for date_str in date_range: - date = datetime.strptime(date_str, '%Y-%m-%d') + date = datetime.strptime(date_str, "%Y-%m-%d") next_date = date + timedelta(days=1) # Count total conversations for this date diff --git a/api/tasks/image_generation_task.py b/api/tasks/image_generation_task.py index c315c29867..26528cb350 100644 --- a/api/tasks/image_generation_task.py +++ b/api/tasks/image_generation_task.py @@ -48,11 +48,7 @@ def generate_image_task( raise Exception(f"End user {end_user_id} not found") # Get the existing UserGeneratedImage entity - user_generated_image = ( - db.session.query(UserGeneratedImage) - .filter(UserGeneratedImage.id == image_id) - .first() - ) + user_generated_image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first() if not user_generated_image: raise Exception(f"UserGeneratedImage {image_id} not found") @@ -67,16 +63,10 @@ def generate_image_task( db.session.commit() raise Exception("Image generation app id is not set") - image_generation_app_model = ( - db.session.query(App) - .filter(App.id == dify_config.IMAGE_GENERATION_APP_ID) - .first() - ) + image_generation_app_model = db.session.query(App).filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first() if image_generation_app_model is None: user_generated_image.status = "failed" - user_generated_image.error_message = ( - "Image generation app model is not found" - ) + user_generated_image.error_message = "Image generation app model is not found" db.session.commit() raise Exception("Image generation app model is not found") @@ -93,10 +83,7 @@ def generate_image_task( .all() ) - recent_messages = [ - f"user: {message.query}\n\nassistant: {message.answer}" - for message in recent_messages - ] + recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages] # Prepare arguments for generation args = { @@ -167,9 +154,7 @@ def generate_image_task( # Update status to failed if we have the entity try: user_generated_image = ( - db.session.query(UserGeneratedImage) - .filter(UserGeneratedImage.id == image_id) - .first() + db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first() ) if user_generated_image: user_generated_image.status = "failed" diff --git a/api/tests/manual/test_registration_api.py b/api/tests/manual/test_registration_api.py index ba8d6cc3e9..718ea35fe6 100644 --- a/api/tests/manual/test_registration_api.py +++ b/api/tests/manual/test_registration_api.py @@ -12,69 +12,69 @@ import requests class RegistrationTester: """Test class for registration API.""" - + def __init__(self, base_url: str = "http://localhost:5001"): self.base_url = base_url self.session = requests.Session() - + def test_send_verification_code(self, email: str) -> dict[str, Any]: """Test sending verification code.""" print(f"🔵 Testing verification code send for: {email}") - + response = self.session.post( f"{self.base_url}/service/auth/email-code-login", json={"email": email}, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) - + print(f" Status: {response.status_code}") print(f" Response: {response.text}") - + return { "status_code": response.status_code, "response": ( - response.json() - if response.headers.get('content-type', '').startswith('application/json') + response.json() + if response.headers.get("content-type", "").startswith("application/json") else response.text - ) + ), } - + def test_registration_with_code(self, email: str, code: str, token: str) -> dict[str, Any]: """Test registration with verification code.""" print(f"🔵 Testing registration for: {email} with code: {code}") - + response = self.session.post( f"{self.base_url}/service/auth/email-code-login/validity", json={"email": email, "code": code, "token": token}, - headers={"Content-Type": "application/json"} + headers={"Content-Type": "application/json"}, ) - + print(f" Status: {response.status_code}") print(f" Response: {response.text}") - + return { "status_code": response.status_code, "response": ( - response.json() - if response.headers.get('content-type', '').startswith('application/json') + response.json() + if response.headers.get("content-type", "").startswith("application/json") else response.text - ) + ), } - + def test_verification_code_sending(self, email: str) -> dict[str, Any]: """Test verification code sending (first step only).""" print(f"\n🚀 Testing verification code sending for: {email}") print("=" * 50) - + # Step 1: Send verification code send_result = self.test_send_verification_code(email) - + if send_result["status_code"] != 200: print(f"❌ Failed to send verification code for {email}") return send_result - + print(f"✅ Verification code sent successfully for {email}") - + # Extract token from send result token = None if isinstance(send_result["response"], dict) and "data" in send_result["response"]: @@ -85,23 +85,23 @@ class RegistrationTester: else: print("❌ No token received from verification code send") return {"status_code": 400, "response": "No token received"} - + return {"status_code": 200, "response": "Verification code sent successfully", "token": token} - + def test_registration_flow_interactive(self, email: str) -> dict[str, Any]: """Test full registration flow with user input for verification code.""" print(f"\n🚀 Testing INTERACTIVE registration flow for: {email}") print("=" * 50) - + # Step 1: Send verification code send_result = self.test_send_verification_code(email) - + if send_result["status_code"] != 200: print(f"❌ Failed to send verification code for {email}") return send_result - + print(f"✅ Verification code sent successfully for {email}") - + # Extract token from send result token = None if isinstance(send_result["response"], dict) and "data" in send_result["response"]: @@ -110,94 +110,91 @@ class RegistrationTester: else: print("❌ No token received from verification code send") return {"status_code": 400, "response": "No token received"} - + # Step 2: Get verification code from user print(f"📧 A verification code has been sent to {email}") verification_code = input("🔢 Please enter the verification code from your email: ").strip() - + if not verification_code: print("❌ No verification code entered") return {"status_code": 400, "response": "No verification code entered"} - + register_result = self.test_registration_with_code(email, verification_code, token) - + if register_result["status_code"] == 200: print(f"✅ Registration successful for {email}") else: print(f"❌ Registration failed for {email}") - + return register_result - + def run_comprehensive_tests(self): """Run comprehensive tests for different email scenarios.""" print("🧪 Running Comprehensive Email Verification Code Tests") print("=" * 60) print("ℹ️ This tests verification code sending (step 1 of registration)") print("ℹ️ For full registration testing, use interactive mode") - + test_cases = [ { "email": "student@university.edu", "description": "University student (.edu domain)", - "expected_org": "Should be assigned to organization if domain match exists" + "expected_org": "Should be assigned to organization if domain match exists", }, { "email": "user@gmail.com", "description": "Personal Gmail account", - "expected_org": "Should register without organization assignment" + "expected_org": "Should register without organization assignment", }, { "email": "employee@company.com", "description": "Company email without pre-configured organization", - "expected_org": "Should register without organization assignment" + "expected_org": "Should register without organization assignment", }, { "email": "admin@startup.io", "description": "Startup email", - "expected_org": "Should register without organization assignment" + "expected_org": "Should register without organization assignment", }, { "email": "test@protonmail.com", "description": "ProtonMail account", - "expected_org": "Should register without organization assignment" - } + "expected_org": "Should register without organization assignment", + }, ] - + results = [] - + for i, test_case in enumerate(test_cases, 1): print(f"\n📋 Test Case {i}: {test_case['description']}") print(f" Email: {test_case['email']}") print(f" Expected: {test_case['expected_org']}") - + result = self.test_verification_code_sending(test_case["email"]) - results.append({ - "test_case": test_case, - "result": result - }) - + results.append({"test_case": test_case, "result": result}) + # Small delay between tests time.sleep(0.5) - + # Summary print("\n📊 Test Results Summary") print("=" * 30) - + for i, test_result in enumerate(results, 1): test_case = test_result["test_case"] result = test_result["result"] - + status = "✅ PASSED" if result["status_code"] == 200 else "❌ FAILED" print(f"{i}. {test_case['description']}: {status}") - + if result["status_code"] != 200: print(f" Error: {result['response']}") - + # Overall summary passed = sum(1 for r in results if r["result"]["status_code"] == 200) total = len(results) print(f"\n🎯 Overall: {passed}/{total} verification code tests passed") - + return results @@ -205,9 +202,9 @@ def main(): """Main function to run tests.""" print("🔍 Email Registration API Test Suite") print("=" * 40) - + tester = RegistrationTester() - + # Check if server is running try: response = requests.get("http://localhost:5001/health", timeout=5) @@ -216,14 +213,14 @@ def main(): print("❌ Server is not running on port 5001") print(" Please start the server with: uv run flask run --host 0.0.0.0 --port=5001 --debug") return - + # Run comprehensive tests results = tester.run_comprehensive_tests() - + # Additional specific tests print("\n🔬 Additional Tests") print("=" * 20) - + # Test invalid email print("\n🔵 Testing invalid email format") invalid_result = tester.test_send_verification_code("invalid.email") @@ -231,7 +228,7 @@ def main(): print("✅ Invalid email correctly rejected") else: print("❌ Invalid email was accepted (should be rejected)") - + # Test registration with invalid code print("\n🔵 Testing registration with invalid code") # First get a valid token @@ -245,11 +242,11 @@ def main(): print("❌ Invalid verification code was accepted (should be rejected)") else: print("⚠️ Could not test invalid code - failed to get token") - + print("\n🎉 All tests completed!") print("\nℹ️ Note: This test uses mock verification codes.") print(" In production, users would receive actual codes via email.") if __name__ == "__main__": - main() \ No newline at end of file + main()