diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a9580a3ba3..d684fe9144 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,13 +8,15 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true @@ -42,20 +44,22 @@ body: attributes: label: Steps to reproduce description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. - placeholder: Having detailed steps helps us reproduce the bug. + placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them. validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior - placeholder: What were you expecting? + description: Describe what you expected to happen. + placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here. validations: - required: false + required: true - type: textarea attributes: label: ❌ Actual Behavior - placeholder: What happened instead? + description: Describe what actually happened. + placeholder: What happened instead? Please do not copy and paste the steps to reproduce here. validations: required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 6877c382c4..c1666d24cf 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,11 @@ blank_issues_enabled: false contact_links: + - name: "\U0001F4A1 Model Providers & Plugins" + url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" + about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details. + - name: "\U0001F4AC Documentation Issues" + url: "https://github.com/langgenius/dify-docs/issues/new" + about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue. - name: "\U0001F4E7 Discussions" url: https://github.com/langgenius/dify/discussions/categories/general - about: General discussions and request help from the community + about: General discussions and seek help from the community diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml deleted file mode 100644 index 8fdbc0fb9a..0000000000 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: "📚 Documentation Issue" -description: Report issues in our documentation -labels: - - documentation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: textarea - attributes: - label: Provide a description of requested docs changes - placeholder: Briefly describe which document needs to be corrected and why. - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index b1952c63a9..bd293e2442 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -8,11 +8,11 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml deleted file mode 100644 index f9c2dfb7d2..0000000000 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: "🌐 Localization/Translation issue" -description: Report incorrect translations. [please use English :)] -labels: - - translation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: input - attributes: - label: Dify version - description: Hover over system tray icon or look at Settings - validations: - required: true - - type: input - attributes: - label: Utility with translation issue - placeholder: Some area - description: Please input here the utility with the translation issue - validations: - required: true - - type: input - attributes: - label: 🌐 Language affected - placeholder: "German" - validations: - required: true - - type: textarea - attributes: - label: ❌ Actual phrase(s) - placeholder: What is there? Please include a screenshot as that is extremely helpful. - validations: - required: true - - type: textarea - attributes: - label: ✔️ Expected phrase(s) - placeholder: What was expected? - validations: - required: true - - type: textarea - attributes: - label: ℹ Why is the current translation wrong - placeholder: Why do you feel this is incorrect? - validations: - required: true diff --git a/README.md b/README.md index e8e3654b98..2909e0e6cf 100644 --- a/README.md +++ b/README.md @@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Using Terraform for Deployment diff --git a/README_AR.md b/README_AR.md index d93bca8646..e959ca0f78 100644 --- a/README_AR.md +++ b/README_AR.md @@ -188,6 +188,7 @@ docker compose up -d - [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts) - [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### استخدام Terraform للتوزيع diff --git a/README_BN.md b/README_BN.md index 3efee3684d..29d7374ea5 100644 --- a/README_BN.md +++ b/README_BN.md @@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + #### টেরাফর্ম ব্যবহার করে ডিপ্লয় diff --git a/README_CN.md b/README_CN.md index 21e27429ec..486a368c09 100644 --- a/README_CN.md +++ b/README_CN.md @@ -194,9 +194,9 @@ docker compose up -d 如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。 -#### 使用 Helm Chart 部署 +#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署 -使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。 +使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。 - [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify) - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) @@ -204,6 +204,10 @@ docker compose up -d - [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) + + + #### 使用 Terraform 部署 使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台 diff --git a/README_DE.md b/README_DE.md index 20c313035e..fce52c34c2 100644 --- a/README_DE.md +++ b/README_DE.md @@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform für die Bereitstellung verwenden diff --git a/README_ES.md b/README_ES.md index e4b7df6686..6fd6dfcee8 100644 --- a/README_ES.md +++ b/README_ES.md @@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop - [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts) - [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uso de Terraform para el despliegue diff --git a/README_FR.md b/README_FR.md index 8fd17fb7c3..b2209fb495 100644 --- a/README_FR.md +++ b/README_FR.md @@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau - [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts) - [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOUVEAU ! Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Utilisation de Terraform pour le déploiement diff --git a/README_JA.md b/README_JA.md index a3ee81e1f2..c658225f90 100644 --- a/README_JA.md +++ b/README_JA.md @@ -202,6 +202,7 @@ docker compose up -d - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraformを使用したデプロイ diff --git a/README_KL.md b/README_KL.md index 3e5ab1a74f..bfafcc7407 100644 --- a/README_KL.md +++ b/README_KL.md @@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform atorlugu pilersitsineq diff --git a/README_KR.md b/README_KR.md index 3c504900e1..282117e776 100644 --- a/README_KR.md +++ b/README_KR.md @@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 - [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Terraform을 사용한 배포 diff --git a/README_PT.md b/README_PT.md index fb5f3662ae..576f6b48f7 100644 --- a/README_PT.md +++ b/README_PT.md @@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts] - [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts) - [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Usando o Terraform para Implantação diff --git a/README_SI.md b/README_SI.md index 647069a220..7ded001d86 100644 --- a/README_SI.md +++ b/README_SI.md @@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases. - [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Uporaba Terraform za uvajanje diff --git a/README_TR.md b/README_TR.md index f52335646a..6e94e54fa0 100644 --- a/README_TR.md +++ b/README_TR.md @@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify' - [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes) - [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s) +- [🚀 YENİ! YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Dağıtım için Terraform Kullanımı diff --git a/README_TW.md b/README_TW.md index 71082ff893..6e3e22b5c1 100644 --- a/README_TW.md +++ b/README_TW.md @@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify 如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。 -如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。 +如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。 - [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify) - [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm) - [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes) - [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s) +- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) ### 使用 Terraform 進行部署 diff --git a/README_VI.md b/README_VI.md index 58d8434fff..51314e6de5 100644 --- a/README_VI.md +++ b/README_VI.md @@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có - [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm) - [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes) - [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s) +- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes) #### Sử dụng Terraform để Triển khai diff --git a/api/.env.example b/api/.env.example index eab017a624..c09c6c230e 100644 --- a/api/.env.example +++ b/api/.env.example @@ -505,6 +505,8 @@ LOGIN_LOCKOUT_DURATION=86400 # Enable OpenTelemetry ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= OTEL_EXPORTER_OTLP_PROTOCOL= diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 427602676f..0c0c06dd46 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -162,6 +162,11 @@ class DatabaseConfig(BaseSettings): default=3600, ) + SQLALCHEMY_POOL_USE_LIFO: bool = Field( + description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.", + default=False, + ) + SQLALCHEMY_POOL_PRE_PING: bool = Field( description="If True, enables connection pool pre-ping feature to check connections.", default=False, @@ -199,6 +204,7 @@ class DatabaseConfig(BaseSettings): "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, "connect_args": connect_args, + "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO, } diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 1b88ddcfe6..7572a696ce 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -12,6 +12,16 @@ class OTelConfig(BaseSettings): default=False, ) + OTLP_TRACE_ENDPOINT: str = Field( + description="OTLP trace endpoint", + default="", + ) + + OTLP_METRIC_ENDPOINT: str = Field( + description="OTLP metric endpoint", + default="", + ) + OTLP_BASE_ENDPOINT: str = Field( description="OTLP base endpoint", default="http://localhost:4318", diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 860166a61a..9fe32dde6d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -151,6 +151,7 @@ class AppApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 0f53860f56..503393f264 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -35,16 +35,20 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_fields) def post(self, app_model): - # The role of the current user in the ta table must be editor, admin, or owner if not current_user.is_editor: raise NotFound() parser = reqparse.RequestParser() - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") args = parser.parse_args() + + description = args.get("description") + if not description: + description = app_model.description or "" + server = AppMCPServer( name=app_model.name, - description=args["description"], + description=description, parameters=json.dumps(args["parameters"], ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, @@ -65,14 +69,22 @@ class AppMCPServerController(Resource): raise NotFound() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, location="json") - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("status", type=str, required=False, location="json") args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: raise NotFound() - server.description = args["description"] + + description = args.get("description") + if description is None: + pass + elif not description: + server.description = app_model.description or "" + else: + server.description = description + server.parameters = json.dumps(args["parameters"], ensure_ascii=False) if args["status"]: if args["status"] not in [status.value for status in AppMCPServerStatus]: diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py index 86aed77412..32b64d10c5 100644 --- a/api/controllers/console/app/statistic.py +++ b/api/controllers/console/app/statistic.py @@ -2,6 +2,7 @@ from datetime import datetime from decimal import Decimal import pytz +import sqlalchemy as sa from flask import jsonify from flask_login import current_user from flask_restful import Resource, reqparse @@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from core.app.entities.app_invoke_entities import InvokeFrom from extensions.ext_database import db from libs.helper import DatetimeString from libs.login import login_required -from models.model import AppMode +from models import AppMode, Message class DailyMessageStatistic(Resource): @@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource): parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args") args = parser.parse_args() - sql_query = """SELECT - DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date, - COUNT(DISTINCT messages.conversation_id) AS conversation_count -FROM - messages -WHERE - app_id = :app_id""" - arg_dict = {"tz": account.timezone, "app_id": app_model.id} - timezone = pytz.timezone(account.timezone) utc_timezone = pytz.utc + stmt = ( + sa.select( + sa.func.date( + sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz")) + ).label("date"), + sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"), + ) + .select_from(Message) + .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value) + ) + if args["start"]: start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M") start_datetime = start_datetime.replace(second=0) - start_datetime_timezone = timezone.localize(start_datetime) start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone) - - sql_query += " AND created_at >= :start" - arg_dict["start"] = start_datetime_utc + stmt = stmt.where(Message.created_at >= start_datetime_utc) if args["end"]: end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M") end_datetime = end_datetime.replace(second=0) - end_datetime_timezone = timezone.localize(end_datetime) end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone) + stmt = stmt.where(Message.created_at < end_datetime_utc) - sql_query += " AND created_at < :end" - arg_dict["end"] = end_datetime_utc - - sql_query += " GROUP BY date ORDER BY date" + stmt = stmt.group_by("date").order_by("date") response_data = [] - with db.engine.begin() as conn: - rs = conn.execute(db.text(sql_query), arg_dict) - for i in rs: - response_data.append({"date": str(i.date), "conversation_count": i.conversation_count}) + rs = conn.execute(stmt, {"tz": account.timezone}) + for row in rs: + response_data.append({"date": str(row.date), "conversation_count": row.conversation_count}) return jsonify({"data": response_data}) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cbf..ba93f82756 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -68,13 +68,18 @@ def _create_pagination_parser(): return parser +def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: + value_type = workflow_draft_var.value_type + return value_type.exposed_type().value + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.value, + "value_type": v.value_type.exposed_type().value, "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 3b48288710..a3438fc2c7 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value # deprecated, should not use. SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 840a3c9d3b..af15324f46 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,9 +16,10 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.moderation.base import ModerationError +from core.variables.variables import VariableUnion from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not workflow: raise ValueError("Workflow not initialized") - user_id = None + user_id: str | None = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: @@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: self.conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, - } + system_inputs = SystemVariable( + query=query, + files=files, + conversation_id=self.conversation.id, + user_id=user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 4c52fc3e83..1dc9796d5b 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -61,12 +61,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db @@ -116,16 +116,16 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, - }, + workflow_system_variables=SystemVariable( + query=message.query, + files=application_generate_entity.files, + conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 07aeb57fa3..3a66ffa578 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. - system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, - } + + system_inputs = SystemVariable( + files=files, + user_id=user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c6b326d8a4..7adc03e9c3 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -54,10 +54,10 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account @@ -107,13 +107,13 @@ class WorkflowAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, - }, + workflow_system_variables=SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17b9ac5827..2f4d234ecd 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 2fa347c204..fbd62437e6 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -14,6 +14,7 @@ class CommonParameterType(StrEnum): APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" + ANY = "any" # Dynamic select parameter # Once you are not sure about the available options until authorization is done diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index a8e9f41a84..b416e48ce4 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -45,17 +45,13 @@ class TemplateTransformer(ABC): result_str = cls.extract_result_str_from_response(response) result = json.loads(result_str) except json.JSONDecodeError as e: - raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...") + raise ValueError(f"Failed to parse JSON response: {str(e)}.") except ValueError as e: # Re-raise ValueError from extract_result_str_from_response raise e except Exception as e: raise ValueError(f"Unexpected error during response transformation: {str(e)}") - # Check if the result contains an error - if isinstance(result, dict) and "error" in result: - raise ValueError(f"JavaScript execution error: {result['error']}") - if not isinstance(result, dict): raise ValueError(f"Result must be a dict, got {type(result).__name__}") if not all(isinstance(k, str) for k in result): diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index b63478e822..bcb31a816f 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -240,7 +240,7 @@ def refresh_authorization( response = requests.post(token_url, data=params) if not response.ok: raise ValueError(f"Token refresh failed: HTTP {response.status_code}") - return OAuthTokens.parse_obj(response.json()) + return OAuthTokens.model_validate(response.json()) def register_client( diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 1c0f582501..7734b8fdd9 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -1,7 +1,7 @@ import logging import queue from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from contextlib import ExitStack from datetime import timedelta from types import TracebackType @@ -171,23 +171,41 @@ class BaseSession( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = ExitStack() + # Initialize executor and future to None for proper cleanup checks + self._executor: ThreadPoolExecutor | None = None + self._receiver_future: Future | None = None def __enter__(self) -> Self: - self._executor = ThreadPoolExecutor() + # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1 + # ensures no unnecessary threads are created. + self._executor = ThreadPoolExecutor(max_workers=1) self._receiver_future = self._executor.submit(self._receive_loop) return self def check_receiver_status(self) -> None: - if self._receiver_future.done(): + """`check_receiver_status` ensures that any exceptions raised during the + execution of `_receive_loop` are retrieved and propagated.""" + if self._receiver_future and self._receiver_future.done(): self._receiver_future.result() def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - self._exit_stack.close() self._read_stream.put(None) self._write_stream.put(None) + # Wait for the receiver loop to finish + if self._receiver_future: + try: + self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds + except TimeoutError: + # If the receiver loop is still running after timeout, we'll force shutdown + pass + + # Shutdown the executor + if self._executor: + self._executor.shutdown(wait=True) + def send_request( self, request: SendRequestT, diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index b18a6905fe..db8fec4ee9 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance): else: node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) return node_span - except Exception: + except Exception as e: + logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) return None def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: @@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, GEN_AI_FRAMEWORK: "dify", INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), @@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", GEN_AI_MODEL_NAME: process_data.get("model_name", ""), @@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, GEN_AI_FRAMEWORK: "dify", diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fcbbc70fc3..be4997a5bf 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance): "trace_id": opik_trace_id, "id": prepare_opik_uuid(created_at, node_execution_id), "parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id), - "name": node_type, + "name": node_name, "type": run_type, "start_time": created_at, "end_time": finished_at, diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 2be65d67a0..47290ee613 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.entities.parameter_entities import CommonParameterType from core.tools.entities.common_entities import I18nObject +from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): @@ -38,6 +39,7 @@ class PluginParameterType(enum.StrEnum): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value # deprecated, should not use. @@ -151,6 +153,10 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): if value and not isinstance(value, list): raise ValueError("The tools selector must be a list.") return value + case PluginParameterType.ANY: + if value and not isinstance(value, str | dict | list | NumberType): + raise ValueError("The var selector must be a string, dictionary, list or number.") + return value case PluginParameterType.ARRAY: if not isinstance(value, list): # Try to parse JSON string for arrays diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index b7f7b31655..04ac8c9649 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": 1, "page_size": 256}, + params={"page": 1, "page_size": 256, "response_type": "paged"}, ) return result.list @@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient): "GET", f"plugin/{tenant_id}/management/list", PluginListResponse, - params={"page": page, "page_size": page_size}, + params={"page": page, "page_size": page_size, "response_type": "paged"}, ) def upload_pkg( diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 25964ae063..0f0fe65f27 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if self.with_variable_tmpl: - vp = VariablePool() + vp = VariablePool.empty() for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py index a124faa503..552068c99e 100644 --- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py +++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py @@ -4,6 +4,7 @@ from typing import Any, Optional import tablestore # type: ignore from pydantic import BaseModel, model_validator +from tablestore import BatchGetRowRequest, TableInBatchGetRowItem from configs import dify_config from core.rag.datasource.vdb.field import Field @@ -50,6 +51,29 @@ class TableStoreVector(BaseVector): self._index_name = f"{collection_name}_idx" self._tags_field = f"{Field.METADATA_KEY.value}_tags" + def create_collection(self, embeddings: list[list[float]], **kwargs): + dimension = len(embeddings[0]) + self._create_collection(dimension) + + def get_by_ids(self, ids: list[str]) -> list[Document]: + docs = [] + request = BatchGetRowRequest() + columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value] + rows_to_get = [[("id", _id)] for _id in ids] + request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1)) + + result = self._tablestore_client.batch_get_row(request) + table_result = result.get_result_by_table(self._table_name) + for item in table_result: + if item.is_ok and item.row: + kv = {k: v for k, v, t in item.row.attribute_columns} + docs.append( + Document( + page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value]) + ) + ) + return docs + def get_type(self) -> str: return VectorType.TABLESTORE diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b5148e245f..64568a8eda 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -16,6 +16,7 @@ from core.plugin.entities.parameters import ( cast_parameter_value, init_frontend_parameter, ) +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY @@ -179,6 +180,10 @@ class ToolInvokeMessage(BaseModel): data: Mapping[str, Any] = Field(..., description="Detailed log data") metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + class RetrieverResourceMessage(BaseModel): + retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -191,13 +196,22 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + RETRIEVER_RESOURCES = "retriever_resources" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: ( - JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage ) meta: dict[str, Any] | None = None @@ -243,6 +257,7 @@ class ToolParameter(PluginParameter): FILES = PluginParameterType.FILES.value APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + ANY = PluginParameterType.ANY.value DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value # MCP object and array type parameters diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index 3f844e8234..a3c84615ca 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -1,5 +1,4 @@ import re -import uuid from json import dumps as json_dumps from json import loads as json_loads from json.decoder import JSONDecodeError @@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser: # remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$ path = re.sub(r"[^a-zA-Z0-9_-]", "", path) if not path: - path = str(uuid.uuid4()) + path = "" interface["operation"]["operationId"] = f"{path}_{interface['method']}" diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e0372..13274f4e0e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,9 +1,9 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Any +from typing import Annotated, Any, TypeAlias -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator from core.file import File @@ -11,6 +11,11 @@ from .types import SegmentType class Segment(BaseModel): + """Segment is runtime type used during the execution of workflow. + + Note: this class is abstract, you should use subclasses of this class instead. + """ + model_config = ConfigDict(frozen=True) value_type: SegmentType @@ -73,7 +78,7 @@ class StringSegment(Segment): class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.FLOAT value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. @@ -92,7 +97,7 @@ class FloatSegment(Segment): class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.INTEGER value: int @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): @property def text(self) -> str: return "" + + +def get_segment_discriminator(v: Any) -> SegmentType | None: + if isinstance(v, Segment): + return v.value_type + elif isinstance(v, dict): + value_type = v.get("value_type") + if value_type is None: + return None + try: + seg_type = SegmentType(value_type) + except ValueError: + return None + return seg_type + else: + # return None if the discriminator value isn't found + return None + + +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. +SegmentUnion: TypeAlias = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d82883..e39237dba5 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,8 +1,27 @@ +from collections.abc import Mapping from enum import StrEnum +from typing import Any, Optional + +from core.file.models import File + + +class ArrayValidation(StrEnum): + """Strategy for validating array elements""" + + # Skip element validation (only check array container) + NONE = "none" + + # Validate the first element (if array is non-empty) + FIRST = "first" + + # Validate all elements in the array. + ALL = "all" class SegmentType(StrEnum): NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" STRING = "string" OBJECT = "object" SECRET = "secret" @@ -19,16 +38,141 @@ class SegmentType(StrEnum): GROUP = "group" - def is_array_type(self): + def is_array_type(self) -> bool: return self in _ARRAY_TYPES + @classmethod + def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. + """ + + if isinstance(value, list): + elem_types: set[SegmentType] = set() + for i in value: + segment_type = cls.infer_segment_type(i) + if segment_type is None: + return None + + elem_types.add(segment_type) + + if len(elem_types) != 1: + if elem_types.issubset(_NUMERICAL_TYPES): + return SegmentType.ARRAY_NUMBER + return SegmentType.ARRAY_ANY + elif all(i.is_array_type() for i in elem_types): + return SegmentType.ARRAY_ANY + match elem_types.pop(): + case SegmentType.STRING: + return SegmentType.ARRAY_STRING + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return SegmentType.ARRAY_NUMBER + case SegmentType.OBJECT: + return SegmentType.ARRAY_OBJECT + case SegmentType.FILE: + return SegmentType.ARRAY_FILE + case SegmentType.NONE: + return SegmentType.ARRAY_ANY + case _: + # This should be unreachable. + raise ValueError(f"not supported value {value}") + if value is None: + return SegmentType.NONE + elif isinstance(value, int) and not isinstance(value, bool): + return SegmentType.INTEGER + elif isinstance(value, float): + return SegmentType.FLOAT + elif isinstance(value, str): + return SegmentType.STRING + elif isinstance(value, dict): + return SegmentType.OBJECT + elif isinstance(value, File): + return SegmentType.FILE + elif isinstance(value, str): + return SegmentType.STRING + else: + return None + + def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: + if not isinstance(value, list): + return False + # Skip element validation if array is empty + if len(value) == 0: + return True + if self == SegmentType.ARRAY_ANY: + return True + element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] + + if array_validation == ArrayValidation.NONE: + return True + elif array_validation == ArrayValidation.FIRST: + return element_type.is_valid(value[0]) + else: + return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + + def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: + """ + Check if a value matches the segment type. + Users of `SegmentType` should call this method, instead of using + `isinstance` manually. + + Args: + value: The value to validate + array_validation: Validation strategy for array types (ignored for non-array types) + + Returns: + True if the value matches the type under the given validation strategy + """ + if self.is_array_type(): + return self._validate_array(value, array_validation) + elif self == SegmentType.NUMBER: + return isinstance(value, (int, float)) + elif self == SegmentType.STRING: + return isinstance(value, str) + elif self == SegmentType.OBJECT: + return isinstance(value, dict) + elif self == SegmentType.SECRET: + return isinstance(value, str) + elif self == SegmentType.FILE: + return isinstance(value, File) + elif self == SegmentType.NONE: + return value is None + else: + raise AssertionError("this statement should be unreachable.") + + def exposed_type(self) -> "SegmentType": + """Returns the type exposed to the frontend. + + The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. + """ + if self in (SegmentType.INTEGER, SegmentType.FLOAT): + return SegmentType.NUMBER + return self + + +_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { + # ARRAY_ANY does not have correpond element type. + SegmentType.ARRAY_STRING: SegmentType.STRING, + SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, + SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, + SegmentType.ARRAY_FILE: SegmentType.FILE, +} _ARRAY_TYPES = frozenset( - [ + list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) + + [ SegmentType.ARRAY_ANY, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_FILE, + ] +) + + +_NUMERICAL_TYPES = frozenset( + [ + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, ] ) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..a31ebc848e 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import cast +from typing import Annotated, TypeAlias, cast from uuid import uuid4 -from pydantic import Field +from pydantic import Discriminator, Field, Tag from core.helper import encrypter @@ -20,6 +20,7 @@ from .segments import ( ObjectSegment, Segment, StringSegment, + get_segment_discriminator, ) from .types import SegmentType @@ -27,6 +28,10 @@ from .types import SegmentType class Variable(Segment): """ A variable is a segment that has a name. + + It is mainly used to store segments and their selector in VariablePool. + + Note: this class is abstract, you should use subclasses of this class instead. """ id: str = Field( @@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + + +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +VariableUnion: TypeAlias = Annotated[ + ( + Annotated[NoneVariable, Tag(SegmentType.NONE)] + | Annotated[StringVariable, Tag(SegmentType.STRING)] + | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] + | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] + | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] + | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[SecretVariable, Tag(SegmentType.SECRET)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 80dda2632d..646a9d3402 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Union +from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories import variable_factory VariableValue = Union[str, int, float, dict, list, File] @@ -23,31 +24,31 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Segment]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. + + # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. user_inputs: Mapping[str, Any] = Field( description="User inputs", default_factory=dict, ) - system_variables: Mapping[SystemVariableKey, Any] = Field( + system_variables: SystemVariable = Field( description="System variables", - default_factory=dict, ) - environment_variables: Sequence[Variable] = Field( + environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", default_factory=list, ) - conversation_variables: Sequence[Variable] = Field( + conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", default_factory=list, ) def model_post_init(self, context: Any, /) -> None: - for key, value in self.system_variables.items(): - self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) + # Create a mapping from field names to SystemVariableKey enum values + self._add_system_variables(self.system_variables) # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) @@ -83,8 +84,22 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = variable + key, hash_key = self._selector_to_keys(selector) + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + + @classmethod + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: + return selector[0], hash(tuple(selector[1:])) + + def _has(self, selector: Sequence[str]) -> bool: + key, hash_key = self._selector_to_keys(selector) + if key not in self.variable_dictionary: + return False + if hash_key not in self.variable_dictionary[key]: + return False + return True def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -102,8 +117,8 @@ class VariablePool(BaseModel): if len(selector) < MIN_SELECTORS_LENGTH: return None - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) + key, hash_key = self._selector_to_keys(selector) + value: Segment | None = self.variable_dictionary[key].get(hash_key) if value is None: selector, attr = selector[:-1], selector[-1] @@ -136,8 +151,9 @@ class VariablePool(BaseModel): if len(selector) == 1: self.variable_dictionary[selector[0]] = {} return + key, hash_key = self._selector_to_keys(selector) hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]].pop(hash_key, None) + self.variable_dictionary[key].pop(hash_key, None) def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) @@ -154,3 +170,20 @@ class VariablePool(BaseModel): if isinstance(segment, FileSegment): return segment return None + + def _add_system_variables(self, system_variable: SystemVariable): + sys_var_mapping = system_variable.to_dict() + for key, value in sys_var_mapping.items(): + if value is None: + continue + selector = (SYSTEM_VARIABLE_NODE_ID, key) + # If the system variable already exists, do not add it again. + # This ensures that we can keep the id of the system variables intact. + if self._has(selector): + continue + self.add(selector, value) # type: ignore + + @classmethod + def empty(cls) -> "VariablePool": + """Create an empty variable pool.""" + return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index afc09bfac5..a62ffe46c9 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() """llm usage info""" + + # The `outputs` field stores the final output values generated by executing workflows or chatflows. + # + # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent + # after a serialization and deserialization round trip. outputs: dict[str, Any] = {} - """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index c447f433aa..8b566c83cd 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -521,18 +521,52 @@ class IterationNode(BaseNode[IterationNodeData]): ) return elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: - yield IterationRunFailedEvent( - iteration_id=self.id, - iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, - start_at=start_at, - inputs=inputs, - outputs={"output": None}, - steps=len(iterator_list_value), - metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, - error=event.error, + yield NodeInIterationFailedEvent( + **metadata_event.model_dump(), ) + outputs[current_index] = None + + # clean nodes resources + for node_id in iteration_graph.node_ids: + variable_pool.remove([node_id]) + + # iteration run failed + if self.node_data.is_parallel: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + parallel_mode_run_id=parallel_mode_run_id, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + else: + yield IterationRunFailedEvent( + iteration_id=self.id, + iteration_node_id=self.node_id, + iteration_node_type=self.node_type, + iteration_node_data=self.node_data, + start_at=start_at, + inputs=inputs, + outputs={"output": outputs}, + steps=len(iterator_list_value), + metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens}, + error=event.error, + ) + + # stop the iterator + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + error=event.error, + ) + ) + return yield metadata_event current_output_segment = variable_pool.get(self.node_data.output_selector) diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3f4a5edab9..d04e0bfae1 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,11 +1,29 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import AfterValidator, BaseModel, Field +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +_VALID_VAR_TYPE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + ] +) + + +def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: + if seg_type not in _VALID_VAR_TYPE: + raise ValueError(...) + return seg_type + class LoopVariableData(BaseModel): """ @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): """ label: str - var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] value: Optional[Any | list[str]] = None diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 11fd7b6c2d..20501d0317 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast from configs import dify_config from core.variables import ( - ArrayNumberSegment, - ArrayObjectSegment, - ArrayStringSegment, IntegerSegment, - ObjectSegment, Segment, SegmentType, - StringSegment, ) from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor +from factories.variable_factory import TypeMismatchError, build_segment_with_type if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]): return variable_mapping @staticmethod - def _get_segment_for_constant(var_type: str, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { - "string": (StringSegment, SegmentType.STRING), - "number": (IntegerSegment, SegmentType.NUMBER), - "object": (ObjectSegment, SegmentType.OBJECT), - "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), - "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER), - "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), - } if var_type in ["array[string]", "array[number]", "array[object]"]: - if value: + if value and isinstance(value, str): value = json.loads(value) else: value = [] - segment_info = segment_mapping.get(var_type) - if not segment_info: - raise ValueError(f"Invalid variable type: {var_type}") - segment_class, value_type = segment_info - return segment_class(value=value, value_type=value_type) + try: + return build_segment_with_type(var_type, value) + except TypeMismatchError as type_exc: + # Attempt to parse the value as a JSON-encoded string, if applicable. + if not isinstance(value, str): + raise + try: + value = json.loads(value) + except ValueError: + raise type_exc + return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 5ee9bc331f..e215591888 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]): def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables + system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling # Set system variables as node outputs. diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 48627a229d..3853a5d920 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -22,7 +22,7 @@ from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent +from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory @@ -373,6 +373,12 @@ class ToolNode(BaseNode[ToolNodeData]): agent_logs.append(agent_log) yield agent_log + elif message.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES: + assert isinstance(message.message, ToolInvokeMessage.RetrieverResourceMessage) + yield RunRetrieverResourceEvent( + retriever_resources=message.message.retriever_resources, + context=message.message.context, + ) # Add agent_logs to outputs['json'] to ensure frontend can access thinking process json_output: list[dict[str, Any]] = [] diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index be5083c9c1..1864b13784 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): def get_zero_value(t: SegmentType): + # TODO(QuantumGhost): this should be a method of `SegmentType`. match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: return variable_factory.build_segment([]) @@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment({}) case SegmentType.STRING: return variable_factory.build_segment("") + case SegmentType.INTEGER: + return variable_factory.build_segment(0) + case SegmentType.FLOAT: + return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 3797bfa77a..7f760e5baa 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -1,5 +1,6 @@ from core.variables import SegmentType +# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy. EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 8fb2a27388..7a20975b15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): case Operation.OVER_WRITE | Operation.CLEAR: return True case Operation.SET: - return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + return variable_type in { + SegmentType.OBJECT, + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, + } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided - return variable_type == SegmentType.NUMBER + return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} case Operation.APPEND | Operation.EXTEND: # Only array variable can be appended or extended return variable_type in { @@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat match variable_type: case SegmentType.STRING | SegmentType.OBJECT: return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { Operation.OVER_WRITE, Operation.SET, @@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False if operation == Operation.DIVIDE and value == 0: diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py new file mode 100644 index 0000000000..df90c16596 --- /dev/null +++ b/api/core/workflow/system_variable.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from core.file.models import File +from core.workflow.enums import SystemVariableKey + + +class SystemVariable(BaseModel): + """A model for managing system variables. + + Fields with a value of `None` are treated as absent and will not be included + in the variable pool. + """ + + model_config = ConfigDict( + extra="forbid", + serialize_by_alias=True, + validate_by_alias=True, + ) + + user_id: str | None = None + + # Ideally, `app_id` and `workflow_id` should be required and not `None`. + # However, there are scenarios in the codebase where these fields are not set. + # To maintain compatibility, they are marked as optional here. + app_id: str | None = None + workflow_id: str | None = None + + files: Sequence[File] = Field(default_factory=list) + + # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. + # To maintain compatibility with existing workflows, it must be serialized + # as `workflow_run_id` in dictionaries or JSON objects, and also referenced + # as `workflow_run_id` in the variable pool. + workflow_execution_id: str | None = Field( + validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), + serialization_alias="workflow_run_id", + default=None, + ) + # Chatflow related fields. + query: str | None = None + conversation_id: str | None = None + dialogue_count: int | None = None + + @model_validator(mode="before") + @classmethod + def validate_json_fields(cls, data): + if isinstance(data, dict): + # For JSON validation, only allow workflow_run_id + if "workflow_execution_id" in data and "workflow_run_id" not in data: + # This is likely from direct instantiation, allow it + return data + elif "workflow_execution_id" in data and "workflow_run_id" in data: + # Both present, remove workflow_execution_id + data = data.copy() + data.pop("workflow_execution_id") + return data + return data + + @classmethod + def empty(cls) -> "SystemVariable": + return cls() + + def to_dict(self) -> dict[SystemVariableKey, Any]: + # NOTE: This method is provided for compatibility with legacy code. + # New code should use the `SystemVariable` object directly instead of converting + # it to a dictionary, as this conversion results in the loss of type information + # for each key, making static analysis more difficult. + + d: dict[SystemVariableKey, Any] = { + SystemVariableKey.FILES: self.files, + } + if self.user_id is not None: + d[SystemVariableKey.USER_ID] = self.user_id + if self.app_id is not None: + d[SystemVariableKey.APP_ID] = self.app_id + if self.workflow_id is not None: + d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id + if self.workflow_execution_id is not None: + d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id + if self.query is not None: + d[SystemVariableKey.QUERY] = self.query + if self.conversation_id is not None: + d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id + if self.dialogue_count is not None: + d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count + return d diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 0aab2426af..50ff733979 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.enums import SystemVariableKey from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now @@ -43,7 +44,7 @@ class WorkflowCycleManager: self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: dict[SystemVariableKey, Any], + workflow_system_variables: SystemVariable, workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, @@ -56,17 +57,22 @@ class WorkflowCycleManager: def handle_workflow_run_start(self) -> WorkflowExecution: inputs = {**self._application_generate_entity.inputs} - for key, value in (self._workflow_system_variables or {}).items(): - if key.value == "conversation": - continue - inputs[f"sys.{key.value}"] = value + + # Iterate over SystemVariable fields using Pydantic's model_fields + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name == SystemVariableKey.CONVERSATION_ID: + continue + inputs[f"sys.{field_name}"] = value # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) + execution_id = str( + self._workflow_system_variables.workflow_execution_id if self._workflow_system_variables else None + ) or str(uuid4()) execution = WorkflowExecution.new( id_=execution_id, workflow_id=self._workflow_info.workflow_id, diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2868dcb7de..1399efcdb1 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom @@ -254,7 +255,7 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=[], ) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b62b0b60d6..0771104fb1 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -193,13 +193,22 @@ def init_app(app: DifyApp): insecure=True, ) else: + headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None + + trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT + if not trace_endpoint: + trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces" exporter = HTTPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=trace_endpoint, + headers=headers, ) + + metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT + if not metric_endpoint: + metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces" metric_exporter = HTTPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=metric_endpoint, + headers=headers, ) else: exporter = ConsoleSpanExporter() diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 250ee4695e..39ebd009d5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): + case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): + case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") @@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType: def build_segment(value: Any, /) -> Segment: + # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` + # below if value is None: return NoneSegment() if isinstance(value, str): @@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) + elif len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) @@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment: raise ValueError(f"not supported value {value}") +_segment_factory: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.OBJECT: ObjectSegment, + # Array types + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, +} + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: if segment_type == SegmentType.NONE: return NoneSegment() else: - raise TypeMismatchError(f"Expected {segment_type}, but got None") + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: @@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: - raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - - # Build segment using existing logic to infer actual type - inferred_segment = build_segment(value) - inferred_type = inferred_segment.value_type + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) if inferred_type == segment_type: - return inferred_segment - - # Type mismatch - raise error with descriptive message - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but value '{value}' " - f"(type: {type(value).__name__}) corresponds to {inferred_type}" - ) + segment_class = _segment_factory[segment_type] + return segment_class(value_type=segment_type, value=value) + elif segment_type == SegmentType.NUMBER and inferred_type in ( + SegmentType.INTEGER, + SegmentType.FLOAT, + ): + segment_class = _segment_factory[inferred_type] + return segment_class(value_type=inferred_type, value=value) + else: + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( @@ -247,6 +278,6 @@ def segment_to_variable( name=name, description=description, value=segment.value, - selector=selector, + selector=list(selector), ), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py new file mode 100644 index 0000000000..8288bd54a3 --- /dev/null +++ b/api/fields/_value_type_serializer.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from core.variables.segments import Segment +from core.variables.types import SegmentType + + +class _VarTypedDict(TypedDict, total=False): + value_type: SegmentType + + +def serialize_value_type(v: _VarTypedDict | Segment) -> str: + if isinstance(v, Segment): + return v.value_type.exposed_type().value + else: + return v["value_type"].exposed_type().value diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 73c224542a..b6d85e0e24 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -188,6 +188,7 @@ app_detail_fields_with_site = { "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, + "max_active_requests": fields.Integer, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 71785e7d67..c5a0c9a49d 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -2,10 +2,12 @@ from flask_restful import fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.String, "description": fields.String, "created_at": TimestampField, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f00ea71c54..930e59cc1c 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) @@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.value, + "value_type": value.value_type.exposed_type().value, "description": value.description, } if isinstance(value, dict): - value_type = value.get("value_type") + value_type_str = value.get("value_type") + if not isinstance(value_type_str, str): + raise TypeError( + f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}" + ) + value_type = SegmentType(value_type_str).exposed_type() if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: raise ValueError(f"Unsupported environment variable value type: {value_type}") return value @@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw): conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.Raw, "description": fields.String, } diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py new file mode 100644 index 0000000000..a8190011ed --- /dev/null +++ b/api/libs/uuid_utils.py @@ -0,0 +1,164 @@ +import secrets +import struct +import time +import uuid + +# Reference for UUIDv7 specification: +# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7 + +# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian). +# +# For details on the `struct.pack` format, refer to: +# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment +_PACK_TIMESTAMP = ">Q" + +# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7) +# into an unsigned 16-bit integer (big-endian). +_PACK_RAND_A = ">H" + + +def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes: + """Create UUIDv7 byte structure with given timestamp and random bytes. + + This is a private helper function that handles the common logic for creating + UUIDv7 byte structure according to RFC 9562 specification. + + UUIDv7 Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + The function performs the following operations: + 1. Creates a 128-bit (16-byte) UUID structure + 2. Packs the timestamp into the first 48 bits (6 bytes) + 3. Sets the version bits to 7 (0111) in the correct position + 4. Sets the variant bits to 10 (binary) in the correct position + 5. Fills the remaining bits with the provided random bytes + + Args: + timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits). + random_bytes: Random bytes to use for the random portions (must be 10 bytes). + First 2 bytes are used for random data A (12 bits after version). + Last 8 bytes are used for random data B (62 bits after variant). + + Returns: + A 16-byte bytes object representing the complete UUIDv7 structure. + + Note: + This function assumes the random_bytes parameter is exactly 10 bytes. + The caller is responsible for providing appropriate random data. + """ + # Create the 128-bit UUID structure + uuid_bytes = bytearray(16) + + # Pack timestamp (48 bits) into first 6 bytes + uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian + + # Next 16 bits: random data A (12 bits) + version (4 bits) + # Take first 2 random bytes and set version to 7 + rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0] + # Clear the highest 4 bits to make room for the version field + # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111). + rand_a = rand_a & 0x0FFF + # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000). + rand_a = rand_a | 0x7000 + uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a) + + # Last 64 bits: random data B (62 bits) + variant (2 bits) + # Use remaining 8 random bytes and set variant to 10 (binary) + uuid_bytes[8:16] = random_bytes[2:10] + + # Set variant bits (first 2 bits of byte 8 should be '10') + uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx + + return bytes(uuid_bytes) + + +def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID: + """Generate a UUID version 7 according to RFC 9562 specification. + + UUIDv7 features a time-ordered value field derived from the widely + implemented and well known Unix Epoch timestamp source, the number of + milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded. + + Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + Args: + timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified. + Should be an integer representing milliseconds since Unix epoch. + + Returns: + A UUID object representing a UUIDv7. + + Example: + >>> import time + >>> # Generate UUIDv7 with current time + >>> uuid_current = uuidv7() + >>> # Generate UUIDv7 with specific timestamp + >>> uuid_specific = uuidv7(int(time.time() * 1000)) + """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + # Generate 10 random bytes for the random portions + random_bytes = secrets.token_bytes(10) + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + return uuid.UUID(bytes=uuid_bytes) + + +def uuidv7_timestamp(id_: uuid.UUID) -> int: + """Extract the timestamp from a UUIDv7. + + UUIDv7 contains a 48-bit timestamp field representing milliseconds since + the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and + returns that timestamp as an integer representing milliseconds since the epoch. + + Args: + id_: A UUID object that should be a UUIDv7 (version 7). + + Returns: + The timestamp as an integer representing milliseconds since Unix epoch. + + Raises: + ValueError: If the provided UUID is not version 7. + + Example: + >>> uuid_v7 = uuidv7() + >>> timestamp = uuidv7_timestamp(uuid_v7) + >>> print(f"UUID was created at: {timestamp} ms") + """ + # Verify this is a UUIDv7 + if id_.version != 7: + raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}") + + # Extract the UUID bytes + uuid_bytes = id_.bytes + + # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds + # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long) + timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6] + ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0] + + # Return timestamp directly in milliseconds as integer + assert isinstance(ts_in_ms, int) + return ts_in_ms + + +def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID: + """Generate a non-random uuidv7 with the given timestamp (first 48 bits) and + all random bits to 0. As the smallest possible uuidv7 for that timestamp, + it may be used as a boundary for partitions. + """ + # Use zero bytes for all random portions + zero_random_bytes = b"\x00" * 10 + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes) + + return uuid.UUID(bytes=uuid_bytes) diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py new file mode 100644 index 0000000000..2bbbb3d28e --- /dev/null +++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py @@ -0,0 +1,86 @@ +"""add uuidv7 function in SQL + +Revision ID: 1c9ba48be8e4 +Revises: 58eb7bdb93fe +Create Date: 2025-07-02 23:32:38.484499 + +""" + +""" +The functions in this files comes from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications. + +LICENSE: + +# Copyright and License + +Copyright (c) 2024, Daniel Vérité + +Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. + +In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage. + +Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications. +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1c9ba48be8e4' +down_revision = '58eb7bdb93fe' +branch_labels: None = None +depends_on: None = None + + +def upgrade(): + # This implementation differs slightly from the original uuidv7 function in + # https://github.com/dverite/postgres-uuidv7-sql/. + # The ability to specify source timestamp has been removed because its type signature is incompatible with + # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be + # generated and controlled within the application layer. + op.execute(sa.text(r""" +/* Main function to generate a uuidv7 value with millisecond precision */ +CREATE FUNCTION uuidv7() RETURNS uuid +AS +$$ + -- Replace the first 48 bits of a uuidv4 with the current + -- number of milliseconds since 1970-01-01 UTC + -- and set the "ver" field to 7 by setting additional bits +SELECT encode( + set_bit( + set_bit( + overlay(uuid_send(gen_random_uuid()) placing + substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from + 3) + from 1 for 6), + 52, 1), + 53, 1), 'hex')::uuid; +$$ LANGUAGE SQL VOLATILE PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7 IS + 'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness'; +""")) + + op.execute(sa.text(r""" +CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid +AS +$$ + /* uuid fields: version=0b0111, variant=0b10 */ +SELECT encode( + overlay('\x00000000000070008000000000000000'::bytea + placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3) + from 1 for 6), + 'hex')::uuid; +$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS + 'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.'; +""" +)) + + +def downgrade(): + op.execute(sa.text("DROP FUNCTION uuidv7")) + op.execute(sa.text("DROP FUNCTION uuidv7_boundary")) diff --git a/api/models/workflow.py b/api/models/workflow.py index 77d48bec4f..9930859201 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,6 +12,7 @@ from sqlalchemy import orm from core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils +from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type @@ -347,7 +348,7 @@ class Workflow(Base): ) @property - def environment_variables(self) -> Sequence[Variable]: + def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: self._environment_variables = "{}" @@ -367,11 +368,15 @@ class Workflow(Base): def decrypt_func(var): if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - else: + elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var + else: + raise AssertionError("this statement should be unreachable.") - results = list(map(decrypt_func, results)) - return results + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( + map(decrypt_func, results) + ) + return decrypted_results @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py index bb66bb3a9d..ebd1d74b20 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -29,11 +29,12 @@ from sqlalchemy.orm import Session, sessionmaker from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.workflow import WorkflowRun +from repositories.api_workflow_run_repository import APIWorkflowRunRepository logger = logging.getLogger(__name__) -class DifyAPISQLAlchemyWorkflowRunRepository: +class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): """ SQLAlchemy implementation of APIWorkflowRunRepository. diff --git a/api/services/app_service.py b/api/services/app_service.py index db0f8cd414..0a08f345df 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -233,6 +233,7 @@ class AppService: app.icon = args.get("icon") app.icon_background = args.get("icon_background") app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) + app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0149d50346..677bc74237 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -3,7 +3,7 @@ import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from sqlalchemy import select @@ -15,10 +15,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable +from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -28,6 +28,7 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -369,7 +370,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -685,36 +686,30 @@ def _setup_variable_pool( ): # Only inject system variables for START node type. if node_type == NodeType.START: - # Create a variable pool. - system_inputs: dict[SystemVariableKey, Any] = { - # From inputs: - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - # From workflow model - SystemVariableKey.APP_ID: workflow.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - # Randomly generated. - SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), - } + system_variable = SystemVariable( + user_id=user_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + files=files or [], + workflow_execution_id=str(uuid.uuid4()), + ) # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: - system_inputs.update( - { - SystemVariableKey.QUERY: query, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.DIALOGUE_COUNT: 0, - } - ) + system_variable.query = query + system_variable.conversation_id = conversation_id + system_variable.dialogue_count = 0 else: - system_inputs = {} + system_variable = SystemVariable.empty() # init variable pool variable_pool = VariablePool( - system_variables=system_inputs, + system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index 13d78c2d83..90bb04f649 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -9,12 +9,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.entities import CodeNodeData +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -50,7 +50,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 1ab0cc2451..50e726febf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -6,11 +6,11 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock @@ -44,7 +44,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index 8acaa54b9c..ff119b7482 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -13,12 +13,12 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -62,12 +62,14 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather today?", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + app_id=app_id, + workflow_id=workflow_id, + files=[], + query="what's the weather today?", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 0df8e8b146..dd8466afa6 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -8,11 +8,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.model_runtime.entities import AssistantPromptMessage from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_config @@ -64,12 +64,9 @@ def init_parameter_extractor_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", files=[], query="what's the weather in SF", conversation_id="abababa" + ), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index a5f2677a59..1f617fc92d 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -6,11 +6,11 @@ import pytest from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock @@ -61,7 +61,7 @@ def test_execute_code(setup_code_executor_mock): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 039beedafe..6907e0163e 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -6,12 +6,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.event.event import RunCompletedEvent from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -44,7 +44,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index b70c8830ed..e9d4ee1935 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -88,6 +88,7 @@ def test_flask_configs(monkeypatch): "pool_pre_ping": False, "pool_recycle": 3600, "pool_size": 30, + "pool_use_lifo": False, } assert config["CONSOLE_WEB_URL"] == "https://example.com" diff --git a/api/tests/unit_tests/controllers/console/test_wraps.py b/api/tests/unit_tests/controllers/console/test_wraps.py new file mode 100644 index 0000000000..9742368f04 --- /dev/null +++ b/api/tests/unit_tests/controllers/console/test_wraps.py @@ -0,0 +1,380 @@ +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from flask_login import LoginManager, UserMixin + +from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout +from controllers.console.workspace.error import AccountNotInitializedError +from controllers.console.wraps import ( + account_initialization_required, + cloud_edition_billing_rate_limit_check, + cloud_edition_billing_resource_check, + enterprise_license_required, + only_edition_cloud, + only_edition_enterprise, + only_edition_self_hosted, + setup_required, +) +from models.account import AccountStatus +from services.feature_service import LicenseStatus + + +class MockUser(UserMixin): + """Simple User class for testing.""" + + def __init__(self, user_id: str): + self.id = user_id + self.current_tenant_id = "tenant123" + + def get_id(self) -> str: + return self.id + + +def create_app_with_login(): + """Create a Flask app with LoginManager configured.""" + app = Flask(__name__) + app.config["SECRET_KEY"] = "test-secret-key" + + login_manager = LoginManager() + login_manager.init_app(app) + + @login_manager.user_loader + def load_user(user_id: str): + return MockUser(user_id) + + return app + + +class TestAccountInitialization: + """Test account initialization decorator""" + + def test_should_allow_initialized_account(self): + """Test that initialized accounts can access protected views""" + # Arrange + mock_user = MagicMock() + mock_user.status = AccountStatus.ACTIVE + + @account_initialization_required + def protected_view(): + return "success" + + # Act + with patch("controllers.console.wraps.current_user", mock_user): + result = protected_view() + + # Assert + assert result == "success" + + def test_should_reject_uninitialized_account(self): + """Test that uninitialized accounts raise AccountNotInitializedError""" + # Arrange + mock_user = MagicMock() + mock_user.status = AccountStatus.UNINITIALIZED + + @account_initialization_required + def protected_view(): + return "success" + + # Act & Assert + with patch("controllers.console.wraps.current_user", mock_user): + with pytest.raises(AccountNotInitializedError): + protected_view() + + +class TestEditionChecks: + """Test edition-specific decorators""" + + def test_only_edition_cloud_allows_cloud_edition(self): + """Test cloud edition decorator allows CLOUD edition""" + + # Arrange + @only_edition_cloud + def cloud_view(): + return "cloud_success" + + # Act + with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"): + result = cloud_view() + + # Assert + assert result == "cloud_success" + + def test_only_edition_cloud_rejects_other_editions(self): + """Test cloud edition decorator rejects non-CLOUD editions""" + # Arrange + app = Flask(__name__) + + @only_edition_cloud + def cloud_view(): + return "cloud_success" + + # Act & Assert + with app.test_request_context(): + with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): + with pytest.raises(Exception) as exc_info: + cloud_view() + assert exc_info.value.code == 404 + + def test_only_edition_enterprise_allows_when_enabled(self): + """Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True""" + + # Arrange + @only_edition_enterprise + def enterprise_view(): + return "enterprise_success" + + # Act + with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True): + result = enterprise_view() + + # Assert + assert result == "enterprise_success" + + def test_only_edition_self_hosted_allows_self_hosted(self): + """Test self-hosted edition decorator allows SELF_HOSTED edition""" + + # Arrange + @only_edition_self_hosted + def self_hosted_view(): + return "self_hosted_success" + + # Act + with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): + result = self_hosted_view() + + # Assert + assert result == "self_hosted_success" + + +class TestBillingResourceLimits: + """Test billing resource limit decorators""" + + def test_should_allow_when_under_resource_limit(self): + """Test that requests are allowed when under resource limits""" + # Arrange + mock_features = MagicMock() + mock_features.billing.enabled = True + mock_features.members.limit = 10 + mock_features.members.size = 5 + + @cloud_edition_billing_resource_check("members") + def add_member(): + return "member_added" + + # Act + with patch("controllers.console.wraps.current_user"): + with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): + result = add_member() + + # Assert + assert result == "member_added" + + def test_should_reject_when_over_resource_limit(self): + """Test that requests are rejected when over resource limits""" + # Arrange + app = create_app_with_login() + mock_features = MagicMock() + mock_features.billing.enabled = True + mock_features.members.limit = 10 + mock_features.members.size = 10 + + @cloud_edition_billing_resource_check("members") + def add_member(): + return "member_added" + + # Act & Assert + with app.test_request_context(): + with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): + with pytest.raises(Exception) as exc_info: + add_member() + assert exc_info.value.code == 403 + assert "members has reached the limit" in str(exc_info.value.description) + + def test_should_check_source_for_documents_limit(self): + """Test document limit checks request source""" + # Arrange + app = create_app_with_login() + mock_features = MagicMock() + mock_features.billing.enabled = True + mock_features.documents_upload_quota.limit = 100 + mock_features.documents_upload_quota.size = 100 + + @cloud_edition_billing_resource_check("documents") + def upload_document(): + return "document_uploaded" + + # Test 1: Should reject when source is datasets + with app.test_request_context("/?source=datasets"): + with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): + with pytest.raises(Exception) as exc_info: + upload_document() + assert exc_info.value.code == 403 + + # Test 2: Should allow when source is not datasets + with app.test_request_context("/?source=other"): + with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features): + result = upload_document() + assert result == "document_uploaded" + + +class TestRateLimiting: + """Test rate limiting decorator""" + + @patch("controllers.console.wraps.redis_client") + @patch("controllers.console.wraps.db") + def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis): + """Test that requests within rate limit are allowed""" + # Arrange + mock_rate_limit = MagicMock() + mock_rate_limit.enabled = True + mock_rate_limit.limit = 10 + mock_redis.zcard.return_value = 5 # 5 requests in window + + @cloud_edition_billing_rate_limit_check("knowledge") + def knowledge_request(): + return "knowledge_success" + + # Act + with patch("controllers.console.wraps.current_user"): + with patch( + "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit + ): + result = knowledge_request() + + # Assert + assert result == "knowledge_success" + mock_redis.zadd.assert_called_once() + mock_redis.zremrangebyscore.assert_called_once() + + @patch("controllers.console.wraps.redis_client") + @patch("controllers.console.wraps.db") + def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis): + """Test that requests over rate limit are rejected and logged""" + # Arrange + app = create_app_with_login() + mock_rate_limit = MagicMock() + mock_rate_limit.enabled = True + mock_rate_limit.limit = 10 + mock_rate_limit.subscription_plan = "pro" + mock_redis.zcard.return_value = 11 # Over limit + + mock_session = MagicMock() + mock_db.session = mock_session + + @cloud_edition_billing_rate_limit_check("knowledge") + def knowledge_request(): + return "knowledge_success" + + # Act & Assert + with app.test_request_context(): + with patch("controllers.console.wraps.current_user", MockUser("test_user")): + with patch( + "controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit + ): + with pytest.raises(Exception) as exc_info: + knowledge_request() + + # Verify error + assert exc_info.value.code == 403 + assert "rate limit" in str(exc_info.value.description) + + # Verify rate limit log was created + mock_session.add.assert_called_once() + mock_session.commit.assert_called_once() + + +class TestSystemSetup: + """Test system setup decorator""" + + @patch("controllers.console.wraps.db") + def test_should_allow_when_setup_complete(self, mock_db): + """Test that requests are allowed when setup is complete""" + # Arrange + mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists + + @setup_required + def admin_view(): + return "admin_success" + + # Act + with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): + result = admin_view() + + # Assert + assert result == "admin_success" + + @patch("controllers.console.wraps.db") + @patch("controllers.console.wraps.os.environ.get") + def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db): + """Test NotInitValidateError when INIT_PASSWORD is set but setup not complete""" + # Arrange + mock_db.session.query.return_value.first.return_value = None # No setup + mock_environ_get.return_value = "some_password" + + @setup_required + def admin_view(): + return "admin_success" + + # Act & Assert + with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): + with pytest.raises(NotInitValidateError): + admin_view() + + @patch("controllers.console.wraps.db") + @patch("controllers.console.wraps.os.environ.get") + def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db): + """Test NotSetupError when no INIT_PASSWORD and setup not complete""" + # Arrange + mock_db.session.query.return_value.first.return_value = None # No setup + mock_environ_get.return_value = None # No INIT_PASSWORD + + @setup_required + def admin_view(): + return "admin_success" + + # Act & Assert + with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"): + with pytest.raises(NotSetupError): + admin_view() + + +class TestEnterpriseLicense: + """Test enterprise license decorator""" + + def test_should_allow_with_valid_license(self): + """Test that valid licenses allow access""" + # Arrange + mock_settings = MagicMock() + mock_settings.license.status = LicenseStatus.ACTIVE + + @enterprise_license_required + def enterprise_feature(): + return "enterprise_success" + + # Act + with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings): + result = enterprise_feature() + + # Assert + assert result == "enterprise_success" + + @pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST]) + def test_should_reject_with_invalid_license(self, invalid_status): + """Test that invalid licenses raise UnauthorizedAndForceLogout""" + # Arrange + mock_settings = MagicMock() + mock_settings.license.status = invalid_status + + @enterprise_license_required + def enterprise_feature(): + return "enterprise_success" + + # Act & Assert + with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings): + with pytest.raises(UnauthorizedAndForceLogout) as exc_info: + enterprise_feature() + assert "license is invalid" in str(exc_info.value) diff --git a/api/tests/unit_tests/core/tools/utils/__init__.py b/api/tests/unit_tests/core/tools/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/tools/utils/test_parser.py b/api/tests/unit_tests/core/tools/utils/test_parser.py new file mode 100644 index 0000000000..8e07293ce0 --- /dev/null +++ b/api/tests/unit_tests/core/tools/utils/test_parser.py @@ -0,0 +1,56 @@ +import pytest +from flask import Flask + +from core.tools.utils.parser import ApiBasedToolSchemaParser + + +@pytest.fixture +def app(): + app = Flask(__name__) + return app + + +def test_parse_openapi_to_tool_bundle_operation_id(app): + openapi = { + "openapi": "3.0.0", + "info": {"title": "Simple API", "version": "1.0.0"}, + "servers": [{"url": "http://localhost:3000"}], + "paths": { + "/": { + "get": { + "summary": "Root endpoint", + "responses": { + "200": { + "description": "Successful response", + } + }, + } + }, + "/api/resources": { + "get": { + "summary": "Non-root endpoint without an operationId", + "responses": { + "200": { + "description": "Successful response", + } + }, + }, + "post": { + "summary": "Non-root endpoint with an operationId", + "operationId": "createResource", + "responses": { + "201": { + "description": "Resource created", + } + }, + }, + }, + }, + } + with app.test_request_context(): + tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi) + + assert len(tool_bundles) == 3 + assert tool_bundles[0].operation_id == "_get" + assert tool_bundles[1].operation_id == "apiresources_get" + assert tool_bundles[2].operation_id == "createResource" diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 1b035d01a7..cdc261fd42 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,14 +1,49 @@ +import dataclasses + +from pydantic import BaseModel + +from core.file import File, FileTransferMethod, FileType from core.helper import encrypter -from core.variables import SecretVariable, StringVariable +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + SegmentUnion, + StringSegment, + get_segment_discriminator, +) +from core.variables.types import SegmentType +from core.variables.variables import ( + ArrayAnyVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, + VariableUnion, +) from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable def test_segment_group_to_text(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), @@ -30,7 +65,7 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group(): assert segments_group.log == "fake-user-id" assert isinstance(segments_group.value[0], StringVariable) assert segments_group.value[0].value == "fake-user-id" + + +class _Segments(BaseModel): + segments: list[SegmentUnion] + + +class _Variables(BaseModel): + variables: list[VariableUnion] + + +def create_test_file( + file_type: FileType = FileType.DOCUMENT, + transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + filename: str = "test.txt", + extension: str = ".txt", + mime_type: str = "text/plain", + size: int = 1024, +) -> File: + """Factory function to create File objects for testing""" + return File( + tenant_id="test-tenant", + type=file_type, + transfer_method=transfer_method, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None, + remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None, + storage_key="test-storage-key", + ) + + +class TestSegmentDumpAndLoad: + """Test suite for segment and variable serialization/deserialization""" + + def test_segments(self): + """Test basic segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_segment_number(self): + """Test number segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_variables(self): + """Test variable serialization compatibility""" + model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) + json = model.model_dump_json() + print("Json: ", json) + restored = _Variables.model_validate_json(json) + assert restored == model + + def test_all_segments_serialization(self): + """Test serialization/deserialization of all segment types""" + # Create one instance of each segment type + test_file = create_test_file() + + all_segments: list[SegmentUnion] = [ + NoneSegment(), + StringSegment(value="test string"), + IntegerSegment(value=42), + FloatSegment(value=3.14), + ObjectSegment(value={"key": "value", "number": 123}), + FileSegment(value=test_file), + ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]), + ArrayStringSegment(value=["hello", "world"]), + ArrayNumberSegment(value=[1, 2.5, 3]), + ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]), + ArrayFileSegment(value=[]), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = _Segments(segments=all_segments) + json_str = model.model_dump_json() + loaded = _Segments.model_validate_json(json_str) + + # Verify all segments are preserved + assert len(loaded.segments) == len(all_segments) + + for original, loaded_segment in zip(all_segments, loaded.segments): + assert type(loaded_segment) == type(original) + assert loaded_segment.value_type == original.value_type + + # For file segments, compare key properties instead of exact equality + if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment): + orig_file = original.value + loaded_file = loaded_segment.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_segment.value == original.value + + def test_all_variables_serialization(self): + """Test serialization/deserialization of all variable types""" + # Create one instance of each variable type + test_file = create_test_file() + + all_variables: list[VariableUnion] = [ + NoneVariable(name="none_var"), + StringVariable(value="test string", name="string_var"), + IntegerVariable(value=42, name="int_var"), + FloatVariable(value=3.14, name="float_var"), + ObjectVariable(value={"key": "value", "number": 123}, name="object_var"), + FileVariable(value=test_file, name="file_var"), + ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"), + ArrayStringVariable(value=["hello", "world"], name="array_string_var"), + ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"), + ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"), + ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = _Variables(variables=all_variables) + json_str = model.model_dump_json() + loaded = _Variables.model_validate_json(json_str) + + # Verify all variables are preserved + assert len(loaded.variables) == len(all_variables) + + for original, loaded_variable in zip(all_variables, loaded.variables): + assert type(loaded_variable) == type(original) + assert loaded_variable.value_type == original.value_type + assert loaded_variable.name == original.name + + # For file variables, compare key properties instead of exact equality + if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable): + orig_file = original.value + loaded_file = loaded_variable.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_variable.value == original.value + + def test_segment_discriminator_function_for_segment_types(self): + """Test the segment discriminator function""" + + @dataclasses.dataclass + class TestCase: + segment: Segment + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneSegment(), + SegmentType.NONE, + ), + TestCase( + StringSegment(value=""), + SegmentType.STRING, + ), + TestCase( + FloatSegment(value=0.0), + SegmentType.FLOAT, + ), + TestCase( + IntegerSegment(value=0), + SegmentType.INTEGER, + ), + TestCase( + ObjectSegment(value={}), + SegmentType.OBJECT, + ), + TestCase( + FileSegment(value=file1), + SegmentType.FILE, + ), + TestCase( + ArrayAnySegment(value=[0, 0.0, ""]), + SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringSegment(value=[""]), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberSegment(value=[0, 0.0]), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectSegment(value=[{}]), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileSegment(value=[file1, file2]), + SegmentType.ARRAY_FILE, + ), + ] + + for test_case in cases: + segment = test_case.segment + assert get_segment_discriminator(segment) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for type {type(segment)}" + ) + model_dict = segment.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(segment)}" + ) + + def test_variable_discriminator_function_for_variable_types(self): + """Test the variable discriminator function""" + + @dataclasses.dataclass + class TestCase: + variable: Variable + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneVariable(name="none_var"), + SegmentType.NONE, + ), + TestCase( + StringVariable(value="test", name="string_var"), + SegmentType.STRING, + ), + TestCase( + FloatVariable(value=0.0, name="float_var"), + SegmentType.FLOAT, + ), + TestCase( + IntegerVariable(value=0, name="int_var"), + SegmentType.INTEGER, + ), + TestCase( + ObjectVariable(value={}, name="object_var"), + SegmentType.OBJECT, + ), + TestCase( + FileVariable(value=file1, name="file_var"), + SegmentType.FILE, + ), + TestCase( + SecretVariable(value="secret", name="secret_var"), + SegmentType.SECRET, + ), + TestCase( + ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"), + SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringVariable(value=[""], name="array_string_var"), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberVariable(value=[0, 0.0], name="array_number_var"), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectVariable(value=[{}], name="array_object_var"), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileVariable(value=[file1, file2], name="array_file_var"), + SegmentType.ARRAY_FILE, + ), + ] + + for test_case in cases: + variable = test_case.variable + assert get_segment_discriminator(variable) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for type {type(variable)}" + ) + model_dict = variable.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(variable)}" + ) + + def test_invlaid_value_for_discriminator(self): + # Test invalid cases + assert get_segment_discriminator({"value_type": "invalid"}) is None + assert get_segment_discriminator({}) is None + assert get_segment_discriminator("not_a_dict") is None + assert get_segment_discriminator(42) is None + assert get_segment_discriminator(object) is None diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py new file mode 100644 index 0000000000..64d0d8c7e7 --- /dev/null +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -0,0 +1,60 @@ +from core.variables.types import SegmentType + + +class TestSegmentTypeIsArrayType: + """ + Test class for SegmentType.is_array_type method. + + Provides comprehensive coverage of all SegmentType values to ensure + correct identification of array and non-array types. + """ + + def test_is_array_type(self): + """ + Test that all SegmentType enum values are covered in our test cases. + + Ensures comprehensive coverage by verifying that every SegmentType + value is tested for the is_array_type method. + """ + # Arrange + all_segment_types = set(SegmentType) + expected_array_types = [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] + expected_non_array_types = [ + SegmentType.INTEGER, + SegmentType.FLOAT, + SegmentType.NUMBER, + SegmentType.STRING, + SegmentType.OBJECT, + SegmentType.SECRET, + SegmentType.FILE, + SegmentType.NONE, + SegmentType.GROUP, + ] + + for seg_type in expected_array_types: + assert seg_type.is_array_type() + + for seg_type in expected_non_array_types: + assert not seg_type.is_array_type() + + # Act & Assert + covered_types = set(expected_array_types) | set(expected_non_array_types) + assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests" + + def test_all_enum_values_are_supported(self): + """ + Test that all enum values are supported and return boolean values. + + Validates that every SegmentType enum value can be processed by + is_array_type method and returns a boolean value. + """ + enum_values: list[SegmentType] = list(SegmentType) + for seg_type in enum_values: + is_array = seg_type.is_array_type() + assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}" diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index 426557c716..925142892c 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -11,6 +11,7 @@ from core.variables import ( SegmentType, StringVariable, ) +from core.variables.variables import Variable def test_frozen_variables(): @@ -75,7 +76,7 @@ def test_object_variable_to_object(): def test_variable_to_object(): - var = StringVariable(name="text", value="text") + var: Variable = StringVariable(name="text", value="text") assert var.to_object() == "text" var = IntegerVariable(name="integer", value=42) assert var.to_object() == 42 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py new file mode 100644 index 0000000000..cf7cee8710 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -0,0 +1,146 @@ +import time +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState +from core.workflow.system_variable import SystemVariable + + +def create_test_graph_runtime_state() -> GraphRuntimeState: + """Factory function to create a GraphRuntimeState with non-empty values for testing.""" + # Create a variable pool with system variables + system_vars = SystemVariable( + user_id="test_user_123", + app_id="test_app_456", + workflow_id="test_workflow_789", + workflow_execution_id="test_execution_001", + query="test query", + conversation_id="test_conv_123", + dialogue_count=5, + ) + variable_pool = VariablePool(system_variables=system_vars) + + # Add some variables to the variable pool + variable_pool.add(["test_node", "test_var"], "test_value") + variable_pool.add(["another_node", "another_var"], 42) + + # Create LLM usage with realistic values + llm_usage = LLMUsage( + prompt_tokens=150, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.15"), + completion_tokens=75, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.15"), + total_tokens=225, + total_price=Decimal("0.30"), + currency="USD", + latency=1.25, + ) + + # Create runtime route state with some node states + node_run_state = RuntimeRouteState() + node_state = node_run_state.create_node_state("test_node_1") + node_run_state.add_route(node_state.id, "target_node_id") + + return GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + total_tokens=100, + llm_usage=llm_usage, + outputs={ + "string_output": "test result", + "int_output": 42, + "float_output": 3.14, + "list_output": ["item1", "item2", "item3"], + "dict_output": {"key1": "value1", "key2": 123}, + "nested_dict": {"level1": {"level2": ["nested", "list", 456]}}, + }, + node_run_steps=5, + node_run_state=node_run_state, + ) + + +def test_basic_round_trip_serialization(): + """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged.""" + # Create a state with non-empty values + original_state = create_test_graph_runtime_state() + + # Serialize to JSON and deserialize back + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Core test: ensure the round-trip preserves all values + assert deserialized_state == original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="python") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="json") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + +def test_outputs_field_round_trip(): + """Test the problematic outputs field maintains values through round-trip serialization.""" + original_state = create_test_graph_runtime_state() + + # Serialize and deserialize + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Verify the outputs field specifically maintains its values + assert deserialized_state.outputs == original_state.outputs + assert deserialized_state == original_state + + +def test_empty_outputs_round_trip(): + """Test round-trip serialization with empty outputs field.""" + variable_pool = VariablePool.empty() + original_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + outputs={}, # Empty outputs + ) + + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + assert deserialized_state == original_state + + +def test_llm_usage_round_trip(): + # Create LLM usage with specific decimal values + llm_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.0015"), + prompt_price_unit=Decimal(1000), + prompt_price=Decimal("0.15"), + completion_tokens=50, + completion_unit_price=Decimal("0.003"), + completion_price_unit=Decimal(1000), + completion_price=Decimal("0.15"), + total_tokens=150, + total_price=Decimal("0.30"), + currency="USD", + latency=2.5, + ) + + json_data = llm_usage.model_dump_json() + deserialized = LLMUsage.model_validate_json(json_data) + assert deserialized == llm_usage + + dict_data = llm_usage.model_dump(mode="python") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage + + dict_data = llm_usage.model_dump(mode="json") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py new file mode 100644 index 0000000000..f3de42479a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py @@ -0,0 +1,401 @@ +import json +import uuid +from datetime import UTC, datetime + +import pytest +from pydantic import ValidationError + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState + +_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45) + + +class TestRouteNodeStateSerialization: + """Test cases for RouteNodeState Pydantic serialization/deserialization.""" + + def _test_route_node_state(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"input_key": "input_value"}, + outputs={"output_key": "output_value"}, + ) + + node_state = RouteNodeState( + node_id="comprehensive_test_node", + start_at=_TEST_DATETIME, + finished_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + node_run_result=node_run_result, + index=5, + paused_at=_TEST_DATETIME, + paused_by="user_123", + failed_reason="test_reason", + ) + return node_state + + def test_route_node_state_comprehensive_field_validation(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + node_state = self._test_route_node_state() + serialized = node_state.model_dump() + + # Comprehensive validation of all RouteNodeState fields + assert serialized["node_id"] == "comprehensive_test_node" + assert serialized["status"] == RouteNodeState.Status.SUCCESS + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["finished_at"] == _TEST_DATETIME + assert serialized["paused_at"] == _TEST_DATETIME + assert serialized["paused_by"] == "user_123" + assert serialized["failed_reason"] == "test_reason" + assert serialized["index"] == 5 + assert "id" in serialized + assert isinstance(serialized["id"], str) + uuid.UUID(serialized["id"]) # Validate UUID format + + # Validate nested NodeRunResult structure + assert serialized["node_run_result"] is not None + assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED + assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"} + assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"} + + def test_route_node_state_minimal_required_fields(self): + """Test RouteNodeState with only required fields, focusing on defaults.""" + node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME) + + serialized = node_state.model_dump() + + # Focus on required fields and default values (not re-testing all fields) + assert serialized["node_id"] == "minimal_node" + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status + assert serialized["index"] == 1 # Default index + assert serialized["node_run_result"] is None # Default None + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + def test_route_node_state_deserialization_from_dict(self): + """Test RouteNodeState deserialization from dictionary data.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + test_id = str(uuid.uuid4()) + + dict_data = { + "id": test_id, + "node_id": "deserialized_node", + "start_at": test_datetime, + "status": "success", + "finished_at": test_datetime, + "index": 3, + } + + node_state = RouteNodeState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert node_state.id == test_id + assert node_state.node_id == "deserialized_node" + assert node_state.start_at == test_datetime + assert node_state.status == RouteNodeState.Status.SUCCESS + assert node_state.finished_at == test_datetime + assert node_state.index == 3 + + def test_route_node_state_round_trip_consistency(self): + node_states = ( + self._test_route_node_state(), + RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME), + ) + for node_state in node_states: + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="python") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="json") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + +class TestRouteNodeStateEnumSerialization: + """Dedicated tests for RouteNodeState Status enum serialization behavior.""" + + def test_status_enum_model_dump_behavior(self): + """Test Status enum serialization in model_dump() returns enum objects.""" + + for status_enum in RouteNodeState.Status: + node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum) + serialized = node_state.model_dump(mode="python") + assert serialized["status"] == status_enum + serialized = node_state.model_dump(mode="json") + assert serialized["status"] == status_enum.value + + def test_status_enum_json_serialization_behavior(self): + """Test Status enum serialization in JSON returns string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + enum_to_string_mapping = { + RouteNodeState.Status.RUNNING: "running", + RouteNodeState.Status.SUCCESS: "success", + RouteNodeState.Status.FAILED: "failed", + RouteNodeState.Status.PAUSED: "paused", + RouteNodeState.Status.EXCEPTION: "exception", + } + + for status_enum, expected_string in enum_to_string_mapping.items(): + node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum) + + json_data = json.loads(node_state.model_dump_json()) + assert json_data["status"] == expected_string + + def test_status_enum_deserialization_from_string(self): + """Test Status enum deserialization from string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + string_to_enum_mapping = { + "running": RouteNodeState.Status.RUNNING, + "success": RouteNodeState.Status.SUCCESS, + "failed": RouteNodeState.Status.FAILED, + "paused": RouteNodeState.Status.PAUSED, + "exception": RouteNodeState.Status.EXCEPTION, + } + + for status_string, expected_enum in string_to_enum_mapping.items(): + dict_data = { + "node_id": "enum_deserialize_test", + "start_at": test_datetime, + "status": status_string, + } + + node_state = RouteNodeState.model_validate(dict_data) + assert node_state.status == expected_enum + + +class TestRuntimeRouteStateSerialization: + """Test cases for RuntimeRouteState Pydantic serialization/deserialization.""" + + _NODE1_ID = "node_1" + _ROUTE_STATE1_ID = str(uuid.uuid4()) + _NODE2_ID = "node_2" + _ROUTE_STATE2_ID = str(uuid.uuid4()) + _NODE3_ID = "node_3" + _ROUTE_STATE3_ID = str(uuid.uuid4()) + + def _get_runtime_route_state(self): + # Create node states with different configurations + node_state_1 = RouteNodeState( + id=self._ROUTE_STATE1_ID, + node_id=self._NODE1_ID, + start_at=_TEST_DATETIME, + index=1, + ) + node_state_2 = RouteNodeState( + id=self._ROUTE_STATE2_ID, + node_id=self._NODE2_ID, + start_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + finished_at=_TEST_DATETIME, + index=2, + ) + node_state_3 = RouteNodeState( + id=self._ROUTE_STATE3_ID, + node_id=self._NODE3_ID, + start_at=_TEST_DATETIME, + status=RouteNodeState.Status.FAILED, + failed_reason="Test failure", + index=3, + ) + + runtime_state = RuntimeRouteState( + routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]}, + node_state_mapping={ + node_state_1.id: node_state_1, + node_state_2.id: node_state_2, + node_state_3.id: node_state_3, + }, + ) + + return runtime_state + + def test_runtime_route_state_comprehensive_structure_validation(self): + """Test comprehensive RuntimeRouteState serialization with full structure validation.""" + + runtime_state = self._get_runtime_route_state() + serialized = runtime_state.model_dump() + + # Comprehensive validation of RuntimeRouteState structure + assert "routes" in serialized + assert "node_state_mapping" in serialized + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + # Validate routes dictionary structure and content + assert len(serialized["routes"]) == 2 + assert self._ROUTE_STATE1_ID in serialized["routes"] + assert self._ROUTE_STATE2_ID in serialized["routes"] + assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID] + assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID] + + # Validate node_state_mapping dictionary structure and content + assert len(serialized["node_state_mapping"]) == 3 + for state_id in [ + self._ROUTE_STATE1_ID, + self._ROUTE_STATE2_ID, + self._ROUTE_STATE3_ID, + ]: + assert state_id in serialized["node_state_mapping"] + node_data = serialized["node_state_mapping"][state_id] + node_state = runtime_state.node_state_mapping[state_id] + assert node_data["node_id"] == node_state.node_id + assert node_data["status"] == node_state.status + assert node_data["index"] == node_state.index + + def test_runtime_route_state_empty_collections(self): + """Test RuntimeRouteState with empty collections, focusing on default behavior.""" + runtime_state = RuntimeRouteState() + serialized = runtime_state.model_dump() + + # Focus on default empty collection behavior + assert serialized["routes"] == {} + assert serialized["node_state_mapping"] == {} + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + def test_runtime_route_state_json_serialization_structure(self): + """Test RuntimeRouteState JSON serialization structure.""" + node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME) + + runtime_state = RuntimeRouteState( + routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state} + ) + + json_str = runtime_state.model_dump_json() + json_data = json.loads(json_str) + + # Focus on JSON structure validation + assert isinstance(json_str, str) + assert isinstance(json_data, dict) + assert "routes" in json_data + assert "node_state_mapping" in json_data + assert json_data["routes"]["source"] == ["target1", "target2"] + assert node_state.id in json_data["node_state_mapping"] + + def test_runtime_route_state_deserialization_from_dict(self): + """Test RuntimeRouteState deserialization from dictionary data.""" + node_id = str(uuid.uuid4()) + + dict_data = { + "routes": {"source_node": ["target_node_1", "target_node_2"]}, + "node_state_mapping": { + node_id: { + "id": node_id, + "node_id": "test_node", + "start_at": _TEST_DATETIME, + "status": "running", + "index": 1, + } + }, + } + + runtime_state = RuntimeRouteState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]} + assert len(runtime_state.node_state_mapping) == 1 + assert node_id in runtime_state.node_state_mapping + + deserialized_node = runtime_state.node_state_mapping[node_id] + assert deserialized_node.node_id == "test_node" + assert deserialized_node.status == RouteNodeState.Status.RUNNING + assert deserialized_node.index == 1 + + def test_runtime_route_state_round_trip_consistency(self): + """Test RuntimeRouteState round-trip serialization consistency.""" + original = self._get_runtime_route_state() + + # Dictionary round trip + dict_data = original.model_dump(mode="python") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + dict_data = original.model_dump(mode="json") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + # JSON round trip + json_str = original.model_dump_json() + json_reconstructed = RuntimeRouteState.model_validate_json(json_str) + assert json_reconstructed == original + + +class TestSerializationEdgeCases: + """Test edge cases and error conditions for serialization/deserialization.""" + + def test_invalid_status_deserialization(self): + """Test deserialization with invalid status values.""" + test_datetime = _TEST_DATETIME + invalid_data = { + "node_id": "invalid_test", + "start_at": test_datetime, + "status": "invalid_status", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "status" in str(exc_info.value) + + def test_missing_required_fields_deserialization(self): + """Test deserialization with missing required fields.""" + incomplete_data = {"id": str(uuid.uuid4())} + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(incomplete_data) + error_str = str(exc_info.value) + assert "node_id" in error_str or "start_at" in error_str + + def test_invalid_datetime_deserialization(self): + """Test deserialization with invalid datetime values.""" + invalid_data = { + "node_id": "datetime_test", + "start_at": "invalid_datetime", + "status": "running", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "start_at" in str(exc_info.value) + + def test_invalid_routes_structure_deserialization(self): + """Test RuntimeRouteState deserialization with invalid routes structure.""" + invalid_data = { + "routes": "invalid_routes_structure", # Should be dict + "node_state_mapping": {}, + } + + with pytest.raises(ValidationError) as exc_info: + RuntimeRouteState.model_validate(invalid_data) + assert "routes" in str(exc_info.value) + + def test_timezone_handling_in_datetime_fields(self): + """Test timezone handling in datetime field serialization.""" + utc_datetime = datetime.now(UTC) + naive_datetime = utc_datetime.replace(tzinfo=None) + + node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime) + dict_ = node_state.model_dump() + + assert dict_["start_at"] == naive_datetime + + # Test round trip + reconstructed = RouteNodeState.model_validate(dict_) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None + + json = node_state.model_dump_json() + + reconstructed = RouteNodeState.model_validate_json(json) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py index c288a5fa13..ed4e42425e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py @@ -8,7 +8,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( BaseNodeEvent, GraphRunFailedEvent, @@ -27,6 +26,7 @@ from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -171,7 +171,8 @@ def test_run_parallel_in_workflow(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + system_variables=SystemVariable(user_id="aaa", app_id="1", workflow_id="1", files=[]), + user_inputs={"query": "hi"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -293,12 +294,12 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="what's the weather in SF", + conversation_id="abababa", + ), user_inputs={}, ) @@ -474,12 +475,12 @@ def test_run_branch(mock_close, mock_remove): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "hi", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="hi", + conversation_id="abababa", + ), user_inputs={"uid": "takato"}, ) @@ -804,18 +805,22 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) pool.add(["pe", "list_output"], ["dify-1", "dify-2"]) variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"} + system_variables=SystemVariable( + user_id="aaa", + files=[], + ), + user_inputs={"query": "hi"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index b7f78d91fa..85ff4f9c05 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -51,7 +51,7 @@ def test_execute_answer(): # construct variable pool pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index c3a3818655..137e8b889d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -3,7 +3,6 @@ from collections.abc import Generator from datetime import UTC, datetime from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphEngineEvent, NodeRunStartedEvent, @@ -15,6 +14,7 @@ from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeSta from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: @@ -180,12 +180,12 @@ def test_process(): graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "what's the weather in SF", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="what's the weather in SF", + conversation_id="abababa", + ), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index d066fc1e33..bb6d72f51e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -7,12 +7,13 @@ from core.workflow.nodes.http_request import ( ) from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout from core.workflow.nodes.http_request.executor import Executor +from core.workflow.system_variable import SystemVariable def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -65,7 +66,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -120,7 +121,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -174,7 +175,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool() + variable_pool = VariablePool(system_variables=SystemVariable.empty()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -201,7 +202,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -280,7 +281,11 @@ def test_init_headers(): authorization=HttpRequestNodeAuthorization(type="no-auth"), ) timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) - return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + return Executor( + node_data=node_data, + timeout=timeout, + variable_pool=VariablePool(system_variables=SystemVariable.empty()), + ) executor = create_executor("aa\n cc:") executor._init_headers() @@ -310,7 +315,11 @@ def test_init_params(): authorization=HttpRequestNodeAuthorization(type="no-auth"), ) timeout = HttpRequestNodeTimeout(connect=10, read=30, write=30) - return Executor(node_data=node_data, timeout=timeout, variable_pool=VariablePool()) + return Executor( + node_data=node_data, + timeout=timeout, + variable_pool=VariablePool(system_variables=SystemVariable.empty()), + ) # Test basic key-value pairs executor = create_executor("key1:value1\nkey2:value2") diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 7fd32a4826..33f9251a72 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -15,6 +15,7 @@ from core.workflow.nodes.http_request import ( HttpRequestNodeBody, HttpRequestNodeData, ) +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -40,7 +41,7 @@ def test_http_request_node_binary_file(monkeypatch): ), ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add( @@ -128,7 +129,7 @@ def test_http_request_node_form_with_file(monkeypatch): ), ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) variable_pool.add( @@ -223,7 +224,7 @@ def test_http_request_node_form_with_multiple_files(monkeypatch): ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py index 362072a3db..17c23b7735 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration.py @@ -7,7 +7,6 @@ from core.variables.segments import ArrayAnySegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState @@ -15,6 +14,7 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode from core.workflow.nodes.iteration.iteration_node import IterationNode from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -151,12 +151,12 @@ def test_run(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -368,12 +368,12 @@ def test_run_parallel(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -584,12 +584,12 @@ def test_iteration_run_in_parallel_mode(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) @@ -808,12 +808,12 @@ def test_iteration_run_error_handle(): # construct variable pool pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "dify", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "1", - }, + system_variables=SystemVariable( + user_id="1", + files=[], + query="dify", + conversation_id="abababa", + ), user_inputs={}, environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index 336c2befcc..fefad0ec95 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -36,6 +36,7 @@ from core.workflow.nodes.llm.entities import ( ) from core.workflow.nodes.llm.file_saver import LLMFileSaver from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.provider import ProviderType from models.workflow import WorkflowType @@ -104,7 +105,7 @@ def graph() -> Graph: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) return GraphRuntimeState( @@ -181,7 +182,7 @@ def test_fetch_files_with_file_segment(): related_id="1", storage_key="", ) - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], file) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -209,7 +210,7 @@ def test_fetch_files_with_array_file_segment(): storage_key="", ), ] - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], ArrayFileSegment(value=files)) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -217,7 +218,7 @@ def test_fetch_files_with_array_file_segment(): def test_fetch_files_with_none_segment(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], NoneSegment()) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -225,7 +226,7 @@ def test_fetch_files_with_none_segment(): def test_fetch_files_with_array_any_segment(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() variable_pool.add(["sys", "files"], ArrayAnySegment(value=[])) result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) @@ -233,7 +234,7 @@ def test_fetch_files_with_array_any_segment(): def test_fetch_files_with_non_existent_variable(): - variable_pool = VariablePool() + variable_pool = VariablePool.empty() result = llm_utils.fetch_files(variable_pool=variable_pool, selector=["sys", "files"]) assert result == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/test_answer.py index abc822e98b..44c31b212e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_answer.py @@ -5,11 +5,11 @@ from unittest.mock import MagicMock from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.answer.answer_node import AnswerNode +from core.workflow.system_variable import SystemVariable from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowType @@ -53,7 +53,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index a6c553faf0..3f83428834 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, NodeRunExceptionEvent, @@ -17,6 +16,7 @@ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntime from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.nodes.llm.node import LLMNode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -167,12 +167,12 @@ class ContinueOnErrorTestHelper: """Helper method to create a graph engine instance for testing""" graph = Graph.init(graph_config=graph_config) variable_pool = VariablePool( - system_variables={ - SystemVariableKey.QUERY: "clear", - SystemVariableKey.FILES: [], - SystemVariableKey.CONVERSATION_ID: "abababa", - SystemVariableKey.USER_ID: "aaa", - }, + system_variables=SystemVariable( + user_id="aaa", + files=[], + query="clear", + conversation_id="abababa", + ), user_inputs=user_inputs or {"uid": "takato"}, ) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c4e411f9d6..167a92484d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -7,12 +7,12 @@ from core.file import File, FileTransferMethod, FileType from core.variables import ArrayFileSegment from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.nodes.if_else.if_else_node import IfElseNode +from core.workflow.system_variable import SystemVariable from core.workflow.utils.condition.entities import Condition, SubCondition, SubVariableCondition from extensions.ext_database import db from models.enums import UserFrom @@ -37,9 +37,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={} - ) + pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -157,7 +155,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, + system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index e121f6338c..2776e57777 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -15,6 +15,7 @@ from core.workflow.nodes.enums import ErrorStrategy from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.tool import ToolNode from core.workflow.nodes.tool.entities import ToolNodeData +from core.workflow.system_variable import SystemVariable from models import UserFrom, WorkflowType @@ -34,7 +35,7 @@ def _create_tool_node(): version="1", ) variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, ) node = ToolNode( diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index deb3e29b86..62e3e37104 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -7,12 +7,12 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable, StringVariable from core.workflow.conversation_variable_updater import ConversationVariableUpdater from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -68,7 +68,7 @@ def test_overwrite_string_variable(): # construct variable pool variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -165,7 +165,7 @@ def test_append_variable_to_array(): conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -256,7 +256,7 @@ def test_clear_array(): conversation_id = str(uuid.uuid4()) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: conversation_id}, + system_variables=SystemVariable(conversation_id=conversation_id), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py index 7c5597dd89..a3a90b0599 100644 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py @@ -5,12 +5,12 @@ from uuid import uuid4 from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import ArrayStringVariable from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation +from core.workflow.system_variable import SystemVariable from models.enums import UserFrom from models.workflow import WorkflowType @@ -109,7 +109,7 @@ def test_remove_first_from_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -196,7 +196,7 @@ def test_remove_last_from_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -275,7 +275,7 @@ def test_remove_first_from_empty_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], @@ -354,7 +354,7 @@ def test_remove_last_from_empty_array(): ) variable_pool = VariablePool( - system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"}, + system_variables=SystemVariable(conversation_id="conversation_id"), user_inputs={}, environment_variables=[], conversation_variables=[conversation_variable], diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py new file mode 100644 index 0000000000..11d788ed79 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -0,0 +1,251 @@ +import json +from typing import Any + +import pytest +from pydantic import ValidationError + +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File +from core.workflow.system_variable import SystemVariable + +# Test data constants for SystemVariable serialization tests +VALID_BASE_DATA: dict[str, Any] = { + "user_id": "a20f06b1-8703-45ab-937c-860a60072113", + "app_id": "661bed75-458d-49c9-b487-fda0762677b9", + "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43", +} + +COMPLETE_VALID_DATA: dict[str, Any] = { + **VALID_BASE_DATA, + "query": "test query", + "files": [], + "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9", + "dialogue_count": 5, + "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5", +} + + +def create_test_file() -> File: + """Create a test File object for serialization tests.""" + return File( + tenant_id="test-tenant-id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test-file-id", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="test-storage-key", + ) + + +class TestSystemVariableSerialization: + """Focused tests for SystemVariable serialization/deserialization logic.""" + + def test_basic_deserialization(self): + """Test successful deserialization from JSON structure with all fields correctly mapped.""" + # Test with complete data + system_var = SystemVariable(**COMPLETE_VALID_DATA) + + # Verify all fields are correctly mapped + assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] + assert system_var.app_id == COMPLETE_VALID_DATA["app_id"] + assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"] + assert system_var.query == COMPLETE_VALID_DATA["query"] + assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"] + assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"] + assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"] + assert system_var.files == [] + + # Test with minimal data (only required fields) + minimal_var = SystemVariable(**VALID_BASE_DATA) + assert minimal_var.user_id == VALID_BASE_DATA["user_id"] + assert minimal_var.app_id == VALID_BASE_DATA["app_id"] + assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] + assert minimal_var.query is None + assert minimal_var.conversation_id is None + assert minimal_var.dialogue_count is None + assert minimal_var.workflow_execution_id is None + assert minimal_var.files == [] + + def test_alias_handling(self): + """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic.""" + workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5" + + # Test workflow_run_id only (preferred alias) + data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var1 = SystemVariable(**data_run_id) + assert system_var1.workflow_execution_id == workflow_id + + # Test workflow_execution_id only (direct field name) + data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var2 = SystemVariable(**data_execution_id) + assert system_var2.workflow_execution_id == workflow_id + + # Test both present - workflow_run_id should take precedence + data_both = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-ignored", + "workflow_run_id": workflow_id, + } + system_var3 = SystemVariable(**data_both) + assert system_var3.workflow_execution_id == workflow_id + + # Test neither present - should be None + system_var4 = SystemVariable(**VALID_BASE_DATA) + assert system_var4.workflow_execution_id is None + + def test_serialization_round_trip(self): + """Test that serialize → deserialize produces the same result with alias handling.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to dict + serialized = original.model_dump(mode="json") + + # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id) + assert "workflow_run_id" in serialized + assert "workflow_execution_id" not in serialized + assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize back + deserialized = SystemVariable(**serialized) + + # Verify all fields match after round-trip + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + assert deserialized.query == original.query + assert deserialized.conversation_id == original.conversation_id + assert deserialized.dialogue_count == original.dialogue_count + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert list(deserialized.files) == list(original.files) + + def test_json_round_trip(self): + """Test JSON serialization/deserialization consistency with proper structure.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to JSON string + json_str = original.model_dump_json() + + # Parse JSON and verify structure + json_data = json.loads(json_str) + assert "workflow_run_id" in json_data + assert "workflow_execution_id" not in json_data + assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize from JSON data + deserialized = SystemVariable(**json_data) + + # Verify key fields match after JSON round-trip + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + + def test_files_field_deserialization(self): + """Test deserialization with File objects in the files field - SystemVariable specific logic.""" + # Test with empty files list + data_empty = {**VALID_BASE_DATA, "files": []} + system_var_empty = SystemVariable(**data_empty) + assert system_var_empty.files == [] + + # Test with single File object + test_file = create_test_file() + data_single = {**VALID_BASE_DATA, "files": [test_file]} + system_var_single = SystemVariable(**data_single) + assert len(system_var_single.files) == 1 + assert system_var_single.files[0].filename == "test.txt" + assert system_var_single.files[0].tenant_id == "test-tenant-id" + + # Test with multiple File objects + file1 = File( + tenant_id="tenant1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="file1", + filename="doc1.txt", + storage_key="key1", + ) + file2 = File( + tenant_id="tenant2", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + filename="image.jpg", + storage_key="key2", + ) + + data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} + system_var_multiple = SystemVariable(**data_multiple) + assert len(system_var_multiple.files) == 2 + assert system_var_multiple.files[0].filename == "doc1.txt" + assert system_var_multiple.files[1].filename == "image.jpg" + + # Verify files field serialization/deserialization + serialized = system_var_multiple.model_dump(mode="json") + deserialized = SystemVariable(**serialized) + assert len(deserialized.files) == 2 + assert deserialized.files[0].filename == "doc1.txt" + assert deserialized.files[1].filename == "image.jpg" + + def test_alias_serialization_consistency(self): + """Test that alias handling works consistently in both serialization directions.""" + workflow_id = "test-workflow-id" + + # Create with workflow_run_id (alias) + data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var = SystemVariable(**data_with_alias) + + # Serialize and verify alias is used + serialized = system_var.model_dump() + assert serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in serialized + + # Deserialize and verify field mapping + deserialized = SystemVariable(**serialized) + assert deserialized.workflow_execution_id == workflow_id + + # Test JSON serialization path + json_serialized = json.loads(system_var.model_dump_json()) + assert json_serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in json_serialized + + json_deserialized = SystemVariable(**json_serialized) + assert json_deserialized.workflow_execution_id == workflow_id + + def test_model_validator_serialization_logic(self): + """Test the custom model validator behavior for serialization scenarios.""" + workflow_id = "test-workflow-execution-id" + + # Test direct instantiation with workflow_execution_id (should work) + data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var1 = SystemVariable(**data1) + assert system_var1.workflow_execution_id == workflow_id + + # Test serialization of the above (should use alias) + serialized1 = system_var1.model_dump() + assert "workflow_run_id" in serialized1 + assert serialized1["workflow_run_id"] == workflow_id + + # Test both present - workflow_run_id takes precedence (validator logic) + data2 = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-removed", + "workflow_run_id": workflow_id, + } + system_var2 = SystemVariable(**data2) + assert system_var2.workflow_execution_id == workflow_id + + # Verify serialization consistency + serialized2 = system_var2.model_dump() + assert serialized2["workflow_run_id"] == workflow_id + + +def test_constructor_with_extra_key(): + # Test that SystemVariable should forbid extra keys + with pytest.raises(ValidationError): + # This should fail because there is an unexpected key. + SystemVariable(invalid_key=1) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index bb8d34fad5..c65b60cb4d 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,17 +1,43 @@ +import uuid +from collections import defaultdict + import pytest -from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, +) +from core.variables.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + StringVariable, + VariableUnion, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture def pool(): - return VariablePool(system_variables={}, user_inputs={}) + return VariablePool( + system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"), + user_inputs={}, + ) @pytest.fixture @@ -52,18 +78,28 @@ def test_use_long_selector(pool): class TestVariablePool: def test_constructor(self): - pool = VariablePool() + # Test with minimal required SystemVariable + minimal_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) + pool = VariablePool(system_variables=minimal_system_vars) + + # Test with all parameters pool = VariablePool( variable_dictionary={}, user_inputs={}, - system_variables={}, + system_variables=minimal_system_vars, environment_variables=[], conversation_variables=[], ) + # Test with more complex SystemVariable + complex_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) pool = VariablePool( user_inputs={"key": "value"}, - system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + system_variables=complex_system_vars, environment_variables=[ segment_to_variable( segment=build_segment(1), @@ -80,6 +116,323 @@ class TestVariablePool: ], ) - def test_constructor_with_invalid_system_variable_key(self): - with pytest.raises(ValidationError): - VariablePool(system_variables={"invalid_key": "value"}) # type: ignore + def test_get_system_variables(self): + sys_var = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_var) + + kv = [ + ("user_id", sys_var.user_id), + ("app_id", sys_var.app_id), + ("workflow_id", sys_var.workflow_id), + ("workflow_run_id", sys_var.workflow_execution_id), + ("query", sys_var.query), + ("conversation_id", sys_var.conversation_id), + ("dialogue_count", sys_var.dialogue_count), + ] + for key, expected_value in kv: + segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key]) + assert segment is not None + assert segment.value == expected_value + + +class TestVariablePoolSerialization: + """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods. + + These tests focus exclusively on serialization/deserialization logic to ensure that + VariablePool data can be properly serialized to dictionaries/JSON and reconstructed + while preserving all data integrity. + """ + + _NODE1_ID = "node_1" + _NODE2_ID = "node_2" + _NODE3_ID = "node_3" + + def _create_pool_without_file(self): + # Create comprehensive system variables + system_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + + # Create environment variables with all types including ArrayFileVariable + env_vars: list[VariableUnion] = [ + StringVariable( + id="env_string_id", + name="env_string", + value="env_string_value", + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"], + ), + IntegerVariable( + id="env_integer_id", + name="env_integer", + value=1, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"], + ), + FloatVariable( + id="env_float_id", + name="env_float", + value=1.0, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"], + ), + ] + + # Create conversation variables with complex data + conv_vars: list[VariableUnion] = [ + StringVariable( + id="conv_string_id", + name="conv_string", + value="conv_string_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"], + ), + IntegerVariable( + id="conv_integer_id", + name="conv_integer", + value=1, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"], + ), + FloatVariable( + id="conv_float_id", + name="conv_float", + value=1.0, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"], + ), + ObjectVariable( + id="conv_object_id", + name="conv_object", + value={"key": "value", "nested": {"data": 123}}, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"], + ), + ArrayStringVariable( + id="conv_array_string_id", + name="conv_array_string", + value=["conv_array_string_value"], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"], + ), + ArrayNumberVariable( + id="conv_array_number_id", + name="conv_array_number", + value=[1, 1.0], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"], + ), + ArrayObjectVariable( + id="conv_array_object_id", + name="conv_array_object", + value=[{"a": 1}, {"b": "2"}], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"], + ), + ] + + # Create comprehensive user inputs + user_inputs = { + "string_input": "test_value", + "number_input": 42, + "object_input": {"nested": {"key": "value"}}, + "array_input": ["item1", "item2", "item3"], + } + + # Create VariablePool + pool = VariablePool( + system_variables=system_vars, + user_inputs=user_inputs, + environment_variables=env_vars, + conversation_variables=conv_vars, + ) + return pool + + def _add_node_data_to_pool(self, pool: VariablePool, with_file=False): + test_file = File( + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_related_id", + remote_url="test_url", + filename="test_file.txt", + storage_key="test_storage_key", + ) + + # Add various segment types to variable dictionary + pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string")) + pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123)) + pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67)) + pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"})) + if with_file: + pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file)) + pool.add((self._NODE1_ID, "none_var"), NoneSegment()) + + # Add array segments including ArrayFileVariable + pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"])) + pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3])) + pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}])) + if with_file: + pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) + pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) + + # Add nested variables + pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) + + def test_system_variables(self): + sys_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_vars) + json = pool.model_dump_json() + pool2 = VariablePool.model_validate_json(json) + assert pool2.system_variables == sys_vars + + for mode in ["json", "python"]: + dict_ = pool.model_dump(mode=mode) + pool2 = VariablePool.model_validate(dict_) + assert pool2.system_variables == sys_vars + + def test_pool_without_file_vars(self): + pool = self._create_pool_without_file() + json = pool.model_dump_json() + pool2 = pool.model_validate_json(json) + assert pool2.system_variables == pool.system_variables + assert pool2.conversation_variables == pool.conversation_variables + assert pool2.environment_variables == pool.environment_variables + assert pool2.user_inputs == pool.user_inputs + assert pool2.variable_dictionary == pool.variable_dictionary + assert pool2 == pool + + def test_basic_dictionary_round_trip(self): + """Test basic round-trip serialization: model_dump() → model_validate()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to dictionary using Pydantic's model_dump() + serialized_data = original_pool.model_dump() + + # Verify serialized data structure + assert isinstance(serialized_data, dict) + assert "system_variables" in serialized_data + assert "user_inputs" in serialized_data + assert "environment_variables" in serialized_data + assert "conversation_variables" in serialized_data + assert "variable_dictionary" in serialized_data + + # Deserialize back using Pydantic's model_validate() + reconstructed_pool = VariablePool.model_validate(serialized_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_json_round_trip(self): + """Test JSON round-trip serialization: model_dump_json() → model_validate_json()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to JSON string using Pydantic's model_dump_json() + json_data = original_pool.model_dump_json() + + # Verify JSON is valid string + assert isinstance(json_data, str) + assert len(json_data) > 0 + + # Deserialize back using Pydantic's model_validate_json() + reconstructed_pool = VariablePool.model_validate_json(json_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_complex_data_serialization(self): + """Test serialization of complex data structures including ArrayFileVariable""" + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool, with_file=True) + + # Test dictionary round-trip + dict_data = original_pool.model_dump() + reconstructed_dict = VariablePool.model_validate(dict_data) + + # Test JSON round-trip + json_data = original_pool.model_dump_json() + reconstructed_json = VariablePool.model_validate_json(json_data) + + # Verify both reconstructed pools are equivalent + self._assert_pools_equal(reconstructed_dict, reconstructed_json) + # TODO: assert the data for file object... + + def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None: + """Assert that two VariablePools contain equivalent data""" + + # Compare system variables + assert pool1.system_variables == pool2.system_variables + + # Compare user inputs + assert dict(pool1.user_inputs) == dict(pool2.user_inputs) + + # Compare environment variables count + assert pool1.environment_variables == pool2.environment_variables + + # Compare conversation variables count + assert pool1.conversation_variables == pool2.conversation_variables + + # Test key variable retrievals to ensure functionality is preserved + test_selectors = [ + (SYSTEM_VARIABLE_NODE_ID, "user_id"), + (SYSTEM_VARIABLE_NODE_ID, "app_id"), + (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"), + (ENVIRONMENT_VARIABLE_NODE_ID, "env_number"), + (CONVERSATION_VARIABLE_NODE_ID, "conv_string"), + (self._NODE1_ID, "string_var"), + (self._NODE1_ID, "int_var"), + (self._NODE1_ID, "float_var"), + (self._NODE2_ID, "array_string"), + (self._NODE2_ID, "array_number"), + (self._NODE3_ID, "nested", "deep", "var"), + ] + + for selector in test_selectors: + val1 = pool1.get(selector) + val2 = pool2.get(selector) + + # Both should exist or both should be None + assert (val1 is None) == (val2 is None) + + if val1 is not None and val2 is not None: + # Values should be equal + assert val1.value == val2.value + # Value types should be the same (more important than exact class type) + assert val1.value_type == val2.value_type + + def test_variable_pool_deserialization_default_dict(self): + variable_pool = VariablePool( + user_inputs={"a": 1, "b": "2"}, + system_variables=SystemVariable(workflow_id=str(uuid.uuid4())), + environment_variables=[ + StringVariable(name="str_var", value="a"), + ], + conversation_variables=[IntegerVariable(name="int_var", value=1)], + ) + assert isinstance(variable_pool.variable_dictionary, defaultdict) + json = variable_pool.model_dump_json() + loaded = VariablePool.model_validate_json(json) + assert isinstance(loaded.variable_dictionary, defaultdict) + + loaded.add(["non_exist_node", "a"], 1) + + pool_dict = variable_pool.model_dump() + loaded = VariablePool.model_validate(pool_dict) + assert isinstance(loaded.variable_dictionary, defaultdict) + loaded.add(["non_exist_node", "a"], 1) diff --git a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py index 646de8bf3a..642bc810ba 100644 --- a/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py +++ b/api/tests/unit_tests/core/workflow/test_workflow_cycle_manager.py @@ -18,10 +18,10 @@ from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from models.enums import CreatorUserRole from models.model import AppMode @@ -67,14 +67,14 @@ def real_app_generate_entity(): @pytest.fixture def real_workflow_system_variables(): - return { - SystemVariableKey.QUERY: "test query", - SystemVariableKey.CONVERSATION_ID: "test-conversation-id", - SystemVariableKey.USER_ID: "test-user-id", - SystemVariableKey.APP_ID: "test-app-id", - SystemVariableKey.WORKFLOW_ID: "test-workflow-id", - SystemVariableKey.WORKFLOW_EXECUTION_ID: "test-workflow-run-id", - } + return SystemVariable( + query="test query", + conversation_id="test-conversation-id", + user_id="test-user-id", + app_id="test-app-id", + workflow_id="test-workflow-id", + workflow_execution_id="test-workflow-run-id", + ) @pytest.fixture diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py index f1cb937bb3..54bf6558bf 100644 --- a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -10,7 +10,7 @@ class TestAppendVariablesRecursively: def test_append_simple_dict_value(self): """Test appending a simple dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["output"] variable_value = {"name": "John", "age": 30} @@ -33,7 +33,7 @@ class TestAppendVariablesRecursively: def test_append_object_segment_value(self): """Test appending an ObjectSegment value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["result"] @@ -60,7 +60,7 @@ class TestAppendVariablesRecursively: def test_append_nested_dict_value(self): """Test appending a nested dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["data"] @@ -97,7 +97,7 @@ class TestAppendVariablesRecursively: def test_append_non_dict_value(self): """Test appending a non-dictionary value (should not recurse)""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["simple"] variable_value = "simple_string" @@ -114,7 +114,7 @@ class TestAppendVariablesRecursively: def test_append_segment_non_object_value(self): """Test appending a Segment that is not ObjectSegment (should not recurse)""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["text"] variable_value = StringSegment(value="Hello World") @@ -132,7 +132,7 @@ class TestAppendVariablesRecursively: def test_append_empty_dict_value(self): """Test appending an empty dictionary value""" - pool = VariablePool() + pool = VariablePool.empty() node_id = "test_node" variable_key_list = ["empty"] variable_value: dict[str, Any] = {} diff --git a/api/tests/unit_tests/factories/test_variable_factory.py b/api/tests/unit_tests/factories/test_variable_factory.py index edd4c5e93e..4f2542a323 100644 --- a/api/tests/unit_tests/factories/test_variable_factory.py +++ b/api/tests/unit_tests/factories/test_variable_factory.py @@ -505,8 +505,8 @@ def test_build_segment_type_for_scalar(): size=1000, ) cases = [ - TestCase(0, SegmentType.NUMBER), - TestCase(0.0, SegmentType.NUMBER), + TestCase(0, SegmentType.INTEGER), + TestCase(0.0, SegmentType.FLOAT), TestCase("", SegmentType.STRING), TestCase(file, SegmentType.FILE), ] @@ -531,14 +531,14 @@ class TestBuildSegmentWithType: result = build_segment_with_type(SegmentType.NUMBER, 42) assert isinstance(result, IntegerSegment) assert result.value == 42 - assert result.value_type == SegmentType.NUMBER + assert result.value_type == SegmentType.INTEGER def test_number_type_float(self): """Test building a number segment with float value.""" result = build_segment_with_type(SegmentType.NUMBER, 3.14) assert isinstance(result, FloatSegment) assert result.value == 3.14 - assert result.value_type == SegmentType.NUMBER + assert result.value_type == SegmentType.FLOAT def test_object_type(self): """Test building an object segment with correct type.""" @@ -652,14 +652,14 @@ class TestBuildSegmentWithType: with pytest.raises(TypeMismatchError) as exc_info: build_segment_with_type(SegmentType.STRING, None) - assert "Expected string, but got None" in str(exc_info.value) + assert "expected string, but got None" in str(exc_info.value) def test_type_mismatch_empty_list_to_non_array(self): """Test type mismatch when expecting non-array type but getting empty list.""" with pytest.raises(TypeMismatchError) as exc_info: build_segment_with_type(SegmentType.STRING, []) - assert "Expected string, but got empty list" in str(exc_info.value) + assert "expected string, but got empty list" in str(exc_info.value) def test_type_mismatch_object_to_array(self): """Test type mismatch when expecting array but getting object.""" @@ -674,19 +674,19 @@ class TestBuildSegmentWithType: # Integer should work result_int = build_segment_with_type(SegmentType.NUMBER, 42) assert isinstance(result_int, IntegerSegment) - assert result_int.value_type == SegmentType.NUMBER + assert result_int.value_type == SegmentType.INTEGER # Float should work result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) assert isinstance(result_float, FloatSegment) - assert result_float.value_type == SegmentType.NUMBER + assert result_float.value_type == SegmentType.FLOAT @pytest.mark.parametrize( ("segment_type", "value", "expected_class"), [ (SegmentType.STRING, "test", StringSegment), - (SegmentType.NUMBER, 42, IntegerSegment), - (SegmentType.NUMBER, 3.14, FloatSegment), + (SegmentType.INTEGER, 42, IntegerSegment), + (SegmentType.FLOAT, 3.14, FloatSegment), (SegmentType.OBJECT, {}, ObjectSegment), (SegmentType.NONE, None, NoneSegment), (SegmentType.ARRAY_STRING, [], ArrayStringSegment), @@ -857,5 +857,5 @@ class TestBuildSegmentValueErrors: # Verify they are processed as integers, not as errors assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" - assert true_segment.value_type == SegmentType.NUMBER - assert false_segment.value_type == SegmentType.NUMBER + assert true_segment.value_type == SegmentType.INTEGER + assert false_segment.value_type == SegmentType.INTEGER diff --git a/api/tests/unit_tests/libs/test_uuid_utils.py b/api/tests/unit_tests/libs/test_uuid_utils.py new file mode 100644 index 0000000000..7dbda95f45 --- /dev/null +++ b/api/tests/unit_tests/libs/test_uuid_utils.py @@ -0,0 +1,351 @@ +import struct +import time +import uuid +from unittest import mock + +import pytest +from hypothesis import given +from hypothesis import strategies as st + +from libs.uuid_utils import _create_uuidv7_bytes, uuidv7, uuidv7_boundary, uuidv7_timestamp + + +# Tests for private helper function _create_uuidv7_bytes +def test_create_uuidv7_bytes_basic_structure(): + """Test basic byte structure creation.""" + timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22" + + result = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + # Should be exactly 16 bytes + assert len(result) == 16 + assert isinstance(result, bytes) + + # Create UUID from bytes to verify it's valid + uuid_obj = uuid.UUID(bytes=result) + assert uuid_obj.version == 7 + + +def test_create_uuidv7_bytes_timestamp_encoding(): + """Test timestamp is correctly encoded in first 48 bits.""" + timestamp_ms = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + random_bytes = b"\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00" + + result = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + # Extract timestamp from first 6 bytes + timestamp_bytes = b"\x00\x00" + result[0:6] + extracted_timestamp = struct.unpack(">Q", timestamp_bytes)[0] + + assert extracted_timestamp == timestamp_ms + + +def test_create_uuidv7_bytes_version_bits(): + """Test version bits are set to 7.""" + timestamp_ms = 1609459200000 + random_bytes = b"\xff\xff\x00\x00\x00\x00\x00\x00\x00\x00" # Set first 2 bytes to all 1s + + result = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + # Extract version from bytes 6-7 + version_and_rand_a = struct.unpack(">H", result[6:8])[0] + version = (version_and_rand_a >> 12) & 0x0F + + assert version == 7 + + +def test_create_uuidv7_bytes_variant_bits(): + """Test variant bits are set correctly.""" + timestamp_ms = 1609459200000 + random_bytes = b"\x00\x00\xff\x00\x00\x00\x00\x00\x00\x00" # Set byte 8 to all 1s + + result = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + # Check variant bits in byte 8 (should be 10xxxxxx) + variant_byte = result[8] + variant_bits = (variant_byte >> 6) & 0b11 + + assert variant_bits == 0b10 # Should be binary 10 + + +def test_create_uuidv7_bytes_random_data(): + """Test random bytes are placed correctly.""" + timestamp_ms = 1609459200000 + random_bytes = b"\x12\x34\x56\x78\x9a\xbc\xde\xf0\x11\x22" + + result = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + # Check random data A (12 bits from bytes 6-7, excluding version) + version_and_rand_a = struct.unpack(">H", result[6:8])[0] + rand_a = version_and_rand_a & 0x0FFF + expected_rand_a = struct.unpack(">H", random_bytes[0:2])[0] & 0x0FFF + assert rand_a == expected_rand_a + + # Check random data B (bytes 8-15, with variant bits preserved) + # Byte 8 should have variant bits set but preserve lower 6 bits + expected_byte_8 = (random_bytes[2] & 0x3F) | 0x80 + assert result[8] == expected_byte_8 + + # Bytes 9-15 should match random_bytes[3:10] + assert result[9:16] == random_bytes[3:10] + + +def test_create_uuidv7_bytes_zero_random(): + """Test with zero random bytes (boundary case).""" + timestamp_ms = 1609459200000 + zero_random_bytes = b"\x00" * 10 + + result = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes) + + # Should still be valid UUIDv7 + uuid_obj = uuid.UUID(bytes=result) + assert uuid_obj.version == 7 + + # Version bits should be 0x7000 + version_and_rand_a = struct.unpack(">H", result[6:8])[0] + assert version_and_rand_a == 0x7000 + + # Variant byte should be 0x80 (variant bits + zero random bits) + assert result[8] == 0x80 + + # Remaining bytes should be zero + assert result[9:16] == b"\x00" * 7 + + +def test_uuidv7_basic_generation(): + """Test basic UUID generation produces valid UUIDv7.""" + result = uuidv7() + + # Should be a UUID object + assert isinstance(result, uuid.UUID) + + # Should be version 7 + assert result.version == 7 + + # Should have correct variant (RFC 4122 variant) + # Variant bits should be 10xxxxxx (0x80-0xBF range) + variant_byte = result.bytes[8] + assert (variant_byte >> 6) == 0b10 + + +def test_uuidv7_with_custom_timestamp(): + """Test UUID generation with custom timestamp.""" + custom_timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + result = uuidv7(custom_timestamp) + + assert isinstance(result, uuid.UUID) + assert result.version == 7 + + # Extract and verify timestamp + extracted_timestamp = uuidv7_timestamp(result) + assert isinstance(extracted_timestamp, int) + assert extracted_timestamp == custom_timestamp # Exact match for integer milliseconds + + +def test_uuidv7_with_none_timestamp(monkeypatch): + """Test UUID generation with None timestamp uses current time.""" + mock_time = 1609459200 + mock_time_func = mock.Mock(return_value=mock_time) + monkeypatch.setattr("time.time", mock_time_func) + result = uuidv7(None) + + assert isinstance(result, uuid.UUID) + assert result.version == 7 + + # Should use the mocked current time (converted to milliseconds) + assert mock_time_func.called + extracted_timestamp = uuidv7_timestamp(result) + assert extracted_timestamp == mock_time * 1000 # 1609459200.0 * 1000 + + +def test_uuidv7_time_ordering(): + """Test that sequential UUIDs have increasing timestamps.""" + # Generate UUIDs with incrementing timestamps (in milliseconds) + timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC + timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC + timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC + + uuid1 = uuidv7(timestamp1) + uuid2 = uuidv7(timestamp2) + uuid3 = uuidv7(timestamp3) + + # Extract timestamps + ts1 = uuidv7_timestamp(uuid1) + ts2 = uuidv7_timestamp(uuid2) + ts3 = uuidv7_timestamp(uuid3) + + # Should be in ascending order + assert ts1 < ts2 < ts3 + + # UUIDs should be lexicographically ordered by their string representation + # due to time-ordering property of UUIDv7 + uuid_strings = [str(uuid1), str(uuid2), str(uuid3)] + assert uuid_strings == sorted(uuid_strings) + + +def test_uuidv7_uniqueness(): + """Test that multiple calls generate different UUIDs.""" + # Generate multiple UUIDs with the same timestamp (in milliseconds) + timestamp = 1609459200000 + uuids = [uuidv7(timestamp) for _ in range(100)] + + # All should be unique despite same timestamp (due to random bits) + assert len(set(uuids)) == 100 + + # All should have the same extracted timestamp + for uuid_obj in uuids: + extracted_ts = uuidv7_timestamp(uuid_obj) + assert extracted_ts == timestamp + + +def test_uuidv7_timestamp_error_handling_wrong_version(): + """Test error handling for non-UUIDv7 inputs.""" + + uuid_v4 = uuid.uuid4() + with pytest.raises(ValueError) as exc_ctx: + uuidv7_timestamp(uuid_v4) + assert "Expected UUIDv7 (version 7)" in str(exc_ctx.value) + assert f"got version {uuid_v4.version}" in str(exc_ctx.value) + + +@given(st.integers(max_value=2**48 - 1, min_value=0)) +def test_uuidv7_timestamp_round_trip(timestamp_ms): + # Generate UUID with timestamp + uuid_obj = uuidv7(timestamp_ms) + + # Extract timestamp back + extracted_timestamp = uuidv7_timestamp(uuid_obj) + + # Should match exactly for integer millisecond timestamps + assert extracted_timestamp == timestamp_ms + + +def test_uuidv7_timestamp_edge_cases(): + """Test timestamp extraction with edge case values.""" + # Test with very small timestamp + small_timestamp = 1 # 1ms after epoch + uuid_small = uuidv7(small_timestamp) + extracted_small = uuidv7_timestamp(uuid_small) + assert extracted_small == small_timestamp + + # Test with large timestamp (year 2038+) + large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds + uuid_large = uuidv7(large_timestamp) + extracted_large = uuidv7_timestamp(uuid_large) + assert extracted_large == large_timestamp + + +def test_uuidv7_boundary_basic_generation(): + """Test basic boundary UUID generation with a known timestamp.""" + timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + result = uuidv7_boundary(timestamp) + + # Should be a UUID object + assert isinstance(result, uuid.UUID) + + # Should be version 7 + assert result.version == 7 + + # Should have correct variant (RFC 4122 variant) + # Variant bits should be 10xxxxxx (0x80-0xBF range) + variant_byte = result.bytes[8] + assert (variant_byte >> 6) == 0b10 + + +def test_uuidv7_boundary_timestamp_extraction(): + """Test that boundary UUID timestamp can be extracted correctly.""" + timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + boundary_uuid = uuidv7_boundary(timestamp) + + # Extract timestamp using existing function + extracted_timestamp = uuidv7_timestamp(boundary_uuid) + + # Should match exactly + assert extracted_timestamp == timestamp + + +def test_uuidv7_boundary_deterministic(): + """Test that boundary UUIDs are deterministic for same timestamp.""" + timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + + # Generate multiple boundary UUIDs with same timestamp + uuid1 = uuidv7_boundary(timestamp) + uuid2 = uuidv7_boundary(timestamp) + uuid3 = uuidv7_boundary(timestamp) + + # Should all be identical + assert uuid1 == uuid2 == uuid3 + assert str(uuid1) == str(uuid2) == str(uuid3) + + +def test_uuidv7_boundary_is_minimum(): + """Test that boundary UUID is lexicographically smaller than regular UUIDs.""" + timestamp = 1609459200000 # 2021-01-01 00:00:00 UTC in milliseconds + + # Generate boundary UUID + boundary_uuid = uuidv7_boundary(timestamp) + + # Generate multiple regular UUIDs with same timestamp + regular_uuids = [uuidv7(timestamp) for _ in range(50)] + + # Boundary UUID should be lexicographically smaller than all regular UUIDs + boundary_str = str(boundary_uuid) + for regular_uuid in regular_uuids: + regular_str = str(regular_uuid) + assert boundary_str < regular_str, f"Boundary {boundary_str} should be < regular {regular_str}" + + # Also test with bytes comparison + boundary_bytes = boundary_uuid.bytes + for regular_uuid in regular_uuids: + regular_bytes = regular_uuid.bytes + assert boundary_bytes < regular_bytes + + +def test_uuidv7_boundary_different_timestamps(): + """Test that boundary UUIDs with different timestamps are ordered correctly.""" + timestamp1 = 1609459200000 # 2021-01-01 00:00:00 UTC + timestamp2 = 1609459201000 # 2021-01-01 00:00:01 UTC + timestamp3 = 1609459202000 # 2021-01-01 00:00:02 UTC + + uuid1 = uuidv7_boundary(timestamp1) + uuid2 = uuidv7_boundary(timestamp2) + uuid3 = uuidv7_boundary(timestamp3) + + # Extract timestamps to verify + ts1 = uuidv7_timestamp(uuid1) + ts2 = uuidv7_timestamp(uuid2) + ts3 = uuidv7_timestamp(uuid3) + + # Should be in ascending order + assert ts1 < ts2 < ts3 + + # UUIDs should be lexicographically ordered + uuid_strings = [str(uuid1), str(uuid2), str(uuid3)] + assert uuid_strings == sorted(uuid_strings) + + # Bytes should also be ordered + assert uuid1.bytes < uuid2.bytes < uuid3.bytes + + +def test_uuidv7_boundary_edge_cases(): + """Test boundary UUID generation with edge case timestamp values.""" + # Test with timestamp 0 (Unix epoch) + epoch_uuid = uuidv7_boundary(0) + assert isinstance(epoch_uuid, uuid.UUID) + assert epoch_uuid.version == 7 + assert uuidv7_timestamp(epoch_uuid) == 0 + + # Test with very large timestamp values + large_timestamp = 2147483647000 # 2038-01-19 03:14:07 UTC in milliseconds + large_uuid = uuidv7_boundary(large_timestamp) + assert isinstance(large_uuid, uuid.UUID) + assert large_uuid.version == 7 + assert uuidv7_timestamp(large_uuid) == large_timestamp + + # Test with current time + current_time = int(time.time() * 1000) + current_uuid = uuidv7_boundary(current_time) + assert isinstance(current_uuid, uuid.UUID) + assert current_uuid.version == 7 + assert uuidv7_timestamp(current_uuid) == current_time diff --git a/docker/.env.example b/docker/.env.example index dabd66f285..94f3766b2e 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -214,6 +214,10 @@ SQLALCHEMY_POOL_SIZE=30 SQLALCHEMY_POOL_RECYCLE=3600 # Whether to print SQL, default is false. SQLALCHEMY_ECHO=false +# If True, will test connections for liveness upon each checkout +SQLALCHEMY_POOL_PRE_PING=false +# Whether to enable the Last in first out option or use default FIFO queue if is false +SQLALCHEMY_POOL_USE_LIFO=false # Maximum number of connections to the database # Default is 100 @@ -1135,6 +1139,8 @@ PLUGIN_VOLCENGINE_TOS_REGION= # OTLP Collector Configuration # ------------------------------ ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= OTEL_EXPORTER_OTLP_PROTOCOL= diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 61362ed9fd..5f0d2b1f87 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -56,6 +56,8 @@ x-shared-env: &shared-api-worker-env SQLALCHEMY_POOL_SIZE: ${SQLALCHEMY_POOL_SIZE:-30} SQLALCHEMY_POOL_RECYCLE: ${SQLALCHEMY_POOL_RECYCLE:-3600} SQLALCHEMY_ECHO: ${SQLALCHEMY_ECHO:-false} + SQLALCHEMY_POOL_PRE_PING: ${SQLALCHEMY_POOL_PRE_PING:-false} + SQLALCHEMY_POOL_USE_LIFO: ${SQLALCHEMY_POOL_USE_LIFO:-false} POSTGRES_MAX_CONNECTIONS: ${POSTGRES_MAX_CONNECTIONS:-100} POSTGRES_SHARED_BUFFERS: ${POSTGRES_SHARED_BUFFERS:-128MB} POSTGRES_WORK_MEM: ${POSTGRES_WORK_MEM:-4MB} @@ -504,6 +506,8 @@ x-shared-env: &shared-api-worker-env PLUGIN_VOLCENGINE_TOS_SECRET_KEY: ${PLUGIN_VOLCENGINE_TOS_SECRET_KEY:-} PLUGIN_VOLCENGINE_TOS_REGION: ${PLUGIN_VOLCENGINE_TOS_REGION:-} ENABLE_OTEL: ${ENABLE_OTEL:-false} + OTLP_TRACE_ENDPOINT: ${OTLP_TRACE_ENDPOINT:-} + OTLP_METRIC_ENDPOINT: ${OTLP_METRIC_ENDPOINT:-} OTLP_BASE_ENDPOINT: ${OTLP_BASE_ENDPOINT:-http://localhost:4318} OTLP_API_KEY: ${OTLP_API_KEY:-} OTEL_EXPORTER_OTLP_PROTOCOL: ${OTEL_EXPORTER_OTLP_PROTOCOL:-} diff --git a/web/app/(commonLayout)/apps/AppCard.tsx b/web/app/(commonLayout)/apps/AppCard.tsx index f50cc10520..e04c3fdea6 100644 --- a/web/app/(commonLayout)/apps/AppCard.tsx +++ b/web/app/(commonLayout)/apps/AppCard.tsx @@ -88,6 +88,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { icon_background, description, use_icon_as_answer_icon, + max_active_requests, }) => { try { await updateAppInfo({ @@ -98,6 +99,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { icon_background, description, use_icon_as_answer_icon, + max_active_requests, }) setShowEditModal(false) notify({ @@ -432,6 +434,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => { appDescription={app.description} appMode={app.mode} appUseIconAsAnswerIcon={app.use_icon_as_answer_icon} + max_active_requests={app.max_active_requests ?? null} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx index acaae3f720..426778c835 100644 --- a/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx +++ b/web/app/(commonLayout)/datasets/(datasetDetailLayout)/[datasetId]/layout-main.tsx @@ -62,7 +62,6 @@ const ExtraInfo = ({ isMobile, relatedApps, expand }: IExtraInfoProps) => { {
diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index 3817ebf5a4..d5a04ec420 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -71,6 +71,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx icon_background, description, use_icon_as_answer_icon, + max_active_requests, }) => { if (!appDetail) return @@ -83,6 +84,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx icon_background, description, use_icon_as_answer_icon, + max_active_requests, }) setShowEditModal(false) notify({ @@ -350,6 +352,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx appDescription={appDetail.description} appMode={appDetail.mode} appUseIconAsAnswerIcon={appDetail.use_icon_as_answer_icon} + max_active_requests={appDetail.max_active_requests ?? null} show={showEditModal} onConfirm={onEdit} onHide={() => setShowEditModal(false)} diff --git a/web/app/components/app/configuration/config-var/config-modal/index.tsx b/web/app/components/app/configuration/config-var/config-modal/index.tsx index 29cbc55b90..8fcc0f4c08 100644 --- a/web/app/components/app/configuration/config-var/config-modal/index.tsx +++ b/web/app/components/app/configuration/config-var/config-modal/index.tsx @@ -1,5 +1,5 @@ 'use client' -import type { FC } from 'react' +import type { ChangeEvent, FC } from 'react' import React, { useCallback, useEffect, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' @@ -11,7 +11,7 @@ import SelectTypeItem from '../select-type-item' import Field from './field' import Input from '@/app/components/base/input' import Toast from '@/app/components/base/toast' -import { checkKeys, getNewVarInWorkflow } from '@/utils/var' +import { checkKeys, getNewVarInWorkflow, replaceSpaceWithUnderscreInVarNameInput } from '@/utils/var' import ConfigContext from '@/context/debug-configuration' import type { InputVar, MoreInfo, UploadFileSetting } from '@/app/components/workflow/types' import Modal from '@/app/components/base/modal' @@ -109,6 +109,20 @@ const ConfigModal: FC = ({ }) }, [checkVariableName, tempPayload.label]) + const handleVarNameChange = useCallback((e: ChangeEvent) => { + replaceSpaceWithUnderscreInVarNameInput(e.target) + const value = e.target.value + const { isValid, errorKey, errorMessageKey } = checkKeys([value], true) + if (!isValid) { + Toast.notify({ + type: 'error', + message: t(`appDebug.varKeyError.${errorMessageKey}`, { key: errorKey }), + }) + return + } + handlePayloadChange('variable')(e.target.value) + }, [handlePayloadChange, t]) + const handleConfirm = () => { const moreInfo = tempPayload.variable === payload?.variable ? undefined @@ -200,7 +214,7 @@ const ConfigModal: FC = ({ handlePayloadChange('variable')(e.target.value)} + onChange={handleVarNameChange} onBlur={handleVarKeyBlur} placeholder={t('appDebug.variableConfig.inputPlaceholder')!} /> diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 66fe85a170..a1b82ab2fe 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -209,7 +209,6 @@ const AgentTools: FC = () => { {item.tool_label} {!item.isDeleted && (
{item.tool_name}
@@ -232,7 +231,6 @@ const AgentTools: FC = () => {
@@ -259,7 +257,6 @@ const AgentTools: FC = () => { {!item.notAuthor && (
{ setCurrentTool(item) diff --git a/web/app/components/app/configuration/prompt-value-panel/index.tsx b/web/app/components/app/configuration/prompt-value-panel/index.tsx index e509ee50e4..b36bf8848a 100644 --- a/web/app/components/app/configuration/prompt-value-panel/index.tsx +++ b/web/app/components/app/configuration/prompt-value-panel/index.tsx @@ -177,7 +177,7 @@ const PromptValuePanel: FC = ({
{canNotRun && ( - +