diff --git a/README.md b/README.md index 1dc7e2dd98..2909e0e6cf 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ README in বাংলা

-Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production. +Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production. ## Quick start @@ -65,7 +65,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
-The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: +The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine: ```bash cd dify @@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Using Terraform for Deployment @@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media ## Security disclosure -To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer. +To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with a detailed answer. ## License -This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions. +This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions. 
diff --git a/README_AR.md b/README_AR.md index d93bca8646..e959ca0f78 100644 --- a/README_AR.md +++ b/README_AR.md @@ -188,6 +188,7 @@ docker compose up -d - [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts) - [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### استخدام Terraform للتوزيع diff --git a/README_BN.md b/README_BN.md index 3efee3684d..29d7374ea5 100644 --- a/README_BN.md +++ b/README_BN.md @@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + #### টেরাফর্ম ব্যবহার করে ডিপ্লয় diff --git a/README_CN.md b/README_CN.md index 21e27429ec..486a368c09 100644 --- a/README_CN.md +++ b/README_CN.md @@ -194,9 +194,9 @@ docker compose up -d 如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 -#### 使用 Helm Chart 部署 +#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 -使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。 +使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) @@ -204,6 +204,10 @@ docker compose up -d - [YAML 文件 by 
@Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + + + #### 使用 Terraform 部署 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 diff --git a/README_DE.md b/README_DE.md index 20c313035e..fce52c34c2 100644 --- a/README_DE.md +++ b/README_DE.md @@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform für die Bereitstellung verwenden diff --git a/README_ES.md b/README_ES.md index e4b7df6686..6fd6dfcee8 100644 --- a/README_ES.md +++ b/README_ES.md @@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop - [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts) - [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uso de Terraform para el despliegue diff --git a/README_FR.md b/README_FR.md index 8fd17fb7c3..b2209fb495 100644 --- a/README_FR.md +++ b/README_FR.md @@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau - [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts) - [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOUVEAU ! 
Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Utilisation de Terraform pour le déploiement diff --git a/README_JA.md b/README_JA.md index a3ee81e1f2..c658225f90 100644 --- a/README_JA.md +++ b/README_JA.md @@ -202,6 +202,7 @@ docker compose up -d - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraformを使用したデプロイ diff --git a/README_KL.md b/README_KL.md index 3e5ab1a74f..bfafcc7407 100644 --- a/README_KL.md +++ b/README_KL.md @@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform atorlugu pilersitsineq diff --git a/README_KR.md b/README_KR.md index 3c504900e1..282117e776 100644 --- a/README_KR.md +++ b/README_KR.md @@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! 
YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform을 사용한 배포 diff --git a/README_PT.md b/README_PT.md index fb5f3662ae..576f6b48f7 100644 --- a/README_PT.md +++ b/README_PT.md @@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] - [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts) - [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Usando o Terraform para Implantação diff --git a/README_SI.md b/README_SI.md index 647069a220..7ded001d86 100644 --- a/README_SI.md +++ b/README_SI.md @@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases. - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uporaba Terraform za uvajanje diff --git a/README_TR.md b/README_TR.md index f52335646a..6e94e54fa0 100644 --- a/README_TR.md +++ b/README_TR.md @@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' - [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes) - [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s) +- [🚀 YENİ! 
YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Dağıtım için Terraform Kullanımı diff --git a/README_TW.md b/README_TW.md index 71082ff893..6e3e22b5c1 100644 --- a/README_TW.md +++ b/README_TW.md @@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 -如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。 +如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) - [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes) - [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) ### 使用 Terraform 進行部署 diff --git a/README_VI.md b/README_VI.md index 58d8434fff..51314e6de5 100644 --- a/README_VI.md +++ b/README_VI.md @@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có - [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 MỚI! 
Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Sử dụng Terraform để Triển khai diff --git a/api/.env.example b/api/.env.example index 7b08c032ed..c2d8c78a7d 100644 --- a/api/.env.example +++ b/api/.env.example @@ -449,6 +449,19 @@ MAX_VARIABLE_SIZE=204800 # hybrid: Save new data to object storage, read from both object storage and RDBMS WORKFLOW_NODE_EXECUTION_STORAGE=rdbms +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # App configuration APP_MAX_EXECUTION_TIME=1200 APP_MAX_ACTIVE_REQUESTS=0 diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 2fd9f94e06..f1d529355d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -546,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings): ) +class RepositoryConfig(BaseSettings): + """ + Configuration for repository implementations + """ + + CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowExecution. 
Specify as a module path", + default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", + ) + + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. " + "Specify as a module path", + default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository", + ) + + API_WORKFLOW_RUN_REPOSITORY: str = Field( + description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path", + default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository", + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -922,6 +949,7 @@ class FeatureConfig( MultiModalTransferConfig, PositionConfig, RagEtlConfig, + RepositoryConfig, SecurityConfig, ToolConfig, UpdateConfig, diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index ccda97d80c..0f53860f56 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -92,7 +92,8 @@ class AppMCPServerRefreshController(Resource): raise NotFound() server = ( db.session.query(AppMCPServer) - .filter(AppMCPServer.id == server_id and AppMCPServer.tenant_id == current_user.current_tenant_id) + .filter(AppMCPServer.id == server_id) + .filter(AppMCPServer.tenant_id == current_user.current_tenant_id) .first() ) if not server: diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 86aed77412..32b64d10c5 100644 --- 
a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode +from models import AppMode, Message class DailyMessageStatistic(Resource): @@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource): parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT - DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(DISTINCT messages.conversation_id) AS conversation_count -FROM - messages -WHERE - app_id = :app_id""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc + stmt = ( + sa.select( + sa.func.date( + sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz")) + ).label("date"), + sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), + ) + .select_from(Message) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + ) + if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - - sql_query 
+= " AND created_at >= :start" - arg_dict["start"] = start_datetime_utc + stmt = stmt.where(Message.created_at >= start_datetime_utc) if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + stmt = stmt.where(Message.created_at < end_datetime_utc) - sql_query += " AND created_at < :end" - arg_dict["end"] = end_datetime_utc - - sql_query += " GROUP BY date ORDER BY date" + stmt = stmt.group_by("date").order_by("date") response_data = [] - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) - for i in rs: - response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) + rs = conn.execute(stmt, {"tz": account.timezone}) + for row in rs: + response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py index 03b60610aa..3322350e25 100644 --- a/api/controllers/console/app/wraps.py +++ b/api/controllers/console/app/wraps.py @@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[ raise AppNotFoundError() app_mode = AppMode.value_of(app_model.mode) - if app_mode == AppMode.CHANNEL: - raise AppNotFoundError() if mode is not None: if isinstance(mode, list): diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index efb4acc5fb..ac2ebf2b09 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -3,7 +3,7 @@ import logging from dateutil.parser import isoparse from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range -from sqlalchemy.orm import Session +from sqlalchemy.orm import 
Session, sessionmaker from werkzeug.exceptions import InternalServerError from controllers.service_api import api @@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource): if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]: raise NotWorkflowAppError() - workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() + # Use repository to get workflow run + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + workflow_run = workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=workflow_run_id, + ) return workflow_run diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 0d304de97a..28bf4a9a23 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -3,6 +3,8 @@ import logging import uuid from typing import Optional, Union, cast +from sqlalchemy import select + from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig @@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner): if isinstance(prompt_message, SystemPromptMessage): result.append(prompt_message) - messages: list[Message] = ( - db.session.query(Message) - .filter( - Message.conversation_id == 
self.message.conversation_id, + messages = ( + ( + db.session.execute( + select(Message) + .where(Message.conversation_id == self.message.conversation_id) + .order_by(Message.created_at.desc()) + ) ) - .order_by(Message.created_at.desc()) + .scalars() .all() ) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 7877408cef..4b8f5ebe27 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, 
app_id=application_generate_entity.app_config.app_id, @@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 40a1e272a7..2f9632e97d 100644 --- 
a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow 
run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 2a85cd5e3d..c6b326d8a4 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,7 +3,6 @@ import time 
from collections.abc import Generator from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME @@ -68,7 +67,6 @@ from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowRun, ) logger = logging.getLogger(__name__) @@ -562,8 +560,6 @@ class WorkflowAppGenerateTaskPipeline: tts_publisher.publish(None) def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: - workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) - assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API @@ -576,10 +572,10 @@ class WorkflowAppGenerateTaskPipeline: return workflow_app_log = WorkflowAppLog() - workflow_app_log.tenant_id = workflow_run.tenant_id - workflow_app_log.app_id = workflow_run.app_id - workflow_app_log.workflow_id = workflow_run.workflow_id - workflow_app_log.workflow_run_id = workflow_run.id + workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id + workflow_app_log.app_id = self._application_generate_entity.app_config.app_id + workflow_app_log.workflow_id = workflow_execution.workflow_id + workflow_app_log.workflow_run_id = workflow_execution.id_ workflow_app_log.created_from = created_from.value workflow_app_log.created_by_role = self._created_by_role workflow_app_log.created_by = self._user_id diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index 84f212a9c1..b416e48ce4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,6 +5,8 @@ from base64 import b64encode from collections.abc import Mapping from typing 
import Any +from core.variables.utils import SegmentJSONEncoder + class TemplateTransformer(ABC): _code_placeholder: str = "{{code}}" @@ -43,17 +45,13 @@ class TemplateTransformer(ABC): result_str = cls.extract_result_str_from_response(response) result = json.loads(result_str) except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...") + raise ValueError(f"Failed to parse JSON response: {str(e)}.") except ValueError as e: # Re-raise ValueError from extract_result_str_from_response raise e except Exception as e: raise ValueError(f"Unexpected error during response transformation: {str(e)}") - # Check if the result contains an error - if isinstance(result, dict) and "error" in result: - raise ValueError(f"JavaScript execution error: {result['error']}") - if not isinstance(result, dict): raise ValueError(f"Result must be a dict, got {type(result).__name__}") if not all(isinstance(k, str) for k in result): @@ -95,7 +93,7 @@ class TemplateTransformer(ABC): @classmethod def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str: - inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode() + inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode() input_base64_encoded = b64encode(inputs_json_str).decode("utf-8") return input_base64_encoded diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 2254b3d4d5..a9f0a92e5d 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,6 +1,8 @@ from collections.abc import Sequence from typing import Optional +from sqlalchemy import select + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.file import file_manager from core.model_manager import ModelInstance @@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database 
import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile -from models.workflow import WorkflowRun +from models.workflow import Workflow, WorkflowRun class TokenBufferMemory: - def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None: + def __init__( + self, + conversation: Conversation, + model_instance: ModelInstance, + ) -> None: self.conversation = conversation self.model_instance = model_instance @@ -36,20 +42,8 @@ class TokenBufferMemory: app_record = self.conversation.app # fetch limited messages, and return reversed - query = ( - db.session.query( - Message.id, - Message.query, - Message.answer, - Message.created_at, - Message.workflow_run_id, - Message.parent_message_id, - Message.answer_tokens, - ) - .filter( - Message.conversation_id == self.conversation.id, - ) - .order_by(Message.created_at.desc()) + stmt = ( + select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc()) ) if message_limit and message_limit > 0: @@ -57,7 +51,9 @@ class TokenBufferMemory: else: message_limit = 500 - messages = query.limit(message_limit).all() + stmt = stmt.limit(message_limit) + + messages = db.session.scalars(stmt).all() # instead of all messages from the conversation, we only need to extract messages # that belong to the thread of last message @@ -74,18 +70,20 @@ class TokenBufferMemory: files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all() if files: file_extra_config = None - if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}: file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config) + elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}: + workflow_run = db.session.scalar( + select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id) + ) + 
if not workflow_run: + raise ValueError(f"Workflow run not found: {message.workflow_run_id}") + workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id)) + if not workflow: + raise ValueError(f"Workflow not found: {workflow_run.workflow_id}") + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) else: - if message.workflow_run_id: - workflow_run = ( - db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first() - ) - - if workflow_run and workflow_run.workflow: - file_extra_config = FileUploadConfigManager.convert( - workflow_run.workflow.features_dict, is_vision=False - ) + raise AssertionError(f"Invalid app mode: {self.conversation.mode}") detail = ImagePromptMessageContent.DETAIL.LOW if file_extra_config and app_record: diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a3dbce0e59..4a7e66d27c 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, 
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f94e5e49d7..8a559c4929 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8bedea20fb..be4997a5bf 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance): 
service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) @@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance): "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), - "name": node_type, + "name": node_name, "type": run_type, "start_time": created_at, "end_time": finished_at, diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 3917348a91..445c6a8741 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/plugin/impl/plugin.py 
b/api/core/plugin/impl/plugin.py index b7f7b31655..04ac8c9649 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": 1, "page_size": 256}, + params={"page": 1, "page_size": 256, "response_type": "paged"}, ) return result.list @@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": page, "page_size": page_size}, + params={"page": page, "page_size": page_size, "response_type": "paged"}, ) def upload_pkg( diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py index f7aef76c87..4b883622a7 100644 --- a/api/core/prompt/utils/extract_thread_messages.py +++ b/api/core/prompt/utils/extract_thread_messages.py @@ -1,10 +1,11 @@ -from typing import Any +from collections.abc import Sequence from constants import UUID_NIL +from models import Message -def extract_thread_messages(messages: list[Any]): - thread_messages = [] +def extract_thread_messages(messages: Sequence[Message]): + thread_messages: list[Message] = [] next_message = None for message in messages: diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py index f49466db6d..de64c27a73 100644 --- a/api/core/prompt/utils/get_thread_messages_length.py +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -1,3 +1,5 @@ +from sqlalchemy import select + from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from models.model import Message @@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int: Get the number of thread messages based on the parent message id. 
""" # Fetch all messages related to the conversation - query = ( - db.session.query( - Message.id, - Message.parent_message_id, - Message.answer, - ) - .filter( - Message.conversation_id == conversation_id, - ) - .order_by(Message.created_at.desc()) - ) + stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc()) - messages = query.all() + messages = db.session.scalars(stmt).all() # Extract thread messages thread_messages = extract_thread_messages(messages) diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 2c5178241c..5a6903d3d5 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Optional from flask import Flask, current_app -from sqlalchemy.orm import load_only +from sqlalchemy.orm import Session, load_only from configs import dify_config from core.rag.data_post_processor.data_post_processor import DataPostProcessor @@ -144,7 +144,8 @@ class RetrievalService: @classmethod def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]: - return db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + return session.query(Dataset).filter(Dataset.id == dataset_id).first() @classmethod def keyword_search( diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index a124faa503..552068c99e 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -4,6 +4,7 @@ from typing import Any, Optional import tablestore # type: ignore from pydantic import BaseModel, model_validator +from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config from core.rag.datasource.vdb.field import Field 
@@ -50,6 +51,29 @@ class TableStoreVector(BaseVector): self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" + def create_collection(self, embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + + def get_by_ids(self, ids: list[str]) -> list[Document]: + docs = [] + request = BatchGetRowRequest() + columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + rows_to_get = [[("id", _id)] for _id in ids] + request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) + + result = self._tablestore_client.batch_get_row(request) + table_result = result.get_result_by_table(self._table_name) + for item in table_result: + if item.is_ok and item.row: + kv = {k: v for k, v, t in item.row.attribute_columns} + docs.append( + Document( + page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) + ) + ) + return docs + def get_type(self) -> str: return VectorType.TABLESTORE diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 3fca48be22..5c0360b064 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast from flask import Flask, current_app from sqlalchemy import Float, and_, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import ( DatasetEntity, @@ -598,7 +599,8 @@ class DatasetRetrieval: metadata_condition: Optional[MetadataCondition] = None, ): with flask_app.app_context(): - dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() + with Session(db.engine) as session: + dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first() if not dataset: return [] diff --git a/api/core/repositories/__init__.py 
b/api/core/repositories/__init__.py index 6452317120..052ba1c2cb 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "DifyCoreRepositoryFactory", + "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py new file mode 100644 index 0000000000..4118aa61c7 --- /dev/null +++ b/api/core/repositories/factory.py @@ -0,0 +1,224 @@ +""" +Repository factory for dynamically creating repository instances based on configuration. + +This module provides a Django-like settings system for repository implementations, +allowing users to configure different repository backends through string paths. +""" + +import importlib +import inspect +import logging +from typing import Protocol, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class RepositoryImportError(Exception): + """Raised when a repository implementation cannot be imported or instantiated.""" + + pass + + +class DifyCoreRepositoryFactory: + """ + Factory for creating repository instances based on configuration. 
+ + This factory supports Django-like settings where repository implementations + are specified as module paths (e.g., 'module.submodule.ClassName'). + """ + + @staticmethod + def _import_class(class_path: str) -> type: + """ + Import a class from a module path string. + + Args: + class_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class + + Raises: + RepositoryImportError: If the class cannot be imported + """ + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + repo_class = getattr(module, class_name) + assert isinstance(repo_class, type) + return repo_class + except (ValueError, ImportError, AttributeError) as e: + raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e + + @staticmethod + def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore + """ + Validate that a class implements the expected repository interface. 
+ + Args: + repository_class: The class to validate + expected_interface: The expected interface/protocol + + Raises: + RepositoryImportError: If the class doesn't implement the interface + """ + # Check if the class has all required methods from the protocol + required_methods = [ + method + for method in dir(expected_interface) + if not method.startswith("_") and callable(getattr(expected_interface, method, None)) + ] + + missing_methods = [] + for method_name in required_methods: + if not hasattr(repository_class, method_name): + missing_methods.append(method_name) + + if missing_methods: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' does not implement required methods " + f"{missing_methods} from interface '{expected_interface.__name__}'" + ) + + @staticmethod + def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: + """ + Validate that a repository class constructor accepts required parameters. + + Args: + repository_class: The class to validate + required_params: List of required parameter names + + Raises: + RepositoryImportError: If the constructor doesn't accept required parameters + """ + + try: + # MyPy may flag the line below with the following error: + # + # > Accessing "__init__" on an instance is unsound, since + # > instance.__init__ could be from an incompatible subclass. + # + # Despite this, we need to ensure that the constructor of `repository_class` + # has a compatible signature. + signature = inspect.signature(repository_class.__init__) # type: ignore[misc] + param_names = list(signature.parameters.keys()) + + # Remove 'self' parameter + if "self" in param_names: + param_names.remove("self") + + missing_params = [param for param in required_params if param not in param_names] + if missing_params: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " + f"{missing_params}. 
Expected parameters: {required_params}" + ) + except Exception as e: + raise RepositoryImportError( + f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" + ) from e + + @classmethod + def create_workflow_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + ) -> WorkflowExecutionRepository: + """ + Create a WorkflowExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowExecutionRepository") + raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e + + @classmethod + def create_workflow_node_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + ) -> WorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository 
instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3f844e8234..a3c84615ca 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,5 +1,4 @@ import re -import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: - path = str(uuid.uuid4()) + path = "" interface["operation"]["operationId"] = f"{path}_{interface['method']}" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py 
b/api/core/workflow/nodes/iteration/iteration_node.py index c447f433aa..8b566c83cd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -521,18 +521,52 @@ class IterationNode(BaseNode[IterationNodeData]): ) return elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": None}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), ) + outputs[current_index] = None + + # clean nodes resources + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + + # stop the iterator + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return yield metadata_event current_output_segment = 
variable_pool.get(self.node_data.output_selector) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index b34d62d669..f05d93d83e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -144,6 +144,8 @@ class KnowledgeRetrievalNode(LLMNode): error=str(e), error_type=type(e).__name__, ) + finally: + db.session.close() def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]: available_datasets = [] @@ -171,6 +173,9 @@ class KnowledgeRetrievalNode(LLMNode): .all() ) + # avoid blocking at retrieval + db.session.close() + for dataset in results: # pass if dataset is not available if not dataset: diff --git a/api/libs/passport.py b/api/libs/passport.py index 8df4f529bc..fe8fc33b5f 100644 --- a/api/libs/passport.py +++ b/api/libs/passport.py @@ -14,9 +14,11 @@ class PassportService: def verify(self, token): try: return jwt.decode(token, self.sk, algorithms=["HS256"]) + except jwt.exceptions.ExpiredSignatureError: + raise Unauthorized("Token has expired.") except jwt.exceptions.InvalidSignatureError: raise Unauthorized("Invalid token signature.") except jwt.exceptions.DecodeError: raise Unauthorized("Invalid token.") - except jwt.exceptions.ExpiredSignatureError: - raise Unauthorized("Token has expired.") + except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors + raise Unauthorized("Invalid token.") diff --git a/api/models/model.py b/api/models/model.py index b1007c4a79..7e9e91727d 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -50,7 +50,6 @@ class AppMode(StrEnum): CHAT = "chat" ADVANCED_CHAT = "advanced-chat" AGENT_CHAT = "agent-chat" - CHANNEL = "channel" @classmethod def value_of(cls, value: str) -> "AppMode": @@ -934,7 +933,7 @@ class Message(Base): created_at: 
Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - workflow_run_id = db.Column(StringUUID) + workflow_run_id: Mapped[str] = db.Column(StringUUID) @property def inputs(self): diff --git a/api/pyproject.toml b/api/pyproject.toml index 420bc771b6..7f1efa671f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -108,7 +108,7 @@ dev = [ "faker~=32.1.0", "lxml-stubs~=0.5.1", "mypy~=1.16.0", - "ruff~=0.11.5", + "ruff~=0.12.3", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", "pytest-cov~=4.1.0", diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..00a2d1f87d --- /dev/null +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -0,0 +1,197 @@ +""" +Service-layer repository protocol for WorkflowNodeExecutionModel operations. + +This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel +that abstracts database queries currently done directly in service classes. This repository is +specifically designed for service-layer needs and is separate from the core domain repository. + +The service repository handles operations that require access to database-specific fields like +tenant_id, app_id, triggered_from, etc., which are not part of the core domain model. 
+""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models.workflow import WorkflowNodeExecutionModel + + +class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol): + """ + Protocol for service-layer operations on WorkflowNodeExecutionModel. + + This repository provides database access patterns specifically needed by service classes, + handling queries that involve database-specific fields and multi-tenancy concerns. + + Key responsibilities: + - Manages database operations for workflow node executions + - Handles multi-tenant data isolation + - Provides batch processing capabilities + - Supports execution lifecycle management + + Implementation notes: + - Returns database models directly (WorkflowNodeExecutionModel) + - Handles tenant/app filtering automatically + - Provides service-specific query patterns + - Focuses on database operations without domain logic + - Supports cleanup and maintenance operations + """ + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method finds the latest execution of a specific node within a workflow, + ordered by creation time. Used primarily for debugging and inspection purposes. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + ... + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. 
+ + This method retrieves all node executions that belong to a specific workflow run, + ordered by index in descending order for proper trace visualization. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + ... + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method retrieves a specific execution by its unique identifier. + Tenant filtering is optional for cases where the execution ID is globally unique. + + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + ... + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + This method is used for cleanup operations to remove expired executions + in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + ... + + def delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. 
+ + This method is used when removing an app and all its related data. + Executions are deleted in batches to avoid overwhelming the database. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + ... + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. + + This method retrieves expired executions without deleting them, + allowing the caller to back up the data before deletion. + + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + ... + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + This method deletes specific executions by their IDs, + typically used after backing up the data. + + This method does not perform tenant isolation checks. The caller is responsible for ensuring proper + data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests), + additional tenant validation should be implemented to prevent unauthorized access. + + Args: + execution_ids: Sequence of execution IDs to delete + + Returns: + The number of executions deleted + """ + ... diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py new file mode 100644 index 0000000000..59e7baeb79 --- /dev/null +++ b/api/repositories/api_workflow_run_repository.py @@ -0,0 +1,181 @@ +""" +API WorkflowRun Repository Protocol + +This module defines the protocol for service-layer WorkflowRun operations. 
+The repository provides an abstraction layer for WorkflowRun database operations +used by service classes, separating service-layer concerns from core domain logic. + +Key Features: +- Paginated workflow run queries with filtering +- Bulk deletion operations with OSS backup support +- Multi-tenant data isolation +- Expired record cleanup with data retention +- Service-layer specific query patterns + +Usage: + This protocol should be used by service classes that need to perform + WorkflowRun database operations. It provides a clean interface that + hides implementation details and supports dependency injection. + +Example: + ```python + from repositories.factory import DifyAPIRepositoryFactory + + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + # Get paginated workflow runs + runs = repo.get_paginated_workflow_runs( + tenant_id="tenant-123", + app_id="app-456", + triggered_from="debugging", + limit=20 + ) + ``` +""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, Protocol + +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun + + +class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol): + """ + Protocol for service-layer WorkflowRun repository operations. + + This protocol defines the interface for WorkflowRun database operations + that are specific to service-layer needs, including pagination, filtering, + and bulk operations with data backup support. + """ + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. 
+ + Retrieves workflow runs for a specific app and trigger source with + cursor-based pagination support. Used primarily for debugging and + workflow run listing in the UI. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + triggered_from: Filter by trigger source (e.g., "debugging", "app-run") + limit: Maximum number of records to return (default: 20) + last_id: Cursor for pagination - ID of the last record from previous page + + Returns: + InfiniteScrollPagination object containing: + - data: List of WorkflowRun objects + - limit: Applied limit + - has_more: Boolean indicating if more records exist + + Raises: + ValueError: If last_id is provided but the corresponding record doesn't exist + """ + ... + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID. + + Retrieves a single workflow run with tenant and app isolation. + Used for workflow run detail views and execution tracking. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + run_id: Workflow run identifier + + Returns: + WorkflowRun object if found, None otherwise + """ + ... + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup. + + Retrieves workflow runs created before the specified date for + cleanup operations. Used by scheduled tasks to remove old data + while maintaining data retention policies. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + before_date: Only return runs created before this date + batch_size: Maximum number of records to return + + Returns: + Sequence of WorkflowRun objects to be processed for cleanup + """ + ... 
+ + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs. + + Performs bulk deletion of workflow runs by ID. This method should + be used after backing up the data to OSS storage for retention. + + Args: + run_ids: Sequence of workflow run IDs to delete + + Returns: + Number of records actually deleted + + Note: + This method performs hard deletion. Ensure data is backed up + to OSS storage before calling this method for compliance with + data retention policies. + """ + ... + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app. + + Performs bulk deletion of all workflow runs associated with an app. + Used during app cleanup operations. Processes records in batches + to avoid memory issues and long-running transactions. + + Args: + tenant_id: Tenant identifier for multi-tenant isolation + app_id: Application identifier + batch_size: Number of records to process in each batch + + Returns: + Total number of records deleted across all batches + + Note: + This method performs hard deletion without backup. Use with caution + and ensure proper data retention policies are followed. + """ + ... diff --git a/api/repositories/factory.py b/api/repositories/factory.py new file mode 100644 index 0000000000..0a0adbf2c2 --- /dev/null +++ b/api/repositories/factory.py @@ -0,0 +1,103 @@ +""" +DifyAPI Repository Factory for creating repository instances. + +This factory is specifically designed for DifyAPI repositories that handle +service-layer operations with dependency injection patterns. 
+""" + +import logging + +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory): + """ + Factory for creating DifyAPI repository instances based on configuration. + + This factory handles the creation of repositories that are specifically designed + for service-layer operations and use dependency injection with sessionmaker + for better testability and separation of concerns. + """ + + @classmethod + def create_api_workflow_node_execution_repository( + cls, session_maker: sessionmaker + ) -> DifyAPIWorkflowNodeExecutionRepository: + """ + Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration. + + This repository is designed for service-layer operations and uses dependency injection + with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes, handling queries + that involve database-specific fields and multi-tenancy concerns. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. 
+ + Returns: + Configured DifyAPIWorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e + + @classmethod + def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository: + """ + Create an APIWorkflowRunRepository instance based on configuration. + + This repository is designed for service-layer WorkflowRun operations and uses dependency + injection with a sessionmaker for better testability and separation of concerns. It provides + database access patterns specifically needed by service classes for workflow run management, + including pagination, filtering, and bulk operations. + + Args: + session_maker: SQLAlchemy sessionmaker to inject for database session management. 
+ + Returns: + Configured APIWorkflowRunRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be imported or instantiated + """ + class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY + logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, APIWorkflowRunRepository) + # Service repository requires session_maker parameter + cls._validate_constructor_signature(repository_class, ["session_maker"]) + + return repository_class(session_maker=session_maker) # type: ignore[no-any-return] + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create APIWorkflowRunRepository") + raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py new file mode 100644 index 0000000000..e6a23ddf9f --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -0,0 +1,290 @@ +""" +SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + +This module provides a concrete implementation of the service repository protocol +using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations. 
+""" + +from collections.abc import Sequence +from datetime import datetime +from typing import Optional + +from sqlalchemy import delete, desc, select +from sqlalchemy.orm import Session, sessionmaker + +from models.workflow import WorkflowNodeExecutionModel +from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository + + +class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository): + """ + SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository. + + This repository provides service-layer database operations for WorkflowNodeExecutionModel + using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository + protocol with the following features: + + - Multi-tenancy data isolation through tenant_id filtering + - Direct database model operations without domain conversion + - Batch processing for efficient large-scale operations + - Optimized query patterns for common access patterns + - Dependency injection for better testability and maintainability + - Session management and transaction handling with proper cleanup + - Maintenance operations for data lifecycle management + - Thread-safe database operations using session-per-request pattern + """ + + def __init__(self, session_maker: sessionmaker[Session]): + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for creating database sessions + """ + self._session_maker = session_maker + + def get_node_last_execution( + self, + tenant_id: str, + app_id: str, + workflow_id: str, + node_id: str, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get the most recent execution for a specific node. + + This method replicates the query pattern from WorkflowService.get_node_last_run() + using SQLAlchemy 2.0 style syntax. 
+ + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_id: The workflow identifier + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_id == workflow_id, + WorkflowNodeExecutionModel.node_id == node_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.created_at)) + .limit(1) + ) + + with self._session_maker() as session: + return session.scalar(stmt) + + def get_executions_by_workflow_run( + self, + tenant_id: str, + app_id: str, + workflow_run_id: str, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get all node executions for a specific workflow run. + + This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions() + using SQLAlchemy 2.0 style syntax. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + workflow_run_id: The workflow run identifier + + Returns: + A sequence of WorkflowNodeExecutionModel instances ordered by index (desc) + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def get_execution_by_id( + self, + execution_id: str, + tenant_id: Optional[str] = None, + ) -> Optional[WorkflowNodeExecutionModel]: + """ + Get a workflow node execution by its ID. + + This method replicates the query pattern from WorkflowDraftVariableService + and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax. 
+ + When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants. + If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should + set `tenant_id` to prevent horizontal privilege escalation. + + Args: + execution_id: The execution identifier + tenant_id: Optional tenant identifier for additional filtering + + Returns: + The WorkflowNodeExecutionModel if found, or None if not found + """ + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id) + + # Add tenant filtering if provided + if tenant_id is not None: + stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id) + + with self._session_maker() as session: + return session.scalar(stmt) + + def delete_expired_executions( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> int: + """ + Delete workflow node executions that are older than the specified date. + + Args: + tenant_id: The tenant identifier + before_date: Delete executions created before this date + batch_size: Maximum number of executions to delete in one batch + + Returns: + The number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def 
delete_executions_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow node executions for a specific app. + + Args: + tenant_id: The tenant identifier + app_id: The application identifier + batch_size: Maximum number of executions to delete in one batch + + Returns: + The total number of executions deleted + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Find executions to delete in batches + stmt = ( + select(WorkflowNodeExecutionModel.id) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_id, + ) + .limit(batch_size) + ) + + execution_ids = session.execute(stmt).scalars().all() + if not execution_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(delete_stmt) + session.commit() + total_deleted += result.rowcount + + # If we deleted fewer than the batch size, we're done + if len(execution_ids) < batch_size: + break + + return total_deleted + + def get_expired_executions_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowNodeExecutionModel]: + """ + Get a batch of expired workflow node executions for backup purposes. 
+ + Args: + tenant_id: The tenant identifier + before_date: Get executions created before this date + batch_size: Maximum number of executions to retrieve + + Returns: + A sequence of WorkflowNodeExecutionModel instances + """ + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at < before_date, + ) + .limit(batch_size) + ) + + with self._session_maker() as session: + return session.execute(stmt).scalars().all() + + def delete_executions_by_ids( + self, + execution_ids: Sequence[str], + ) -> int: + """ + Delete workflow node executions by their IDs. + + Args: + execution_ids: List of execution IDs to delete + + Returns: + The number of executions deleted + """ + if not execution_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids)) + result = session.execute(stmt) + session.commit() + return result.rowcount diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py new file mode 100644 index 0000000000..ebd1d74b20 --- /dev/null +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -0,0 +1,203 @@ +""" +SQLAlchemy API WorkflowRun Repository Implementation + +This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository +protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0 +style queries with proper session management and multi-tenant data isolation. 
+ +Key Features: +- SQLAlchemy 2.0 style queries for modern database operations +- Cursor-based pagination for efficient large dataset handling +- Bulk operations with batch processing for performance +- Multi-tenant data isolation and security +- Proper session management with dependency injection + +Implementation Notes: +- Uses sessionmaker for consistent session management +- Implements cursor-based pagination using created_at timestamps +- Provides efficient bulk deletion with batch processing +- Maintains data consistency with proper transaction handling +""" + +import logging +from collections.abc import Sequence +from datetime import datetime +from typing import Optional, cast + +from sqlalchemy import delete, select +from sqlalchemy.orm import Session, sessionmaker + +from libs.infinite_scroll_pagination import InfiniteScrollPagination +from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository + +logger = logging.getLogger(__name__) + + +class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): + """ + SQLAlchemy implementation of APIWorkflowRunRepository. + + Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0 + style queries. Supports dependency injection through sessionmaker and + maintains proper multi-tenant data isolation. + + Args: + session_maker: SQLAlchemy sessionmaker instance for database connections + """ + + def __init__(self, session_maker: sessionmaker[Session]) -> None: + """ + Initialize the repository with a sessionmaker. + + Args: + session_maker: SQLAlchemy sessionmaker for database connections + """ + self._session_maker = session_maker + + def get_paginated_workflow_runs( + self, + tenant_id: str, + app_id: str, + triggered_from: str, + limit: int = 20, + last_id: Optional[str] = None, + ) -> InfiniteScrollPagination: + """ + Get paginated workflow runs with filtering. 
+ + Implements cursor-based pagination using created_at timestamps for + efficient handling of large datasets. Filters by tenant, app, and + trigger source for proper data isolation. + """ + with self._session_maker() as session: + # Build base query with filters + base_stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.triggered_from == triggered_from, + ) + + if last_id: + # Get the last workflow run for cursor-based pagination + last_run_stmt = base_stmt.where(WorkflowRun.id == last_id) + last_workflow_run = session.scalar(last_run_stmt) + + if not last_workflow_run: + raise ValueError("Last workflow run does not exist") + + # Get records created before the last run's timestamp + base_stmt = base_stmt.where( + WorkflowRun.created_at < last_workflow_run.created_at, + WorkflowRun.id != last_workflow_run.id, + ) + + # Fetch the newest records first, plus one extra row to detect whether more pages exist + workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all() + + # Check if there are more records for pagination + has_more = len(workflow_runs) > limit + if has_more: + workflow_runs = workflow_runs[:-1] + + return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) + + def get_workflow_run_by_id( + self, + tenant_id: str, + app_id: str, + run_id: str, + ) -> Optional[WorkflowRun]: + """ + Get a specific workflow run by ID with tenant and app isolation. + """ + with self._session_maker() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + WorkflowRun.id == run_id, + ) + return cast(Optional[WorkflowRun], session.scalar(stmt)) + + def get_expired_runs_batch( + self, + tenant_id: str, + before_date: datetime, + batch_size: int = 1000, + ) -> Sequence[WorkflowRun]: + """ + Get a batch of expired workflow runs for cleanup operations. 
+ """ + with self._session_maker() as session: + stmt = ( + select(WorkflowRun) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.created_at < before_date, + ) + .limit(batch_size) + ) + return cast(Sequence[WorkflowRun], session.scalars(stmt).all()) + + def delete_runs_by_ids( + self, + run_ids: Sequence[str], + ) -> int: + """ + Delete workflow runs by their IDs using bulk deletion. + """ + if not run_ids: + return 0 + + with self._session_maker() as session: + stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(stmt) + session.commit() + + deleted_count = cast(int, result.rowcount) + logger.info(f"Deleted {deleted_count} workflow runs by IDs") + return deleted_count + + def delete_runs_by_app( + self, + tenant_id: str, + app_id: str, + batch_size: int = 1000, + ) -> int: + """ + Delete all workflow runs for a specific app in batches. + """ + total_deleted = 0 + + while True: + with self._session_maker() as session: + # Get a batch of run IDs to delete + stmt = ( + select(WorkflowRun.id) + .where( + WorkflowRun.tenant_id == tenant_id, + WorkflowRun.app_id == app_id, + ) + .limit(batch_size) + ) + run_ids = session.scalars(stmt).all() + + if not run_ids: + break + + # Delete the batch + delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids)) + result = session.execute(delete_stmt) + session.commit() + + batch_deleted = result.rowcount + total_deleted += batch_deleted + + logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}") + + # If we deleted fewer records than the batch size, we're done + if batch_deleted < batch_size: + break + + logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}") + return total_deleted diff --git a/api/services/app_service.py b/api/services/app_service.py index d08462d001..db0f8cd414 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -47,8 +47,6 @@ class AppService: filters.append(App.mode == 
AppMode.ADVANCED_CHAT.value) elif args["mode"] == "agent-chat": filters.append(App.mode == AppMode.AGENT_CHAT.value) - elif args["mode"] == "channel": - filters.append(App.mode == AppMode.CHANNEL.value) if args.get("is_created_by_me", False): filters.append(App.created_by == user_id) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 1fd560d581..ddd16b2e0c 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.model_runtime.utils.encoders import jsonable_encoder @@ -14,7 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecutionModel, WorkflowRun +from repositories.factory import DifyAPIRepositoryFactory from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs: ) ) - while True: - with Session(db.engine).no_autoflush as session: - workflow_node_executions = ( - session.query(WorkflowNodeExecutionModel) - .filter( - WorkflowNodeExecutionModel.tenant_id == tenant_id, - WorkflowNodeExecutionModel.created_at - < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) + # Process expired workflow node executions with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 - if 
len(workflow_node_executions) == 0: - break + while True: + # Get a batch of expired executions for backup + workflow_node_executions = node_execution_repo.get_expired_executions_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) - # save workflow node executions - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder(workflow_node_executions), - ).encode("utf-8"), - ) + if len(workflow_node_executions) == 0: + break + + # Save workflow node executions to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder(workflow_node_executions), + ).encode("utf-8"), + ) - workflow_node_execution_ids = [ - workflow_node_execution.id for workflow_node_execution in workflow_node_executions - ] + # Extract IDs for deletion + workflow_node_execution_ids = [ + workflow_node_execution.id for workflow_node_execution in workflow_node_executions + ] - # delete workflow node executions - session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), - ).delete(synchronize_session=False) - session.commit() + # Delete the backed up executions + deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids) + total_deleted += deleted_count - click.echo( - click.style( - f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" - f" workflow node executions for tenant {tenant_id}" - ) + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}" + f" workflow node executions for tenant {tenant_id}" ) + ) + + # If we got fewer than the batch size, we're done + if len(workflow_node_executions) < batch: + break + + # Process 
expired workflow runs with backup + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + before_date = datetime.datetime.now() - datetime.timedelta(days=days) + total_deleted = 0 while True: - with Session(db.engine).no_autoflush as session: - workflow_runs = ( - session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == tenant_id, - WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days), - ) - .limit(batch) - .all() - ) + # Get a batch of expired workflow runs for backup + workflow_runs = workflow_run_repo.get_expired_runs_batch( + tenant_id=tenant_id, + before_date=before_date, + batch_size=batch, + ) - if len(workflow_runs) == 0: - break + if len(workflow_runs) == 0: + break + + # Save workflow runs to storage + storage.save( + f"free_plan_tenant_expired_logs/" + f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" + f"-{time.time()}.json", + json.dumps( + jsonable_encoder( + [workflow_run.to_dict() for workflow_run in workflow_runs], + ), + ).encode("utf-8"), + ) - # save workflow runs + # Extract IDs for deletion + workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] - storage.save( - f"free_plan_tenant_expired_logs/" - f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}" - f"-{time.time()}.json", - json.dumps( - jsonable_encoder( - [workflow_run.to_dict() for workflow_run in workflow_runs], - ), - ).encode("utf-8"), - ) + # Delete the backed up workflow runs + deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids) + total_deleted += deleted_count - workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs] + click.echo( + click.style( + f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}" + f" workflow runs for tenant {tenant_id}" + ) + ) - # delete workflow runs - session.query(WorkflowRun).filter( - 
WorkflowRun.id.in_(workflow_run_ids), - ).delete(synchronize_session=False) - session.commit() + # If we got fewer than the batch size, we're done + if len(workflow_runs) < batch: + break @classmethod def process(cls, days: int, batch: int, tenant_ids: list[str]): diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index 8c06ee9386..54d45f45ea 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -29,7 +29,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e @@ -40,7 +40,7 @@ class EnterpriseService: raise ValueError("No data found.") try: # parse the UTC timestamp from the response - return datetime.fromisoformat(data.replace("Z", "+00:00")) + return datetime.fromisoformat(data) except ValueError as e: raise ValueError(f"Invalid date format: {data}") from e diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 603064ca07..88d4224e97 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -95,7 +95,7 @@ class WeightKeywordSetting(BaseModel): class WeightModel(BaseModel): - weight_type: Optional[str] = None + weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None vector_setting: Optional[WeightVectorSetting] = None keyword_setting: Optional[WeightKeywordSetting] = None diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index d7fb4a7c1b..0f22afd8dd 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -427,6 +427,9 @@ class 
PluginService: manager = PluginInstaller() + # collect actual plugin_unique_identifiers + actual_plugin_unique_identifiers = [] + metas = [] features = FeatureService.get_system_features() # check if already downloaded @@ -437,6 +440,8 @@ class PluginService: # check if the plugin is available to install PluginService._check_plugin_installation_scope(plugin_decode_response.verification) # already downloaded, skip + actual_plugin_unique_identifiers.append(plugin_unique_identifier) + metas.append({"plugin_unique_identifier": plugin_unique_identifier}) except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(plugin_unique_identifier) @@ -447,17 +452,15 @@ class PluginService: ) # check if the plugin is available to install PluginService._check_plugin_installation_scope(response.verification) + # use response plugin_unique_identifier + actual_plugin_unique_identifiers.append(response.unique_identifier) + metas.append({"plugin_unique_identifier": response.unique_identifier}) return manager.install_from_identifiers( tenant_id, - plugin_unique_identifiers, + actual_plugin_unique_identifiers, PluginInstallationSource.Marketplace, - [ - { - "plugin_unique_identifier": plugin_unique_identifier, - } - for plugin_unique_identifier in plugin_unique_identifiers - ], + metas, ) @staticmethod diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 44fd72b5e4..f306e1f062 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any, ClassVar -from sqlalchemy import Engine, orm, select +from sqlalchemy import Engine, orm from sqlalchemy.dialects.postgresql import insert -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from sqlalchemy.sql.expression import and_, or_ from 
core.app.entities.app_invoke_entities import InvokeFrom @@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation from models.enums import DraftVariableType -from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable +from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable +from repositories.factory import DifyAPIRepositoryFactory _logger = logging.getLogger(__name__) @@ -117,7 +118,24 @@ class WorkflowDraftVariableService: _session: Session def __init__(self, session: Session) -> None: + """ + Initialize the WorkflowDraftVariableService with a SQLAlchemy session. + + Args: + session (Session): The SQLAlchemy session used to execute database queries. + The provided session must be bound to an `Engine` object, not a specific `Connection`. + + Raises: + AssertionError: If the provided session is not bound to an `Engine` object. + """ self._session = session + engine = session.get_bind() + # Ensure the session is bound to an engine. 
+ assert isinstance(engine, Engine) + session_maker = sessionmaker(bind=engine, expire_on_commit=False) + self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None: return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first() @@ -248,8 +266,7 @@ class WorkflowDraftVariableService: _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) return None - query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) - node_exec = self._session.scalars(query).first() + node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id) if node_exec is None: _logger.warning( "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", @@ -298,6 +315,8 @@ class WorkflowDraftVariableService: def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name): + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") if variable_type == DraftVariableType.CONVERSATION: return self._reset_conv_var(workflow, variable) else: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d3086..e43999a8c9 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,9 +2,9 @@ import threading from collections.abc import Sequence from typing import Optional +from sqlalchemy.orm import sessionmaker + import contexts -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from 
extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -15,10 +15,18 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) -from models.workflow import WorkflowNodeExecutionTriggeredFrom +from repositories.factory import DifyAPIRepositoryFactory class WorkflowRunService: + def __init__(self): + """Initialize WorkflowRunService with repository dependencies.""" + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker + ) + self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination: """ Get advanced chat app workflow run list @@ -62,45 +70,16 @@ class WorkflowRunService: :param args: request args """ limit = int(args.get("limit", 20)) + last_id = args.get("last_id") - base_query = db.session.query(WorkflowRun).filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value, + return self._workflow_run_repo.get_paginated_workflow_runs( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value, + limit=limit, + last_id=last_id, ) - if args.get("last_id"): - last_workflow_run = base_query.filter( - WorkflowRun.id == args.get("last_id"), - ).first() - - if not last_workflow_run: - raise ValueError("Last workflow run not exists") - - workflow_runs = ( - base_query.filter( - WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id - ) - .order_by(WorkflowRun.created_at.desc()) - .limit(limit) - .all() - ) - else: - workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all() - - has_more = 
False - if len(workflow_runs) == limit: - current_page_first_workflow_run = workflow_runs[-1] - rest_count = base_query.filter( - WorkflowRun.created_at < current_page_first_workflow_run.created_at, - WorkflowRun.id != current_page_first_workflow_run.id, - ).count() - - if rest_count > 0: - has_more = True - - return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more) - def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]: """ Get workflow run detail @@ -108,18 +87,12 @@ class WorkflowRunService: :param app_model: app model :param run_id: workflow run id """ - workflow_run = ( - db.session.query(WorkflowRun) - .filter( - WorkflowRun.tenant_id == app_model.tenant_id, - WorkflowRun.app_id == app_model.id, - WorkflowRun.id == run_id, - ) - .first() + return self._workflow_run_repo.get_workflow_run_by_id( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + run_id=run_id, ) - return workflow_run - def get_workflow_run_node_executions( self, app_model: App, @@ -137,17 +110,13 @@ class WorkflowRunService: if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # Get tenant_id from user + tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if tenant_id is None: + raise ValueError("User tenant_id cannot be None") - # Use the repository to get the database models directly - order_config = OrderConfig(order_by=["index"], order_direction="desc") - workflow_node_executions = repository.get_db_models_by_workflow_run( - workflow_run_id=run_id, order_config=order_config + return self._node_execution_service_repo.get_executions_by_workflow_run( + tenant_id=tenant_id, + app_id=app_model.id, + workflow_run_id=run_id, ) - - return workflow_node_executions diff --git a/api/services/workflow_service.py 
b/api/services/workflow_service.py index 2be57fd51c..0149d50346 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -7,13 +7,13 @@ from typing import Any, Optional from uuid import uuid4 from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.orm import Session, sessionmaker from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -41,6 +41,7 @@ from models.workflow import ( WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) +from repositories.factory import DifyAPIRepositoryFactory from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError from services.workflow.workflow_converter import WorkflowConverter @@ -57,21 +58,32 @@ class WorkflowService: Workflow Service """ - def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: - # TODO(QuantumGhost): This query is not fully covered by index. 
- criteria = ( - WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id, - WorkflowNodeExecutionModel.app_id == app_model.id, - WorkflowNodeExecutionModel.workflow_id == workflow.id, - WorkflowNodeExecutionModel.node_id == node_id, + def __init__(self, session_maker: sessionmaker | None = None): + """Initialize WorkflowService with repository dependencies.""" + if session_maker is None: + session_maker = sessionmaker(bind=db.engine, expire_on_commit=False) + self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository( + session_maker ) - node_exec = ( - db.session.query(WorkflowNodeExecutionModel) - .filter(*criteria) - .order_by(WorkflowNodeExecutionModel.created_at.desc()) - .first() + + def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None: + """ + Get the most recent execution for a specific node. + + Args: + app_model: The application model + workflow: The workflow model + node_id: The node identifier + + Returns: + The most recent WorkflowNodeExecutionModel for the node, or None if not found + """ + return self._node_execution_service_repo.get_node_last_execution( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + workflow_id=workflow.id, + node_id=node_id, ) - return node_exec def is_workflow_exist(self, app_model: App) -> bool: return ( @@ -396,7 +408,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -404,8 +416,9 @@ class WorkflowService: ) repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution = repository.to_db_model(node_execution) + workflow_node_execution = 
self._node_execution_service_repo.get_execution_by_id(node_execution.id) + if workflow_node_execution is None: + raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving") with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( @@ -418,6 +431,7 @@ class WorkflowService: ) draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs) session.commit() + return workflow_node_execution def run_free_workflow_node( @@ -429,7 +443,7 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( + node_execution = self._handle_node_run_result( invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, @@ -441,7 +455,7 @@ class WorkflowService: node_id=node_id, ) - return workflow_node_execution + return node_execution def _handle_node_run_result( self, diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index 4a62cb74b4..179adcbd6e 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -6,6 +6,7 @@ import click from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import sessionmaker from extensions.ext_database import db from models import ( @@ -31,7 +32,8 @@ from models import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog +from repositories.factory import DifyAPIRepositoryFactory @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -189,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: 
str): def _delete_app_workflow_runs(tenant_id: str, app_id: str): - def del_workflow_run(workflow_run_id: str): - db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False) - - _delete_records( - """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_run, - "workflow run", + """Delete all workflow runs for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) + + deleted_count = workflow_run_repo.delete_runs_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}") -def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - def del_workflow_node_execution(workflow_node_execution_id: str): - db.session.query(WorkflowNodeExecutionModel).filter( - WorkflowNodeExecutionModel.id == workflow_node_execution_id - ).delete(synchronize_session=False) - _delete_records( - """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", - {"tenant_id": tenant_id, "app_id": app_id}, - del_workflow_node_execution, - "workflow node execution", +def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): + """Delete all workflow node executions for an app using the service repository.""" + session_maker = sessionmaker(bind=db.engine) + node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker) + + deleted_count = node_execution_repo.delete_executions_by_app( + tenant_id=tenant_id, + app_id=app_id, + batch_size=1000, ) + logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}") + def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): def del_workflow_app_log(workflow_app_log_id: 
str): diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 638323f850..8acaa54b9c 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -116,11 +116,11 @@ def test_execute_llm(flask_req_ctx): mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), + prompt_price_unit=Decimal(1000), prompt_price=Decimal("0.00003"), completion_tokens=20, completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), + completion_price_unit=Decimal(1000), completion_price=Decimal("0.00004"), total_tokens=50, total_price=Decimal("0.00007"), @@ -219,11 +219,11 @@ def test_execute_llm_with_jinja2(flask_req_ctx, setup_code_executor_mock): mock_usage = LLMUsage( prompt_tokens=30, prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1000"), + prompt_price_unit=Decimal(1000), prompt_price=Decimal("0.00003"), completion_tokens=20, completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1000"), + completion_price_unit=Decimal(1000), completion_price=Decimal("0.00004"), total_tokens=50, total_price=Decimal("0.00007"), diff --git a/api/tests/unit_tests/core/repositories/__init__.py b/api/tests/unit_tests/core/repositories/__init__.py new file mode 100644 index 0000000000..c65d7da61d --- /dev/null +++ b/api/tests/unit_tests/core/repositories/__init__.py @@ -0,0 +1 @@ +# Unit tests for core repositories module diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py new file mode 100644 index 0000000000..fce4a6fb6b --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -0,0 +1,455 @@ +""" +Unit tests for the RepositoryFactory. 
+ +This module tests the factory pattern implementation for creating repository instances +based on configuration, including error handling and validation. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +class TestRepositoryFactory: + """Test cases for RepositoryFactory.""" + + def test_import_class_success(self): + """Test successful class import.""" + # Test importing a real class + class_path = "unittest.mock.MagicMock" + result = DifyCoreRepositoryFactory._import_class(class_path) + assert result is MagicMock + + def test_import_class_invalid_path(self): + """Test import with invalid module path.""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("invalid.module.path") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_invalid_class_name(self): + """Test import with invalid class name.""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_malformed_path(self): + """Test import with malformed path (no dots).""" + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._import_class("invalidpath") + assert "Cannot import repository class" in str(exc_info.value) + + def 
test_validate_repository_interface_success(self): + """Test successful interface validation.""" + + # Create a mock class that implements the required methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + # Create a mock interface with the same methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + # Should not raise an exception + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_repository_interface_missing_methods(self): + """Test interface validation with missing methods.""" + + # Create a mock class that doesn't implement all required methods + class IncompleteRepository: + def save(self): + pass + + # Missing get_by_id method + + # Create a mock interface with required methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) + assert "does not implement required methods" in str(exc_info.value) + assert "get_by_id" in str(exc_info.value) + + def test_validate_constructor_signature_success(self): + """Test successful constructor signature validation.""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from): + pass + + # Should not raise an exception + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_missing_params(self): + """Test constructor validation with missing parameters.""" + + class IncompleteRepository: + def __init__(self, session_factory, user): + # Missing app_id and triggered_from parameters + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature( + IncompleteRepository, ["session_factory", 
"user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value) + + def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture): + """Test constructor validation when inspection fails.""" + # Mock inspect.signature to raise an exception + mocker.patch("inspect.signature", side_effect=Exception("Inspection failed")) + + class MockRepository: + def __init__(self, session_factory): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) + assert "Failed to validate constructor signature" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + 
triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_import_error(self, mock_config): + """Test WorkflowExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with validation error.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import to succeed but validation to fail + mock_repository_class = MagicMock() + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + DifyCoreRepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + 
session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowNodeExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + app_id = "test-app-id" + triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + + # Mock the imported class to be a valid repository + 
mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_import_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + def test_repository_import_error_exception(self): + """Test RepositoryImportError exception.""" + error_message = "Test error message" + exception = RepositoryImportError(error_message) + assert str(exception) == error_message + assert isinstance(exception, Exception) + + 
@patch("core.repositories.factory.dify_config") + def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture): + """Test repository creation with Engine instead of sessionmaker.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies with Engine instead of sessionmaker + mock_engine = MagicMock(spec=Engine) + mock_user = MagicMock(spec=Account) + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + result = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=mock_engine, # Using Engine instead of sessionmaker + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Verify the repository was created with the Engine + mock_repository_class.assert_called_once_with( + session_factory=mock_engine, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_validation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with validation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import to succeed but validation to fail + 
mock_repository_class = MagicMock() + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + DifyCoreRepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), + patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) + + def test_validate_repository_interface_with_private_methods(self): + """Test interface validation ignores private methods.""" + + # Create a mock 
class with private methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Create a mock interface with private methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Should not raise an exception (private methods are ignored) + DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_constructor_signature_with_extra_params(self): + """Test constructor validation with extra parameters (should pass).""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None): + pass + + # Should not raise an exception (extra parameters are allowed) + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_with_kwargs(self): + """Test constructor validation with **kwargs (current implementation doesn't support this).""" + + class MockRepository: + def __init__(self, session_factory, user, **kwargs): + pass + + # Current implementation doesn't handle **kwargs, so this should raise an exception + with pytest.raises(RepositoryImportError) as exc_info: + DifyCoreRepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/tools/utils/__init__.py b/api/tests/unit_tests/core/tools/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py new file mode 100644 index 0000000000..8e07293ce0 --- /dev/null +++ 
b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -0,0 +1,56 @@ +import pytest +from flask import Flask + +from core.tools.utils.parser import ApiBasedToolSchemaParser + + +@pytest.fixture +def app(): + app = Flask(__name__) + return app + + +def test_parse_openapi_to_tool_bundle_operation_id(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "Simple API", "version": "1.0.0"}, + "servers": [{"url": "http://localhost:3000"}], + "paths": { + "/": { + "get": { + "summary": "Root endpoint", + "responses": { + "200": { + "description": "Successful response", + } + }, + } + }, + "/api/resources": { + "get": { + "summary": "Non-root endpoint without an operationId", + "responses": { + "200": { + "description": "Successful response", + } + }, + }, + "post": { + "summary": "Non-root endpoint with an operationId", + "operationId": "createResource", + "responses": { + "201": { + "description": "Resource created", + } + }, + }, + }, + }, + } + with app.test_request_context(): + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(tool_bundles) == 3 + assert tool_bundles[0].operation_id == "_get" + assert tool_bundles[1].operation_id == "apiresources_get" + assert tool_bundles[2].operation_id == "createResource" diff --git a/api/tests/unit_tests/libs/test_login.py b/api/tests/unit_tests/libs/test_login.py new file mode 100644 index 0000000000..39671077d4 --- /dev/null +++ b/api/tests/unit_tests/libs/test_login.py @@ -0,0 +1,232 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask, g +from flask_login import LoginManager, UserMixin + +from libs.login import _get_user, current_user, login_required + + +class MockUser(UserMixin): + """Mock user class for testing.""" + + def __init__(self, id: str, is_authenticated: bool = True): + self.id = id + self._is_authenticated = is_authenticated + + @property + def is_authenticated(self): + return self._is_authenticated + + +class 
TestLoginRequired: + """Test cases for login_required decorator.""" + + @pytest.fixture + def setup_app(self, app: Flask): + """Set up Flask app with login manager.""" + # Initialize login manager + login_manager = LoginManager() + login_manager.init_app(app) + + # Mock unauthorized handler + login_manager.unauthorized = MagicMock(return_value="Unauthorized") + + # Add a dummy user loader to prevent exceptions + @login_manager.user_loader + def load_user(user_id): + return None + + return app + + def test_authenticated_user_can_access_protected_view(self, setup_app: Flask): + """Test that authenticated users can access protected views.""" + + @login_required + def protected_view(): + return "Protected content" + + with setup_app.test_request_context(): + # Mock authenticated user + mock_user = MockUser("test_user", is_authenticated=True) + with patch("libs.login._get_user", return_value=mock_user): + result = protected_view() + assert result == "Protected content" + + def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask): + """Test that unauthenticated users are handed to the unauthorized handler.""" + + @login_required + def protected_view(): + return "Protected content" + + with setup_app.test_request_context(): + # Mock unauthenticated user + mock_user = MockUser("test_user", is_authenticated=False) + with patch("libs.login._get_user", return_value=mock_user): + result = protected_view() + assert result == "Unauthorized" + setup_app.login_manager.unauthorized.assert_called_once() + + def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask): + """Test that LOGIN_DISABLED config bypasses authentication.""" + + @login_required + def protected_view(): + return "Protected content" + + with setup_app.test_request_context(): + # Mock unauthenticated user and LOGIN_DISABLED + mock_user = MockUser("test_user", is_authenticated=False) + with patch("libs.login._get_user", return_value=mock_user): + with patch("libs.login.dify_config") as 
mock_config: + mock_config.LOGIN_DISABLED = True + + result = protected_view() + assert result == "Protected content" + # Ensure unauthorized was not called + setup_app.login_manager.unauthorized.assert_not_called() + + def test_options_request_bypasses_authentication(self, setup_app: Flask): + """Test that OPTIONS requests are exempt from authentication.""" + + @login_required + def protected_view(): + return "Protected content" + + with setup_app.test_request_context(method="OPTIONS"): + # Mock unauthenticated user + mock_user = MockUser("test_user", is_authenticated=False) + with patch("libs.login._get_user", return_value=mock_user): + result = protected_view() + assert result == "Protected content" + # Ensure unauthorized was not called + setup_app.login_manager.unauthorized.assert_not_called() + + def test_flask_2_compatibility(self, setup_app: Flask): + """Test Flask 2.x compatibility with ensure_sync.""" + + @login_required + def protected_view(): + return "Protected content" + + # Mock Flask 2.x ensure_sync + setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content") + + with setup_app.test_request_context(): + mock_user = MockUser("test_user", is_authenticated=True) + with patch("libs.login._get_user", return_value=mock_user): + result = protected_view() + assert result == "Synced content" + setup_app.ensure_sync.assert_called_once() + + def test_flask_1_compatibility(self, setup_app: Flask): + """Test Flask 1.x compatibility without ensure_sync.""" + + @login_required + def protected_view(): + return "Protected content" + + # Remove ensure_sync to simulate Flask 1.x + if hasattr(setup_app, "ensure_sync"): + delattr(setup_app, "ensure_sync") + + with setup_app.test_request_context(): + mock_user = MockUser("test_user", is_authenticated=True) + with patch("libs.login._get_user", return_value=mock_user): + result = protected_view() + assert result == "Protected content" + + +class TestGetUser: + """Test cases for _get_user function.""" + + def 
test_get_user_returns_user_from_g(self, app: Flask): + """Test that _get_user returns user from g._login_user.""" + mock_user = MockUser("test_user") + + with app.test_request_context(): + g._login_user = mock_user + user = _get_user() + assert user == mock_user + assert user.id == "test_user" + + def test_get_user_loads_user_if_not_in_g(self, app: Flask): + """Test that _get_user loads user if not already in g.""" + mock_user = MockUser("test_user") + + # Mock login manager + login_manager = MagicMock() + login_manager._load_user = MagicMock() + app.login_manager = login_manager + + with app.test_request_context(): + # Simulate _load_user setting g._login_user + def side_effect(): + g._login_user = mock_user + + login_manager._load_user.side_effect = side_effect + + user = _get_user() + assert user == mock_user + login_manager._load_user.assert_called_once() + + def test_get_user_returns_none_without_request_context(self, app: Flask): + """Test that _get_user returns None outside request context.""" + # Outside of request context + user = _get_user() + assert user is None + + +class TestCurrentUser: + """Test cases for current_user proxy.""" + + def test_current_user_proxy_returns_authenticated_user(self, app: Flask): + """Test that current_user proxy returns authenticated user.""" + mock_user = MockUser("test_user", is_authenticated=True) + + with app.test_request_context(): + with patch("libs.login._get_user", return_value=mock_user): + assert current_user.id == "test_user" + assert current_user.is_authenticated is True + + def test_current_user_proxy_returns_none_when_no_user(self, app: Flask): + """Test that current_user proxy handles None user.""" + with app.test_request_context(): + with patch("libs.login._get_user", return_value=None): + # When _get_user returns None, attribute access through the + # current_user proxy should raise AttributeError + try: + # Try to access an attribute that would exist on a real user + _ = current_user.id + pytest.fail("Should 
have raised AttributeError") + except AttributeError: + # This is expected when current_user is None + pass + + def test_current_user_proxy_thread_safety(self, app: Flask): + """Test that current_user proxy is thread-safe.""" + import threading + + results = {} + + def check_user_in_thread(user_id: str, index: int): + with app.test_request_context(): + mock_user = MockUser(user_id) + with patch("libs.login._get_user", return_value=mock_user): + results[index] = current_user.id + + # Create multiple threads with different users + threads = [] + for i in range(5): + thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i)) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join() + + # Verify each thread got its own user + for i in range(5): + assert results[i] == f"user_{i}" diff --git a/api/tests/unit_tests/libs/test_passport.py b/api/tests/unit_tests/libs/test_passport.py new file mode 100644 index 0000000000..f33484c18d --- /dev/null +++ b/api/tests/unit_tests/libs/test_passport.py @@ -0,0 +1,205 @@ +from datetime import UTC, datetime, timedelta +from unittest.mock import patch + +import jwt +import pytest +from werkzeug.exceptions import Unauthorized + +from libs.passport import PassportService + + +class TestPassportService: + """Test PassportService JWT operations""" + + @pytest.fixture + def passport_service(self): + """Create PassportService instance with test secret key""" + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "test-secret-key-for-testing" + return PassportService() + + @pytest.fixture + def another_passport_service(self): + """Create another PassportService instance with different secret key""" + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "another-secret-key-for-testing" + return PassportService() + + # Core functionality tests + def test_should_issue_and_verify_token(self, 
passport_service): + """Test complete JWT lifecycle: issue and verify""" + payload = {"user_id": "123", "app_code": "test-app"} + token = passport_service.issue(payload) + + # Verify token format + assert isinstance(token, str) + assert len(token.split(".")) == 3 # JWT format: header.payload.signature + + # Verify token content + decoded = passport_service.verify(token) + assert decoded == payload + + def test_should_handle_different_payload_types(self, passport_service): + """Test issuing and verifying tokens with different payload types""" + test_cases = [ + {"string": "value"}, + {"number": 42}, + {"float": 3.14}, + {"boolean": True}, + {"null": None}, + {"array": [1, 2, 3]}, + {"nested": {"key": "value"}}, + {"unicode": "中文测试"}, + {"emoji": "🔐"}, + {}, # Empty payload + ] + + for payload in test_cases: + token = passport_service.issue(payload) + decoded = passport_service.verify(token) + assert decoded == payload + + # Security tests + def test_should_reject_modified_token(self, passport_service): + """Test that any modification to token invalidates it""" + token = passport_service.issue({"user": "test"}) + + # Test multiple modification points + test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1] + + for pos in test_positions: + if pos < len(token) and token[pos] != ".": + # Change one character + tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :] + with pytest.raises(Unauthorized): + passport_service.verify(tampered) + + def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service): + """Test key isolation - token from one service should not work with another""" + payload = {"user_id": "123", "app_code": "test-app"} + token = passport_service.issue(payload) + + with pytest.raises(Unauthorized) as exc_info: + another_passport_service.verify(token) + assert str(exc_info.value) == "401 Unauthorized: Invalid token signature." 
+ + def test_should_use_hs256_algorithm(self, passport_service): + """Test that HS256 algorithm is used for signing""" + payload = {"test": "data"} + token = passport_service.issue(payload) + + # Decode header without relying on JWT internals + # Use jwt.get_unverified_header which is a public API + header = jwt.get_unverified_header(token) + assert header["alg"] == "HS256" + + def test_should_reject_token_with_wrong_algorithm(self, passport_service): + """Test rejection of token signed with different algorithm""" + payload = {"user_id": "123"} + + # Create token with different algorithm + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "test-secret-key-for-testing" + # Create token with HS512 instead of HS256 + wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512") + + # Should fail because service expects HS256 + # InvalidAlgorithmError is now caught by PyJWTError handler + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify(wrong_alg_token) + assert str(exc_info.value) == "401 Unauthorized: Invalid token." 
+ + # Exception handling tests + def test_should_handle_invalid_tokens(self, passport_service): + """Test handling of various invalid token formats""" + invalid_tokens = [ + ("not.a.token", "Invalid token."), + ("invalid-jwt-format", "Invalid token."), + ("xxx.yyy.zzz", "Invalid token."), + ("a.b", "Invalid token."), # Missing signature + ("", "Invalid token."), # Empty string + (" ", "Invalid token."), # Whitespace + (None, "Invalid token."), # None value + # Malformed base64 + ("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."), + ] + + for invalid_token, expected_message in invalid_tokens: + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify(invalid_token) + assert expected_message in str(exc_info.value) + + def test_should_reject_expired_token(self, passport_service): + """Test rejection of expired token""" + past_time = datetime.now(UTC) - timedelta(hours=1) + payload = {"user_id": "123", "exp": past_time.timestamp()} + + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "test-secret-key-for-testing" + token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256") + + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify(token) + assert str(exc_info.value) == "401 Unauthorized: Token has expired." 
+ + # Configuration tests + def test_should_handle_empty_secret_key(self): + """Test behavior when SECRET_KEY is empty""" + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = "" + service = PassportService() + + # Empty secret key should still work but is insecure + payload = {"test": "data"} + token = service.issue(payload) + decoded = service.verify(token) + assert decoded == payload + + def test_should_handle_none_secret_key(self): + """Test behavior when SECRET_KEY is None""" + with patch("libs.passport.dify_config") as mock_config: + mock_config.SECRET_KEY = None + service = PassportService() + + payload = {"test": "data"} + # The JWT library raises TypeError or InvalidKeyError when the secret is None + with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)): + service.issue(payload) + + # Boundary condition tests + def test_should_handle_large_payload(self, passport_service): + """Test handling of large payload""" + # Test with 100KB instead of 1MB for faster tests + large_data = "x" * (100 * 1024) + payload = {"data": large_data} + + token = passport_service.issue(payload) + decoded = passport_service.verify(token) + + assert decoded["data"] == large_data + + def test_should_handle_special_characters_in_payload(self, passport_service): + """Test handling of special characters in payload""" + special_payloads = [ + {"special": "!@#$%^&*()"}, + {"quotes": 'He said "Hello"'}, + {"backslash": "path\\to\\file"}, + {"newline": "line1\nline2"}, + {"unicode": "🔐🔑🛡️"}, + {"mixed": "Test123!@#中文🔐"}, + ] + + for payload in special_payloads: + token = passport_service.issue(payload) + decoded = passport_service.verify(token) + assert decoded == payload + + def test_should_catch_generic_pyjwt_errors(self, passport_service): + """Test that generic PyJWTError exceptions are caught and converted to Unauthorized""" + # Mock jwt.decode to raise a generic PyJWTError + with patch("libs.passport.jwt.decode") as mock_decode: + mock_decode.side_effect = 
jwt.exceptions.PyJWTError("Generic JWT error") + + with pytest.raises(Unauthorized) as exc_info: + passport_service.verify("some-token") + assert str(exc_info.value) == "401 Unauthorized: Invalid token." diff --git a/api/tests/unit_tests/services/services_test_help.py b/api/tests/unit_tests/services/services_test_help.py new file mode 100644 index 0000000000..c6b962f7fc --- /dev/null +++ b/api/tests/unit_tests/services/services_test_help.py @@ -0,0 +1,59 @@ +from unittest.mock import MagicMock + + +class ServiceDbTestHelper: + """ + Helper class for service database query tests. + """ + + @staticmethod + def setup_db_query_filter_by_mock(mock_db, query_results): + """ + Smart database query mock that responds based on model type and query parameters. + + Args: + mock_db: Mock database session + query_results: Dict mapping (model_name, filter_key, filter_value) to return value + Example: {('Account', 'email', 'test@example.com'): mock_account} + """ + + def query_side_effect(model): + mock_query = MagicMock() + + def filter_by_side_effect(**kwargs): + mock_filter_result = MagicMock() + + def first_side_effect(): + # Find matching result based on model and filter parameters + for (model_name, filter_key, filter_value), result in query_results.items(): + if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value: + return result + return None + + mock_filter_result.first.side_effect = first_side_effect + + # Handle order_by calls for complex queries + def order_by_side_effect(*args, **kwargs): + mock_order_result = MagicMock() + + def order_first_side_effect(): + # Look for order_by results in the same query_results dict + for (model_name, filter_key, filter_value), result in query_results.items(): + if ( + model.__name__ == model_name + and filter_key == "order_by" + and filter_value == "first_available" + ): + return result + return None + + mock_order_result.first.side_effect = order_first_side_effect + return mock_order_result 
+ + mock_filter_result.order_by.side_effect = order_by_side_effect + return mock_filter_result + + mock_query.filter_by.side_effect = filter_by_side_effect + return mock_query + + mock_db.session.query.side_effect = query_side_effect diff --git a/api/tests/unit_tests/services/test_account_service.py b/api/tests/unit_tests/services/test_account_service.py new file mode 100644 index 0000000000..13900ab6d1 --- /dev/null +++ b/api/tests/unit_tests/services/test_account_service.py @@ -0,0 +1,1545 @@ +import json +from datetime import datetime, timedelta +from unittest.mock import MagicMock, patch + +import pytest + +from configs import dify_config +from models.account import Account +from services.account_service import AccountService, RegisterService, TenantService +from services.errors.account import ( + AccountAlreadyInTenantError, + AccountLoginError, + AccountNotFoundError, + AccountPasswordError, + AccountRegisterError, + CurrentPasswordIncorrectError, +) +from tests.unit_tests.services.services_test_help import ServiceDbTestHelper + + +class TestAccountAssociatedDataFactory: + """Factory class for creating test data and mock objects for account service tests.""" + + @staticmethod + def create_account_mock( + account_id: str = "user-123", + email: str = "test@example.com", + name: str = "Test User", + status: str = "active", + password: str = "hashed_password", + password_salt: str = "salt", + interface_language: str = "en-US", + interface_theme: str = "light", + timezone: str = "UTC", + **kwargs, + ) -> MagicMock: + """Create a mock account with specified attributes.""" + account = MagicMock(spec=Account) + account.id = account_id + account.email = email + account.name = name + account.status = status + account.password = password + account.password_salt = password_salt + account.interface_language = interface_language + account.interface_theme = interface_theme + account.timezone = timezone + # Set last_active_at to a datetime object that's older than 10 minutes 
+ account.last_active_at = datetime.now() - timedelta(minutes=15) + account.initialized_at = None + for key, value in kwargs.items(): + setattr(account, key, value) + return account + + @staticmethod + def create_tenant_join_mock( + tenant_id: str = "tenant-456", + account_id: str = "user-123", + current: bool = True, + role: str = "normal", + **kwargs, + ) -> MagicMock: + """Create a mock tenant account join record.""" + tenant_join = MagicMock() + tenant_join.tenant_id = tenant_id + tenant_join.account_id = account_id + tenant_join.current = current + tenant_join.role = role + for key, value in kwargs.items(): + setattr(tenant_join, key, value) + return tenant_join + + @staticmethod + def create_feature_service_mock(allow_register: bool = True): + """Create a mock feature service.""" + mock_service = MagicMock() + mock_service.get_system_features.return_value.is_allow_register = allow_register + return mock_service + + @staticmethod + def create_billing_service_mock(email_frozen: bool = False): + """Create a mock billing service.""" + mock_service = MagicMock() + mock_service.is_email_in_freeze.return_value = email_frozen + return mock_service + + +class TestAccountService: + """ + Comprehensive unit tests for AccountService methods. 
+ + This test suite covers all account-related operations including: + - Authentication and login + - Account creation and registration + - Password management + - JWT token generation + - User loading and tenant management + - Error conditions and edge cases + """ + + @pytest.fixture + def mock_db_dependencies(self): + """Common mock setup for database dependencies.""" + with patch("services.account_service.db") as mock_db: + mock_db.session.add = MagicMock() + mock_db.session.commit = MagicMock() + yield { + "db": mock_db, + } + + @pytest.fixture + def mock_password_dependencies(self): + """Mock setup for password-related functions.""" + with ( + patch("services.account_service.compare_password") as mock_compare_password, + patch("services.account_service.hash_password") as mock_hash_password, + patch("services.account_service.valid_password") as mock_valid_password, + ): + yield { + "compare_password": mock_compare_password, + "hash_password": mock_hash_password, + "valid_password": mock_valid_password, + } + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + patch("services.account_service.PassportService") as mock_passport_service, + ): + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + @pytest.fixture + def mock_db_with_autospec(self): + """ + Mock database with autospec for more realistic behavior. + This approach preserves the actual method signatures and behavior. 
+ """ + with patch("services.account_service.db", autospec=True) as mock_db: + # Create a more realistic session mock + mock_session = MagicMock() + mock_db.session = mock_session + + # Setup basic session methods + mock_session.add = MagicMock() + mock_session.commit = MagicMock() + mock_session.query = MagicMock() + + yield mock_db + + def _assert_database_operations_called(self, mock_db): + """Helper method to verify database operations were called.""" + mock_db.session.commit.assert_called() + + def _assert_database_operations_not_called(self, mock_db): + """Helper method to verify database operations were not called.""" + mock_db.session.commit.assert_not_called() + + def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs): + """Helper method to verify that specific exception is raised.""" + with pytest.raises(exception_type): + callable_func(*args, **kwargs) + + # ==================== Authentication Tests ==================== + + def test_authenticate_success(self, mock_db_dependencies, mock_password_dependencies): + """Test successful authentication with correct email and password.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Setup smart database query mock + query_results = {("Account", "email", "test@example.com"): mock_account} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + mock_password_dependencies["compare_password"].return_value = True + + # Execute test + result = AccountService.authenticate("test@example.com", "password") + + # Verify results + assert result == mock_account + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_authenticate_account_not_found(self, mock_db_dependencies): + """Test authentication when account does not exist.""" + # Setup smart database query mock - no matching results + query_results = {("Account", "email", "notfound@example.com"): None} + 
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Execute test and verify exception + self._assert_exception_raised( + AccountNotFoundError, AccountService.authenticate, "notfound@example.com", "password" + ) + + def test_authenticate_account_banned(self, mock_db_dependencies): + """Test authentication when account is banned.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") + + # Setup smart database query mock + query_results = {("Account", "email", "banned@example.com"): mock_account} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Execute test and verify exception + self._assert_exception_raised(AccountLoginError, AccountService.authenticate, "banned@example.com", "password") + + def test_authenticate_password_error(self, mock_db_dependencies, mock_password_dependencies): + """Test authentication with wrong password.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Setup smart database query mock + query_results = {("Account", "email", "test@example.com"): mock_account} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + mock_password_dependencies["compare_password"].return_value = False + + # Execute test and verify exception + self._assert_exception_raised( + AccountPasswordError, AccountService.authenticate, "test@example.com", "wrongpassword" + ) + + def test_authenticate_pending_account_activates(self, mock_db_dependencies, mock_password_dependencies): + """Test authentication for a pending account, which should activate on login.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="pending") + + # Setup smart database query mock + query_results = {("Account", "email", "pending@example.com"): mock_account} + 
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + mock_password_dependencies["compare_password"].return_value = True + + # Execute test + result = AccountService.authenticate("pending@example.com", "password") + + # Verify results + assert result == mock_account + assert mock_account.status == "active" + self._assert_database_operations_called(mock_db_dependencies["db"]) + + # ==================== Account Creation Tests ==================== + + def test_create_account_success( + self, mock_db_dependencies, mock_password_dependencies, mock_external_service_dependencies + ): + """Test successful account creation with all required parameters.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + mock_password_dependencies["hash_password"].return_value = b"hashed_password" + + # Execute test + result = AccountService.create_account( + email="test@example.com", + name="Test User", + interface_language="en-US", + password="password123", + interface_theme="light", + ) + + # Verify results + assert result.email == "test@example.com" + assert result.name == "Test User" + assert result.interface_language == "en-US" + assert result.interface_theme == "light" + assert result.password is not None + assert result.password_salt is not None + assert result.timezone is not None + + # Verify database operations + mock_db_dependencies["db"].session.add.assert_called_once() + added_account = mock_db_dependencies["db"].session.add.call_args[0][0] + assert added_account.email == "test@example.com" + assert added_account.name == "Test User" + assert added_account.interface_language == "en-US" + assert added_account.interface_theme == "light" + assert added_account.password is not None + assert added_account.password_salt is not None + assert added_account.timezone is 
not None + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_create_account_registration_disabled(self, mock_external_service_dependencies): + """Test account creation when registration is disabled.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = False + + # Execute test and verify exception + self._assert_exception_raised( + Exception, # AccountNotFound + AccountService.create_account, + email="test@example.com", + name="Test User", + interface_language="en-US", + ) + + def test_create_account_email_frozen(self, mock_db_dependencies, mock_external_service_dependencies): + """Test account creation with frozen email address.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = True + # Remember the global flag so it is restored even if the assertion fails + original_billing_enabled = dify_config.BILLING_ENABLED + dify_config.BILLING_ENABLED = True + + # Execute test and verify exception + try: + self._assert_exception_raised( + AccountRegisterError, + AccountService.create_account, + email="frozen@example.com", + name="Test User", + interface_language="en-US", + ) + finally: + dify_config.BILLING_ENABLED = original_billing_enabled + + def test_create_account_without_password(self, mock_db_dependencies, mock_external_service_dependencies): + """Test account creation without password (for invite-based registration).""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Execute test + result = AccountService.create_account( + email="test@example.com", + name="Test User", + interface_language="zh-CN", + password=None, + interface_theme="dark", + ) + + # Verify results + assert result.email == "test@example.com" + assert result.name == "Test User" + assert result.interface_language 
== "zh-CN" + assert result.interface_theme == "dark" + assert result.password is None + assert result.password_salt is None + assert result.timezone is not None + + # Verify database operations + mock_db_dependencies["db"].session.add.assert_called_once() + added_account = mock_db_dependencies["db"].session.add.call_args[0][0] + assert added_account.email == "test@example.com" + assert added_account.name == "Test User" + assert added_account.interface_language == "zh-CN" + assert added_account.interface_theme == "dark" + assert added_account.password is None + assert added_account.password_salt is None + assert added_account.timezone is not None + self._assert_database_operations_called(mock_db_dependencies["db"]) + + # ==================== Password Management Tests ==================== + + def test_update_account_password_success(self, mock_db_dependencies, mock_password_dependencies): + """Test successful password update with correct current password and valid new password.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_password_dependencies["compare_password"].return_value = True + mock_password_dependencies["valid_password"].return_value = None + mock_password_dependencies["hash_password"].return_value = b"new_hashed_password" + + # Execute test + result = AccountService.update_account_password(mock_account, "old_password", "new_password123") + + # Verify results + assert result == mock_account + assert mock_account.password is not None + assert mock_account.password_salt is not None + + # Verify password validation was called + mock_password_dependencies["compare_password"].assert_called_once_with( + "old_password", "hashed_password", "salt" + ) + mock_password_dependencies["valid_password"].assert_called_once_with("new_password123") + + # Verify database operations + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_update_account_password_current_password_incorrect(self, 
mock_password_dependencies): + """Test password update with incorrect current password.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_password_dependencies["compare_password"].return_value = False + + # Execute test and verify exception + self._assert_exception_raised( + CurrentPasswordIncorrectError, + AccountService.update_account_password, + mock_account, + "wrong_password", + "new_password123", + ) + + # Verify password comparison was called + mock_password_dependencies["compare_password"].assert_called_once_with( + "wrong_password", "hashed_password", "salt" + ) + + def test_update_account_password_invalid_new_password(self, mock_password_dependencies): + """Test password update with invalid new password.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_password_dependencies["compare_password"].return_value = True + mock_password_dependencies["valid_password"].side_effect = ValueError("Password too short") + + # Execute test and verify exception + self._assert_exception_raised( + ValueError, AccountService.update_account_password, mock_account, "old_password", "short" + ) + + # Verify password validation was called + mock_password_dependencies["valid_password"].assert_called_once_with("short") + + # ==================== User Loading Tests ==================== + + def test_load_user_success(self, mock_db_dependencies): + """Test successful user loading with current tenant.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock() + + # Setup smart database query mock + query_results = { + ("Account", "id", "user-123"): mock_account, + ("TenantAccountJoin", "account_id", "user-123"): mock_tenant_join, + } + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Mock datetime + with 
patch("services.account_service.datetime") as mock_datetime: + mock_now = datetime.now() + mock_datetime.now.return_value = mock_now + mock_datetime.UTC = "UTC" + + # Execute test + result = AccountService.load_user("user-123") + + # Verify results + assert result == mock_account + assert mock_account.set_tenant_id.called + + def test_load_user_not_found(self, mock_db_dependencies): + """Test user loading when user does not exist.""" + # Setup smart database query mock - no matching results + query_results = {("Account", "id", "non-existent-user"): None} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Execute test + result = AccountService.load_user("non-existent-user") + + # Verify results + assert result is None + + def test_load_user_banned(self, mock_db_dependencies): + """Test user loading when user is banned.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock(status="banned") + + # Setup smart database query mock + query_results = {("Account", "id", "user-123"): mock_account} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Execute test and verify exception + self._assert_exception_raised( + Exception, # Unauthorized + AccountService.load_user, + "user-123", + ) + + def test_load_user_no_current_tenant(self, mock_db_dependencies): + """Test user loading when user has no current tenant but has available tenants.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_available_tenant = TestAccountAssociatedDataFactory.create_tenant_join_mock(current=False) + + # Setup smart database query mock for complex scenario + query_results = { + ("Account", "id", "user-123"): mock_account, + ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant + ("TenantAccountJoin", "order_by", "first_available"): mock_available_tenant, # First available tenant + } + 
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Mock datetime + with patch("services.account_service.datetime") as mock_datetime: + mock_now = datetime.now() + mock_datetime.now.return_value = mock_now + mock_datetime.UTC = "UTC" + + # Execute test + result = AccountService.load_user("user-123") + + # Verify results + assert result == mock_account + assert mock_available_tenant.current is True + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_load_user_no_tenants(self, mock_db_dependencies): + """Test user loading when user has no tenants at all.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Setup smart database query mock for no tenants scenario + query_results = { + ("Account", "id", "user-123"): mock_account, + ("TenantAccountJoin", "account_id", "user-123"): None, # No current tenant + ("TenantAccountJoin", "order_by", "first_available"): None, # No available tenants + } + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Mock datetime + with patch("services.account_service.datetime") as mock_datetime: + mock_now = datetime.now() + mock_datetime.now.return_value = mock_now + mock_datetime.UTC = "UTC" + + # Execute test + result = AccountService.load_user("user-123") + + # Verify results + assert result is None + + +class TestTenantService: + """ + Comprehensive unit tests for TenantService methods. 
+ + This test suite covers all tenant-related operations including: + - Tenant creation and management + - Member management and permissions + - Tenant switching + - Role updates and permission checks + - Error conditions and edge cases + """ + + @pytest.fixture + def mock_db_dependencies(self): + """Common mock setup for database dependencies.""" + with patch("services.account_service.db") as mock_db: + mock_db.session.add = MagicMock() + mock_db.session.commit = MagicMock() + yield { + "db": mock_db, + } + + @pytest.fixture + def mock_rsa_dependencies(self): + """Mock setup for RSA-related functions.""" + with patch("services.account_service.generate_key_pair") as mock_generate_key_pair: + yield mock_generate_key_pair + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + ): + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + } + + def _assert_database_operations_called(self, mock_db): + """Helper method to verify database operations were called.""" + mock_db.session.commit.assert_called() + + def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs): + """Helper method to verify that specific exception is raised.""" + with pytest.raises(exception_type): + callable_func(*args, **kwargs) + + # ==================== Tenant Creation Tests ==================== + + def test_create_owner_tenant_if_not_exist_new_user( + self, mock_db_dependencies, mock_rsa_dependencies, mock_external_service_dependencies + ): + """Test creating owner tenant for new user without existing tenants.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Setup smart database query mock - no existing tenant joins + query_results = { + 
("TenantAccountJoin", "account_id", "user-123"): None, + ("TenantAccountJoin", "tenant_id", "tenant-456"): None, # For has_roles check + } + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Setup external service mocks + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + + # Mock tenant creation + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_tenant.name = "Test User's Workspace" + + # Mock database operations + mock_db_dependencies["db"].session.add = MagicMock() + + # Mock RSA key generation + mock_rsa_dependencies.return_value = "mock_public_key" + + # Mock has_roles method to return False (no existing owner) + with patch("services.account_service.TenantService.has_roles") as mock_has_roles: + mock_has_roles.return_value = False + + # Mock Tenant creation to set proper ID + with patch("services.account_service.Tenant") as mock_tenant_class: + mock_tenant_instance = MagicMock() + mock_tenant_instance.id = "tenant-456" + mock_tenant_instance.name = "Test User's Workspace" + mock_tenant_class.return_value = mock_tenant_instance + + # Execute test + TenantService.create_owner_tenant_if_not_exist(mock_account) + + # Verify tenant was created with correct parameters + mock_db_dependencies["db"].session.add.assert_called() + + # Get all calls to session.add + add_calls = mock_db_dependencies["db"].session.add.call_args_list + + # Should have at least 2 calls: one for Tenant, one for TenantAccountJoin + assert len(add_calls) >= 2 + + # Verify Tenant was added with correct name + tenant_added = False + tenant_account_join_added = False + + for call in add_calls: + added_object = call[0][0] # First argument of the call + + # Check if it's a Tenant object + if hasattr(added_object, 
"name") and hasattr(added_object, "id"): + # This should be a Tenant object + assert added_object.name == "Test User's Workspace" + tenant_added = True + + # Check if it's a TenantAccountJoin object + elif ( + hasattr(added_object, "tenant_id") + and hasattr(added_object, "account_id") + and hasattr(added_object, "role") + ): + # This should be a TenantAccountJoin object + assert added_object.tenant_id is not None + assert added_object.account_id == "user-123" + assert added_object.role == "owner" + tenant_account_join_added = True + + assert tenant_added, "Tenant object was not added to database" + assert tenant_account_join_added, "TenantAccountJoin object was not added to database" + + self._assert_database_operations_called(mock_db_dependencies["db"]) + assert mock_rsa_dependencies.called, "RSA key generation was not called" + + # ==================== Member Management Tests ==================== + + def test_create_tenant_member_success(self, mock_db_dependencies): + """Test successful tenant member creation.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Setup smart database query mock - no existing member + query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): None} + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Mock database operations + mock_db_dependencies["db"].session.add = MagicMock() + + # Execute test + result = TenantService.create_tenant_member(mock_tenant, mock_account, "normal") + + # Verify member was created with correct parameters + assert result is not None + mock_db_dependencies["db"].session.add.assert_called_once() + + # Verify the TenantAccountJoin object was added with correct parameters + added_tenant_account_join = mock_db_dependencies["db"].session.add.call_args[0][0] + assert added_tenant_account_join.tenant_id == "tenant-456" + assert 
added_tenant_account_join.account_id == "user-123" + assert added_tenant_account_join.role == "normal" + + self._assert_database_operations_called(mock_db_dependencies["db"]) + + # ==================== Tenant Switching Tests ==================== + + def test_switch_tenant_success(self): + """Test successful tenant switching.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + mock_tenant_join = TestAccountAssociatedDataFactory.create_tenant_join_mock( + tenant_id="tenant-456", account_id="user-123", current=False + ) + + # Mock the complex query in switch_tenant method + with patch("services.account_service.db") as mock_db: + # Mock the join query that returns the tenant_account_join + mock_query = MagicMock() + mock_filter = MagicMock() + mock_filter.first.return_value = mock_tenant_join + mock_query.filter.return_value = mock_filter + mock_query.join.return_value = mock_query + mock_db.session.query.return_value = mock_query + + # Execute test + TenantService.switch_tenant(mock_account, "tenant-456") + + # Verify tenant was switched + assert mock_tenant_join.current is True + self._assert_database_operations_called(mock_db) + + def test_switch_tenant_no_tenant_id(self): + """Test tenant switching without providing tenant ID.""" + # Setup test data + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + + # Execute test and verify exception + self._assert_exception_raised(ValueError, TenantService.switch_tenant, mock_account, None) + + # ==================== Role Management Tests ==================== + + def test_update_member_role_success(self): + """Test successful member role update.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789") + mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123") + mock_target_join = 
TestAccountAssociatedDataFactory.create_tenant_join_mock( + tenant_id="tenant-456", account_id="member-789", role="normal" + ) + mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock( + tenant_id="tenant-456", account_id="operator-123", role="owner" + ) + + # Mock the database queries in update_member_role method + with patch("services.account_service.db") as mock_db: + # Mock the first query for operator permission check + mock_query1 = MagicMock() + mock_filter1 = MagicMock() + mock_filter1.first.return_value = mock_operator_join + mock_query1.filter_by.return_value = mock_filter1 + + # Mock the second query for target member + mock_query2 = MagicMock() + mock_filter2 = MagicMock() + mock_filter2.first.return_value = mock_target_join + mock_query2.filter_by.return_value = mock_filter2 + + # Make the query method return different mocks for different calls + mock_db.session.query.side_effect = [mock_query1, mock_query2] + + # Execute test + TenantService.update_member_role(mock_tenant, mock_member, "admin", mock_operator) + + # Verify role was updated + assert mock_target_join.role == "admin" + self._assert_database_operations_called(mock_db) + + # ==================== Permission Check Tests ==================== + + def test_check_member_permission_success(self, mock_db_dependencies): + """Test successful member permission check.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123") + mock_member = TestAccountAssociatedDataFactory.create_account_mock(account_id="member-789") + mock_operator_join = TestAccountAssociatedDataFactory.create_tenant_join_mock( + tenant_id="tenant-456", account_id="operator-123", role="owner" + ) + + # Setup smart database query mock + query_results = {("TenantAccountJoin", "tenant_id", "tenant-456"): mock_operator_join} + 
ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Execute test - should not raise exception + TenantService.check_member_permission(mock_tenant, mock_operator, mock_member, "add") + + def test_check_member_permission_operate_self(self): + """Test member permission check when operator tries to operate self.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_operator = TestAccountAssociatedDataFactory.create_account_mock(account_id="operator-123") + + # Execute test and verify exception + from services.errors.account import CannotOperateSelfError + + self._assert_exception_raised( + CannotOperateSelfError, + TenantService.check_member_permission, + mock_tenant, + mock_operator, + mock_operator, # Same as operator + "add", + ) + + +class TestRegisterService: + """ + Comprehensive unit tests for RegisterService methods. + + This test suite covers all registration-related operations including: + - System setup + - Account registration + - Member invitation + - Token management + - Invitation validation + - Error conditions and edge cases + """ + + @pytest.fixture + def mock_db_dependencies(self): + """Common mock setup for database dependencies.""" + with patch("services.account_service.db") as mock_db: + mock_db.session.add = MagicMock() + mock_db.session.commit = MagicMock() + mock_db.session.begin_nested = MagicMock() + mock_db.session.rollback = MagicMock() + yield { + "db": mock_db, + } + + @pytest.fixture + def mock_redis_dependencies(self): + """Mock setup for Redis-related functions.""" + with patch("services.account_service.redis_client") as mock_redis: + yield mock_redis + + @pytest.fixture + def mock_external_service_dependencies(self): + """Mock setup for external service dependencies.""" + with ( + patch("services.account_service.FeatureService") as mock_feature_service, + patch("services.account_service.BillingService") as mock_billing_service, + 
patch("services.account_service.PassportService") as mock_passport_service, + ): + yield { + "feature_service": mock_feature_service, + "billing_service": mock_billing_service, + "passport_service": mock_passport_service, + } + + @pytest.fixture + def mock_task_dependencies(self): + """Mock setup for task dependencies.""" + with patch("services.account_service.send_invite_member_mail_task") as mock_send_mail: + yield mock_send_mail + + def _assert_database_operations_called(self, mock_db): + """Helper method to verify database operations were called.""" + mock_db.session.commit.assert_called() + + def _assert_database_operations_not_called(self, mock_db): + """Helper method to verify database operations were not called.""" + mock_db.session.commit.assert_not_called() + + def _assert_exception_raised(self, exception_type, callable_func, *args, **kwargs): + """Helper method to verify that specific exception is raised.""" + with pytest.raises(exception_type): + callable_func(*args, **kwargs) + + # ==================== Setup Tests ==================== + + def test_setup_success(self, mock_db_dependencies, mock_external_service_dependencies): + """Test successful system setup.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.return_value = mock_account + + # Mock TenantService.create_owner_tenant_if_not_exist + with patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_tenant: + # Mock DifySetup + with patch("services.account_service.DifySetup") as mock_dify_setup: + mock_dify_setup_instance = MagicMock() + 
mock_dify_setup.return_value = mock_dify_setup_instance + + # Execute test + RegisterService.setup("admin@example.com", "Admin User", "password123", "192.168.1.1") + + # Verify results + mock_create_account.assert_called_once_with( + email="admin@example.com", + name="Admin User", + interface_language="en-US", + password="password123", + is_setup=True, + ) + mock_create_tenant.assert_called_once_with(account=mock_account, is_setup=True) + mock_dify_setup.assert_called_once() + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_setup_failure_rollback(self, mock_db_dependencies, mock_external_service_dependencies): + """Test setup failure with proper rollback.""" + # Setup mocks to simulate failure + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account to raise exception + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.side_effect = Exception("Database error") + + # Execute test and verify exception + self._assert_exception_raised( + ValueError, + RegisterService.setup, + "admin@example.com", + "Admin User", + "password123", + "192.168.1.1", + ) + + # Verify the failure path issued cleanup queries (this checks session.query, not session.rollback) + mock_db_dependencies["db"].session.query.assert_called() + + # ==================== Registration Tests ==================== + + def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies): + """Test successful account registration.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + 
].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.return_value = mock_account + + # Mock TenantService.create_tenant and create_tenant_member + with ( + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.tenant_was_created") as mock_event, + ): + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_create_tenant.return_value = mock_tenant + + # Execute test + result = RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + # Verify results + assert result == mock_account + assert result.status == "active" + assert result.initialized_at is not None + mock_create_account.assert_called_once_with( + email="test@example.com", + name="Test User", + interface_language="en-US", + password="password123", + is_setup=False, + ) + mock_create_tenant.assert_called_once_with("Test User's Workspace") + mock_create_member.assert_called_once_with(mock_tenant, mock_account, role="owner") + mock_event.send.assert_called_once_with(mock_tenant) + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies): + """Test account registration with OAuth integration.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + 
].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account and link_account_integrate + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + with ( + patch("services.account_service.AccountService.create_account") as mock_create_account, + patch("services.account_service.AccountService.link_account_integrate") as mock_link_account, + ): + mock_create_account.return_value = mock_account + + # Mock TenantService methods + with ( + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.tenant_was_created") as mock_event, + ): + mock_tenant = MagicMock() + mock_create_tenant.return_value = mock_tenant + + # Execute test + result = RegisterService.register( + email="test@example.com", + name="Test User", + password=None, + open_id="oauth123", + provider="google", + language="en-US", + ) + + # Verify results + assert result == mock_account + mock_link_account.assert_called_once_with("google", "oauth123", mock_account) + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_register_with_pending_status(self, mock_db_dependencies, mock_external_service_dependencies): + """Test account registration with pending status.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + 
].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account + mock_account = TestAccountAssociatedDataFactory.create_account_mock() + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.return_value = mock_account + + # Mock TenantService methods + with ( + patch("services.account_service.TenantService.create_tenant") as mock_create_tenant, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.tenant_was_created") as mock_event, + ): + mock_tenant = MagicMock() + mock_create_tenant.return_value = mock_tenant + + # Execute test with pending status + from models.account import AccountStatus + + result = RegisterService.register( + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + status=AccountStatus.PENDING, + ) + + # Verify results + assert result == mock_account + assert result.status == "pending" + self._assert_database_operations_called(mock_db_dependencies["db"]) + + def test_register_workspace_not_allowed(self, mock_db_dependencies, mock_external_service_dependencies): + """Test registration when workspace creation is not allowed.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.is_allow_create_workspace = True + mock_external_service_dependencies[ + "feature_service" + ].get_system_features.return_value.license.workspaces.is_available.return_value = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account + mock_account = 
TestAccountAssociatedDataFactory.create_account_mock() + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.return_value = mock_account + + # Execute test and verify exception + from services.errors.workspace import WorkSpaceNotAllowedCreateError + + with patch("services.account_service.TenantService.create_tenant") as mock_create_tenant: + mock_create_tenant.side_effect = WorkSpaceNotAllowedCreateError() + + self._assert_exception_raised( + AccountRegisterError, + RegisterService.register, + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + # Verify rollback was called + mock_db_dependencies["db"].session.rollback.assert_called() + + def test_register_general_exception(self, mock_db_dependencies, mock_external_service_dependencies): + """Test registration with general exception handling.""" + # Setup mocks + mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True + mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False + + # Mock AccountService.create_account to raise exception + with patch("services.account_service.AccountService.create_account") as mock_create_account: + mock_create_account.side_effect = Exception("Unexpected error") + + # Execute test and verify exception + self._assert_exception_raised( + AccountRegisterError, + RegisterService.register, + email="test@example.com", + name="Test User", + password="password123", + language="en-US", + ) + + # Verify rollback was called + mock_db_dependencies["db"].session.rollback.assert_called() + + # ==================== Member Invitation Tests ==================== + + def test_invite_new_member_new_account(self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies): + """Test inviting a new member who doesn't have an account.""" + # Setup test data + mock_tenant = MagicMock() + 
mock_tenant.id = "tenant-456" + mock_tenant.name = "Test Workspace" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + + # Mock database queries - need to mock the Session query + mock_session = MagicMock() + mock_session.query.return_value.filter_by.return_value.first.return_value = None # No existing account + + with patch("services.account_service.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.__exit__.return_value = None + + # Mock RegisterService.register + mock_new_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="new-user-456", email="newuser@example.com", name="newuser", status="pending" + ) + with patch("services.account_service.RegisterService.register") as mock_register: + mock_register.return_value = mock_new_account + + # Mock TenantService methods + with ( + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.TenantService.switch_tenant") as mock_switch_tenant, + patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token, + ): + mock_generate_token.return_value = "invite-token-123" + + # Execute test + result = RegisterService.invite_new_member( + tenant=mock_tenant, + email="newuser@example.com", + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + # Verify results + assert result == "invite-token-123" + mock_register.assert_called_once_with( + email="newuser@example.com", + name="newuser", + language="en-US", + status="pending", + is_setup=True, + ) + mock_create_member.assert_called_once_with(mock_tenant, mock_new_account, "normal") + mock_switch_tenant.assert_called_once_with(mock_new_account, mock_tenant.id) + 
mock_generate_token.assert_called_once_with(mock_tenant, mock_new_account) + mock_task_dependencies.delay.assert_called_once() + + def test_invite_new_member_existing_account( + self, mock_db_dependencies, mock_redis_dependencies, mock_task_dependencies + ): + """Test inviting a new member who already has an account.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_tenant.name = "Test Workspace" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="existing-user-456", email="existing@example.com", status="pending" + ) + + # Mock database queries - need to mock the Session query + mock_session = MagicMock() + mock_session.query.return_value.filter_by.return_value.first.return_value = mock_existing_account + + with patch("services.account_service.Session") as mock_session_class: + mock_session_class.return_value.__enter__.return_value = mock_session + mock_session_class.return_value.__exit__.return_value = None + + # Mock the db.session.query for TenantAccountJoin + mock_db_query = MagicMock() + mock_db_query.filter_by.return_value.first.return_value = None # No existing member + mock_db_dependencies["db"].session.query.return_value = mock_db_query + + # Mock TenantService methods + with ( + patch("services.account_service.TenantService.check_member_permission") as mock_check_permission, + patch("services.account_service.TenantService.create_tenant_member") as mock_create_member, + patch("services.account_service.RegisterService.generate_invite_token") as mock_generate_token, + ): + mock_generate_token.return_value = "invite-token-123" + + # Execute test + result = RegisterService.invite_new_member( + tenant=mock_tenant, + email="existing@example.com", + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + # Verify results + assert result == "invite-token-123" + 
mock_create_member.assert_called_once_with(mock_tenant, mock_existing_account, "normal") + mock_generate_token.assert_called_once_with(mock_tenant, mock_existing_account) + mock_task_dependencies.delay.assert_called_once() + + def test_invite_new_member_already_in_tenant(self, mock_db_dependencies, mock_redis_dependencies): + """Test inviting a member who is already in the tenant.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_inviter = TestAccountAssociatedDataFactory.create_account_mock(account_id="inviter-123", name="Inviter") + mock_existing_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="existing-user-456", email="existing@example.com", status="active" + ) + + # Mock database queries + query_results = { + ("Account", "email", "existing@example.com"): mock_existing_account, + ( + "TenantAccountJoin", + "tenant_id", + "tenant-456", + ): TestAccountAssociatedDataFactory.create_tenant_join_mock(), + } + ServiceDbTestHelper.setup_db_query_filter_by_mock(mock_db_dependencies["db"], query_results) + + # Mock TenantService methods + with patch("services.account_service.TenantService.check_member_permission") as mock_check_permission: + # Execute test and verify exception + self._assert_exception_raised( + AccountAlreadyInTenantError, + RegisterService.invite_new_member, + tenant=mock_tenant, + email="existing@example.com", + language="en-US", + role="normal", + inviter=mock_inviter, + ) + + def test_invite_new_member_no_inviter(self): + """Test inviting a member without providing an inviter.""" + # Setup test data + mock_tenant = MagicMock() + + # Execute test and verify exception + self._assert_exception_raised( + ValueError, + RegisterService.invite_new_member, + tenant=mock_tenant, + email="test@example.com", + language="en-US", + role="normal", + inviter=None, + ) + + # ==================== Token Management Tests ==================== + + def test_generate_invite_token_success(self, 
mock_redis_dependencies): + """Test successful invite token generation.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="user-123", email="test@example.com" + ) + + # Mock uuid generation + with patch("services.account_service.uuid.uuid4") as mock_uuid: + mock_uuid.return_value = "test-uuid-123" + + # Execute test + result = RegisterService.generate_invite_token(mock_tenant, mock_account) + + # Verify results + assert result == "test-uuid-123" + mock_redis_dependencies.setex.assert_called_once() + + # Verify the stored data + call_args = mock_redis_dependencies.setex.call_args + assert call_args[0][0] == "member_invite:token:test-uuid-123" + stored_data = json.loads(call_args[0][2]) + assert stored_data["account_id"] == "user-123" + assert stored_data["email"] == "test@example.com" + assert stored_data["workspace_id"] == "tenant-456" + + def test_is_valid_invite_token_valid(self, mock_redis_dependencies): + """Test checking valid invite token.""" + # Setup mock + mock_redis_dependencies.get.return_value = b'{"test": "data"}' + + # Execute test + result = RegisterService.is_valid_invite_token("valid-token") + + # Verify results + assert result is True + mock_redis_dependencies.get.assert_called_once_with("member_invite:token:valid-token") + + def test_is_valid_invite_token_invalid(self, mock_redis_dependencies): + """Test checking invalid invite token.""" + # Setup mock + mock_redis_dependencies.get.return_value = None + + # Execute test + result = RegisterService.is_valid_invite_token("invalid-token") + + # Verify results + assert result is False + mock_redis_dependencies.get.assert_called_once_with("member_invite:token:invalid-token") + + def test_revoke_token_with_workspace_and_email(self, mock_redis_dependencies): + """Test revoking token with workspace ID and email.""" + # Execute test + RegisterService.revoke_token("workspace-123", 
"test@example.com", "token-123") + + # Verify results + mock_redis_dependencies.delete.assert_called_once() + call_args = mock_redis_dependencies.delete.call_args + assert "workspace-123" in call_args[0][0] + # The email is hashed, so we check for the hash pattern instead + assert "member_invite_token:" in call_args[0][0] + + def test_revoke_token_without_workspace_and_email(self, mock_redis_dependencies): + """Test revoking token without workspace ID and email.""" + # Execute test + RegisterService.revoke_token("", "", "token-123") + + # Verify results + mock_redis_dependencies.delete.assert_called_once_with("member_invite:token:token-123") + + # ==================== Invitation Validation Tests ==================== + + def test_get_invitation_if_token_valid_success(self, mock_db_dependencies, mock_redis_dependencies): + """Test successful invitation validation.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_tenant.status = "normal" + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="user-123", email="test@example.com" + ) + + with patch("services.account_service.RegisterService._get_invitation_by_token") as mock_get_invitation_by_token: + # Mock the invitation data returned by _get_invitation_by_token + invitation_data = { + "account_id": "user-123", + "email": "test@example.com", + "workspace_id": "tenant-456", + } + mock_get_invitation_by_token.return_value = invitation_data + + # Mock database queries - complex query mocking + mock_query1 = MagicMock() + mock_query1.filter.return_value.first.return_value = mock_tenant + + mock_query2 = MagicMock() + mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal") + + mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + + # Execute test + result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + + # Verify results + assert result 
is not None + assert result["account"] == mock_account + assert result["tenant"] == mock_tenant + assert result["data"] == invitation_data + + def test_get_invitation_if_token_valid_no_token_data(self, mock_redis_dependencies): + """Test invitation validation with no token data.""" + # Setup mock + mock_redis_dependencies.get.return_value = None + + # Execute test + result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + + # Verify results + assert result is None + + def test_get_invitation_if_token_valid_tenant_not_found(self, mock_db_dependencies, mock_redis_dependencies): + """Test invitation validation when tenant is not found.""" + # Setup mock Redis data + invitation_data = { + "account_id": "user-123", + "email": "test@example.com", + "workspace_id": "tenant-456", + } + mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() + + # Mock database queries - no tenant found + mock_query = MagicMock() + mock_query.filter.return_value.first.return_value = None + mock_db_dependencies["db"].session.query.return_value = mock_query + + # Execute test + result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + + # Verify results + assert result is None + + def test_get_invitation_if_token_valid_account_not_found(self, mock_db_dependencies, mock_redis_dependencies): + """Test invitation validation when account is not found.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_tenant.status = "normal" + + # Mock Redis data + invitation_data = { + "account_id": "user-123", + "email": "test@example.com", + "workspace_id": "tenant-456", + } + mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() + + # Mock database queries + mock_query1 = MagicMock() + mock_query1.filter.return_value.first.return_value = mock_tenant + + mock_query2 = MagicMock() + 
mock_query2.join.return_value.filter.return_value.first.return_value = None # No account found + + mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + + # Execute test + result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + + # Verify results + assert result is None + + def test_get_invitation_if_token_valid_account_id_mismatch(self, mock_db_dependencies, mock_redis_dependencies): + """Test invitation validation when account ID doesn't match.""" + # Setup test data + mock_tenant = MagicMock() + mock_tenant.id = "tenant-456" + mock_tenant.status = "normal" + mock_account = TestAccountAssociatedDataFactory.create_account_mock( + account_id="different-user-456", email="test@example.com" + ) + + # Mock Redis data with different account ID + invitation_data = { + "account_id": "user-123", + "email": "test@example.com", + "workspace_id": "tenant-456", + } + mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() + + # Mock database queries + mock_query1 = MagicMock() + mock_query1.filter.return_value.first.return_value = mock_tenant + + mock_query2 = MagicMock() + mock_query2.join.return_value.filter.return_value.first.return_value = (mock_account, "normal") + + mock_db_dependencies["db"].session.query.side_effect = [mock_query1, mock_query2] + + # Execute test + result = RegisterService.get_invitation_if_token_valid("tenant-456", "test@example.com", "token-123") + + # Verify results + assert result is None + + # ==================== Helper Method Tests ==================== + + def test_get_invitation_token_key(self): + """Test the _get_invitation_token_key helper method.""" + # Execute test + result = RegisterService._get_invitation_token_key("test-token") + + # Verify results + assert result == "member_invite:token:test-token" + + def test_get_invitation_by_token_with_workspace_and_email(self, mock_redis_dependencies): + """Test _get_invitation_by_token with 
workspace ID and email.""" + # Setup mock + mock_redis_dependencies.get.return_value = b"user-123" + + # Execute test + result = RegisterService._get_invitation_by_token("token-123", "workspace-456", "test@example.com") + + # Verify results + assert result is not None + assert result["account_id"] == "user-123" + assert result["email"] == "test@example.com" + assert result["workspace_id"] == "workspace-456" + + def test_get_invitation_by_token_without_workspace_and_email(self, mock_redis_dependencies): + """Test _get_invitation_by_token without workspace ID and email.""" + # Setup mock + invitation_data = { + "account_id": "user-123", + "email": "test@example.com", + "workspace_id": "tenant-456", + } + mock_redis_dependencies.get.return_value = json.dumps(invitation_data).encode() + + # Execute test + result = RegisterService._get_invitation_by_token("token-123") + + # Verify results + assert result is not None + assert result == invitation_data + + def test_get_invitation_by_token_no_data(self, mock_redis_dependencies): + """Test _get_invitation_by_token with no data.""" + # Setup mock + mock_redis_dependencies.get.return_value = None + + # Execute test + result = RegisterService._get_invitation_by_token("token-123") + + # Verify results + assert result is None diff --git a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py index 223020c2c5..2c87eaf805 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_deletion.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_deletion.py @@ -10,7 +10,8 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE @pytest.fixture def workflow_setup(): - workflow_service = WorkflowService() + mock_session_maker = MagicMock() + workflow_service = WorkflowService(mock_session_maker) session = MagicMock(spec=Session) tenant_id = "test-tenant-id" workflow_id = "test-workflow-id" diff --git 
a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index c5c9cf1050..8b1348b75b 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -1,14 +1,14 @@ import dataclasses import secrets -from unittest import mock -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest +from sqlalchemy import Engine from sqlalchemy.orm import Session from core.variables import StringSegment from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.nodes import NodeType +from core.workflow.nodes.enums import NodeType from models.enums import DraftVariableType from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable from services.workflow_draft_variable_service import ( @@ -18,13 +18,25 @@ from services.workflow_draft_variable_service import ( ) +@pytest.fixture +def mock_engine() -> Engine: + return Mock(spec=Engine) + + +@pytest.fixture +def mock_session(mock_engine) -> Session: + mock_session = Mock(spec=Session) + mock_session.get_bind.return_value = mock_engine + return mock_session + + class TestDraftVariableSaver: def _get_test_app_id(self): suffix = secrets.token_hex(6) return f"test_app_id_{suffix}" def test__should_variable_be_visible(self): - mock_session = mock.MagicMock(spec=Session) + mock_session = MagicMock(spec=Session) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -70,7 +82,7 @@ class TestDraftVariableSaver: ), ] - mock_session = mock.MagicMock(spec=Session) + mock_session = MagicMock(spec=Session) test_app_id = self._get_test_app_id() saver = DraftVariableSaver( session=mock_session, @@ -105,9 +117,8 @@ class TestWorkflowDraftVariableService: 
conversation_variables=[], ) - def test_reset_conversation_variable(self): + def test_reset_conversation_variable(self, mock_session): """Test resetting a conversation variable""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -131,9 +142,8 @@ class TestWorkflowDraftVariableService: mock_reset_conv.assert_called_once_with(workflow, variable) assert result == expected_result - def test_reset_node_variable_with_no_execution_id(self): + def test_reset_node_variable_with_no_execution_id(self, mock_session): """Test resetting a node variable with no execution ID - should delete variable""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -158,11 +168,26 @@ class TestWorkflowDraftVariableService: mock_session.flush.assert_called_once() assert result is None - def test_reset_node_variable_with_missing_execution_record(self): + def test_reset_node_variable_with_missing_execution_record( + self, + mock_engine, + mock_session, + monkeypatch, + ): """Test resetting a node variable when execution record doesn't exist""" - mock_session = Mock(spec=Session) + mock_repo_session = Mock(spec=Session) + + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__exit__.return_value = None + monkeypatch.setattr("services.workflow_draft_variable_service.sessionmaker", mock_session_maker) service = WorkflowDraftVariableService(mock_session) + # Mock the repository to return None (no execution record found) + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = None + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -171,24 +196,41 @@ class TestWorkflowDraftVariableService: variable = 
WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Mock session.scalars to return None (no execution record found) - mock_scalars = Mock() - mock_scalars.first.return_value = None - mock_session.scalars.return_value = mock_scalars + # Variable is editable by default from factory method result = service._reset_node_var_or_sys_var(workflow, variable) + mock_session_maker.assert_called_once_with(bind=mock_engine, expire_on_commit=False) # Should delete the variable and return None mock_session.delete.assert_called_once_with(instance=variable) mock_session.flush.assert_called_once() assert result is None - def test_reset_node_variable_with_valid_execution_record(self): + def test_reset_node_variable_with_valid_execution_record( + self, + mock_session, + monkeypatch, + ): """Test resetting a node variable with valid execution record - should restore from execution""" - mock_session = Mock(spec=Session) + mock_repo_session = Mock(spec=Session) + + mock_session_maker = MagicMock() + # Mock the context manager protocol for sessionmaker + mock_session_maker.return_value.__enter__.return_value = mock_repo_session + mock_session_maker.return_value.__exit__.return_value = None + monkeypatch.setattr( + "services.workflow_draft_variable_service.sessionmaker", mock_session_maker + ) service = WorkflowDraftVariableService(mock_session) + # Create mock execution record + mock_execution = Mock(spec=WorkflowNodeExecutionModel) + mock_execution.outputs_dict = {"test_var": "output_value"} + + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution + test_app_id = self._get_test_app_id() workflow = self._create_test_workflow(test_app_id) @@ -197,16 +239,7 @@ class TestWorkflowDraftVariableService: variable = 
WorkflowDraftVariable.new_node_variable( app_id=test_app_id, node_id="test_node_id", name="test_var", value=test_value, node_execution_id="exec-id" ) - - # Create mock execution record - mock_execution = Mock(spec=WorkflowNodeExecutionModel) - mock_execution.process_data_dict = {"test_var": "process_value"} - mock_execution.outputs_dict = {"test_var": "output_value"} - - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Variable is editable by default from factory method # Mock workflow methods mock_node_config = {"type": "test_node"} @@ -224,9 +257,8 @@ class TestWorkflowDraftVariableService: # Should return the updated variable assert result == variable - def test_reset_non_editable_system_variable_raises_error(self): + def test_reset_non_editable_system_variable_raises_error(self, mock_session): """Test that resetting a non-editable system variable raises an error""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -242,24 +274,13 @@ class TestWorkflowDraftVariableService: editable=False, # Non-editable system variable ) - # Mock the service to properly check system variable editability - with patch.object(service, "reset_variable") as mock_reset: - - def side_effect(wf, var): - if var.get_variable_type() == DraftVariableType.SYS and not is_system_variable_editable(var.name): - raise VariableResetError(f"cannot reset system variable, variable_id={var.id}") - return var - - mock_reset.side_effect = side_effect - - with pytest.raises(VariableResetError) as exc_info: - service.reset_variable(workflow, variable) - assert "cannot reset system variable" in str(exc_info.value) - assert f"variable_id={variable.id}" in str(exc_info.value) + with pytest.raises(VariableResetError) as exc_info: + service.reset_variable(workflow, variable) + assert "cannot reset 
system variable" in str(exc_info.value) + assert f"variable_id={variable.id}" in str(exc_info.value) - def test_reset_editable_system_variable_succeeds(self): + def test_reset_editable_system_variable_succeeds(self, mock_session): """Test that resetting an editable system variable succeeds""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -279,10 +300,9 @@ class TestWorkflowDraftVariableService: mock_execution = Mock(spec=WorkflowNodeExecutionModel) mock_execution.outputs_dict = {"sys.files": "[]"} - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution result = service._reset_node_var_or_sys_var(workflow, variable) @@ -291,9 +311,8 @@ class TestWorkflowDraftVariableService: assert variable.last_edited_at is None mock_session.flush.assert_called() - def test_reset_query_system_variable_succeeds(self): + def test_reset_query_system_variable_succeeds(self, mock_session): """Test that resetting query system variable (another editable one) succeeds""" - mock_session = Mock(spec=Session) service = WorkflowDraftVariableService(mock_session) test_app_id = self._get_test_app_id() @@ -313,10 +332,9 @@ class TestWorkflowDraftVariableService: mock_execution = Mock(spec=WorkflowNodeExecutionModel) mock_execution.outputs_dict = {"sys.query": "reset query"} - # Mock session.scalars to return the execution record - mock_scalars = Mock() - mock_scalars.first.return_value = mock_execution - mock_session.scalars.return_value = mock_scalars + # Mock the repository to return the execution record + service._api_node_execution_repo = Mock() + 
service._api_node_execution_repo.get_execution_by_id.return_value = mock_execution result = service._reset_node_var_or_sys_var(workflow, variable) diff --git a/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py new file mode 100644 index 0000000000..32d2f8b7e0 --- /dev/null +++ b/api/tests/unit_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -0,0 +1,288 @@ +from datetime import datetime +from unittest.mock import MagicMock +from uuid import uuid4 + +import pytest +from sqlalchemy.orm import Session + +from models.workflow import WorkflowNodeExecutionModel +from repositories.sqlalchemy_api_workflow_node_execution_repository import ( + DifyAPISQLAlchemyWorkflowNodeExecutionRepository, +) + + +class TestSQLAlchemyWorkflowNodeExecutionServiceRepository: + @pytest.fixture + def repository(self): + mock_session_maker = MagicMock() + return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker) + + @pytest.fixture + def mock_execution(self): + execution = MagicMock(spec=WorkflowNodeExecutionModel) + execution.id = str(uuid4()) + execution.tenant_id = "tenant-123" + execution.app_id = "app-456" + execution.workflow_id = "workflow-789" + execution.workflow_run_id = "run-101" + execution.node_id = "node-202" + execution.index = 1 + execution.created_at = "2023-01-01T00:00:00Z" + return execution + + def test_get_node_last_execution_found(self, repository, mock_execution): + """Test getting the last execution for a node when it exists.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = mock_execution + + # Act + result = repository.get_node_last_execution( + tenant_id="tenant-123", + app_id="app-456", + workflow_id="workflow-789", + node_id="node-202", + ) + + # Assert + 
assert result == mock_execution + mock_session.scalar.assert_called_once() + # Verify the query was constructed correctly + call_args = mock_session.scalar.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + + def test_get_node_last_execution_not_found(self, repository): + """Test getting the last execution for a node when it doesn't exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act + result = repository.get_node_last_execution( + tenant_id="tenant-123", + app_id="app-456", + workflow_id="workflow-789", + node_id="node-202", + ) + + # Assert + assert result is None + mock_session.scalar.assert_called_once() + + def test_get_executions_by_workflow_run(self, repository, mock_execution): + """Test getting all executions for a workflow run.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + executions = [mock_execution] + mock_session.execute.return_value.scalars.return_value.all.return_value = executions + + # Act + result = repository.get_executions_by_workflow_run( + tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == executions + mock_session.execute.assert_called_once() + # Verify the query was constructed correctly + call_args = mock_session.execute.call_args[0][0] + assert hasattr(call_args, "compile") # It's a SQLAlchemy statement + + def test_get_executions_by_workflow_run_empty(self, repository): + """Test getting executions for a workflow run when none exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.execute.return_value.scalars.return_value.all.return_value = [] + + # Act + result = repository.get_executions_by_workflow_run( + 
tenant_id="tenant-123", + app_id="app-456", + workflow_run_id="run-101", + ) + + # Assert + assert result == [] + mock_session.execute.assert_called_once() + + def test_get_execution_by_id_found(self, repository, mock_execution): + """Test getting execution by ID when it exists.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = mock_execution + + # Act + result = repository.get_execution_by_id(mock_execution.id) + + # Assert + assert result == mock_execution + mock_session.scalar.assert_called_once() + + def test_get_execution_by_id_not_found(self, repository): + """Test getting execution by ID when it doesn't exist.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + mock_session.scalar.return_value = None + + # Act + result = repository.get_execution_by_id("non-existent-id") + + # Assert + assert result is None + mock_session.scalar.assert_called_once() + + def test_repository_implements_protocol(self, repository): + """Test that the repository implements the required protocol methods.""" + # Verify all protocol methods are implemented + assert hasattr(repository, "get_node_last_execution") + assert hasattr(repository, "get_executions_by_workflow_run") + assert hasattr(repository, "get_execution_by_id") + + # Verify methods are callable + assert callable(repository.get_node_last_execution) + assert callable(repository.get_executions_by_workflow_run) + assert callable(repository.get_execution_by_id) + assert callable(repository.delete_expired_executions) + assert callable(repository.delete_executions_by_app) + assert callable(repository.get_expired_executions_batch) + assert callable(repository.delete_executions_by_ids) + + def test_delete_expired_executions(self, repository): + """Test deleting expired executions.""" + # Arrange + mock_session = 
MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the select query to return some IDs first time, then empty to stop loop + execution_ids = ["id1", "id2"] # Less than batch_size to trigger break + + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute + + before_date = datetime(2023, 1, 1) + + # Act + result = repository.delete_expired_executions( + tenant_id="tenant-123", + before_date=before_date, + batch_size=1000, + ) + + # Assert + assert result == 2 + assert mock_session.execute.call_count == 2 # One select call, one delete call + mock_session.commit.assert_called_once() + + def test_delete_executions_by_app(self, repository): + """Test deleting executions by app.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the select query to return some IDs first time, then empty to stop loop + execution_ids = ["id1", "id2"] + + # Mock execute method to handle both select and delete statements + def mock_execute(stmt): + mock_result = MagicMock() + # For select statements, return execution IDs + if hasattr(stmt, "limit"): # This is our select statement + mock_result.scalars.return_value.all.return_value = execution_ids + else: # This is our delete statement + mock_result.rowcount = 2 + return mock_result + + mock_session.execute.side_effect = mock_execute + + # Act + result = repository.delete_executions_by_app( + tenant_id="tenant-123", + app_id="app-456", + batch_size=1000, + ) + + # Assert + assert result == 2 + assert 
mock_session.execute.call_count == 2 # One select call, one delete call + mock_session.commit.assert_called_once() + + def test_get_expired_executions_batch(self, repository): + """Test getting expired executions batch for backup.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Create mock execution objects + mock_execution1 = MagicMock() + mock_execution1.id = "exec-1" + mock_execution2 = MagicMock() + mock_execution2.id = "exec-2" + + mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2] + + before_date = datetime(2023, 1, 1) + + # Act + result = repository.get_expired_executions_batch( + tenant_id="tenant-123", + before_date=before_date, + batch_size=1000, + ) + + # Assert + assert len(result) == 2 + assert result[0].id == "exec-1" + assert result[1].id == "exec-2" + mock_session.execute.assert_called_once() + + def test_delete_executions_by_ids(self, repository): + """Test deleting executions by IDs.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Mock the delete query result + mock_result = MagicMock() + mock_result.rowcount = 3 + mock_session.execute.return_value = mock_result + + execution_ids = ["id1", "id2", "id3"] + + # Act + result = repository.delete_executions_by_ids(execution_ids) + + # Assert + assert result == 3 + mock_session.execute.assert_called_once() + mock_session.commit.assert_called_once() + + def test_delete_executions_by_ids_empty_list(self, repository): + """Test deleting executions with empty ID list.""" + # Arrange + mock_session = MagicMock(spec=Session) + repository._session_maker.return_value.__enter__.return_value = mock_session + + # Act + result = repository.delete_executions_by_ids([]) + + # Assert + assert result == 0 + mock_session.query.assert_not_called() + 
mock_session.commit.assert_not_called() diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py index 13393668ea..9700cbaf0e 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_service.py @@ -10,7 +10,8 @@ from services.workflow_service import WorkflowService class TestWorkflowService: @pytest.fixture def workflow_service(self): - return WorkflowService() + mock_session_maker = MagicMock() + return WorkflowService(mock_session_maker) @pytest.fixture def mock_app(self): diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 728c58fc5b..93284eed4b 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -27,11 +27,11 @@ def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LL return LLMUsage( prompt_tokens=prompt_tokens, prompt_unit_price=Decimal("0.001"), - prompt_price_unit=Decimal("1"), + prompt_price_unit=Decimal(1), prompt_price=Decimal(str(prompt_tokens)) * Decimal("0.001"), completion_tokens=completion_tokens, completion_unit_price=Decimal("0.002"), - completion_price_unit=Decimal("1"), + completion_price_unit=Decimal(1), completion_price=Decimal(str(completion_tokens)) * Decimal("0.002"), total_tokens=prompt_tokens + completion_tokens, total_price=Decimal(str(prompt_tokens)) * Decimal("0.001") + Decimal(str(completion_tokens)) * Decimal("0.002"), diff --git a/api/uv.lock b/api/uv.lock index e108e0c445..21b6b20f53 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1498,7 +1498,7 @@ dev = [ { name = "pytest-cov", specifier = "~=4.1.0" }, { name = "pytest-env", specifier = "~=1.1.3" }, { name = "pytest-mock", 
specifier = "~=3.14.0" }, - { name = "ruff", specifier = "~=0.11.5" }, + { name = "ruff", specifier = "~=0.12.3" }, { name = "scipy-stubs", specifier = ">=1.15.3.0" }, { name = "types-aiofiles", specifier = "~=24.1.0" }, { name = "types-beautifulsoup4", specifier = "~=4.12.0" }, @@ -5088,27 +5088,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.11.13" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ed/da/9c6f995903b4d9474b39da91d2d626659af3ff1eeb43e9ae7c119349dba6/ruff-0.11.13.tar.gz", hash = "sha256:26fa247dc68d1d4e72c179e08889a25ac0c7ba4d78aecfc835d49cbfd60bf514", size = 4282054, upload-time = "2025-06-05T21:00:15.721Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/ce/a11d381192966e0b4290842cc8d4fac7dc9214ddf627c11c1afff87da29b/ruff-0.11.13-py3-none-linux_armv6l.whl", hash = "sha256:4bdfbf1240533f40042ec00c9e09a3aade6f8c10b6414cf11b519488d2635d46", size = 10292516, upload-time = "2025-06-05T20:59:32.944Z" }, - { url = "https://files.pythonhosted.org/packages/78/db/87c3b59b0d4e753e40b6a3b4a2642dfd1dcaefbff121ddc64d6c8b47ba00/ruff-0.11.13-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:aef9c9ed1b5ca28bb15c7eac83b8670cf3b20b478195bd49c8d756ba0a36cf48", size = 11106083, upload-time = "2025-06-05T20:59:37.03Z" }, - { url = "https://files.pythonhosted.org/packages/77/79/d8cec175856ff810a19825d09ce700265f905c643c69f45d2b737e4a470a/ruff-0.11.13-py3-none-macosx_11_0_arm64.whl", hash = "sha256:53b15a9dfdce029c842e9a5aebc3855e9ab7771395979ff85b7c1dedb53ddc2b", size = 10436024, upload-time = "2025-06-05T20:59:39.741Z" }, - { url = "https://files.pythonhosted.org/packages/8b/5b/f6d94f2980fa1ee854b41568368a2e1252681b9238ab2895e133d303538f/ruff-0.11.13-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab153241400789138d13f362c43f7edecc0edfffce2afa6a68434000ecd8f69a", size = 10646324, upload-time = "2025-06-05T20:59:42.185Z" }, - { url = 
"https://files.pythonhosted.org/packages/6c/9c/b4c2acf24ea4426016d511dfdc787f4ce1ceb835f3c5fbdbcb32b1c63bda/ruff-0.11.13-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6c51f93029d54a910d3d24f7dd0bb909e31b6cd989a5e4ac513f4eb41629f0dc", size = 10174416, upload-time = "2025-06-05T20:59:44.319Z" }, - { url = "https://files.pythonhosted.org/packages/f3/10/e2e62f77c65ede8cd032c2ca39c41f48feabedb6e282bfd6073d81bb671d/ruff-0.11.13-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1808b3ed53e1a777c2ef733aca9051dc9bf7c99b26ece15cb59a0320fbdbd629", size = 11724197, upload-time = "2025-06-05T20:59:46.935Z" }, - { url = "https://files.pythonhosted.org/packages/bb/f0/466fe8469b85c561e081d798c45f8a1d21e0b4a5ef795a1d7f1a9a9ec182/ruff-0.11.13-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d28ce58b5ecf0f43c1b71edffabe6ed7f245d5336b17805803312ec9bc665933", size = 12511615, upload-time = "2025-06-05T20:59:49.534Z" }, - { url = "https://files.pythonhosted.org/packages/17/0e/cefe778b46dbd0cbcb03a839946c8f80a06f7968eb298aa4d1a4293f3448/ruff-0.11.13-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:55e4bc3a77842da33c16d55b32c6cac1ec5fb0fbec9c8c513bdce76c4f922165", size = 12117080, upload-time = "2025-06-05T20:59:51.654Z" }, - { url = "https://files.pythonhosted.org/packages/5d/2c/caaeda564cbe103bed145ea557cb86795b18651b0f6b3ff6a10e84e5a33f/ruff-0.11.13-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:633bf2c6f35678c56ec73189ba6fa19ff1c5e4807a78bf60ef487b9dd272cc71", size = 11326315, upload-time = "2025-06-05T20:59:54.469Z" }, - { url = "https://files.pythonhosted.org/packages/75/f0/782e7d681d660eda8c536962920c41309e6dd4ebcea9a2714ed5127d44bd/ruff-0.11.13-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4ffbc82d70424b275b089166310448051afdc6e914fdab90e08df66c43bb5ca9", size = 11555640, upload-time = "2025-06-05T20:59:56.986Z" }, - { url = 
"https://files.pythonhosted.org/packages/5d/d4/3d580c616316c7f07fb3c99dbecfe01fbaea7b6fd9a82b801e72e5de742a/ruff-0.11.13-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a9ddd3ec62a9a89578c85842b836e4ac832d4a2e0bfaad3b02243f930ceafcc", size = 10507364, upload-time = "2025-06-05T20:59:59.154Z" }, - { url = "https://files.pythonhosted.org/packages/5a/dc/195e6f17d7b3ea6b12dc4f3e9de575db7983db187c378d44606e5d503319/ruff-0.11.13-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d237a496e0778d719efb05058c64d28b757c77824e04ffe8796c7436e26712b7", size = 10141462, upload-time = "2025-06-05T21:00:01.481Z" }, - { url = "https://files.pythonhosted.org/packages/f4/8e/39a094af6967faa57ecdeacb91bedfb232474ff8c3d20f16a5514e6b3534/ruff-0.11.13-py3-none-musllinux_1_2_i686.whl", hash = "sha256:26816a218ca6ef02142343fd24c70f7cd8c5aa6c203bca284407adf675984432", size = 11121028, upload-time = "2025-06-05T21:00:04.06Z" }, - { url = "https://files.pythonhosted.org/packages/5a/c0/b0b508193b0e8a1654ec683ebab18d309861f8bd64e3a2f9648b80d392cb/ruff-0.11.13-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:51c3f95abd9331dc5b87c47ac7f376db5616041173826dfd556cfe3d4977f492", size = 11602992, upload-time = "2025-06-05T21:00:06.249Z" }, - { url = "https://files.pythonhosted.org/packages/7c/91/263e33ab93ab09ca06ce4f8f8547a858cc198072f873ebc9be7466790bae/ruff-0.11.13-py3-none-win32.whl", hash = "sha256:96c27935418e4e8e77a26bb05962817f28b8ef3843a6c6cc49d8783b5507f250", size = 10474944, upload-time = "2025-06-05T21:00:08.459Z" }, - { url = "https://files.pythonhosted.org/packages/46/f4/7c27734ac2073aae8efb0119cae6931b6fb48017adf048fdf85c19337afc/ruff-0.11.13-py3-none-win_amd64.whl", hash = "sha256:29c3189895a8a6a657b7af4e97d330c8a3afd2c9c8f46c81e2fc5a31866517e3", size = 11548669, upload-time = "2025-06-05T21:00:11.147Z" }, - { url = "https://files.pythonhosted.org/packages/ec/bf/b273dd11673fed8a6bd46032c0ea2a04b2ac9bfa9c628756a5856ba113b0/ruff-0.11.13-py3-none-win_arm64.whl", hash = 
"sha256:b4385285e9179d608ff1d2fb9922062663c658605819a6876d8beef0c30b7f3b", size = 10683928, upload-time = "2025-06-05T21:00:13.758Z" }, +version = "0.12.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c3/2a/43955b530c49684d3c38fcda18c43caf91e99204c2a065552528e0552d4f/ruff-0.12.3.tar.gz", hash = "sha256:f1b5a4b6668fd7b7ea3697d8d98857390b40c1320a63a178eee6be0899ea2d77", size = 4459341, upload-time = "2025-07-11T13:21:16.086Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/fd/b44c5115539de0d598d75232a1cc7201430b6891808df111b8b0506aae43/ruff-0.12.3-py3-none-linux_armv6l.whl", hash = "sha256:47552138f7206454eaf0c4fe827e546e9ddac62c2a3d2585ca54d29a890137a2", size = 10430499, upload-time = "2025-07-11T13:20:26.321Z" }, + { url = "https://files.pythonhosted.org/packages/43/c5/9eba4f337970d7f639a37077be067e4ec80a2ad359e4cc6c5b56805cbc66/ruff-0.12.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:0a9153b000c6fe169bb307f5bd1b691221c4286c133407b8827c406a55282041", size = 11213413, upload-time = "2025-07-11T13:20:30.017Z" }, + { url = "https://files.pythonhosted.org/packages/e2/2c/fac3016236cf1fe0bdc8e5de4f24c76ce53c6dd9b5f350d902549b7719b2/ruff-0.12.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fa6b24600cf3b750e48ddb6057e901dd5b9aa426e316addb2a1af185a7509882", size = 10586941, upload-time = "2025-07-11T13:20:33.046Z" }, + { url = "https://files.pythonhosted.org/packages/c5/0f/41fec224e9dfa49a139f0b402ad6f5d53696ba1800e0f77b279d55210ca9/ruff-0.12.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2506961bf6ead54887ba3562604d69cb430f59b42133d36976421bc8bd45901", size = 10783001, upload-time = "2025-07-11T13:20:35.534Z" }, + { url = "https://files.pythonhosted.org/packages/0d/ca/dd64a9ce56d9ed6cad109606ac014860b1c217c883e93bf61536400ba107/ruff-0.12.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = 
"sha256:c4faaff1f90cea9d3033cbbcdf1acf5d7fb11d8180758feb31337391691f3df0", size = 10269641, upload-time = "2025-07-11T13:20:38.459Z" }, + { url = "https://files.pythonhosted.org/packages/63/5c/2be545034c6bd5ce5bb740ced3e7014d7916f4c445974be11d2a406d5088/ruff-0.12.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40dced4a79d7c264389de1c59467d5d5cefd79e7e06d1dfa2c75497b5269a5a6", size = 11875059, upload-time = "2025-07-11T13:20:41.517Z" }, + { url = "https://files.pythonhosted.org/packages/8e/d4/a74ef1e801ceb5855e9527dae105eaff136afcb9cc4d2056d44feb0e4792/ruff-0.12.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:0262d50ba2767ed0fe212aa7e62112a1dcbfd46b858c5bf7bbd11f326998bafc", size = 12658890, upload-time = "2025-07-11T13:20:44.442Z" }, + { url = "https://files.pythonhosted.org/packages/13/c8/1057916416de02e6d7c9bcd550868a49b72df94e3cca0aeb77457dcd9644/ruff-0.12.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:12371aec33e1a3758597c5c631bae9a5286f3c963bdfb4d17acdd2d395406687", size = 12232008, upload-time = "2025-07-11T13:20:47.374Z" }, + { url = "https://files.pythonhosted.org/packages/f5/59/4f7c130cc25220392051fadfe15f63ed70001487eca21d1796db46cbcc04/ruff-0.12.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:560f13b6baa49785665276c963edc363f8ad4b4fc910a883e2625bdb14a83a9e", size = 11499096, upload-time = "2025-07-11T13:20:50.348Z" }, + { url = "https://files.pythonhosted.org/packages/d4/01/a0ad24a5d2ed6be03a312e30d32d4e3904bfdbc1cdbe63c47be9d0e82c79/ruff-0.12.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:023040a3499f6f974ae9091bcdd0385dd9e9eb4942f231c23c57708147b06311", size = 11688307, upload-time = "2025-07-11T13:20:52.945Z" }, + { url = "https://files.pythonhosted.org/packages/93/72/08f9e826085b1f57c9a0226e48acb27643ff19b61516a34c6cab9d6ff3fa/ruff-0.12.3-py3-none-musllinux_1_2_aarch64.whl", hash = 
"sha256:883d844967bffff5ab28bba1a4d246c1a1b2933f48cb9840f3fdc5111c603b07", size = 10661020, upload-time = "2025-07-11T13:20:55.799Z" }, + { url = "https://files.pythonhosted.org/packages/80/a0/68da1250d12893466c78e54b4a0ff381370a33d848804bb51279367fc688/ruff-0.12.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2120d3aa855ff385e0e562fdee14d564c9675edbe41625c87eeab744a7830d12", size = 10246300, upload-time = "2025-07-11T13:20:58.222Z" }, + { url = "https://files.pythonhosted.org/packages/6a/22/5f0093d556403e04b6fd0984fc0fb32fbb6f6ce116828fd54306a946f444/ruff-0.12.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6b16647cbb470eaf4750d27dddc6ebf7758b918887b56d39e9c22cce2049082b", size = 11263119, upload-time = "2025-07-11T13:21:01.503Z" }, + { url = "https://files.pythonhosted.org/packages/92/c9/f4c0b69bdaffb9968ba40dd5fa7df354ae0c73d01f988601d8fac0c639b1/ruff-0.12.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:e1417051edb436230023575b149e8ff843a324557fe0a265863b7602df86722f", size = 11746990, upload-time = "2025-07-11T13:21:04.524Z" }, + { url = "https://files.pythonhosted.org/packages/fe/84/7cc7bd73924ee6be4724be0db5414a4a2ed82d06b30827342315a1be9e9c/ruff-0.12.3-py3-none-win32.whl", hash = "sha256:dfd45e6e926deb6409d0616078a666ebce93e55e07f0fb0228d4b2608b2c248d", size = 10589263, upload-time = "2025-07-11T13:21:07.148Z" }, + { url = "https://files.pythonhosted.org/packages/07/87/c070f5f027bd81f3efee7d14cb4d84067ecf67a3a8efb43aadfc72aa79a6/ruff-0.12.3-py3-none-win_amd64.whl", hash = "sha256:a946cf1e7ba3209bdef039eb97647f1c77f6f540e5845ec9c114d3af8df873e7", size = 11695072, upload-time = "2025-07-11T13:21:11.004Z" }, + { url = "https://files.pythonhosted.org/packages/e0/30/f3eaf6563c637b6e66238ed6535f6775480db973c836336e4122161986fc/ruff-0.12.3-py3-none-win_arm64.whl", hash = "sha256:5f9c7c9c8f84c2d7f27e93674d27136fbf489720251544c4da7fb3d742e011b1", size = 10805855, upload-time = "2025-07-11T13:21:13.547Z" }, ] [[package]] diff --git a/docker/.env.example 
b/docker/.env.example index e2d7436067..38b1a39815 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -801,6 +801,19 @@ WORKFLOW_FILE_UPLOAD_LIMIT=10 # hybrid: Save new data to object storage, read from both object storage and RDBMS WORKFLOW_NODE_EXECUTION_STORAGE=rdbms +# Repository configuration +# Core workflow execution repository implementation +CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository + +# Core workflow node execution repository implementation +CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository + +# API workflow node execution repository implementation +API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository + +# API workflow run repository implementation +API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository + # HTTP request node in workflow configuration HTTP_REQUEST_NODE_MAX_BINARY_SIZE=10485760 HTTP_REQUEST_NODE_MAX_TEXT_SIZE=1048576 diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 3803c26a33..c54cd6621a 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -356,6 +356,10 @@ x-shared-env: &shared-api-worker-env WORKFLOW_PARALLEL_DEPTH_LIMIT: ${WORKFLOW_PARALLEL_DEPTH_LIMIT:-3} WORKFLOW_FILE_UPLOAD_LIMIT: ${WORKFLOW_FILE_UPLOAD_LIMIT:-10} WORKFLOW_NODE_EXECUTION_STORAGE: ${WORKFLOW_NODE_EXECUTION_STORAGE:-rdbms} + CORE_WORKFLOW_EXECUTION_REPOSITORY: ${CORE_WORKFLOW_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository} + CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: 
${CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY:-core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository} + API_WORKFLOW_NODE_EXECUTION_REPOSITORY: ${API_WORKFLOW_NODE_EXECUTION_REPOSITORY:-repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository} + API_WORKFLOW_RUN_REPOSITORY: ${API_WORKFLOW_RUN_REPOSITORY:-repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository} HTTP_REQUEST_NODE_MAX_BINARY_SIZE: ${HTTP_REQUEST_NODE_MAX_BINARY_SIZE:-10485760} HTTP_REQUEST_NODE_MAX_TEXT_SIZE: ${HTTP_REQUEST_NODE_MAX_TEXT_SIZE:-1048576} HTTP_REQUEST_NODE_SSL_VERIFY: ${HTTP_REQUEST_NODE_SSL_VERIFY:-True} diff --git a/web/app/(commonLayout)/explore/installed/[appId]/page.tsx b/web/app/(commonLayout)/explore/installed/[appId]/page.tsx index 938a03992b..e288c62b5d 100644 --- a/web/app/(commonLayout)/explore/installed/[appId]/page.tsx +++ b/web/app/(commonLayout)/explore/installed/[appId]/page.tsx @@ -1,16 +1,18 @@ -import type { FC } from 'react' import React from 'react' import Main from '@/app/components/explore/installed-app' export type IInstalledAppProps = { - params: Promise<{ + params: { appId: string - }> + } } -const InstalledApp: FC = async ({ params }) => { +// Using Next.js page convention for async server components +async function InstalledApp({ params }: IInstalledAppProps) { + const appId = (await params).appId return ( -
+
) } -export default React.memo(InstalledApp) + +export default InstalledApp diff --git a/web/app/(shareLayout)/chat/[token]/page.tsx b/web/app/(shareLayout)/chat/[token]/page.tsx index 640c40378f..8ce67585f0 100644 --- a/web/app/(shareLayout)/chat/[token]/page.tsx +++ b/web/app/(shareLayout)/chat/[token]/page.tsx @@ -1,10 +1,13 @@ 'use client' import React from 'react' import ChatWithHistoryWrap from '@/app/components/base/chat/chat-with-history' +import AuthenticatedLayout from '../../components/authenticated-layout' const Chat = () => { return ( - + + + ) } diff --git a/web/app/(shareLayout)/chatbot/[token]/page.tsx b/web/app/(shareLayout)/chatbot/[token]/page.tsx index 6196afecc4..5323d0dacc 100644 --- a/web/app/(shareLayout)/chatbot/[token]/page.tsx +++ b/web/app/(shareLayout)/chatbot/[token]/page.tsx @@ -1,10 +1,13 @@ 'use client' import React from 'react' import EmbeddedChatbot from '@/app/components/base/chat/embedded-chatbot' +import AuthenticatedLayout from '../../components/authenticated-layout' const Chatbot = () => { return ( - + + + ) } diff --git a/web/app/(shareLayout)/completion/[token]/page.tsx b/web/app/(shareLayout)/completion/[token]/page.tsx index e8bc9d79f5..ae91338b9a 100644 --- a/web/app/(shareLayout)/completion/[token]/page.tsx +++ b/web/app/(shareLayout)/completion/[token]/page.tsx @@ -1,9 +1,12 @@ import React from 'react' import Main from '@/app/components/share/text-generation' +import AuthenticatedLayout from '../../components/authenticated-layout' const Completion = () => { return ( -
+ +
+ ) } diff --git a/web/app/(shareLayout)/components/authenticated-layout.tsx b/web/app/(shareLayout)/components/authenticated-layout.tsx new file mode 100644 index 0000000000..e3cfc8e6a8 --- /dev/null +++ b/web/app/(shareLayout)/components/authenticated-layout.tsx @@ -0,0 +1,84 @@ +'use client' + +import AppUnavailable from '@/app/components/base/app-unavailable' +import Loading from '@/app/components/base/loading' +import { removeAccessToken } from '@/app/components/share/utils' +import { useWebAppStore } from '@/context/web-app-context' +import { useGetUserCanAccessApp } from '@/service/access-control' +import { useGetWebAppInfo, useGetWebAppMeta, useGetWebAppParams } from '@/service/use-share' +import { usePathname, useRouter, useSearchParams } from 'next/navigation' +import React, { useCallback, useEffect } from 'react' +import { useTranslation } from 'react-i18next' + +const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => { + const { t } = useTranslation() + const updateAppInfo = useWebAppStore(s => s.updateAppInfo) + const updateAppParams = useWebAppStore(s => s.updateAppParams) + const updateWebAppMeta = useWebAppStore(s => s.updateWebAppMeta) + const updateUserCanAccessApp = useWebAppStore(s => s.updateUserCanAccessApp) + const { isFetching: isFetchingAppParams, data: appParams, error: appParamsError } = useGetWebAppParams() + const { isFetching: isFetchingAppInfo, data: appInfo, error: appInfoError } = useGetWebAppInfo() + const { isFetching: isFetchingAppMeta, data: appMeta, error: appMetaError } = useGetWebAppMeta() + const { data: userCanAccessApp, error: useCanAccessAppError } = useGetUserCanAccessApp({ appId: appInfo?.app_id, isInstalledApp: false }) + + useEffect(() => { + if (appInfo) + updateAppInfo(appInfo) + if (appParams) + updateAppParams(appParams) + if (appMeta) + updateWebAppMeta(appMeta) + updateUserCanAccessApp(Boolean(userCanAccessApp && userCanAccessApp?.result)) + }, [appInfo, appMeta, appParams, updateAppInfo, 
updateAppParams, updateUserCanAccessApp, updateWebAppMeta, userCanAccessApp]) + + const router = useRouter() + const pathname = usePathname() + const searchParams = useSearchParams() + const getSigninUrl = useCallback(() => { + const params = new URLSearchParams(searchParams) + params.delete('message') + params.set('redirect_url', pathname) + return `/webapp-signin?${params.toString()}` + }, [searchParams, pathname]) + + const backToHome = useCallback(() => { + removeAccessToken() + const url = getSigninUrl() + router.replace(url) + }, [getSigninUrl, router]) + + if (appInfoError) { + return
+ +
+ } + if (appParamsError) { + return
+ +
+ } + if (appMetaError) { + return
+ +
+ } + if (useCanAccessAppError) { + return
+ +
+ } + if (userCanAccessApp && !userCanAccessApp.result) { + return
+ + {t('common.userProfile.logout')} +
+ } + if (isFetchingAppInfo || isFetchingAppParams || isFetchingAppMeta) { + return
+ +
+ } + return <>{children} +} + +export default React.memo(AuthenticatedLayout) diff --git a/web/app/(shareLayout)/components/splash.tsx b/web/app/(shareLayout)/components/splash.tsx new file mode 100644 index 0000000000..4fe9efe4dd --- /dev/null +++ b/web/app/(shareLayout)/components/splash.tsx @@ -0,0 +1,80 @@ +'use client' +import type { FC, PropsWithChildren } from 'react' +import { useEffect } from 'react' +import { useCallback } from 'react' +import { useWebAppStore } from '@/context/web-app-context' +import { useRouter, useSearchParams } from 'next/navigation' +import AppUnavailable from '@/app/components/base/app-unavailable' +import { checkOrSetAccessToken, removeAccessToken, setAccessToken } from '@/app/components/share/utils' +import { useTranslation } from 'react-i18next' +import { fetchAccessToken } from '@/service/share' +import Loading from '@/app/components/base/loading' +import { AccessMode } from '@/models/access-control' + +const Splash: FC = ({ children }) => { + const { t } = useTranslation() + const shareCode = useWebAppStore(s => s.shareCode) + const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode) + const searchParams = useSearchParams() + const router = useRouter() + const redirectUrl = searchParams.get('redirect_url') + const tokenFromUrl = searchParams.get('web_sso_token') + const message = searchParams.get('message') + const code = searchParams.get('code') + const getSigninUrl = useCallback(() => { + const params = new URLSearchParams(searchParams) + params.delete('message') + params.delete('code') + return `/webapp-signin?${params.toString()}` + }, [searchParams]) + + const backToHome = useCallback(() => { + removeAccessToken() + const url = getSigninUrl() + router.replace(url) + }, [getSigninUrl, router]) + + useEffect(() => { + (async () => { + if (message) + return + if (shareCode && tokenFromUrl && redirectUrl) { + localStorage.setItem('webapp_access_token', tokenFromUrl) + const tokenResp = await fetchAccessToken({ 
appCode: shareCode, webAppAccessToken: tokenFromUrl }) + await setAccessToken(shareCode, tokenResp.access_token) + router.replace(decodeURIComponent(redirectUrl)) + return + } + if (shareCode && redirectUrl && localStorage.getItem('webapp_access_token')) { + const tokenResp = await fetchAccessToken({ appCode: shareCode, webAppAccessToken: localStorage.getItem('webapp_access_token') }) + await setAccessToken(shareCode, tokenResp.access_token) + router.replace(decodeURIComponent(redirectUrl)) + return + } + if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) { + await checkOrSetAccessToken(shareCode) + router.replace(decodeURIComponent(redirectUrl)) + } + })() + }, [shareCode, redirectUrl, router, tokenFromUrl, message, webAppAccessMode]) + + if (message) { + return
+ + {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')} +
+ } + if (tokenFromUrl) { + return
+ +
+ } + if (webAppAccessMode === AccessMode.PUBLIC && redirectUrl) { + return
+ +
+ } + return <>{children} +} + +export default Splash diff --git a/web/app/(shareLayout)/layout.tsx b/web/app/(shareLayout)/layout.tsx index d057ba7599..5af913cac9 100644 --- a/web/app/(shareLayout)/layout.tsx +++ b/web/app/(shareLayout)/layout.tsx @@ -1,54 +1,15 @@ -'use client' -import React, { useEffect, useState } from 'react' -import type { FC } from 'react' -import { usePathname, useSearchParams } from 'next/navigation' -import Loading from '../components/base/loading' -import { useGlobalPublicStore } from '@/context/global-public-context' -import { AccessMode } from '@/models/access-control' -import { getAppAccessModeByAppCode } from '@/service/share' +import type { FC, PropsWithChildren } from 'react' +import WebAppStoreProvider from '@/context/web-app-context' +import Splash from './components/splash' -const Layout: FC<{ - children: React.ReactNode -}> = ({ children }) => { - const isGlobalPending = useGlobalPublicStore(s => s.isGlobalPending) - const setWebAppAccessMode = useGlobalPublicStore(s => s.setWebAppAccessMode) - const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) - const pathname = usePathname() - const searchParams = useSearchParams() - const redirectUrl = searchParams.get('redirect_url') - const [isLoading, setIsLoading] = useState(true) - useEffect(() => { - (async () => { - if (!isGlobalPending && !systemFeatures.webapp_auth.enabled) { - setIsLoading(false) - return - } - - let appCode: string | null = null - if (redirectUrl) { - const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`) - appCode = url.pathname.split('/').pop() || null - } - else { - appCode = pathname.split('/').pop() || null - } - - if (!appCode) - return - setIsLoading(true) - const ret = await getAppAccessModeByAppCode(appCode) - setWebAppAccessMode(ret?.accessMode || AccessMode.PUBLIC) - setIsLoading(false) - })() - }, [pathname, redirectUrl, setWebAppAccessMode, isGlobalPending, systemFeatures.webapp_auth.enabled]) - if 
(isLoading || isGlobalPending) { - return
- -
- } +const Layout: FC<PropsWithChildren> = ({ children }) => { return (
- {children} + + + {children} + +
) } diff --git a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx index 9f9a8ad4e3..5e3f6fff1d 100644 --- a/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx +++ b/web/app/(shareLayout)/webapp-reset-password/set-password/page.tsx @@ -9,8 +9,7 @@ import Button from '@/app/components/base/button' import { changeWebAppPasswordWithToken } from '@/service/common' import Toast from '@/app/components/base/toast' import Input from '@/app/components/base/input' - -const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ +import { validPassword } from '@/config' const ChangePasswordForm = () => { const { t } = useTranslation() diff --git a/web/app/(shareLayout)/webapp-signin/layout.tsx b/web/app/(shareLayout)/webapp-signin/layout.tsx index a03364d326..7649982072 100644 --- a/web/app/(shareLayout)/webapp-signin/layout.tsx +++ b/web/app/(shareLayout)/webapp-signin/layout.tsx @@ -3,10 +3,13 @@ import cn from '@/utils/classnames' import { useGlobalPublicStore } from '@/context/global-public-context' import useDocumentTitle from '@/hooks/use-document-title' +import type { PropsWithChildren } from 'react' +import { useTranslation } from 'react-i18next' -export default function SignInLayout({ children }: any) { - const { systemFeatures } = useGlobalPublicStore() - useDocumentTitle('') +export default function SignInLayout({ children }: PropsWithChildren) { + const { t } = useTranslation() + const systemFeatures = useGlobalPublicStore(s => s.systemFeatures) + useDocumentTitle(t('login.webapp.login')) return <>
diff --git a/web/app/(shareLayout)/webapp-signin/normalForm.tsx b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
index d6bdf607ba..44006a9f1e 100644
--- a/web/app/(shareLayout)/webapp-signin/normalForm.tsx
+++ b/web/app/(shareLayout)/webapp-signin/normalForm.tsx
@@ -1,3 +1,4 @@
+'use client'
 import React, { useCallback, useEffect, useState } from 'react'
 import { useTranslation } from 'react-i18next'
 import Link from 'next/link'
diff --git a/web/app/(shareLayout)/webapp-signin/page.tsx b/web/app/(shareLayout)/webapp-signin/page.tsx
index 967516c416..1c6209b902 100644
--- a/web/app/(shareLayout)/webapp-signin/page.tsx
+++ b/web/app/(shareLayout)/webapp-signin/page.tsx
@@ -1,36 +1,30 @@
 'use client'
 import { useRouter, useSearchParams } from 'next/navigation'
 import type { FC } from 'react'
-import React, { useCallback, useEffect } from 'react'
+import React, { useCallback } from 'react'
 import { useTranslation } from 'react-i18next'
-import Toast from '@/app/components/base/toast'
-import { removeAccessToken, setAccessToken } from '@/app/components/share/utils'
+import { removeAccessToken } from '@/app/components/share/utils'
 import { useGlobalPublicStore } from '@/context/global-public-context'
-import Loading from '@/app/components/base/loading'
 import AppUnavailable from '@/app/components/base/app-unavailable'
 import NormalForm from './normalForm'
 import { AccessMode } from '@/models/access-control'
 import ExternalMemberSsoAuth from './components/external-member-sso-auth'
-import { fetchAccessToken } from '@/service/share'
+import { useWebAppStore } from '@/context/web-app-context'

 const WebSSOForm: FC = () => {
   const { t } = useTranslation()
   const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
-  const webAppAccessMode = useGlobalPublicStore(s => s.webAppAccessMode)
+  const webAppAccessMode = useWebAppStore(s => s.webAppAccessMode)
   const searchParams = useSearchParams()
   const router = useRouter()

   const redirectUrl = searchParams.get('redirect_url')
-  const tokenFromUrl = searchParams.get('web_sso_token')
-  const message = searchParams.get('message')
-  const code = searchParams.get('code')

   const getSigninUrl = useCallback(() => {
-    const params = new URLSearchParams(searchParams)
-    params.delete('message')
-    params.delete('code')
+    const params = new URLSearchParams()
+    params.append('redirect_url', redirectUrl || '')
     return `/webapp-signin?${params.toString()}`
-  }, [searchParams])
+  }, [redirectUrl])

   const backToHome = useCallback(() => {
     removeAccessToken()
@@ -38,73 +32,12 @@ const WebSSOForm: FC = () => {
     router.replace(url)
   }, [getSigninUrl, router])

-  const showErrorToast = (msg: string) => {
-    Toast.notify({
-      type: 'error',
-      message: msg,
-    })
-  }
-
-  const getAppCodeFromRedirectUrl = useCallback(() => {
-    if (!redirectUrl)
-      return null
-    const url = new URL(`${window.location.origin}${decodeURIComponent(redirectUrl)}`)
-    const appCode = url.pathname.split('/').pop()
-    if (!appCode)
-      return null
-
-    return appCode
-  }, [redirectUrl])
-
-  useEffect(() => {
-    (async () => {
-      if (message)
-        return
-
-      const appCode = getAppCodeFromRedirectUrl()
-      if (appCode && tokenFromUrl && redirectUrl) {
-        localStorage.setItem('webapp_access_token', tokenFromUrl)
-        const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: tokenFromUrl })
-        await setAccessToken(appCode, tokenResp.access_token)
-        router.replace(decodeURIComponent(redirectUrl))
-        return
-      }
-      if (appCode && redirectUrl && localStorage.getItem('webapp_access_token')) {
-        const tokenResp = await fetchAccessToken({ appCode, webAppAccessToken: localStorage.getItem('webapp_access_token') })
-        await setAccessToken(appCode, tokenResp.access_token)
-        router.replace(decodeURIComponent(redirectUrl))
-      }
-    })()
-  }, [getAppCodeFromRedirectUrl, redirectUrl, router, tokenFromUrl, message])
-
-  useEffect(() => {
-    if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC && redirectUrl)
-      router.replace(decodeURIComponent(redirectUrl))
-  }, [webAppAccessMode, router, redirectUrl])
-
-  if (tokenFromUrl) {
-    return
- -
- } - - if (message) { - return
- - {code === '403' ? t('common.userProfile.logout') : t('share.login.backToHome')} -
- } if (!redirectUrl) { - showErrorToast('redirect url is invalid.') return
} - if (webAppAccessMode && webAppAccessMode === AccessMode.PUBLIC) { - return
- -
- } + if (!systemFeatures.webapp_auth.enabled) { return

{t('login.webapp.disabled')}

diff --git a/web/app/(shareLayout)/workflow/[token]/page.tsx b/web/app/(shareLayout)/workflow/[token]/page.tsx index e93bc8c1af..4f5923e91f 100644 --- a/web/app/(shareLayout)/workflow/[token]/page.tsx +++ b/web/app/(shareLayout)/workflow/[token]/page.tsx @@ -1,10 +1,13 @@ import React from 'react' import Main from '@/app/components/share/text-generation' +import AuthenticatedLayout from '../../components/authenticated-layout' const Workflow = () => { return ( -
+ +
+ ) } diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index a469286900..55fa2983dd 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -21,6 +21,7 @@ import Input from '@/app/components/base/input' import PremiumBadge from '@/app/components/base/premium-badge' import { useGlobalPublicStore } from '@/context/global-public-context' import EmailChangeModal from './email-change-modal' +import { validPassword } from '@/config' const titleClassName = ` system-sm-semibold text-text-secondary @@ -29,8 +30,6 @@ const descriptionClassName = ` mt-1 body-xs-regular text-text-tertiary ` -const validPassword = /^(?=.*[a-zA-Z])(?=.*\d).{8,}$/ - export default function AccountPage() { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index c28cc20df5..3817ebf5a4 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -308,13 +308,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx operations={operations} />
-
- -
+