Apply code formatting with ruff formatter

- Fixed formatting issues across 23 files
- Ensured all linting checks pass with uv run ruff check
- Code now follows consistent formatting standards
- No functional changes, only formatting improvements

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
pull/21891/head
ytqh 11 months ago
parent de684cdd21
commit 816ea24571

@ -52,19 +52,13 @@ def reset_password(email, new_password, password_confirm):
account = db.session.query(Account).filter(Account.email == email).one_or_none() account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account: if not account:
click.echo( click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.style("Account not found for email: {}".format(email), fg="red")
)
return return
try: try:
valid_password(new_password) valid_password(new_password)
except: except:
click.echo( click.echo(click.style("Invalid password. Must match {}".format(password_pattern), fg="red"))
click.style(
"Invalid password. Must match {}".format(password_pattern), fg="red"
)
)
return return
# generate password salt # generate password salt
@ -96,9 +90,7 @@ def reset_email(email, new_email, email_confirm):
account = db.session.query(Account).filter(Account.email == email).one_or_none() account = db.session.query(Account).filter(Account.email == email).one_or_none()
if not account: if not account:
click.echo( click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
click.style("Account not found for email: {}".format(email), fg="red")
)
return return
try: try:
@ -132,34 +124,24 @@ def reset_encrypt_key_pair():
Only support SELF_HOSTED mode. Only support SELF_HOSTED mode.
""" """
if dify_config.EDITION != "SELF_HOSTED": if dify_config.EDITION != "SELF_HOSTED":
click.echo( click.echo(click.style("This command is only for SELF_HOSTED installations.", fg="red"))
click.style("This command is only for SELF_HOSTED installations.", fg="red")
)
return return
tenants = db.session.query(Tenant).all() tenants = db.session.query(Tenant).all()
for tenant in tenants: for tenant in tenants:
if not tenant: if not tenant:
click.echo( click.echo(click.style("No workspaces found. Run /install first.", fg="red"))
click.style("No workspaces found. Run /install first.", fg="red")
)
return return
tenant.encrypt_public_key = generate_key_pair(tenant.id) tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.query(Provider).filter( db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
Provider.provider_type == "custom", Provider.tenant_id == tenant.id db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
).delete()
db.session.query(ProviderModel).filter(
ProviderModel.tenant_id == tenant.id
).delete()
db.session.commit() db.session.commit()
click.echo( click.echo(
click.style( click.style(
"Congratulations! The asymmetric key pair of workspace {} has been reset.".format( "Congratulations! The asymmetric key pair of workspace {} has been reset.".format(tenant.id),
tenant.id
),
fg="green", fg="green",
) )
) )
@ -209,15 +191,12 @@ def migrate_annotation_vector_database():
for app in apps: for app in apps:
total_count = total_count + 1 total_count = total_count + 1
click.echo( click.echo(
f"Processing the {total_count} app {app.id}. " f"Processing the {total_count} app {app.id}. " + f"{create_count} created, {skipped_count} skipped."
+ f"{create_count} created, {skipped_count} skipped."
) )
try: try:
click.echo("Creating app annotation index: {}".format(app.id)) click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = ( app_annotation_setting = (
db.session.query(AppAnnotationSetting) db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
.filter(AppAnnotationSetting.app_id == app.id)
.first()
) )
if not app_annotation_setting: if not app_annotation_setting:
@ -227,22 +206,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info # get dataset_collection_binding info
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
DatasetCollectionBinding.id
== app_annotation_setting.collection_binding_id
)
.first() .first()
) )
if not dataset_collection_binding: if not dataset_collection_binding:
click.echo( click.echo("App annotation collection binding not found: {}".format(app.id))
"App annotation collection binding not found: {}".format(app.id)
)
continue continue
annotations = ( annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
db.session.query(MessageAnnotation)
.filter(MessageAnnotation.app_id == app.id)
.all()
)
dataset = Dataset( dataset = Dataset(
id=app.id, id=app.id,
tenant_id=app.tenant_id, tenant_id=app.tenant_id,
@ -264,24 +234,14 @@ def migrate_annotation_vector_database():
) )
documents.append(document) documents.append(document)
vector = Vector( vector = Vector(dataset, attributes=["doc_id", "annotation_id", "app_id"])
dataset, attributes=["doc_id", "annotation_id", "app_id"]
)
click.echo(f"Migrating annotations for app: {app.id}.") click.echo(f"Migrating annotations for app: {app.id}.")
try: try:
vector.delete() vector.delete()
click.echo( click.echo(click.style(f"Deleted vector index for app {app.id}.", fg="green"))
click.style(
f"Deleted vector index for app {app.id}.", fg="green"
)
)
except Exception as e: except Exception as e:
click.echo( click.echo(click.style(f"Failed to delete vector index for app {app.id}.", fg="red"))
click.style(
f"Failed to delete vector index for app {app.id}.", fg="red"
)
)
raise e raise e
if documents: if documents:
try: try:
@ -292,11 +252,7 @@ def migrate_annotation_vector_database():
) )
) )
vector.create(documents) vector.create(documents)
click.echo( click.echo(click.style(f"Created vector index for app {app.id}.", fg="green"))
click.style(
f"Created vector index for app {app.id}.", fg="green"
)
)
except Exception as e: except Exception as e:
click.echo( click.echo(
click.style( click.style(
@ -310,9 +266,7 @@ def migrate_annotation_vector_database():
except Exception as e: except Exception as e:
click.echo( click.echo(
click.style( click.style(
"Error creating app annotation index: {} {}".format( "Error creating app annotation index: {} {}".format(e.__class__.__name__, str(e)),
e.__class__.__name__, str(e)
),
fg="red", fg="red",
) )
) )
@ -378,9 +332,7 @@ def migrate_knowledge_vector_database():
f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped." f"Processing the {total_count} dataset {dataset.id}. {create_count} created, {skipped_count} skipped."
) )
try: try:
click.echo( click.echo("Creating dataset vector database index: {}".format(dataset.id))
"Creating dataset vector database index: {}".format(dataset.id)
)
if dataset.index_struct_dict: if dataset.index_struct_dict:
if dataset.index_struct_dict["type"] == vector_type: if dataset.index_struct_dict["type"] == vector_type:
skipped_count = skipped_count + 1 skipped_count = skipped_count + 1
@ -393,10 +345,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id: if dataset.collection_binding_id:
dataset_collection_binding = ( dataset_collection_binding = (
db.session.query(DatasetCollectionBinding) db.session.query(DatasetCollectionBinding)
.filter( .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
DatasetCollectionBinding.id
== dataset.collection_binding_id
)
.one_or_none() .one_or_none()
) )
if dataset_collection_binding: if dataset_collection_binding:
@ -407,9 +356,7 @@ def migrate_knowledge_vector_database():
collection_name = Dataset.gen_collection_name_by_id(dataset_id) collection_name = Dataset.gen_collection_name_by_id(dataset_id)
elif vector_type in lower_collection_vector_types: elif vector_type in lower_collection_vector_types:
collection_name = Dataset.gen_collection_name_by_id( collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
dataset_id
).lower()
else: else:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")
@ -508,9 +455,7 @@ def migrate_knowledge_vector_database():
db.session.rollback() db.session.rollback()
click.echo( click.echo(
click.style( click.style(
"Error creating dataset index: {} {}".format( "Error creating dataset index: {} {}".format(e.__class__.__name__, str(e)),
e.__class__.__name__, str(e)
),
fg="red", fg="red",
) )
) )
@ -572,9 +517,9 @@ def convert_to_agent_apps():
db.session.commit() db.session.commit()
# update conversation mode to agent # update conversation mode to agent
db.session.query(Conversation).filter( db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
Conversation.app_id == app.id {Conversation.mode: AppMode.AGENT_CHAT.value}
).update({Conversation.mode: AppMode.AGENT_CHAT.value}) )
db.session.commit() db.session.commit()
click.echo(click.style("Converted app: {}".format(app.id), fg="green")) click.echo(click.style("Converted app: {}".format(app.id), fg="green"))
@ -588,9 +533,7 @@ def convert_to_agent_apps():
click.echo( click.echo(
click.style( click.style(
"Conversion complete. Converted {} agent apps.".format( "Conversion complete. Converted {} agent apps.".format(len(proceeded_app_ids)),
len(proceeded_app_ids)
),
fg="green", fg="green",
) )
) )
@ -723,15 +666,11 @@ def old_metadata_migration():
) )
db.session.add(dataset_metadata_binding) db.session.add(dataset_metadata_binding)
else: else:
dataset_metadata_binding = ( dataset_metadata_binding = DatasetMetadataBinding.query.filter(
DatasetMetadataBinding.query.filter( DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.dataset_id DatasetMetadataBinding.document_id == document.id,
== document.dataset_id, DatasetMetadataBinding.metadata_id == dataset_metadata.id,
DatasetMetadataBinding.document_id == document.id, ).first()
DatasetMetadataBinding.metadata_id
== dataset_metadata.id,
).first()
)
if not dataset_metadata_binding: if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding( dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id, tenant_id=document.tenant_id,
@ -750,9 +689,7 @@ def old_metadata_migration():
@click.option("--email", prompt=True, help="Tenant account email.") @click.option("--email", prompt=True, help="Tenant account email.")
@click.option("--name", prompt=True, help="Workspace name.") @click.option("--name", prompt=True, help="Workspace name.")
@click.option("--language", prompt=True, help="Account language, default: en-US.") @click.option("--language", prompt=True, help="Account language, default: en-US.")
def create_tenant( def create_tenant(email: str, language: Optional[str] = None, name: Optional[str] = None):
email: str, language: Optional[str] = None, name: Optional[str] = None
):
""" """
Create tenant account Create tenant account
""" """
@ -790,9 +727,7 @@ def create_tenant(
click.echo( click.echo(
click.style( click.style(
"Account and tenant created.\nAccount: {}\nPassword: {}".format( "Account and tenant created.\nAccount: {}\nPassword: {}".format(email, new_password),
email, new_password
),
fg="green", fg="green",
) )
) )
@ -867,19 +802,13 @@ where sites.id is null limit 1000"""
fg="red", fg="red",
) )
) )
logging.exception( logging.exception(f"Failed to fix app related site missing issue, app_id: {app_id}")
f"Failed to fix app related site missing issue, app_id: {app_id}"
)
continue continue
if not processed_count: if not processed_count:
break break
click.echo( click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green"))
click.style(
"Fix for missing app-related sites completed successfully!", fg="green"
)
)
@click.command( @click.command(
@ -895,9 +824,7 @@ where sites.id is null limit 1000"""
help="Type of login ID (phone or email)", help="Type of login ID (phone or email)",
) )
@click.option("--organization-id", required=True, help="Organization ID") @click.option("--organization-id", required=True, help="Organization ID")
def create_admin_account( def create_admin_account(name: str, login_id: str, login_id_type: str, organization_id: str):
name: str, login_id: str, login_id_type: str, organization_id: str
):
""" """
Create or update an admin account with a phone number or email for a specific organization. Create or update an admin account with a phone number or email for a specific organization.
This command will create a new account if the login ID doesn't exist, This command will create a new account if the login ID doesn't exist,
@ -907,46 +834,26 @@ def create_admin_account(
# Check if organization exists # Check if organization exists
from models.organization import Organization, OrganizationMember, OrganizationRole from models.organization import Organization, OrganizationMember, OrganizationRole
organization = ( organization = db.session.query(Organization).filter(Organization.id == organization_id).first()
db.session.query(Organization)
.filter(Organization.id == organization_id)
.first()
)
if not organization: if not organization:
click.echo( click.echo(click.style(f"Organization with ID {organization_id} not found.", fg="red"))
click.style(
f"Organization with ID {organization_id} not found.", fg="red"
)
)
return return
# Get tenant from organization # Get tenant from organization
tenant = ( tenant = db.session.query(Tenant).filter(Tenant.id == organization.tenant_id).first()
db.session.query(Tenant).filter(Tenant.id == organization.tenant_id).first()
)
if not tenant: if not tenant:
click.echo( click.echo(click.style(f"Tenant for organization {organization_id} not found.", fg="red"))
click.style(
f"Tenant for organization {organization_id} not found.", fg="red"
)
)
return return
# Check if account exists with this login ID # Check if account exists with this login ID
account = None account = None
if login_id_type == "phone": if login_id_type == "phone":
account = ( account = db.session.query(Account).filter(Account.phone == login_id).first()
db.session.query(Account).filter(Account.phone == login_id).first()
)
else: # email else: # email
account = ( account = db.session.query(Account).filter(Account.email == login_id).first()
db.session.query(Account).filter(Account.email == login_id).first()
)
if account: if account:
click.echo( click.echo(f"Account with {login_id_type} {login_id} already exists. Updating account...")
f"Account with {login_id_type} {login_id} already exists. Updating account..."
)
# Update account # Update account
account.name = name account.name = name
@ -1010,9 +917,7 @@ def create_admin_account(
if org_member: if org_member:
# Update role to admin # Update role to admin
org_member.role = OrganizationRole.ADMIN org_member.role = OrganizationRole.ADMIN
click.echo( click.echo(f"Updated account role to {OrganizationRole.ADMIN} in organization {organization.name}")
f"Updated account role to {OrganizationRole.ADMIN} in organization {organization.name}"
)
else: else:
# Add account to organization with admin role # Add account to organization with admin role
org_member = OrganizationMember( org_member = OrganizationMember(
@ -1023,9 +928,7 @@ def create_admin_account(
created_by=account.id, created_by=account.id,
) )
db.session.add(org_member) db.session.add(org_member)
click.echo( click.echo(f"Added account to organization {organization.name} with role {OrganizationRole.ADMIN}")
f"Added account to organization {organization.name} with role {OrganizationRole.ADMIN}"
)
db.session.commit() db.session.commit()
@ -1044,12 +947,8 @@ def create_admin_account(
click.echo(click.style(f"Error: {str(e)}", fg="red")) click.echo(click.style(f"Error: {str(e)}", fg="red"))
@click.command( @click.command("create-organization", help="Create a new organization for multi-school support.")
"create-organization", help="Create a new organization for multi-school support." @click.option("--tenant-id", required=True, help="ID of the tenant that owns this organization")
)
@click.option(
"--tenant-id", required=True, help="ID of the tenant that owns this organization"
)
@click.option("--name", required=True, help="Name of the organization") @click.option("--name", required=True, help="Name of the organization")
@click.option("--code", required=True, help="Unique code for the organization") @click.option("--code", required=True, help="Unique code for the organization")
@click.option( @click.option(
@ -1060,21 +959,15 @@ def create_admin_account(
help="Type of organization", help="Type of organization",
) )
@click.option("--description", default="", help="Description of the organization") @click.option("--description", default="", help="Description of the organization")
@click.option( @click.option("--email-domains", default="", help="Comma-separated list of allowed email domains")
"--email-domains", default="", help="Comma-separated list of allowed email domains"
)
@click.option("--created-by", required=True, help="Account ID of the creator") @click.option("--created-by", required=True, help="Account ID of the creator")
def create_organization_cmd( def create_organization_cmd(tenant_id, name, code, org_type, description, email_domains, created_by):
tenant_id, name, code, org_type, description, email_domains, created_by
):
"""Create a new organization under a tenant for multi-school support""" """Create a new organization under a tenant for multi-school support"""
try: try:
# Check if code already exists # Check if code already exists
from models.organization import Organization from models.organization import Organization
existing = ( existing = db.session.query(Organization).filter(Organization.code == code).first()
db.session.query(Organization).filter(Organization.code == code).first()
)
if existing: if existing:
click.echo(f"Error: Organization with code '{code}' already exists") click.echo(f"Error: Organization with code '{code}' already exists")
return return
@ -1106,9 +999,7 @@ def create_organization_cmd(
db.session.add(organization) db.session.add(organization)
db.session.commit() db.session.commit()
click.echo( click.echo(f"Organization '{name}' (ID: {organization.id}) created successfully")
f"Organization '{name}' (ID: {organization.id}) created successfully"
)
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
@ -1120,17 +1011,13 @@ def create_organization_cmd(
@click.option("--name", help="New name for the organization") @click.option("--name", help="New name for the organization")
@click.option("--description", help="New description") @click.option("--description", help="New description")
@click.option("--email-domains", help="Comma-separated list of allowed email domains") @click.option("--email-domains", help="Comma-separated list of allowed email domains")
@click.option( @click.option("--status", type=click.Choice(["active", "inactive"]), help="Organization status")
"--status", type=click.Choice(["active", "inactive"]), help="Organization status"
)
def update_organization_cmd(org_id, name, description, email_domains, status): def update_organization_cmd(org_id, name, description, email_domains, status):
"""Update an existing organization's configuration""" """Update an existing organization's configuration"""
try: try:
from models.organization import Organization from models.organization import Organization
organization = ( organization = db.session.query(Organization).filter(Organization.id == org_id).first()
db.session.query(Organization).filter(Organization.id == org_id).first()
)
if not organization: if not organization:
click.echo(f"Error: Organization with ID '{org_id}' not found") click.echo(f"Error: Organization with ID '{org_id}' not found")
return return
@ -1225,9 +1112,7 @@ def show_organization_cmd(org_id):
try: try:
from models.organization import Organization from models.organization import Organization
organization = ( organization = db.session.query(Organization).filter(Organization.id == org_id).first()
db.session.query(Organization).filter(Organization.id == org_id).first()
)
if not organization: if not organization:
click.echo(f"Error: Organization with ID '{org_id}' not found") click.echo(f"Error: Organization with ID '{org_id}' not found")
@ -1257,27 +1142,19 @@ def show_organization_cmd(org_id):
@click.option( @click.option(
"--role", "--role",
required=True, required=True,
type=click.Choice( type=click.Choice(["admin", "teacher", "student", "staff", "manager", "employee", "guest"]),
["admin", "teacher", "student", "staff", "manager", "employee", "guest"]
),
help="Role in the organization", help="Role in the organization",
) )
@click.option("--department", help="Department within the organization") @click.option("--department", help="Department within the organization")
@click.option("--title", help="Job title or position") @click.option("--title", help="Job title or position")
@click.option( @click.option("--is-default", is_flag=True, help="Set as the account's default organization")
"--is-default", is_flag=True, help="Set as the account's default organization" def add_account_to_organization_cmd(org_id, account_id, role, department, title, is_default):
)
def add_account_to_organization_cmd(
org_id, account_id, role, department, title, is_default
):
"""Add an account to an organization with appropriate role and metadata""" """Add an account to an organization with appropriate role and metadata"""
try: try:
from models.organization import Organization, OrganizationMember from models.organization import Organization, OrganizationMember
# Check if organization exists # Check if organization exists
organization = ( organization = db.session.query(Organization).filter(Organization.id == org_id).first()
db.session.query(Organization).filter(Organization.id == org_id).first()
)
if not organization: if not organization:
click.echo(f"Error: Organization with ID '{org_id}' not found") click.echo(f"Error: Organization with ID '{org_id}' not found")
return return
@ -1299,9 +1176,7 @@ def add_account_to_organization_cmd(
) )
if existing: if existing:
click.echo( click.echo("Account is already a member of this organization. Updating role and metadata.")
"Account is already a member of this organization. Updating role and metadata."
)
existing.role = role existing.role = role
existing.department = department existing.department = department
existing.title = title existing.title = title
@ -1374,9 +1249,7 @@ def upload_private_key_file_cloud_storage(tenant_id: Optional[str] = None):
) )
file_key = f"privkeys/{tenant_id}/private.pem" file_key = f"privkeys/{tenant_id}/private.pem"
file_content = Path( file_content = Path(f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file_key}").read_bytes()
f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file_key}"
).read_bytes()
storage.save(filename=file_key, data=file_content) storage.save(filename=file_key, data=file_content)
click.echo( click.echo(
click.style( click.style(
@ -1386,9 +1259,7 @@ def upload_private_key_file_cloud_storage(tenant_id: Optional[str] = None):
) )
@click.command( @click.command("upload-local-files-to-cloud-storage", help="upload local files to cloud storage")
"upload-local-files-to-cloud-storage", help="upload local files to cloud storage"
)
def upload_local_files_to_cloud_storage(): def upload_local_files_to_cloud_storage():
""" """
upload local files to cloud storage upload local files to cloud storage
@ -1406,14 +1277,10 @@ def upload_local_files_to_cloud_storage():
batch_size = 100 batch_size = 100
processed_count = 0 processed_count = 0
while processed_count < total_count: while processed_count < total_count:
files: list[UploadFile] = ( files: list[UploadFile] = UploadFile.query.filter_by(storage_type="local").limit(batch_size).all()
UploadFile.query.filter_by(storage_type="local").limit(batch_size).all()
)
for file in files: for file in files:
target_filepath = ( target_filepath = f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file.key}"
f"{os.environ.get('STORAGE_LOCAL_PATH', 'storage')}/{file.key}"
)
# if the file exists # if the file exists
if not os.path.exists(target_filepath): if not os.path.exists(target_filepath):
@ -1459,11 +1326,7 @@ def upload_local_files_to_cloud_storage():
processed_count += 1 processed_count += 1
if processed_count % 10 == 0 or processed_count == total_count: if processed_count % 10 == 0 or processed_count == total_count:
click.echo( click.echo(click.style(f"Processed {processed_count}/{total_count} files\n", fg="blue"))
click.style(
f"Processed {processed_count}/{total_count} files\n", fg="blue"
)
)
time.sleep(3) time.sleep(3)
click.echo( click.echo(
@ -1564,9 +1427,7 @@ def install_plugins(input_file: str, output_file: str, workers: int):
click.echo(click.style("Install plugins completed.", fg="green")) click.echo(click.style("Install plugins completed.", fg="green"))
@click.command( @click.command("clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs.")
"clear-free-plan-tenant-expired-logs", help="Clear free plan tenant expired logs."
)
@click.option( @click.option(
"--days", "--days",
prompt=True, prompt=True,
@ -1593,9 +1454,7 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids)
click.echo( click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
click.style("Clear free plan tenant expired logs completed.", fg="green")
)
@click.option( @click.option(
@ -1651,9 +1510,7 @@ def clear_orphaned_file_records(force: bool):
) )
) )
for ids_table in ids_tables: for ids_table in ids_tables:
click.echo( click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow"))
click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")
)
click.echo("") click.echo("")
click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
@ -1704,9 +1561,7 @@ def clear_orphaned_file_records(force: bool):
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(query)) rs = conn.execute(db.text(query))
for i in rs: for i in rs:
orphaned_message_files.append( orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
{"id": str(i[0]), "message_id": str(i[1])}
)
if orphaned_message_files: if orphaned_message_files:
click.echo( click.echo(
@ -1732,9 +1587,7 @@ def clear_orphaned_file_records(force: bool):
abort=True, abort=True,
) )
click.echo( click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
click.style("- Deleting orphaned message_files records", fg="white")
)
query = "DELETE FROM message_files WHERE id IN :ids" query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn: with db.engine.begin() as conn:
conn.execute( conn.execute(
@ -1755,11 +1608,7 @@ def clear_orphaned_file_records(force: bool):
) )
) )
except Exception as e: except Exception as e:
click.echo( click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red"))
click.style(
f"Error deleting orphaned message_files records: {str(e)}", fg="red"
)
)
# clean up the orphaned records in the rest of the *_files tables # clean up the orphaned records in the rest of the *_files tables
try: try:
@ -1776,14 +1625,8 @@ def clear_orphaned_file_records(force: bool):
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(query)) rs = conn.execute(db.text(query))
for i in rs: for i in rs:
all_files_in_tables.append( all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
{"table": files_table["table"], "id": str(i[0]), "key": i[1]} click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
)
click.echo(
click.style(
f"Found {len(all_files_in_tables)} files in tables.", fg="white"
)
)
# fetch referred table and columns # fetch referred table and columns
guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
@ -1798,15 +1641,12 @@ def clear_orphaned_file_records(force: bool):
) )
) )
query = ( query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} " f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
f"WHERE {ids_table['column']} IS NOT NULL"
) )
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(query)) rs = conn.execute(db.text(query))
for i in rs: for i in rs:
all_ids_in_tables.append( all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
{"table": ids_table["table"], "id": str(i[0])}
)
elif ids_table["type"] == "text": elif ids_table["type"] == "text":
click.echo( click.echo(
click.style( click.style(
@ -1842,11 +1682,7 @@ def clear_orphaned_file_records(force: bool):
for i in rs: for i in rs:
for j in i[0]: for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j}) all_ids_in_tables.append({"table": ids_table["table"], "id": j})
click.echo( click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
click.style(
f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"
)
)
except Exception as e: except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
@ -1864,9 +1700,7 @@ def clear_orphaned_file_records(force: bool):
) )
) )
return return
click.echo( click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white"))
click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")
)
for file in orphaned_files: for file in orphaned_files:
click.echo(click.style(f"- orphaned file id: {file}", fg="black")) click.echo(click.style(f"- orphaned file id: {file}", fg="black"))
if not force: if not force:
@ -1888,13 +1722,9 @@ def clear_orphaned_file_records(force: bool):
with db.engine.begin() as conn: with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
except Exception as e: except Exception as e:
click.echo( click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")
)
return return
click.echo( click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green"))
click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")
)
@click.option( @click.option(
@ -1903,9 +1733,7 @@ def clear_orphaned_file_records(force: bool):
is_flag=True, is_flag=True,
help="Skip user confirmation and force the command to execute.", help="Skip user confirmation and force the command to execute.",
) )
@click.command( @click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.")
"remove-orphaned-files-on-storage", help="Remove orphaned files on the storage."
)
def remove_orphaned_files_on_storage(force: bool): def remove_orphaned_files_on_storage(force: bool):
""" """
Remove orphaned files on the storage. Remove orphaned files on the storage.
@ -1981,32 +1809,20 @@ def remove_orphaned_files_on_storage(force: bool):
all_files_in_tables = [] all_files_in_tables = []
try: try:
for files_table in files_tables: for files_table in files_tables:
click.echo( click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
click.style(
f"- Listing files from table {files_table['table']}", fg="white"
)
)
query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
with db.engine.begin() as conn: with db.engine.begin() as conn:
rs = conn.execute(db.text(query)) rs = conn.execute(db.text(query))
for i in rs: for i in rs:
all_files_in_tables.append(str(i[0])) all_files_in_tables.append(str(i[0]))
click.echo( click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
click.style(
f"Found {len(all_files_in_tables)} files in tables.", fg="white"
)
)
except Exception as e: except Exception as e:
click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
all_files_on_storage = [] all_files_on_storage = []
for storage_path in storage_paths: for storage_path in storage_paths:
try: try:
click.echo( click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
click.style(
f"- Scanning files on storage path {storage_path}", fg="white"
)
)
files = storage.scan(path=storage_path, files=True, directories=False) files = storage.scan(path=storage_path, files=True, directories=False)
all_files_on_storage.extend(files) all_files_on_storage.extend(files)
except FileNotFoundError as e: except FileNotFoundError as e:
@ -2025,18 +1841,12 @@ def remove_orphaned_files_on_storage(force: bool):
) )
) )
continue continue
click.echo( click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white"))
click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")
)
# find orphaned files # find orphaned files
orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables))
if not orphaned_files: if not orphaned_files:
click.echo( click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green"))
click.style(
"No orphaned files found. There is nothing to remove.", fg="green"
)
)
return return
click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white"))
for file in orphaned_files: for file in orphaned_files:
@ -2057,18 +1867,10 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) click.echo(click.style(f"- Removing orphaned file: {file}", fg="white"))
except Exception as e: except Exception as e:
error_files += 1 error_files += 1
click.echo( click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red"))
click.style(
f"- Error deleting orphaned file {file}: {str(e)}", fg="red"
)
)
continue continue
if error_files == 0: if error_files == 0:
click.echo( click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
click.style(
f"Removed {removed_files} orphaned files without errors.", fg="green"
)
)
else: else:
click.echo( click.echo(
click.style( click.style(

@ -238,9 +238,7 @@ class LoginApi(Resource):
AccountService.reset_login_error_rate_limit(login_id) AccountService.reset_login_error_rate_limit(login_id)
# Generate token for the authenticated admin # Generate token for the authenticated admin
token_pair = AccountService.login( token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
account, ip_address=extract_remote_ip(request)
)
response_data = token_pair.model_dump() response_data = token_pair.model_dump()

@ -181,6 +181,6 @@ class OperationLogs(Resource):
pass pass
api.add_resource(WatermarkSettings, '/settings/watermark') api.add_resource(WatermarkSettings, "/settings/watermark")
api.add_resource(SystemInfo, '/settings/info') api.add_resource(SystemInfo, "/settings/info")
api.add_resource(OperationLogs, '/settings/logs') api.add_resource(OperationLogs, "/settings/logs")

@ -109,15 +109,15 @@ class UserStats(Resource):
""" """
try: try:
# Parse date parameters # Parse date parameters
start_date_str = request.args.get('start_date') start_date_str = request.args.get("start_date")
end_date_str = request.args.get('end_date') end_date_str = request.args.get("end_date")
if not start_date_str or not end_date_str: if not start_date_str or not end_date_str:
raise BadRequest("start_date and end_date are required") raise BadRequest("start_date and end_date are required")
try: try:
start_date = datetime.strptime(start_date_str, '%Y-%m-%d') start_date = datetime.strptime(start_date_str, "%Y-%m-%d")
end_date = datetime.strptime(end_date_str, '%Y-%m-%d') end_date = datetime.strptime(end_date_str, "%Y-%m-%d")
end_date = end_date.replace(hour=23, minute=59, second=59) end_date = end_date.replace(hour=23, minute=59, second=59)
except ValueError: except ValueError:
raise BadRequest("Invalid date format. Use YYYY-MM-DD") raise BadRequest("Invalid date format. Use YYYY-MM-DD")
@ -187,15 +187,15 @@ class ConversationStats(Resource):
""" """
try: try:
# Parse date parameters # Parse date parameters
start_date_str = request.args.get('start_date') start_date_str = request.args.get("start_date")
end_date_str = request.args.get('end_date') end_date_str = request.args.get("end_date")
if not start_date_str or not end_date_str: if not start_date_str or not end_date_str:
raise BadRequest("start_date and end_date are required") raise BadRequest("start_date and end_date are required")
try: try:
start_date = datetime.strptime(start_date_str, '%Y-%m-%d') start_date = datetime.strptime(start_date_str, "%Y-%m-%d")
end_date = datetime.strptime(end_date_str, '%Y-%m-%d') end_date = datetime.strptime(end_date_str, "%Y-%m-%d")
end_date = end_date.replace(hour=23, minute=59, second=59) end_date = end_date.replace(hour=23, minute=59, second=59)
except ValueError: except ValueError:
raise BadRequest("Invalid date format. Use YYYY-MM-DD") raise BadRequest("Invalid date format. Use YYYY-MM-DD")
@ -215,6 +215,6 @@ class ConversationStats(Resource):
return {"error": "An error occurred while processing the request"}, 500 return {"error": "An error occurred while processing the request"}, 500
api.add_resource(RiskStats, '/stats/risk') api.add_resource(RiskStats, "/stats/risk")
api.add_resource(UserStats, '/stats/user') api.add_resource(UserStats, "/stats/user")
api.add_resource(ConversationStats, '/stats/conversation') api.add_resource(ConversationStats, "/stats/conversation")

@ -120,4 +120,4 @@ class StudentConversation(Resource):
raise NotFound("Last Conversation Not Exists.") raise NotFound("Last Conversation Not Exists.")
api.add_resource(StudentConversation, '/students/<string:student_id>/conversation') api.add_resource(StudentConversation, "/students/<string:student_id>/conversation")

@ -100,11 +100,11 @@ class StudentList(Resource):
from flask import request from flask import request
# Get query parameters with defaults # Get query parameters with defaults
health_status = request.args.get('health_status') health_status = request.args.get("health_status")
begin_date = request.args.get('begin_date') begin_date = request.args.get("begin_date")
end_date = request.args.get('end_date') end_date = request.args.get("end_date")
page = int(request.args.get('page', 1)) page = int(request.args.get("page", 1))
limit = int(request.args.get('limit', 20)) limit = int(request.args.get("limit", 20))
# Validate parameters # Validate parameters
if begin_date: if begin_date:
@ -122,13 +122,13 @@ class StudentList(Resource):
# Build query filters # Build query filters
filters = {} filters = {}
if health_status: if health_status:
filters['health_status'] = health_status filters["health_status"] = health_status
if begin_date: if begin_date:
filters['last_chat_at__gte'] = begin_date filters["last_chat_at__gte"] = begin_date
if end_date: if end_date:
filters['last_chat_at__lte'] = end_date filters["last_chat_at__lte"] = end_date
# Get students with pagination # Get students with pagination
offset = (page - 1) * limit offset = (page - 1) * limit
@ -142,4 +142,4 @@ class StudentList(Resource):
) )
api.add_resource(StudentList, '/students') api.add_resource(StudentList, "/students")

@ -44,16 +44,20 @@ def validate_admin_token_and_extract_info(view: Optional[Callable] = None):
raise Unauthorized("Invalid token: user not found") raise Unauthorized("Invalid token: user not found")
if account.status != AccountStatus.ACTIVE: if account.status != AccountStatus.ACTIVE:
raise Unauthorized("Invalid token: account is not active") raise Unauthorized("Invalid token: account is not active")
# Check if user has admin role in their current organization # Check if user has admin role in their current organization
org_member = db.session.query(OrganizationMember).filter( org_member = (
OrganizationMember.account_id == user_id, db.session.query(OrganizationMember)
OrganizationMember.organization_id == account.current_organization_id .filter(
).first() OrganizationMember.account_id == user_id,
OrganizationMember.organization_id == account.current_organization_id,
)
.first()
)
if not org_member: if not org_member:
raise Unauthorized("Invalid token: user is not a member of any organization") raise Unauthorized("Invalid token: user is not a member of any organization")
# Check if the user has admin role # Check if the user has admin role
if org_member.role != OrganizationRole.ADMIN: if org_member.role != OrganizationRole.ADMIN:
raise Unauthorized("Invalid token: account does not have admin privileges") raise Unauthorized("Invalid token: account does not have admin privileges")

@ -54,25 +54,17 @@ class AnswersSummaryAnalysisApi(Resource):
return {"error": "exam_answers file_id is required"}, 400 return {"error": "exam_answers file_id is required"}, 400
# Read the exam answers file to get categories and correct answers # Read the exam answers file to get categories and correct answers
exam_answers_file_content, _ = self._read_file_with_encoding_detection( exam_answers_file_content, _ = self._read_file_with_encoding_detection(exam_answers_file_id)
exam_answers_file_id
)
if not exam_answers_file_content: if not exam_answers_file_content:
return {"error": "Failed to read exam answers file or file not found"}, 404 return {"error": "Failed to read exam answers file or file not found"}, 404
# Parse the exam answers file # Parse the exam answers file
exam_answers, categories, correct_answer = self._parse_exam_answers( exam_answers, categories, correct_answer = self._parse_exam_answers(exam_answers_file_content)
exam_answers_file_content
)
if not categories or not correct_answer: if not categories or not correct_answer:
return { return {"error": "Failed to parse categories and correct answers from exam file"}, 400
"error": "Failed to parse categories and correct answers from exam file"
}, 400
# Read the user answers file content with encoding detection # Read the user answers file content with encoding detection
user_answers_file_content, _ = self._read_file_with_encoding_detection( user_answers_file_content, _ = self._read_file_with_encoding_detection(user_answers_file_id)
user_answers_file_id
)
if not user_answers_file_content: if not user_answers_file_content:
return {"error": "Failed to read user answers file or file not found"}, 404 return {"error": "Failed to read user answers file or file not found"}, 404
@ -82,9 +74,7 @@ class AnswersSummaryAnalysisApi(Resource):
return {"error": "Failed to parse user answers from file"}, 400 return {"error": "Failed to parse user answers from file"}, 400
# Calculate category statistics # Calculate category statistics
summary_analysis = self._calculate_category_statistics( summary_analysis = self._calculate_category_statistics(user_answers, correct_answer, categories)
user_answers, correct_answer, categories
)
# Return the response # Return the response
return jsonify( return jsonify(
@ -95,16 +85,12 @@ class AnswersSummaryAnalysisApi(Resource):
} }
) )
def _read_file_with_encoding_detection( def _read_file_with_encoding_detection(self, file_id: str) -> tuple[Optional[str], Optional[str]]:
self, file_id: str
) -> tuple[Optional[str], Optional[str]]:
"""Read file content with automatic encoding detection. """Read file content with automatic encoding detection.
Supports both CSV and XLSX files, converting XLSX to CSV text format. Supports both CSV and XLSX files, converting XLSX to CSV text format.
""" """
try: try:
upload_file = ( upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
)
if not upload_file: if not upload_file:
return None, None return None, None
@ -112,15 +98,12 @@ class AnswersSummaryAnalysisApi(Resource):
file_content = storage.load_once(upload_file.key) file_content = storage.load_once(upload_file.key)
# Check if the file is Excel (.xlsx) based on filename or mime type # Check if the file is Excel (.xlsx) based on filename or mime type
file_extension = ( file_extension = upload_file.name.split(".")[-1].lower() if upload_file.name else ""
upload_file.name.split(".")[-1].lower() if upload_file.name else ""
)
mime_type = upload_file.mime_type if upload_file.mime_type else "" mime_type = upload_file.mime_type if upload_file.mime_type else ""
is_excel = ( is_excel = (
file_extension == "xlsx" file_extension == "xlsx"
or mime_type or mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
) )
if is_excel: if is_excel:
@ -174,9 +157,7 @@ class AnswersSummaryAnalysisApi(Resource):
print(f"Error reading file: {str(e)}") print(f"Error reading file: {str(e)}")
return None, None return None, None
def _parse_exam_answers( def _parse_exam_answers(self, file_content: str) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[str]]:
self, file_content: str
) -> tuple[list[dict[str, Any]], list[dict[str, Any]], list[str]]:
"""Parse exam answers from the file content. """Parse exam answers from the file content.
Expected format is CSV with columns: Expected format is CSV with columns:
@ -207,9 +188,7 @@ class AnswersSummaryAnalysisApi(Resource):
exam_answers = [] exam_answers = []
category_map = defaultdict(list) category_map = defaultdict(list)
correct_answers = [ correct_answers = [""] * 1000 # Initialize with empty strings, we'll trim later
""
] * 1000 # Initialize with empty strings, we'll trim later
max_question_num = 0 max_question_num = 0
for row in csv_reader: for row in csv_reader:
@ -247,9 +226,7 @@ class AnswersSummaryAnalysisApi(Resource):
correct_answers = correct_answers[:max_question_num] correct_answers = correct_answers[:max_question_num]
# Convert category_map to the expected categories format # Convert category_map to the expected categories format
categories = [ categories = [{"name": cat, "items": items} for cat, items in category_map.items()]
{"name": cat, "items": items} for cat, items in category_map.items()
]
return exam_answers, categories, correct_answers return exam_answers, categories, correct_answers
except Exception as e: except Exception as e:
@ -281,9 +258,7 @@ class AnswersSummaryAnalysisApi(Resource):
result = [] result = []
for row in csv_reader: for row in csv_reader:
if ( if not row or len(row) < 4: # Skip empty rows or rows with insufficient data
not row or len(row) < 4
): # Skip empty rows or rows with insufficient data
continue continue
# Extract student ID and name # Extract student ID and name
@ -293,9 +268,7 @@ class AnswersSummaryAnalysisApi(Resource):
# Extract answers (skip ID, name, and score columns) # Extract answers (skip ID, name, and score columns)
answers = [ans.strip() for ans in row[3:]] answers = [ans.strip() for ans in row[3:]]
result.append( result.append({"user_name": name, "code": student_id, "answers": answers})
{"user_name": name, "code": student_id, "answers": answers}
)
return result return result
except Exception as e: except Exception as e:
@ -399,9 +372,7 @@ class GenerateAnalysisReportApi(Resource):
data = request.get_json() data = request.get_json()
summary_analysis = data.get("summary_analysis") summary_analysis = data.get("summary_analysis")
school_name = data.get( school_name = data.get("school_name", "山东单县一中") # Default value if not provided
"school_name", "山东单县一中"
) # Default value if not provided
html_template = data.get("html_template") html_template = data.get("html_template")
if not summary_analysis: if not summary_analysis:
@ -507,9 +478,7 @@ class GenerateAnalysisReportApi(Resource):
# Create the HTML with the template # Create the HTML with the template
template = Template(html_template) template = Template(html_template)
html_content = template.render( html_content = template.render(school_name=school_name, summary_analysis=summary_analysis)
school_name=school_name, summary_analysis=summary_analysis
)
# Generate PDF # Generate PDF
html = HTML(string=html_content) html = HTML(string=html_content)

@ -198,7 +198,6 @@ class EmailCodeLoginApi(Resource):
is_new_user = account is None is_new_user = account is None
if account is None: if account is None:
# Create new account # Create new account
account = AccountService.create_account_in_tenant( account = AccountService.create_account_in_tenant(
tenant=tenant, tenant=tenant,
@ -212,9 +211,11 @@ class EmailCodeLoginApi(Resource):
OrganizationService.assign_account_to_organization(account, organization.id) OrganizationService.assign_account_to_organization(account, organization.id)
else: else:
if (
if (organization is not None and account.current_organization_id is not None organization is not None
and account.current_organization_id != organization.id): and account.current_organization_id is not None
and account.current_organization_id != organization.id
):
raise OrganizationMismatchError() raise OrganizationMismatchError()
connected_tenant = TenantService.get_join_tenants(account) connected_tenant = TenantService.get_join_tenants(account)

@ -89,33 +89,33 @@ class UserProfile(Resource):
validated_data = {} validated_data = {}
# Validate username if provided # Validate username if provided
if 'username' in data: if "username" in data:
username = data['username'] username = data["username"]
# Validate username (Chinese or English only, max 10 chars) # Validate username (Chinese or English only, max 10 chars)
if not re.match(r'^[a-zA-Z\u4e00-\u9fa5]+$', username) or len(username) > 10: if not re.match(r"^[a-zA-Z\u4e00-\u9fa5]+$", username) or len(username) > 10:
return {"success": False, "message": "Invalid username format"}, 400 return {"success": False, "message": "Invalid username format"}, 400
validated_data['username'] = username validated_data["username"] = username
# Validate gender if provided # Validate gender if provided
if 'gender' in data: if "gender" in data:
gender_str = data['gender'] gender_str = data["gender"]
if gender_str not in ["unknown", "male", "female"]: if gender_str not in ["unknown", "male", "female"]:
return {"success": False, "message": "Invalid gender value"}, 400 return {"success": False, "message": "Invalid gender value"}, 400
validated_data['gender'] = gender_str validated_data["gender"] = gender_str
# Validate major if provided # Validate major if provided
if 'major' in data: if "major" in data:
major = data['major'] major = data["major"]
# Allow None as a valid value (to clear the field) # Allow None as a valid value (to clear the field)
if major is None: if major is None:
validated_data['major'] = None validated_data["major"] = None
elif not isinstance(major, str): elif not isinstance(major, str):
return {"success": False, "message": "Major must be a string value or null"}, 400 return {"success": False, "message": "Major must be a string value or null"}, 400
elif len(major) > 50: elif len(major) > 50:
return {"success": False, "message": "Major exceeds maximum length of 50"}, 400 return {"success": False, "message": "Major exceeds maximum length of 50"}, 400
else: else:
validated_data['major'] = major validated_data["major"] = major
# Use the service to update user profile # Use the service to update user profile
success, error = EndUserService.update_user_profile(end_user, validated_data) success, error = EndUserService.update_user_profile(end_user, validated_data)
@ -126,4 +126,4 @@ class UserProfile(Resource):
return {"success": True} return {"success": True}
api.add_resource(UserProfile, '/user/profile') api.add_resource(UserProfile, "/user/profile")

@ -27,15 +27,11 @@ def load_user_from_request(request_from_flask_login):
raise Unauthorized("Invalid Authorization token.") raise Unauthorized("Invalid Authorization token.")
else: else:
if " " not in auth_header: if " " not in auth_header:
raise Unauthorized( raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
"Invalid Authorization header format. Expected 'Bearer <api-key>' format."
)
auth_scheme, auth_token = auth_header.split(None, 1) auth_scheme, auth_token = auth_header.split(None, 1)
auth_scheme = auth_scheme.lower() auth_scheme = auth_scheme.lower()
if auth_scheme != "bearer": if auth_scheme != "bearer":
raise Unauthorized( raise Unauthorized("Invalid Authorization header format. Expected 'Bearer <api-key>' format.")
"Invalid Authorization header format. Expected 'Bearer <api-key>' format."
)
decoded = PassportService().verify(auth_token) decoded = PassportService().verify(auth_token)
user_id = decoded.get("user_id") user_id = decoded.get("user_id")

@ -111,9 +111,7 @@ def init_app(app: DifyApp):
) as span: ) as span:
span.set_status(StatusCode.ERROR) span.set_status(StatusCode.ERROR)
span.record_exception(record.exc_info[1]) span.record_exception(record.exc_info[1])
span.set_attribute( span.set_attribute("exception.type", record.exc_info[0].__name__)
"exception.type", record.exc_info[0].__name__
)
span.set_attribute("exception.message", str(record.exc_info[1])) span.set_attribute("exception.message", str(record.exc_info[1]))
except Exception: except Exception:
pass pass
@ -198,9 +196,7 @@ def init_app(app: DifyApp):
set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader])) set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))
if not is_celery_worker(): if not is_celery_worker():
init_flask_instrumentor(app) init_flask_instrumentor(app)
CeleryInstrumentor( CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()
).instrument()
instrument_exception_logging() instrument_exception_logging()
init_sqlalchemy_instrumentor(app) init_sqlalchemy_instrumentor(app)
atexit.register(shutdown_tracer) atexit.register(shutdown_tracer)
@ -221,6 +217,4 @@ def init_celery_worker(*args, **kwargs):
metric_provider = get_meter_provider() metric_provider = get_meter_provider()
if dify_config.DEBUG: if dify_config.DEBUG:
logging.info("Initializing OpenTelemetry for Celery worker") logging.info("Initializing OpenTelemetry for Celery worker")
CeleryInstrumentor( CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
tracer_provider=tracer_provider, meter_provider=metric_provider
).instrument()

@ -47,11 +47,10 @@ class PhoneSms:
access_key_secret=secret, access_key_secret=secret,
) )
# Endpoint 请参考 https://api.aliyun.com/product/Dysmsapi # Endpoint 请参考 https://api.aliyun.com/product/Dysmsapi
config.endpoint = 'dysmsapi.aliyuncs.com' config.endpoint = "dysmsapi.aliyuncs.com"
return Dysmsapi20170525Client(config) return Dysmsapi20170525Client(config)
def send_sms(self, phone_numbers: str, code: str) -> None: def send_sms(self, phone_numbers: str, code: str) -> None:
if not self._client: if not self._client:
raise ValueError("PhoneSms client is not initialized") raise ValueError("PhoneSms client is not initialized")
@ -63,7 +62,7 @@ class PhoneSms:
) )
response = self._client.send_sms_with_options(send_sms_request, util_models.RuntimeOptions()) response = self._client.send_sms_with_options(send_sms_request, util_models.RuntimeOptions())
if response.body.code != 'OK': if response.body.code != "OK":
raise Exception(response.body.message) raise Exception(response.body.message)

@ -2,23 +2,22 @@ from dify_app import DifyApp
def init_app(app: DifyApp): def init_app(app: DifyApp):
from flasgger import Swagger from flasgger import Swagger
app.config['SWAGGER'] = { app.config["SWAGGER"] = {
'title': 'API Docs', "title": "API Docs",
'uiversion': 3, "uiversion": 3,
'url_prefix': '/openapi', "url_prefix": "/openapi",
'specs_route': '/', "specs_route": "/",
'static_url_path': '/flasgger_static', "static_url_path": "/flasgger_static",
'securityDefinitions': { "securityDefinitions": {
'ApiKeyAuth': { "ApiKeyAuth": {
'type': 'apiKey', "type": "apiKey",
'name': 'Authorization', "name": "Authorization",
'in': 'header', "in": "header",
'description': 'API Key Authorization header using Bearer scheme. Example: "Bearer {token}"' "description": 'API Key Authorization header using Bearer scheme. Example: "Bearer {token}"',
} }
} },
} }
Swagger(app) Swagger(app)

@ -59,13 +59,13 @@ class Organization(db.Model): # type: ignore[name-defined]
def allowed_email_domains(self) -> list[str]: def allowed_email_domains(self) -> list[str]:
"""Get list of allowed email domains for this organization""" """Get list of allowed email domains for this organization"""
settings = self.settings_dict settings = self.settings_dict
return settings.get('allowed_email_domains', []) return settings.get("allowed_email_domains", [])
@allowed_email_domains.setter @allowed_email_domains.setter
def allowed_email_domains(self, domains: list[str]): def allowed_email_domains(self, domains: list[str]):
"""Set allowed email domains for this organization""" """Set allowed email domains for this organization"""
settings = self.settings_dict settings = self.settings_dict
settings['allowed_email_domains'] = domains settings["allowed_email_domains"] = domains
self.settings_dict = settings self.settings_dict = settings
@property @property
@ -78,7 +78,7 @@ class Organization(db.Model): # type: ignore[name-defined]
if not self.is_email_restricted: if not self.is_email_restricted:
return True return True
email_domain = email.split('@')[-1].lower() email_domain = email.split("@")[-1].lower()
return email_domain in self.allowed_email_domains return email_domain in self.allowed_email_domains
@property @property

@ -43,9 +43,7 @@ def user_profile_generate_task():
logger.info(f"No users to update. for app_id {app_id}") logger.info(f"No users to update. for app_id {app_id}")
continue continue
logger.info( logger.info(f"Found {len(users_to_update)} users profile and memory updates. in app_id {app_id}")
f"Found {len(users_to_update)} users profile and memory updates. in app_id {app_id}"
)
update_user_profile_for_appid(users_to_update) update_user_profile_for_appid(users_to_update)
end_at = time.perf_counter() end_at = time.perf_counter()
@ -64,9 +62,7 @@ def update_user_profile_for_appid(users_to_update: list[EndUser]):
batch = users_to_update[i : i + batch_size] batch = users_to_update[i : i + batch_size]
try: try:
for user in batch: for user in batch:
new_messages, latest_messages_created_at = fetch_new_messages_for_user( new_messages, latest_messages_created_at = fetch_new_messages_for_user(user)
user
)
if len(new_messages) > 0: if len(new_messages) > 0:
process_user_memory(user, new_messages) process_user_memory(user, new_messages)
@ -91,9 +87,7 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]:
) )
latest_message_query = latest_message_query.filter(Message.app_id == app_id) latest_message_query = latest_message_query.filter(Message.app_id == app_id)
latest_message_subquery = latest_message_query.group_by( latest_message_subquery = latest_message_query.group_by(Message.from_end_user_id).subquery()
Message.from_end_user_id
).subquery()
# Then join with EndUser to find users who need memory updates # Then join with EndUser to find users who need memory updates
users_query = ( users_query = (
@ -106,8 +100,7 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]:
EndUser.app_id == app_id, EndUser.app_id == app_id,
or_( or_(
EndUser.profile_updated_at.is_(None), EndUser.profile_updated_at.is_(None),
EndUser.profile_updated_at EndUser.profile_updated_at < latest_message_subquery.c.latest_message_time,
< latest_message_subquery.c.latest_message_time,
), ),
) )
) )
@ -122,14 +115,10 @@ def fetch_users_to_update(app_id: str) -> list[EndUser]:
def fetch_new_messages_for_user(user: EndUser) -> tuple[str, datetime]: def fetch_new_messages_for_user(user: EndUser) -> tuple[str, datetime]:
"""Fetch new messages for a user.""" """Fetch new messages for a user."""
message_query = db.session.query(Message).filter( message_query = db.session.query(Message).filter(Message.from_end_user_id == user.id)
Message.from_end_user_id == user.id
)
message_query = message_query.filter(Message.app_id == user.app_id) message_query = message_query.filter(Message.app_id == user.app_id)
if user.profile_updated_at: if user.profile_updated_at:
message_query = message_query.filter( message_query = message_query.filter(Message.created_at > user.profile_updated_at)
Message.created_at > user.profile_updated_at
)
new_messages = message_query.order_by(asc(Message.created_at)).all() new_messages = message_query.order_by(asc(Message.created_at)).all()
if len(new_messages) == 0: if len(new_messages) == 0:
@ -150,9 +139,7 @@ def process_user_memory(user: EndUser, new_messages: str):
memory_app_id = dify_config.USER_MEMORY_GENERATION_APP_ID memory_app_id = dify_config.USER_MEMORY_GENERATION_APP_ID
if memory_app_id == "": if memory_app_id == "":
logger.warning( logger.warning("No memory generation app_id provided, skipping memory generation.")
"No memory generation app_id provided, skipping memory generation."
)
return return
memory_app_model = db.session.query(App).filter(App.id == memory_app_id).first() memory_app_model = db.session.query(App).filter(App.id == memory_app_id).first()
@ -195,18 +182,12 @@ def process_user_health_summary(user: EndUser, new_messages: str):
health_summary_app_id = dify_config.USER_HEALTH_SUMMARY_GENERATION_APP_ID health_summary_app_id = dify_config.USER_HEALTH_SUMMARY_GENERATION_APP_ID
if health_summary_app_id == "": if health_summary_app_id == "":
logger.warning( logger.warning("No health summary app_id provided, skipping health summary generation.")
"No health summary app_id provided, skipping health summary generation."
)
return return
health_summary_app_model = ( health_summary_app_model = db.session.query(App).filter(App.id == health_summary_app_id).first()
db.session.query(App).filter(App.id == health_summary_app_id).first()
)
if health_summary_app_model is None: if health_summary_app_model is None:
logger.error( logger.error(f"App not found for health summary generation app_id {health_summary_app_id}")
f"App not found for health summary generation app_id {health_summary_app_id}"
)
return return
args = { args = {
@ -237,9 +218,7 @@ def process_user_health_summary(user: EndUser, new_messages: str):
result = response["data"]["outputs"]["result"] result = response["data"]["outputs"]["result"]
if result is None: if result is None:
logger.warning( logger.warning(f"Health summary generation failed with None result for user {user.id}")
f"Health summary generation failed with None result for user {user.id}"
)
return return
# preprocess result in case of ```json xxxx``` # preprocess result in case of ```json xxxx```

@ -70,9 +70,7 @@ REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService: class AccountService:
reset_password_rate_limiter = RateLimiter( reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1
)
email_code_login_rate_limiter = RateLimiter( email_code_login_rate_limiter = RateLimiter(
prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1 prefix="email_code_login_rate_limit", max_attempts=1, time_window=60 * 1
) )
@ -122,16 +120,12 @@ class AccountService:
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.") raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by( current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first()
account_id=account.id, current=True
).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.current_tenant_id = current_tenant.tenant_id
else: else:
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id) TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if not available_ta: if not available_ta:
return None return None
@ -140,9 +134,7 @@ class AccountService:
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta( if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
minutes=10
):
account.last_active_at = datetime.now(UTC).replace(tzinfo=None) account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit() db.session.commit()
@ -150,9 +142,7 @@ class AccountService:
@staticmethod @staticmethod
def get_account_jwt_token(account: Account) -> str: def get_account_jwt_token(account: Account) -> str:
exp_dt = datetime.now(UTC) + timedelta( exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES)
minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES
)
exp = int(exp_dt.timestamp()) exp = int(exp_dt.timestamp())
payload = { payload = {
"user_id": account.id, "user_id": account.id,
@ -165,9 +155,7 @@ class AccountService:
return token return token
@staticmethod @staticmethod
def authenticate( def authenticate(email: str, password: str, invite_token: Optional[str] = None) -> Account:
email: str, password: str, invite_token: Optional[str] = None
) -> Account:
"""authenticate account with email and password""" """authenticate account with email and password"""
account = db.session.query(Account).filter_by(email=email).first() account = db.session.query(Account).filter_by(email=email).first()
@ -186,9 +174,7 @@ class AccountService:
account.password = base64_password_hashed account.password = base64_password_hashed
account.password_salt = base64_salt account.password_salt = base64_salt
if account.password is None or not compare_password( if account.password is None or not compare_password(password, account.password, account.password_salt):
password, account.password, account.password_salt
):
raise AccountPasswordError("Invalid email or password.") raise AccountPasswordError("Invalid email or password.")
if account.status == AccountStatus.PENDING.value: if account.status == AccountStatus.PENDING.value:
@ -202,9 +188,7 @@ class AccountService:
@staticmethod @staticmethod
def update_account_password(account, password, new_password): def update_account_password(account, password, new_password):
"""update account password""" """update account password"""
if account.password and not compare_password( if account.password and not compare_password(password, account.password, account.password_salt):
password, account.password, account.password_salt
):
raise CurrentPasswordIncorrectError("Current password is incorrect.") raise CurrentPasswordIncorrectError("Current password is incorrect.")
# may be raised # may be raised
@ -352,11 +336,9 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = ( account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by(
AccountIntegrate.query.filter_by( account_id=account.id, provider=provider
account_id=account.id, provider=provider ).first()
).first()
)
if account_integrate: if account_integrate:
# If it exists, update the record # If it exists, update the record
@ -376,9 +358,7 @@ class AccountService:
db.session.commit() db.session.commit()
logging.info(f"Account {account.id} linked {provider} account {open_id}.") logging.info(f"Account {account.id} linked {provider} account {open_id}.")
except Exception as e: except Exception as e:
logging.exception( logging.exception(f"Failed to link {provider} account {open_id} to Account {account.id}")
f"Failed to link {provider} account {open_id} to Account {account.id}"
)
raise LinkAccountIntegrateError("Failed to link account.") from e raise LinkAccountIntegrateError("Failed to link account.") from e
@staticmethod @staticmethod
@ -425,20 +405,14 @@ class AccountService:
@staticmethod @staticmethod
def logout(*, account: Account) -> None: def logout(*, account: Account) -> None:
refresh_token = redis_client.get( refresh_token = redis_client.get(AccountService._get_account_refresh_token_key(account.id))
AccountService._get_account_refresh_token_key(account.id)
)
if refresh_token: if refresh_token:
AccountService._delete_refresh_token( AccountService._delete_refresh_token(refresh_token.decode("utf-8"), account.id)
refresh_token.decode("utf-8"), account.id
)
@staticmethod @staticmethod
def refresh_token(refresh_token: str) -> TokenPair: def refresh_token(refresh_token: str) -> TokenPair:
# Verify the refresh token # Verify the refresh token
account_id = redis_client.get( account_id = redis_client.get(AccountService._get_refresh_token_key(refresh_token))
AccountService._get_refresh_token_key(refresh_token)
)
if not account_id: if not account_id:
raise ValueError("Invalid refresh token") raise ValueError("Invalid refresh token")
@ -525,9 +499,7 @@ class AccountService:
if email is None: if email is None:
raise ValueError("Email must be provided.") raise ValueError("Email must be provided.")
if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith( if dify_config.DEBUG_ORG_EMAIL_DOMAIN and email.endswith(dify_config.DEBUG_ORG_EMAIL_DOMAIN):
dify_config.DEBUG_ORG_EMAIL_DOMAIN
):
code = dify_config.DEBUG_CODE_FOR_LOGIN code = dify_config.DEBUG_CODE_FOR_LOGIN
elif cls.email_code_login_rate_limiter.is_rate_limited(email): elif cls.email_code_login_rate_limiter.is_rate_limited(email):
from controllers.console.auth.error import ( from controllers.console.auth.error import (
@ -659,9 +631,7 @@ class AccountService:
redis_client.setex(freeze_key, 60 * 60, 1) redis_client.setex(freeze_key, 60 * 60, 1)
return True return True
else: else:
redis_client.setex( redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
hour_limit_key, 60 * 10, hour_limit_count + 1
) # first time limit 10 minutes
# add hour limit count # add hour limit count
redis_client.incr(hour_limit_key) redis_client.incr(hour_limit_key)
@ -697,9 +667,7 @@ class AccountService:
organization_id = admin_account.current_organization_id organization_id = admin_account.current_organization_id
if not organization_id: if not organization_id:
logging.warning( logging.warning(f"Account {admin_account.id} is not a member of any organization.")
f"Account {admin_account.id} is not a member of any organization."
)
return None return None
# If organization_id is provided, check if account is an admin member of that organization # If organization_id is provided, check if account is an admin member of that organization
@ -716,9 +684,7 @@ class AccountService:
) )
if not org_member: if not org_member:
logging.warning( logging.warning(f"Account {admin_account.id} is not a member of any organization.")
f"Account {admin_account.id} is not a member of any organization."
)
return None return None
return admin_account return admin_account
@ -744,9 +710,7 @@ class AccountService:
current_minute_count = int(current_minute_count) current_minute_count = int(current_minute_count)
# check current hour count # check current hour count
if ( if current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE: # Use same limit as email
current_minute_count > dify_config.EMAIL_SEND_IP_LIMIT_PER_MINUTE
): # Use same limit as email
hour_limit_count = redis_client.get(hour_limit_key) hour_limit_count = redis_client.get(hour_limit_key)
if hour_limit_count is None: if hour_limit_count is None:
hour_limit_count = 0 hour_limit_count = 0
@ -756,9 +720,7 @@ class AccountService:
redis_client.setex(freeze_key, 60 * 60, 1) redis_client.setex(freeze_key, 60 * 60, 1)
return True return True
else: else:
redis_client.setex( redis_client.setex(hour_limit_key, 60 * 10, hour_limit_count + 1) # first time limit 10 minutes
hour_limit_key, 60 * 10, hour_limit_count + 1
) # first time limit 10 minutes
# add hour limit count # add hour limit count
redis_client.incr(hour_limit_key) redis_client.incr(hour_limit_key)
@ -823,11 +785,7 @@ class AccountService:
Returns None if no admin account with this ID exists. Returns None if no admin account with this ID exists.
Raises Unauthorized if account is banned. Raises Unauthorized if account is banned.
""" """
account = ( account = db.session.query(Account).filter((Account.email == login_id) | (Account.phone == login_id)).first()
db.session.query(Account)
.filter((Account.email == login_id) | (Account.phone == login_id))
.first()
)
if not account: if not account:
return None return None
@ -842,7 +800,6 @@ class AccountService:
class TenantService: class TenantService:
@staticmethod @staticmethod
def get_tenant_by_id(tenant_id: str) -> Tenant: def get_tenant_by_id(tenant_id: str) -> Tenant:
return Tenant.query.filter_by(id=tenant_id).first() return Tenant.query.filter_by(id=tenant_id).first()
@ -877,53 +834,38 @@ class TenantService:
): ):
"""Check if user have a workspace or not""" """Check if user have a workspace or not"""
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id) TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first()
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if available_ta: if available_ta:
return return
"""Create owner tenant if not exist""" """Create owner tenant if not exist"""
if ( if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup:
not FeatureService.get_system_features().is_allow_create_workspace
and not is_setup
):
raise WorkSpaceNotAllowedCreateError() raise WorkSpaceNotAllowedCreateError()
if name: if name:
tenant = TenantService.create_tenant(name=name, is_setup=is_setup) tenant = TenantService.create_tenant(name=name, is_setup=is_setup)
else: else:
tenant = TenantService.create_tenant( tenant = TenantService.create_tenant(name=f"{account.name}'s Workspace", is_setup=is_setup)
name=f"{account.name}'s Workspace", is_setup=is_setup
)
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
db.session.commit() db.session.commit()
tenant_was_created.send(tenant) tenant_was_created.send(tenant)
@staticmethod @staticmethod
def create_tenant_member( def create_tenant_member(tenant: Tenant, account: Account, role: str = "normal") -> TenantAccountJoin:
tenant: Tenant, account: Account, role: str = "normal"
) -> TenantAccountJoin:
"""Create tenant member""" """Create tenant member"""
if role == TenantAccountRole.OWNER.value: if role == TenantAccountRole.OWNER.value:
if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]): if TenantService.has_roles(tenant, [TenantAccountRole.OWNER]):
logging.error(f"Tenant {tenant.id} has already an owner.") logging.error(f"Tenant {tenant.id} has already an owner.")
raise Exception("Tenant already has an owner.") raise Exception("Tenant already has an owner.")
ta = ( ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
db.session.query(TenantAccountJoin)
.filter_by(tenant_id=tenant.id, account_id=account.id)
.first()
)
if ta: if ta:
ta.role = role ta.role = role
else: else:
ta = TenantAccountJoin( ta = TenantAccountJoin(tenant_id=tenant.id, account_id=account.id, role=role)
tenant_id=tenant.id, account_id=account.id, role=role
)
db.session.add(ta) db.session.add(ta)
db.session.commit() db.session.commit()
@ -949,9 +891,7 @@ class TenantService:
if not tenant: if not tenant:
raise TenantNotFoundError("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
@ -978,9 +918,7 @@ class TenantService:
) )
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError( raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
"Tenant not found or account is not a member of the tenant."
)
else: else:
TenantAccountJoin.query.filter( TenantAccountJoin.query.filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.account_id == account.id,
@ -1065,9 +1003,7 @@ class TenantService:
return cast(int, db.session.query(func.count(Tenant.id)).scalar()) return cast(int, db.session.query(func.count(Tenant.id)).scalar())
@staticmethod @staticmethod
def check_member_permission( def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None:
tenant: Tenant, operator: Account, member: Account | None, action: str
) -> None:
"""Check member permission""" """Check member permission"""
perms = { perms = {
"add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN],
@ -1081,26 +1017,20 @@ class TenantService:
if operator.id == member.id: if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by( ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first()
tenant_id=tenant.id, account_id=operator.id
).first()
if not ta_operator or ta_operator.role not in perms[action]: if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.") raise NoPermissionError(f"No permission to {action} member.")
@staticmethod @staticmethod
def remove_member_from_tenant( def remove_member_from_tenant(tenant: Tenant, account: Account, operator: Account) -> None:
tenant: Tenant, account: Account, operator: Account
) -> None:
"""Remove member from tenant""" """Remove member from tenant"""
if operator.id == account.id: if operator.id == account.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
TenantService.check_member_permission(tenant, operator, account, "remove") TenantService.check_member_permission(tenant, operator, account, "remove")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
raise MemberNotInTenantError("Member not in tenant.") raise MemberNotInTenantError("Member not in tenant.")
@ -1108,26 +1038,18 @@ class TenantService:
db.session.commit() db.session.commit()
@staticmethod @staticmethod
def update_member_role( def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account) -> None:
tenant: Tenant, member: Account, new_role: str, operator: Account
) -> None:
"""Update member role""" """Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update") TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by( target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first()
tenant_id=tenant.id, account_id=member.id
).first()
if target_member_join.role == new_role: if target_member_join.role == new_role:
raise RoleAlreadyAssignedError( raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
"The provided role is already assigned to the member."
)
if new_role == "owner": if new_role == "owner":
# Find the current owner and change their role to 'admin' # Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by( current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
tenant_id=tenant.id, role="owner"
).first()
current_owner_join.role = "admin" current_owner_join.role = "admin"
# Update the role of the target member # Update the role of the target member
@ -1137,9 +1059,7 @@ class TenantService:
@staticmethod @staticmethod
def dissolve_tenant(tenant: Tenant, operator: Account) -> None: def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
"""Dissolve tenant""" """Dissolve tenant"""
if not TenantService.check_member_permission( if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
tenant, operator, operator, "remove"
):
raise NoPermissionError("No permission to dissolve tenant.") raise NoPermissionError("No permission to dissolve tenant.")
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
db.session.delete(tenant) db.session.delete(tenant)
@ -1224,10 +1144,7 @@ class RegisterService:
if open_id is not None and provider is not None: if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account) AccountService.link_account_integrate(provider, open_id, account)
if ( if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
FeatureService.get_system_features().is_allow_create_workspace
and create_workspace_required
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant
@ -1281,9 +1198,7 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id) TenantService.switch_tenant(account, tenant.id)
else: else:
TenantService.check_member_permission(tenant, inviter, account, "add") TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by( ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first()
tenant_id=tenant.id, account_id=account.id
).first()
if not ta: if not ta:
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)
@ -1330,9 +1245,7 @@ class RegisterService:
def revoke_token(cls, workspace_id: str, email: str, token: str): def revoke_token(cls, workspace_id: str, email: str, token: str):
if workspace_id and email: if workspace_id and email:
email_hash = sha256(email.encode()).hexdigest() email_hash = sha256(email.encode()).hexdigest()
cache_key = "member_invite_token:{}, {}:{}".format( cache_key = "member_invite_token:{}, {}:{}".format(workspace_id, email_hash, token)
workspace_id, email_hash, token
)
redis_client.delete(cache_key) redis_client.delete(cache_key)
else: else:
redis_client.delete(cls._get_invitation_token_key(token)) redis_client.delete(cls._get_invitation_token_key(token))
@ -1347,9 +1260,7 @@ class RegisterService:
tenant = ( tenant = (
db.session.query(Tenant) db.session.query(Tenant)
.filter( .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal"
)
.first() .first()
) )

@ -32,8 +32,8 @@ class EndUserService:
db.session.query( db.session.query(
Message.from_end_user_id, Message.from_end_user_id,
func.count( func.count(
func.distinct(func.date(func.timezone('UTC+8', func.timezone('UTC', Message.created_at)))) func.distinct(func.date(func.timezone("UTC+8", func.timezone("UTC", Message.created_at))))
).label('active_days'), ).label("active_days"),
) )
.filter(Message.app_id == app_model.id) .filter(Message.app_id == app_model.id)
.group_by(Message.from_end_user_id) .group_by(Message.from_end_user_id)
@ -44,9 +44,9 @@ class EndUserService:
subq = ( subq = (
db.session.query( db.session.query(
Conversation.from_end_user_id, Conversation.from_end_user_id,
func.max(Conversation.created_at).label('last_chat_at'), func.max(Conversation.created_at).label("last_chat_at"),
func.min(Conversation.created_at).label('first_chat_at'), func.min(Conversation.created_at).label("first_chat_at"),
func.count(Message.id).label('total_messages'), func.count(Message.id).label("total_messages"),
) )
.filter(Conversation.app_id == app_model.id) .filter(Conversation.app_id == app_model.id)
.join(Message, Message.conversation_id == Conversation.id) .join(Message, Message.conversation_id == Conversation.id)
@ -75,14 +75,14 @@ class EndUserService:
# Apply filters # Apply filters
filter_conditions = [] filter_conditions = []
if 'health_status' in filters: if "health_status" in filters:
filter_conditions.append(EndUser.health_status == filters['health_status']) filter_conditions.append(EndUser.health_status == filters["health_status"])
if 'last_chat_at__gte' in filters: if "last_chat_at__gte" in filters:
filter_conditions.append(subq.c.last_chat_at >= filters['last_chat_at__gte']) filter_conditions.append(subq.c.last_chat_at >= filters["last_chat_at__gte"])
if 'last_chat_at__lte' in filters: if "last_chat_at__lte" in filters:
filter_conditions.append(subq.c.last_chat_at <= filters['last_chat_at__lte']) filter_conditions.append(subq.c.last_chat_at <= filters["last_chat_at__lte"])
# Apply all filter conditions # Apply all filter conditions
if filter_conditions: if filter_conditions:
@ -109,17 +109,17 @@ class EndUserService:
# Convert to dictionary for JSON serialization # Convert to dictionary for JSON serialization
end_user_dict = { end_user_dict = {
'id': end_user.external_user_id, "id": end_user.external_user_id,
'email': end_user.email, "email": end_user.email,
'first_chat_at': end_user.first_chat_at, "first_chat_at": end_user.first_chat_at,
'last_chat_at': end_user.last_chat_at, "last_chat_at": end_user.last_chat_at,
'total_messages': end_user.total_messages, "total_messages": end_user.total_messages,
'active_days': end_user.active_days, "active_days": end_user.active_days,
'health_status': end_user.health_status, "health_status": end_user.health_status,
'topics': end_user.topics, "topics": end_user.topics,
'summary': end_user.summary, "summary": end_user.summary,
'major': end_user.major, "major": end_user.major,
'organization_id': end_user.organization_id, "organization_id": end_user.organization_id,
} }
users.append(end_user_dict) users.append(end_user_dict)
@ -172,18 +172,18 @@ class EndUserService:
""" """
try: try:
# Update username if provided # Update username if provided
if 'username' in profile_data: if "username" in profile_data:
end_user.name = profile_data['username'] end_user.name = profile_data["username"]
# Update gender if provided # Update gender if provided
if 'gender' in profile_data: if "gender" in profile_data:
gender_str = profile_data['gender'] gender_str = profile_data["gender"]
gender_map = {"unknown": 0, "male": 1, "female": 2} gender_map = {"unknown": 0, "male": 1, "female": 2}
end_user.gender = gender_map[gender_str] end_user.gender = gender_map[gender_str]
# Update major if provided # Update major if provided
if 'major' in profile_data: if "major" in profile_data:
major = profile_data['major'] major = profile_data["major"]
# Create a new dictionary if extra_profile is None # Create a new dictionary if extra_profile is None
if end_user.extra_profile is None: if end_user.extra_profile is None:
@ -191,7 +191,7 @@ class EndUserService:
# Make a copy of the existing dictionary to ensure changes are detected # Make a copy of the existing dictionary to ensure changes are detected
extra_profile = dict(end_user.extra_profile) extra_profile = dict(end_user.extra_profile)
extra_profile['major'] = major extra_profile["major"] = major
end_user.extra_profile = extra_profile end_user.extra_profile = extra_profile
# Force the change to be detected # Force the change to be detected

@ -19,7 +19,6 @@ DEFAULT_IMAGE_EXTENSION = ".png"
class ImageGenerationService: class ImageGenerationService:
generate_image_rate_limiter = RateLimiter( generate_image_rate_limiter = RateLimiter(
prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1 prefix="generate_image_rate_limit", max_attempts=dify_config.IMAGE_GENERATION_DAILY_LIMIT, time_window=86400 * 1
) )
@ -62,7 +61,6 @@ class ImageGenerationService:
@staticmethod @staticmethod
def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination: def pagination_image_list(end_user: EndUser, limit: int, offset: int) -> MultiPagePagination:
query = ( query = (
db.session.query(UserGeneratedImage) db.session.query(UserGeneratedImage)
.filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id) .filter(UserGeneratedImage.app_id == end_user.app_id, UserGeneratedImage.end_user_id == end_user.id)

@ -21,16 +21,16 @@ class OrganizationService:
Returns: Returns:
Organization or None if no match found Organization or None if no match found
""" """
if not email or '@' not in email: if not email or "@" not in email:
return None return None
# Get email domain # Get email domain
email_domain = email.split('@')[-1].lower() email_domain = email.split("@")[-1].lower()
# Get active organizations for this tenant # Get active organizations for this tenant
organizations = ( organizations = (
db.session.query(Organization) db.session.query(Organization)
.filter(Organization.tenant_id == tenant_id, Organization.status == 'active') .filter(Organization.tenant_id == tenant_id, Organization.status == "active")
.all() .all()
) )
@ -186,7 +186,7 @@ class OrganizationService:
""" """
return ( return (
db.session.query(Organization) db.session.query(Organization)
.filter(Organization.tenant_id == tenant_id, Organization.status == 'active') .filter(Organization.tenant_id == tenant_id, Organization.status == "active")
.all() .all()
) )

@ -102,13 +102,13 @@ class StatsService:
date_range = [] date_range = []
current_date = start_date current_date = start_date
while current_date <= end_date: while current_date <= end_date:
date_range.append(current_date.strftime('%Y-%m-%d')) date_range.append(current_date.strftime("%Y-%m-%d"))
current_date += timedelta(days=1) current_date += timedelta(days=1)
daily_stats = [] daily_stats = []
for date_str in date_range: for date_str in date_range:
date = datetime.strptime(date_str, '%Y-%m-%d') date = datetime.strptime(date_str, "%Y-%m-%d")
next_date = date + timedelta(days=1) next_date = date + timedelta(days=1)
# Count active users (users who had a conversation on this date) # Count active users (users who had a conversation on this date)
@ -154,8 +154,8 @@ class StatsService:
active_user_ids_query = active_user_ids_query.filter(Message.organization_id == organization_id) active_user_ids_query = active_user_ids_query.filter(Message.organization_id == organization_id)
# Get the intersection to find active new users # Get the intersection to find active new users
new_user_ids = [user_id for user_id, in new_user_ids_query.all()] new_user_ids = [user_id for (user_id,) in new_user_ids_query.all()]
active_user_ids = [user_id for user_id, in active_user_ids_query.all()] active_user_ids = [user_id for (user_id,) in active_user_ids_query.all()]
# Count users who appear in both lists (created today AND active today) # Count users who appear in both lists (created today AND active today)
active_new_users = len(set(new_user_ids).intersection(set(active_user_ids))) active_new_users = len(set(new_user_ids).intersection(set(active_user_ids)))
@ -184,13 +184,13 @@ class StatsService:
date_range = [] date_range = []
current_date = start_date current_date = start_date
while current_date <= end_date: while current_date <= end_date:
date_range.append(current_date.strftime('%Y-%m-%d')) date_range.append(current_date.strftime("%Y-%m-%d"))
current_date += timedelta(days=1) current_date += timedelta(days=1)
daily_stats = [] daily_stats = []
for date_str in date_range: for date_str in date_range:
date = datetime.strptime(date_str, '%Y-%m-%d') date = datetime.strptime(date_str, "%Y-%m-%d")
next_date = date + timedelta(days=1) next_date = date + timedelta(days=1)
# Count total conversations for this date # Count total conversations for this date

@ -48,11 +48,7 @@ def generate_image_task(
raise Exception(f"End user {end_user_id} not found") raise Exception(f"End user {end_user_id} not found")
# Get the existing UserGeneratedImage entity # Get the existing UserGeneratedImage entity
user_generated_image = ( user_generated_image = db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
db.session.query(UserGeneratedImage)
.filter(UserGeneratedImage.id == image_id)
.first()
)
if not user_generated_image: if not user_generated_image:
raise Exception(f"UserGeneratedImage {image_id} not found") raise Exception(f"UserGeneratedImage {image_id} not found")
@ -67,16 +63,10 @@ def generate_image_task(
db.session.commit() db.session.commit()
raise Exception("Image generation app id is not set") raise Exception("Image generation app id is not set")
image_generation_app_model = ( image_generation_app_model = db.session.query(App).filter(App.id == dify_config.IMAGE_GENERATION_APP_ID).first()
db.session.query(App)
.filter(App.id == dify_config.IMAGE_GENERATION_APP_ID)
.first()
)
if image_generation_app_model is None: if image_generation_app_model is None:
user_generated_image.status = "failed" user_generated_image.status = "failed"
user_generated_image.error_message = ( user_generated_image.error_message = "Image generation app model is not found"
"Image generation app model is not found"
)
db.session.commit() db.session.commit()
raise Exception("Image generation app model is not found") raise Exception("Image generation app model is not found")
@ -93,10 +83,7 @@ def generate_image_task(
.all() .all()
) )
recent_messages = [ recent_messages = [f"user: {message.query}\n\nassistant: {message.answer}" for message in recent_messages]
f"user: {message.query}\n\nassistant: {message.answer}"
for message in recent_messages
]
# Prepare arguments for generation # Prepare arguments for generation
args = { args = {
@ -167,9 +154,7 @@ def generate_image_task(
# Update status to failed if we have the entity # Update status to failed if we have the entity
try: try:
user_generated_image = ( user_generated_image = (
db.session.query(UserGeneratedImage) db.session.query(UserGeneratedImage).filter(UserGeneratedImage.id == image_id).first()
.filter(UserGeneratedImage.id == image_id)
.first()
) )
if user_generated_image: if user_generated_image:
user_generated_image.status = "failed" user_generated_image.status = "failed"

@ -12,69 +12,69 @@ import requests
class RegistrationTester: class RegistrationTester:
"""Test class for registration API.""" """Test class for registration API."""
def __init__(self, base_url: str = "http://localhost:5001"): def __init__(self, base_url: str = "http://localhost:5001"):
self.base_url = base_url self.base_url = base_url
self.session = requests.Session() self.session = requests.Session()
def test_send_verification_code(self, email: str) -> dict[str, Any]: def test_send_verification_code(self, email: str) -> dict[str, Any]:
"""Test sending verification code.""" """Test sending verification code."""
print(f"🔵 Testing verification code send for: {email}") print(f"🔵 Testing verification code send for: {email}")
response = self.session.post( response = self.session.post(
f"{self.base_url}/service/auth/email-code-login", f"{self.base_url}/service/auth/email-code-login",
json={"email": email}, json={"email": email},
headers={"Content-Type": "application/json"} headers={"Content-Type": "application/json"},
) )
print(f" Status: {response.status_code}") print(f" Status: {response.status_code}")
print(f" Response: {response.text}") print(f" Response: {response.text}")
return { return {
"status_code": response.status_code, "status_code": response.status_code,
"response": ( "response": (
response.json() response.json()
if response.headers.get('content-type', '').startswith('application/json') if response.headers.get("content-type", "").startswith("application/json")
else response.text else response.text
) ),
} }
def test_registration_with_code(self, email: str, code: str, token: str) -> dict[str, Any]: def test_registration_with_code(self, email: str, code: str, token: str) -> dict[str, Any]:
"""Test registration with verification code.""" """Test registration with verification code."""
print(f"🔵 Testing registration for: {email} with code: {code}") print(f"🔵 Testing registration for: {email} with code: {code}")
response = self.session.post( response = self.session.post(
f"{self.base_url}/service/auth/email-code-login/validity", f"{self.base_url}/service/auth/email-code-login/validity",
json={"email": email, "code": code, "token": token}, json={"email": email, "code": code, "token": token},
headers={"Content-Type": "application/json"} headers={"Content-Type": "application/json"},
) )
print(f" Status: {response.status_code}") print(f" Status: {response.status_code}")
print(f" Response: {response.text}") print(f" Response: {response.text}")
return { return {
"status_code": response.status_code, "status_code": response.status_code,
"response": ( "response": (
response.json() response.json()
if response.headers.get('content-type', '').startswith('application/json') if response.headers.get("content-type", "").startswith("application/json")
else response.text else response.text
) ),
} }
def test_verification_code_sending(self, email: str) -> dict[str, Any]: def test_verification_code_sending(self, email: str) -> dict[str, Any]:
"""Test verification code sending (first step only).""" """Test verification code sending (first step only)."""
print(f"\n🚀 Testing verification code sending for: {email}") print(f"\n🚀 Testing verification code sending for: {email}")
print("=" * 50) print("=" * 50)
# Step 1: Send verification code # Step 1: Send verification code
send_result = self.test_send_verification_code(email) send_result = self.test_send_verification_code(email)
if send_result["status_code"] != 200: if send_result["status_code"] != 200:
print(f"❌ Failed to send verification code for {email}") print(f"❌ Failed to send verification code for {email}")
return send_result return send_result
print(f"✅ Verification code sent successfully for {email}") print(f"✅ Verification code sent successfully for {email}")
# Extract token from send result # Extract token from send result
token = None token = None
if isinstance(send_result["response"], dict) and "data" in send_result["response"]: if isinstance(send_result["response"], dict) and "data" in send_result["response"]:
@ -85,23 +85,23 @@ class RegistrationTester:
else: else:
print("❌ No token received from verification code send") print("❌ No token received from verification code send")
return {"status_code": 400, "response": "No token received"} return {"status_code": 400, "response": "No token received"}
return {"status_code": 200, "response": "Verification code sent successfully", "token": token} return {"status_code": 200, "response": "Verification code sent successfully", "token": token}
def test_registration_flow_interactive(self, email: str) -> dict[str, Any]: def test_registration_flow_interactive(self, email: str) -> dict[str, Any]:
"""Test full registration flow with user input for verification code.""" """Test full registration flow with user input for verification code."""
print(f"\n🚀 Testing INTERACTIVE registration flow for: {email}") print(f"\n🚀 Testing INTERACTIVE registration flow for: {email}")
print("=" * 50) print("=" * 50)
# Step 1: Send verification code # Step 1: Send verification code
send_result = self.test_send_verification_code(email) send_result = self.test_send_verification_code(email)
if send_result["status_code"] != 200: if send_result["status_code"] != 200:
print(f"❌ Failed to send verification code for {email}") print(f"❌ Failed to send verification code for {email}")
return send_result return send_result
print(f"✅ Verification code sent successfully for {email}") print(f"✅ Verification code sent successfully for {email}")
# Extract token from send result # Extract token from send result
token = None token = None
if isinstance(send_result["response"], dict) and "data" in send_result["response"]: if isinstance(send_result["response"], dict) and "data" in send_result["response"]:
@ -110,94 +110,91 @@ class RegistrationTester:
else: else:
print("❌ No token received from verification code send") print("❌ No token received from verification code send")
return {"status_code": 400, "response": "No token received"} return {"status_code": 400, "response": "No token received"}
# Step 2: Get verification code from user # Step 2: Get verification code from user
print(f"📧 A verification code has been sent to {email}") print(f"📧 A verification code has been sent to {email}")
verification_code = input("🔢 Please enter the verification code from your email: ").strip() verification_code = input("🔢 Please enter the verification code from your email: ").strip()
if not verification_code: if not verification_code:
print("❌ No verification code entered") print("❌ No verification code entered")
return {"status_code": 400, "response": "No verification code entered"} return {"status_code": 400, "response": "No verification code entered"}
register_result = self.test_registration_with_code(email, verification_code, token) register_result = self.test_registration_with_code(email, verification_code, token)
if register_result["status_code"] == 200: if register_result["status_code"] == 200:
print(f"✅ Registration successful for {email}") print(f"✅ Registration successful for {email}")
else: else:
print(f"❌ Registration failed for {email}") print(f"❌ Registration failed for {email}")
return register_result return register_result
def run_comprehensive_tests(self): def run_comprehensive_tests(self):
"""Run comprehensive tests for different email scenarios.""" """Run comprehensive tests for different email scenarios."""
print("🧪 Running Comprehensive Email Verification Code Tests") print("🧪 Running Comprehensive Email Verification Code Tests")
print("=" * 60) print("=" * 60)
print(" This tests verification code sending (step 1 of registration)") print(" This tests verification code sending (step 1 of registration)")
print(" For full registration testing, use interactive mode") print(" For full registration testing, use interactive mode")
test_cases = [ test_cases = [
{ {
"email": "student@university.edu", "email": "student@university.edu",
"description": "University student (.edu domain)", "description": "University student (.edu domain)",
"expected_org": "Should be assigned to organization if domain match exists" "expected_org": "Should be assigned to organization if domain match exists",
}, },
{ {
"email": "user@gmail.com", "email": "user@gmail.com",
"description": "Personal Gmail account", "description": "Personal Gmail account",
"expected_org": "Should register without organization assignment" "expected_org": "Should register without organization assignment",
}, },
{ {
"email": "employee@company.com", "email": "employee@company.com",
"description": "Company email without pre-configured organization", "description": "Company email without pre-configured organization",
"expected_org": "Should register without organization assignment" "expected_org": "Should register without organization assignment",
}, },
{ {
"email": "admin@startup.io", "email": "admin@startup.io",
"description": "Startup email", "description": "Startup email",
"expected_org": "Should register without organization assignment" "expected_org": "Should register without organization assignment",
}, },
{ {
"email": "test@protonmail.com", "email": "test@protonmail.com",
"description": "ProtonMail account", "description": "ProtonMail account",
"expected_org": "Should register without organization assignment" "expected_org": "Should register without organization assignment",
} },
] ]
results = [] results = []
for i, test_case in enumerate(test_cases, 1): for i, test_case in enumerate(test_cases, 1):
print(f"\n📋 Test Case {i}: {test_case['description']}") print(f"\n📋 Test Case {i}: {test_case['description']}")
print(f" Email: {test_case['email']}") print(f" Email: {test_case['email']}")
print(f" Expected: {test_case['expected_org']}") print(f" Expected: {test_case['expected_org']}")
result = self.test_verification_code_sending(test_case["email"]) result = self.test_verification_code_sending(test_case["email"])
results.append({ results.append({"test_case": test_case, "result": result})
"test_case": test_case,
"result": result
})
# Small delay between tests # Small delay between tests
time.sleep(0.5) time.sleep(0.5)
# Summary # Summary
print("\n📊 Test Results Summary") print("\n📊 Test Results Summary")
print("=" * 30) print("=" * 30)
for i, test_result in enumerate(results, 1): for i, test_result in enumerate(results, 1):
test_case = test_result["test_case"] test_case = test_result["test_case"]
result = test_result["result"] result = test_result["result"]
status = "✅ PASSED" if result["status_code"] == 200 else "❌ FAILED" status = "✅ PASSED" if result["status_code"] == 200 else "❌ FAILED"
print(f"{i}. {test_case['description']}: {status}") print(f"{i}. {test_case['description']}: {status}")
if result["status_code"] != 200: if result["status_code"] != 200:
print(f" Error: {result['response']}") print(f" Error: {result['response']}")
# Overall summary # Overall summary
passed = sum(1 for r in results if r["result"]["status_code"] == 200) passed = sum(1 for r in results if r["result"]["status_code"] == 200)
total = len(results) total = len(results)
print(f"\n🎯 Overall: {passed}/{total} verification code tests passed") print(f"\n🎯 Overall: {passed}/{total} verification code tests passed")
return results return results
@ -205,9 +202,9 @@ def main():
"""Main function to run tests.""" """Main function to run tests."""
print("🔍 Email Registration API Test Suite") print("🔍 Email Registration API Test Suite")
print("=" * 40) print("=" * 40)
tester = RegistrationTester() tester = RegistrationTester()
# Check if server is running # Check if server is running
try: try:
response = requests.get("http://localhost:5001/health", timeout=5) response = requests.get("http://localhost:5001/health", timeout=5)
@ -216,14 +213,14 @@ def main():
print("❌ Server is not running on port 5001") print("❌ Server is not running on port 5001")
print(" Please start the server with: uv run flask run --host 0.0.0.0 --port=5001 --debug") print(" Please start the server with: uv run flask run --host 0.0.0.0 --port=5001 --debug")
return return
# Run comprehensive tests # Run comprehensive tests
results = tester.run_comprehensive_tests() results = tester.run_comprehensive_tests()
# Additional specific tests # Additional specific tests
print("\n🔬 Additional Tests") print("\n🔬 Additional Tests")
print("=" * 20) print("=" * 20)
# Test invalid email # Test invalid email
print("\n🔵 Testing invalid email format") print("\n🔵 Testing invalid email format")
invalid_result = tester.test_send_verification_code("invalid.email") invalid_result = tester.test_send_verification_code("invalid.email")
@ -231,7 +228,7 @@ def main():
print("✅ Invalid email correctly rejected") print("✅ Invalid email correctly rejected")
else: else:
print("❌ Invalid email was accepted (should be rejected)") print("❌ Invalid email was accepted (should be rejected)")
# Test registration with invalid code # Test registration with invalid code
print("\n🔵 Testing registration with invalid code") print("\n🔵 Testing registration with invalid code")
# First get a valid token # First get a valid token
@ -245,11 +242,11 @@ def main():
print("❌ Invalid verification code was accepted (should be rejected)") print("❌ Invalid verification code was accepted (should be rejected)")
else: else:
print("⚠️ Could not test invalid code - failed to get token") print("⚠️ Could not test invalid code - failed to get token")
print("\n🎉 All tests completed!") print("\n🎉 All tests completed!")
print("\n Note: This test uses mock verification codes.") print("\n Note: This test uses mock verification codes.")
print(" In production, users would receive actual codes via email.") print(" In production, users would receive actual codes via email.")
if __name__ == "__main__": if __name__ == "__main__":
main() main()

Loading…
Cancel
Save