diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 1853ba06d5..022f71bfb4 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-npm add -g pnpm@10.12.1
+npm add -g pnpm@10.13.1
cd web && pnpm install
pipx install uv
@@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc
+
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index a9580a3ba3..d684fe9144 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -8,13 +8,15 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
+ - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
+ required: true
- label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+ - label: I confirm that I am using English to submit this report; otherwise, it will be closed.
required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
+ - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :)
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
@@ -42,20 +44,22 @@ body:
attributes:
label: Steps to reproduce
description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks.
- placeholder: Having detailed steps helps us reproduce the bug.
+ placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
- placeholder: What were you expecting?
+ description: Describe what you expected to happen.
+ placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here.
validations:
- required: false
+ required: true
- type: textarea
attributes:
label: ❌ Actual Behavior
- placeholder: What happened instead?
+ description: Describe what actually happened.
+ placeholder: What happened instead? Please do not copy and paste the steps to reproduce here.
validations:
required: false
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index 6877c382c4..c1666d24cf 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -1,5 +1,11 @@
blank_issues_enabled: false
contact_links:
+ - name: "\U0001F4A1 Model Providers & Plugins"
+ url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
+ about: Report issues with official plugins or model providers; you will need to provide the plugin version and other relevant details.
+ - name: "\U0001F4AC Documentation Issues"
+ url: "https://github.com/langgenius/dify-docs/issues/new"
+ about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue.
- name: "\U0001F4E7 Discussions"
url: https://github.com/langgenius/dify/discussions/categories/general
- about: General discussions and request help from the community
+ about: General discussions and seeking help from the community
diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml
deleted file mode 100644
index 8fdbc0fb9a..0000000000
--- a/.github/ISSUE_TEMPLATE/document_issue.yml
+++ /dev/null
@@ -1,24 +0,0 @@
-name: "📚 Documentation Issue"
-description: Report issues in our documentation
-labels:
- - documentation
-body:
- - type: checkboxes
- attributes:
- label: Self Checks
- description: "To make sure we get to you in time, please check the following :)"
- options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
- required: true
- - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
- required: true
- - label: "Please do not modify this template :) and fill in all the required fields."
- required: true
- - type: textarea
- attributes:
- label: Provide a description of requested docs changes
- placeholder: Briefly describe which document needs to be corrected and why.
- validations:
- required: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index b1952c63a9..bd293e2442 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -8,11 +8,11 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
+ - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+ - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
+ - label: I confirm that I am using English to submit this report; otherwise, it will be closed.
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml
deleted file mode 100644
index f9c2dfb7d2..0000000000
--- a/.github/ISSUE_TEMPLATE/translation_issue.yml
+++ /dev/null
@@ -1,55 +0,0 @@
-name: "🌐 Localization/Translation issue"
-description: Report incorrect translations. [please use English :)]
-labels:
- - translation
-body:
- - type: checkboxes
- attributes:
- label: Self Checks
- description: "To make sure we get to you in time, please check the following :)"
- options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
- required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
- required: true
- - label: "Please do not modify this template :) and fill in all the required fields."
- required: true
- - type: input
- attributes:
- label: Dify version
- description: Hover over system tray icon or look at Settings
- validations:
- required: true
- - type: input
- attributes:
- label: Utility with translation issue
- placeholder: Some area
- description: Please input here the utility with the translation issue
- validations:
- required: true
- - type: input
- attributes:
- label: 🌐 Language affected
- placeholder: "German"
- validations:
- required: true
- - type: textarea
- attributes:
- label: ❌ Actual phrase(s)
- placeholder: What is there? Please include a screenshot as that is extremely helpful.
- validations:
- required: true
- - type: textarea
- attributes:
- label: ✔️ Expected phrase(s)
- placeholder: What was expected?
- validations:
- required: true
- - type: textarea
- attributes:
- label: ℹ Why is the current translation wrong
- placeholder: Why do you feel this is incorrect?
- validations:
- required: true
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index f08befefb8..a5a5071fae 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -47,15 +47,17 @@ jobs:
- name: Run Unit tests
run: |
uv run --project api bash dev/pytest/pytest_unit_tests.sh
+
+ - name: Coverage Summary
+ run: |
+ set -x
# Extract coverage percentage and create a summary
TOTAL_COVERAGE=$(python -c 'import json; print(json.load(open("coverage.json"))["totals"]["percent_covered_display"])')
# Create a detailed coverage summary
echo "### Test Coverage Summary :test_tube:" >> $GITHUB_STEP_SUMMARY
echo "Total Coverage: ${TOTAL_COVERAGE}%" >> $GITHUB_STEP_SUMMARY
- echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
- uv run --project api coverage report >> $GITHUB_STEP_SUMMARY
- echo "\`\`\`" >> $GITHUB_STEP_SUMMARY
+ uv run --project api coverage report --format=markdown >> $GITHUB_STEP_SUMMARY
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py
@@ -83,9 +85,15 @@ jobs:
compose-file: |
docker/docker-compose.middleware.yaml
services: |
+ db
+ redis
sandbox
ssrf_proxy
+ - name: setup test config
+ run: |
+ cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
+
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index cc735ae67c..b933560a5e 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
+ - "build/**"
tags:
- "*"
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index b06ab9653e..a283f8d5ca 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -28,7 +28,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
api/**
@@ -75,7 +75,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: web/**
@@ -113,7 +113,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
@@ -144,7 +144,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
**.sh
@@ -152,13 +152,15 @@ jobs:
**.yml
**Dockerfile
dev/**
+ .editorconfig
- name: Super-linter
- uses: super-linter/super-linter/slim@v7
+ uses: super-linter/super-linter/slim@v8
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
- DEFAULT_BRANCH: main
+ DEFAULT_BRANCH: origin/main
+ EDITORCONFIG_FILE_NAME: editorconfig-checker.json
FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
@@ -168,16 +170,6 @@ jobs:
# FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
# VALIDATE_GITHUB_ACTIONS: true
VALIDATE_DOCKERFILE_HADOLINT: true
+ VALIDATE_EDITORCONFIG: true
VALIDATE_XML: true
VALIDATE_YAML: true
-
- - name: EditorConfig checks
- uses: super-linter/super-linter/slim@v7
- env:
- DEFAULT_BRANCH: main
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- IGNORE_GENERATED_FILES: true
- IGNORE_GITIGNORED_FILES: true
- # EditorConfig validation
- VALIDATE_EDITORCONFIG: true
- EDITORCONFIG_FILE_NAME: editorconfig-checker.json
diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml
index 512d14b2ee..912267094b 100644
--- a/.github/workflows/vdb-tests.yml
+++ b/.github/workflows/vdb-tests.yml
@@ -84,10 +84,14 @@ jobs:
elasticsearch
oceanbase
- - name: Check VDB Ready (TiDB, Oceanbase)
+ - name: setup test config
run: |
- uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- uv run --project api python api/tests/integration_tests/vdb/oceanbase/check_oceanbase_ready.py
+ echo $(pwd)
+ ls -lah .
+ cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
+
+ - name: Check VDB Ready (TiDB)
+ run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py
- name: Test Vector Stores
run: uv run --project api bash dev/pytest/pytest_vdb.sh
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 37cfdc5c1e..c3f8fdbaf6 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -27,7 +27,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: web/**
diff --git a/.gitignore b/.gitignore
index 4c938b7682..dd4673a3d2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/*
docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/*
+docker/volumes/matrixone/*
!docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf
@@ -213,3 +214,4 @@ mise.toml
# AI Assistant
.roo/
+api/.env.backup
diff --git a/README.md b/README.md
index ca09adec08..2909e0e6cf 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@
-Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
+Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
## Quick start
@@ -65,7 +65,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
-The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
+The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
```bash
cd dify
@@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Using Terraform for Deployment
@@ -226,6 +227,15 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Using Alibaba Cloud Computing Nest
+
+Quickly deploy Dify to Alibaba Cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Using Alibaba Cloud Data Management
+
+Deploy Dify to Alibaba Cloud in one click with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
@@ -252,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media
## Security disclosure
-To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
+To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with a detailed answer.
## License
-This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.
+This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions.
diff --git a/README_AR.md b/README_AR.md
index df288fd33c..e959ca0f78 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -188,6 +188,7 @@ docker compose up -d
- [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts)
- [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### استخدام Terraform للتوزيع
@@ -209,6 +210,14 @@ docker compose up -d
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### استخدام Alibaba Cloud للنشر
+ [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### استخدام Alibaba Cloud Data Management للنشر
+
+انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## المساهمة
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.
diff --git a/README_BN.md b/README_BN.md
index 4a5b5f3928..29d7374ea5 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
+
#### টেরাফর্ম ব্যবহার করে ডিপ্লয়
@@ -225,6 +227,15 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud ব্যবহার করে ডিপ্লয়
+
+ [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয়
+
+ [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contributing
যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)।
diff --git a/README_CN.md b/README_CN.md
index ba7ee0006d..486a368c09 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -194,9 +194,9 @@ docker compose up -d
如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。
-#### 使用 Helm Chart 部署
+#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署
-使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。
+使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
@@ -204,6 +204,10 @@ docker compose up -d
- [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
+
+
+
#### 使用 Terraform 部署
使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台
@@ -221,6 +225,15 @@ docker compose up -d
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### 使用 阿里云计算巢 部署
+
+使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
+
+#### 使用 阿里云数据管理DMS 部署
+
+使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
+
+
## Star History
[](https://star-history.com/#langgenius/dify&Date)
diff --git a/README_DE.md b/README_DE.md
index f6023a3935..fce52c34c2 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform für die Bereitstellung verwenden
@@ -221,6 +222,15 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contributing
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.
diff --git a/README_ES.md b/README_ES.md
index 12f2ce8c11..6fd6dfcee8 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop
- [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts)
- [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Uso de Terraform para el despliegue
@@ -221,6 +222,15 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contribuir
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_FR.md b/README_FR.md
index b106615b31..b2209fb495 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau
- [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts)
- [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NOUVEAU ! Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Utilisation de Terraform pour le déploiement
@@ -219,6 +220,15 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contribuer
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_JA.md b/README_JA.md
index 26703f3958..c658225f90 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。
- **Dify Community Editionのセルフホスティング**
-この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。
+この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。
詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。
- **企業/組織向けのDify**
@@ -202,6 +202,7 @@ docker compose up -d
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraformを使用したデプロイ
@@ -220,6 +221,13 @@ docker compose up -d
##### AWS
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
+
+
## 貢献
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。
diff --git a/README_KL.md b/README_KL.md
index ea91baa5aa..bfafcc7407 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform atorlugu pilersitsineq
@@ -219,6 +220,15 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
##### AWS
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_KR.md b/README_KR.md
index 89301e8b2c..282117e776 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform을 사용한 배포
@@ -213,6 +214,15 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
##### AWS
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
+
+
## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.
diff --git a/README_PT.md b/README_PT.md
index 157772d528..576f6b48f7 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts]
- [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts)
- [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Usando o Terraform para Implantação
@@ -218,6 +219,15 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
diff --git a/README_SI.md b/README_SI.md
index 14de1ea792..7ded001d86 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases.
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Uporaba Terraform za uvajanje
@@ -219,6 +220,15 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Prispevam
Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
diff --git a/README_TR.md b/README_TR.md
index 563a05af3c..6e94e54fa0 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify'
- [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm)
- [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes)
- [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s)
+- [🚀 YENİ! YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Dağıtım için Terraform Kullanımı
@@ -212,6 +213,15 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
##### AWS
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
+
+
## Katkıda Bulunma
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.
diff --git a/README_TW.md b/README_TW.md
index f4a76ac109..6e3e22b5c1 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。
-如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。
+如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。
- [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify)
- [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm)
- [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes)
- [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
### 使用 Terraform 進行部署
@@ -224,6 +225,15 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+#### 使用 阿里雲計算巢 進行部署
+
+[阿里雲計算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### 使用 阿里雲數據管理DMS 進行部署
+
+透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
+
+
## 貢獻
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。
diff --git a/README_VI.md b/README_VI.md
index 4e1e05cbf3..51314e6de5 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có
- [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Sử dụng Terraform để Triển khai
@@ -214,6 +215,16 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
##### AWS
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+
+#### Alibaba Cloud
+
+[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
+
+#### Alibaba Cloud Data Management
+
+Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
+
+
## Đóng góp
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.
diff --git a/api/.env.example b/api/.env.example
index 7878308588..daa0df535b 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -5,17 +5,22 @@
SECRET_KEY=
# Console API base URL
-CONSOLE_API_URL=http://127.0.0.1:5001
-CONSOLE_WEB_URL=http://127.0.0.1:3000
+CONSOLE_API_URL=http://localhost:5001
+CONSOLE_WEB_URL=http://localhost:3000
# Service API base URL
-SERVICE_API_URL=http://127.0.0.1:5001
+SERVICE_API_URL=http://localhost:5001
# Web APP base URL
-APP_WEB_URL=http://127.0.0.1:3000
+APP_WEB_URL=http://localhost:3000
# Files URL
-FILES_URL=http://127.0.0.1:5001
+FILES_URL=http://localhost:5001
+
+# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
+# Set this to the internal Docker service URL for proper plugin file access.
+# Example: INTERNAL_FILES_URL=http://api:5001
+INTERNAL_FILES_URL=http://127.0.0.1:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
@@ -49,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
-
+CELERY_BACKEND=redis
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@@ -133,12 +138,14 @@ SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
-WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
-CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
+WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
+CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
VECTOR_STORE=weaviate
+# Prefix used to create collection name in vector database
+VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
@@ -294,6 +301,13 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30
+# Matrixone configuration
+MATRIXONE_HOST=127.0.0.1
+MATRIXONE_PORT=6001
+MATRIXONE_USER=dump
+MATRIXONE_PASSWORD=111
+MATRIXONE_DATABASE=dify
+
# Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin
@@ -332,9 +346,11 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
-# Mail configuration, support: resend, smtp
+# Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE=
+# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM=no-reply
+# resend configuration
RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com
# smtp configuration
@@ -344,7 +360,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
-
+# SendGrid configuration
+SENDGRID_API_KEY=
# Sentry configuration
SENTRY_DSN=
@@ -434,6 +451,19 @@ MAX_VARIABLE_SIZE=204800
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
@@ -467,6 +497,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
@@ -477,6 +509,8 @@ LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
+OTLP_TRACE_ENDPOINT=
+OTLP_METRIC_ENDPOINT=
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
diff --git a/api/.ruff.toml b/api/.ruff.toml
index facb0d5419..0169613bf8 100644
--- a/api/.ruff.toml
+++ b/api/.ruff.toml
@@ -1,6 +1,4 @@
-exclude = [
- "migrations/*",
-]
+exclude = ["migrations/*"]
line-length = 120
[format]
@@ -9,14 +7,14 @@ quote-style = "double"
[lint]
preview = false
select = [
- "B", # flake8-bugbear rules
- "C4", # flake8-comprehensions
- "E", # pycodestyle E rules
- "F", # pyflakes rules
- "FURB", # refurb rules
- "I", # isort rules
- "N", # pep8-naming
- "PT", # flake8-pytest-style rules
+ "B", # flake8-bugbear rules
+ "C4", # flake8-comprehensions
+ "E", # pycodestyle E rules
+ "F", # pyflakes rules
+ "FURB", # refurb rules
+ "I", # isort rules
+ "N", # pep8-naming
+ "PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object
@@ -24,19 +22,19 @@ select = [
"PLR0402", # manual-from-import
"PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison
- "RUF013", # implicit-optional
- "RUF019", # unnecessary-key-check
- "RUF100", # unused-noqa
- "RUF101", # redirected-noqa
- "RUF200", # invalid-pyproject-toml
- "RUF022", # unsorted-dunder-all
- "S506", # unsafe-yaml-load
- "SIM", # flake8-simplify rules
- "TRY400", # error-instead-of-exception
- "TRY401", # verbose-log-message
- "UP", # pyupgrade rules
- "W191", # tab-indentation
- "W605", # invalid-escape-sequence
+ "RUF013", # implicit-optional
+ "RUF019", # unnecessary-key-check
+ "RUF100", # unused-noqa
+ "RUF101", # redirected-noqa
+ "RUF200", # invalid-pyproject-toml
+ "RUF022", # unsorted-dunder-all
+ "S506", # unsafe-yaml-load
+ "SIM", # flake8-simplify rules
+ "TRY400", # error-instead-of-exception
+ "TRY401", # verbose-log-message
+ "UP", # pyupgrade rules
+ "W191", # tab-indentation
+ "W605", # invalid-escape-sequence
# security related linting rules
# RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec`
@@ -47,36 +45,37 @@ select = [
]
ignore = [
- "E402", # module-import-not-at-top-of-file
- "E711", # none-comparison
- "E712", # true-false-comparison
- "E721", # type-comparison
- "E722", # bare-except
- "F821", # undefined-name
- "F841", # unused-variable
+ "E402", # module-import-not-at-top-of-file
+ "E711", # none-comparison
+ "E712", # true-false-comparison
+ "E721", # type-comparison
+ "E722", # bare-except
+ "F821", # undefined-name
+ "F841", # unused-variable
"FURB113", # repeated-append
"FURB152", # math-constant
- "UP007", # non-pep604-annotation
- "UP032", # f-string
- "UP045", # non-pep604-annotation-optional
- "B005", # strip-with-multi-characters
- "B006", # mutable-argument-default
- "B007", # unused-loop-control-variable
- "B026", # star-arg-unpacking-after-keyword-arg
- "B903", # class-as-data-structure
- "B904", # raise-without-from-inside-except
- "B905", # zip-without-explicit-strict
- "N806", # non-lowercase-variable-in-function
- "N815", # mixed-case-variable-in-class-scope
- "PT011", # pytest-raises-too-broad
- "SIM102", # collapsible-if
- "SIM103", # needless-bool
- "SIM105", # suppressible-exception
- "SIM107", # return-in-try-except-finally
- "SIM108", # if-else-block-instead-of-if-exp
- "SIM113", # enumerate-for-loop
- "SIM117", # multiple-with-statements
- "SIM210", # if-expr-with-true-false
+ "UP007", # non-pep604-annotation
+ "UP032", # f-string
+ "UP045", # non-pep604-annotation-optional
+ "B005", # strip-with-multi-characters
+ "B006", # mutable-argument-default
+ "B007", # unused-loop-control-variable
+ "B026", # star-arg-unpacking-after-keyword-arg
+ "B903", # class-as-data-structure
+ "B904", # raise-without-from-inside-except
+ "B905", # zip-without-explicit-strict
+ "N806", # non-lowercase-variable-in-function
+ "N815", # mixed-case-variable-in-class-scope
+ "PT011", # pytest-raises-too-broad
+ "SIM102", # collapsible-if
+ "SIM103", # needless-bool
+ "SIM105", # suppressible-exception
+ "SIM107", # return-in-try-except-finally
+ "SIM108", # if-else-block-instead-of-if-exp
+ "SIM113", # enumerate-for-loop
+ "SIM117", # multiple-with-statements
+ "SIM210", # if-expr-with-true-false
+ "UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
]
[lint.per-file-ignores]
diff --git a/api/Dockerfile b/api/Dockerfile
index 7e4997507f..8c7a1717b9 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -47,6 +47,8 @@ RUN \
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
+ # install fonts to support the use of tools like pypdfium2
+ fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
diff --git a/api/commands.py b/api/commands.py
index 0a6cc61a68..9f933a378c 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -2,19 +2,22 @@ import base64
import json
import logging
import secrets
-from typing import Optional
+from typing import Any, Optional
import click
from flask import current_app
+from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
+from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
+from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
+from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
@@ -281,6 +285,7 @@ def migrate_knowledge_vector_database():
VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS,
VectorType.TABLESTORE,
+ VectorType.MATRIXONE,
}
lower_collection_vector_types = {
VectorType.ANALYTICDB,
@@ -1154,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
+
+
+@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
+@click.option("--provider", prompt=True, help="Provider name")
+@click.option("--client-params", prompt=True, help="Client Params")
+def setup_system_tool_oauth_client(provider, client_params):
+ """
+ Setup system tool oauth client
+ """
+ provider_id = ToolProviderID(provider)
+ provider_name = provider_id.provider_name
+ plugin_id = provider_id.plugin_id
+
+ try:
+ # json validate
+ click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
+ client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
+ click.echo(click.style("Client params validated successfully.", fg="green"))
+
+ click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
+ click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
+ oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+ click.echo(click.style("Client params encrypted successfully.", fg="green"))
+ except Exception as e:
+ click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
+ return
+
+ deleted_count = (
+ db.session.query(ToolOAuthSystemClient)
+ .filter_by(
+ provider=provider_name,
+ plugin_id=plugin_id,
+ )
+ .delete()
+ )
+ if deleted_count > 0:
+ click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
+
+ oauth_client = ToolOAuthSystemClient(
+ provider=provider_name,
+ plugin_id=plugin_id,
+ encrypted_oauth_params=oauth_client_params,
+ )
+ db.session.add(oauth_client)
+ db.session.commit()
+ click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
diff --git a/api/configs/app_config.py b/api/configs/app_config.py
index 3a3ad35ee7..20f8c40427 100644
--- a/api/configs/app_config.py
+++ b/api/configs/app_config.py
@@ -1,8 +1,11 @@
import logging
+from pathlib import Path
from typing import Any
from pydantic.fields import FieldInfo
-from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
+from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
+
+from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
@@ -99,4 +102,12 @@ class DifyConfig(
RemoteSettingsSourceFactory(settings_cls),
dotenv_settings,
file_secret_settings,
+ TomlConfigSettingsSource(
+ settings_cls=settings_cls,
+ toml_file=search_file_upwards(
+ base_dir_path=Path(__file__).parent,
+ target_file_name="pyproject.toml",
+ max_search_parent_depth=2,
+ ),
+ ),
)
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index a3da5c1b49..f1d529355d 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings):
description="Duration in minutes for which a password reset token remains valid",
default=5,
)
+ CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
+ description="Duration in minutes for which a change email token remains valid",
+ default=5,
+ )
+
+ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
+ description="Duration in minutes for which a owner transfer token remains valid",
+ default=5,
+ )
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
@@ -237,6 +246,13 @@ class FileAccessConfig(BaseSettings):
default="",
)
+ INTERNAL_FILES_URL: str = Field(
+ description="Internal base URL for file access within Docker network,"
+ " used for plugin daemon and internal service communication."
+ " Falls back to FILES_URL if not specified.",
+ default="",
+ )
+
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
default=300,
@@ -530,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings):
)
+class RepositoryConfig(BaseSettings):
+ """
+ Configuration for repository implementations
+ """
+
+ CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
+ description="Repository implementation for WorkflowExecution. Specify as a module path",
+ default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
+ )
+
+ CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+ description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+ default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
+ )
+
+ API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+ description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
+ "Specify as a module path",
+ default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
+ )
+
+ API_WORKFLOW_RUN_REPOSITORY: str = Field(
+ description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
+ default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
+ )
+
+
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
@@ -580,6 +623,16 @@ class AuthConfig(BaseSettings):
default=86400,
)
+ CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field(
+ description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.",
+ default=86400,
+ )
+
+ OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field(
+ description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.",
+ default=86400,
+ )
+
class ModerationConfig(BaseSettings):
"""
@@ -609,7 +662,7 @@ class MailConfig(BaseSettings):
"""
MAIL_TYPE: Optional[str] = Field(
- description="Email service provider type ('smtp' or 'resend'), default to None.",
+ description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
default=None,
)
@@ -663,6 +716,11 @@ class MailConfig(BaseSettings):
default=50,
)
+ SENDGRID_API_KEY: Optional[str] = Field(
+ description="API key for SendGrid service",
+ default=None,
+ )
+
class RagEtlConfig(BaseSettings):
"""
@@ -891,6 +949,7 @@ class FeatureConfig(
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
+ RepositoryConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 2dcf1710b0..587ea55ca7 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig
+from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig
@@ -84,6 +85,11 @@ class VectorStoreConfig(BaseSettings):
default=False,
)
+ VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
+ description="Prefix used to create collection name in vector database",
+ default="Vector_index",
+ )
+
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
@@ -161,6 +167,11 @@ class DatabaseConfig(BaseSettings):
default=3600,
)
+ SQLALCHEMY_POOL_USE_LIFO: bool = Field(
+ description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.",
+ default=False,
+ )
+
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
default=False,
@@ -198,13 +209,14 @@ class DatabaseConfig(BaseSettings):
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
+ "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
- default="database",
+ default="redis",
)
CELERY_BROKER_URL: Optional[str] = Field(
@@ -222,6 +234,10 @@ class CeleryConfig(DatabaseConfig):
default=None,
)
+ CELERY_SENTINEL_PASSWORD: Optional[str] = Field(
+ description="Password of the Redis Sentinel master.",
+ default=None,
+ )
CELERY_SENTINEL_SOCKET_TIMEOUT: Optional[PositiveFloat] = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
@@ -323,5 +339,6 @@ class MiddlewareConfig(
OpenGaussConfig,
TableStoreConfig,
DatasetQueueMonitorConfig,
+ MatrixoneConfig,
):
pass
diff --git a/api/configs/middleware/vdb/matrixone_config.py b/api/configs/middleware/vdb/matrixone_config.py
new file mode 100644
index 0000000000..9400612d8e
--- /dev/null
+++ b/api/configs/middleware/vdb/matrixone_config.py
@@ -0,0 +1,14 @@
+from pydantic import BaseModel, Field
+
+
+class MatrixoneConfig(BaseModel):
+ """Matrixone vector database configuration."""
+
+ MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
+ MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
+ MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
+ MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
+ MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
+ MATRIXONE_METRIC: str = Field(
+ default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
+ )
diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py
index 1b88ddcfe6..7572a696ce 100644
--- a/api/configs/observability/otel/otel_config.py
+++ b/api/configs/observability/otel/otel_config.py
@@ -12,6 +12,16 @@ class OTelConfig(BaseSettings):
default=False,
)
+ OTLP_TRACE_ENDPOINT: str = Field(
+ description="OTLP trace endpoint",
+ default="",
+ )
+
+ OTLP_METRIC_ENDPOINT: str = Field(
+ description="OTLP metric endpoint",
+ default="",
+ )
+
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py
index 0107df22c5..f511e20e6b 100644
--- a/api/configs/packaging/__init__.py
+++ b/api/configs/packaging/__init__.py
@@ -1,17 +1,13 @@
from pydantic import Field
-from pydantic_settings import BaseSettings
+from configs.packaging.pyproject import PyProjectConfig, PyProjectTomlConfig
-class PackagingInfo(BaseSettings):
+
+class PackagingInfo(PyProjectTomlConfig):
"""
Packaging build information
"""
- CURRENT_VERSION: str = Field(
- description="Dify version",
- default="1.4.3",
- )
-
COMMIT_SHA: str = Field(
description="SHA-1 checksum of the git commit used to build the app",
default="",
diff --git a/api/configs/packaging/pyproject.py b/api/configs/packaging/pyproject.py
new file mode 100644
index 0000000000..90b1ecba06
--- /dev/null
+++ b/api/configs/packaging/pyproject.py
@@ -0,0 +1,17 @@
+from pydantic import BaseModel, Field
+from pydantic_settings import BaseSettings
+
+
+class PyProjectConfig(BaseModel):
+ version: str = Field(description="Dify version", default="")
+
+
+class PyProjectTomlConfig(BaseSettings):
+ """
+ configs in api/pyproject.toml
+ """
+
+ project: PyProjectConfig = Field(
+ description="configs in the project section of pyproject.toml",
+ default=PyProjectConfig(),
+ )
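With this change CURRENT_VERSION is no longer hard-coded; PackagingInfo inherits its version from PyProjectTomlConfig. A hedged sketch (not part of the patch) of one plausible way such a settings class gets populated from pyproject.toml with pydantic-settings; the real source wiring lives in the root config, outside this hunk, and the class name and file path below are placeholders:

# Illustrative sketch only -- pydantic-settings ships a TOML source that can map the
# [project] table of pyproject.toml onto the `project` field declared above.
from pydantic_settings import (
    BaseSettings,
    PydanticBaseSettingsSource,
    SettingsConfigDict,
    TomlConfigSettingsSource,
)

from configs.packaging.pyproject import PyProjectTomlConfig


class PyProjectTomlConfigWithSource(PyProjectTomlConfig):
    model_config = SettingsConfigDict(extra="ignore")  # ignore unrelated pyproject tables

    @classmethod
    def settings_customise_sources(
        cls,
        settings_cls: type[BaseSettings],
        init_settings: PydanticBaseSettingsSource,
        env_settings: PydanticBaseSettingsSource,
        dotenv_settings: PydanticBaseSettingsSource,
        file_secret_settings: PydanticBaseSettingsSource,
    ) -> tuple[PydanticBaseSettingsSource, ...]:
        return (TomlConfigSettingsSource(settings_cls, toml_file="api/pyproject.toml"),)


print(PyProjectTomlConfigWithSource().project.version)  # version string read from pyproject.toml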
diff --git a/api/constants/__init__.py b/api/constants/__init__.py
index a84de0a451..9e052320ac 100644
--- a/api/constants/__init__.py
+++ b/api/constants/__init__.py
@@ -1,6 +1,7 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
+UNKNOWN_VALUE = "[__UNKNOWN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py
index a974c63e35..e25f92399c 100644
--- a/api/controllers/console/__init__.py
+++ b/api/controllers/console/__init__.py
@@ -56,6 +56,7 @@ from .app import (
conversation,
conversation_variables,
generator,
+ mcp_server,
message,
model_config,
ops_trace,
@@ -63,6 +64,7 @@ from .app import (
statistic,
workflow,
workflow_app_log,
+ workflow_draft_variable,
workflow_run,
workflow_statistic,
)
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 8cb7ad9f5b..f5257fae79 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
- with Session(db.engine) as session:
- app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
+ app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@@ -78,38 +77,38 @@ class InsertExploreAppListApi(Resource):
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
- if not recommended_app:
- recommended_app = RecommendedApp(
- app_id=app.id,
- description=desc,
- copyright=copy_right,
- privacy_policy=privacy_policy,
- custom_disclaimer=custom_disclaimer,
- language=args["language"],
- category=args["category"],
- position=args["position"],
- )
-
- db.session.add(recommended_app)
-
- app.is_public = True
- db.session.commit()
-
- return {"result": "success"}, 201
- else:
- recommended_app.description = desc
- recommended_app.copyright = copy_right
- recommended_app.privacy_policy = privacy_policy
- recommended_app.custom_disclaimer = custom_disclaimer
- recommended_app.language = args["language"]
- recommended_app.category = args["category"]
- recommended_app.position = args["position"]
+ if not recommended_app:
+ recommended_app = RecommendedApp(
+ app_id=app.id,
+ description=desc,
+ copyright=copy_right,
+ privacy_policy=privacy_policy,
+ custom_disclaimer=custom_disclaimer,
+ language=args["language"],
+ category=args["category"],
+ position=args["position"],
+ )
+
+ db.session.add(recommended_app)
+
+ app.is_public = True
+ db.session.commit()
+
+ return {"result": "success"}, 201
+ else:
+ recommended_app.description = desc
+ recommended_app.copyright = copy_right
+ recommended_app.privacy_policy = privacy_policy
+ recommended_app.custom_disclaimer = custom_disclaimer
+ recommended_app.language = args["language"]
+ recommended_app.category = args["category"]
+ recommended_app.position = args["position"]
- app.is_public = True
+ app.is_public = True
- db.session.commit()
+ db.session.commit()
- return {"result": "success"}, 200
+ return {"result": "success"}, 200
class InsertExploreAppApi(Resource):
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index 14fd4679a1..2b48afd550 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
- if not file.filename or not file.filename.endswith(".csv"):
+ if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file)
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 860166a61a..9fe32dde6d 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -151,6 +151,7 @@ class AppApi(Resource):
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
+ parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args()
app_service = AppService()
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index 5dc6515ce0..9ffb94e9f9 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -17,6 +17,8 @@ from libs.login import login_required
from models import Account
from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import FeatureService
class AppImportApi(Resource):
@@ -60,7 +62,9 @@ class AppImportApi(Resource):
app_id=args.get("app_id"),
)
session.commit()
-
+ if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
+ # update web app setting as private
+ EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result
status = result.status
if status == ImportStatus.FAILED.value:
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 5f2def8d8e..665cf1aede 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -90,23 +90,11 @@ class ChatMessageTextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
- if (
- app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
- and app_model.workflow
- and app_model.workflow.features_dict
- ):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- if text_to_speech is None:
- raise ValueError("TTS is not enabled")
- voice = args.get("voice") or text_to_speech.get("voice")
- else:
- try:
- if app_model.app_model_config is None:
- raise ValueError("AppModelConfig not found")
- voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
- except Exception:
- voice = None
- response = AudioService.transcript_tts(app_model=app_model, text=text, message_id=message_id, voice=voice)
+ voice = args.get("voice", None)
+
+ response = AudioService.transcript_tts(
+ app_model=app_model, text=text, voice=voice, message_id=message_id, is_draft=True
+ )
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 70d6216497..4eef9fed43 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,4 +1,4 @@
-from datetime import UTC, datetime
+from datetime import datetime
import pytz # pip install pytz
from flask_login import current_user
@@ -19,6 +19,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
+from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
- conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
new file mode 100644
index 0000000000..503393f264
--- /dev/null
+++ b/api/controllers/console/app/mcp_server.py
@@ -0,0 +1,119 @@
+import json
+from enum import StrEnum
+
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from werkzeug.exceptions import NotFound
+
+from controllers.console import api
+from controllers.console.app.wraps import get_app_model
+from controllers.console.wraps import account_initialization_required, setup_required
+from extensions.ext_database import db
+from fields.app_fields import app_server_fields
+from libs.login import login_required
+from models.model import AppMCPServer
+
+
+class AppMCPServerStatus(StrEnum):
+ ACTIVE = "active"
+ INACTIVE = "inactive"
+
+
+class AppMCPServerController(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model
+ @marshal_with(app_server_fields)
+ def get(self, app_model):
+ server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
+ return server
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model
+ @marshal_with(app_server_fields)
+ def post(self, app_model):
+ if not current_user.is_editor:
+ raise NotFound()
+ parser = reqparse.RequestParser()
+ parser.add_argument("description", type=str, required=False, location="json")
+ parser.add_argument("parameters", type=dict, required=True, location="json")
+ args = parser.parse_args()
+
+ description = args.get("description")
+ if not description:
+ description = app_model.description or ""
+
+ server = AppMCPServer(
+ name=app_model.name,
+ description=description,
+ parameters=json.dumps(args["parameters"], ensure_ascii=False),
+ status=AppMCPServerStatus.ACTIVE,
+ app_id=app_model.id,
+ tenant_id=current_user.current_tenant_id,
+ server_code=AppMCPServer.generate_server_code(16),
+ )
+ db.session.add(server)
+ db.session.commit()
+ return server
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model
+ @marshal_with(app_server_fields)
+ def put(self, app_model):
+ if not current_user.is_editor:
+ raise NotFound()
+ parser = reqparse.RequestParser()
+ parser.add_argument("id", type=str, required=True, location="json")
+ parser.add_argument("description", type=str, required=False, location="json")
+ parser.add_argument("parameters", type=dict, required=True, location="json")
+ parser.add_argument("status", type=str, required=False, location="json")
+ args = parser.parse_args()
+ server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
+ if not server:
+ raise NotFound()
+
+ description = args.get("description")
+ if description is None:
+ pass
+ elif not description:
+ server.description = app_model.description or ""
+ else:
+ server.description = description
+
+ server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
+ if args["status"]:
+ if args["status"] not in [status.value for status in AppMCPServerStatus]:
+ raise ValueError("Invalid status")
+ server.status = args["status"]
+ db.session.commit()
+ return server
+
+
+class AppMCPServerRefreshController(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @marshal_with(app_server_fields)
+ def get(self, server_id):
+ if not current_user.is_editor:
+ raise NotFound()
+ server = (
+ db.session.query(AppMCPServer)
+ .filter(AppMCPServer.id == server_id)
+ .filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
+ .first()
+ )
+ if not server:
+ raise NotFound()
+ server.server_code = AppMCPServer.generate_server_code(16)
+ db.session.commit()
+ return server
+
+
+api.add_resource(AppMCPServerController, "/apps/<uuid:app_id>/server")
+api.add_resource(AppMCPServerRefreshController, "/apps/<uuid:server_id>/server/refresh")
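A hedged usage sketch (not part of the patch) of the new console endpoint registered above; the base URL, app id, and bearer token below are placeholders, and console session handling is omitted:

# Illustrative sketch only -- placeholder base URL, app id and token.
import requests

CONSOLE_API = "http://localhost:5001/console/api"            # hypothetical local console API
APP_ID = "00000000-0000-0000-0000-000000000000"               # hypothetical app id
HEADERS = {"Authorization": "Bearer CONSOLE_ACCESS_TOKEN"}    # hypothetical token

# Create the MCP server record for an app (POST); fetch it later with GET on the same path.
resp = requests.post(
    f"{CONSOLE_API}/apps/{APP_ID}/server",
    json={"description": "Demo MCP server", "parameters": {}},
    headers=HEADERS,
)
print(resp.status_code, resp.json())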
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index b7a4c31a15..ea659f9f5b 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
@@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
-from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
+from models.model import AppMode, Conversation, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource):
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
- message_id = str(args["message_id"])
-
- message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
-
- if not message:
- raise NotFound("Message Not Exists.")
-
- feedback = message.admin_feedback
-
- if not args["rating"] and feedback:
- db.session.delete(feedback)
- elif args["rating"] and feedback:
- feedback.rating = args["rating"]
- elif not args["rating"] and not feedback:
- raise ValueError("rating cannot be None when feedback not exists")
- else:
- feedback = MessageFeedback(
- app_id=app_model.id,
- conversation_id=message.conversation_id,
- message_id=message.id,
- rating=args["rating"],
- from_source="admin",
- from_account_id=current_user.id,
+ try:
+ MessageService.create_feedback(
+ app_model=app_model,
+ message_id=str(args["message_id"]),
+ user=current_user,
+ rating=args.get("rating"),
+ content=None,
)
- db.session.add(feedback)
-
- db.session.commit()
+ except services.errors.message.MessageNotExistsError:
+ raise NotFound("Message Not Exists.")
return {"result": "success"}
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 3c3a359eeb..358a5e8cdb 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,5 +1,3 @@
-from datetime import UTC, datetime
-
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Site
@@ -77,7 +76,7 @@ class AppSite(Resource):
setattr(site, attr_name, value)
site.updated_by = current_user.id
- site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ site.updated_at = naive_utc_now()
db.session.commit()
return site
@@ -101,7 +100,7 @@ class AppSiteAccessTokenReset(Resource):
site.code = Site.generate_code(16)
site.updated_by = current_user.id
- site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ site.updated_at = naive_utc_now()
db.session.commit()
return site
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index 86aed77412..32b64d10c5 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
+import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
-from models.model import AppMode
+from models import AppMode, Message
class DailyMessageStatistic(Resource):
@@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
- sql_query = """SELECT
- DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
- COUNT(DISTINCT messages.conversation_id) AS conversation_count
-FROM
- messages
-WHERE
- app_id = :app_id"""
- arg_dict = {"tz": account.timezone, "app_id": app_model.id}
-
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
+ stmt = (
+ sa.select(
+ sa.func.date(
+ sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
+ ).label("date"),
+ sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
+ )
+ .select_from(Message)
+ .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
+ )
+
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
-
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-
- sql_query += " AND created_at >= :start"
- arg_dict["start"] = start_datetime_utc
+ stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
-
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
+ stmt = stmt.where(Message.created_at < end_datetime_utc)
- sql_query += " AND created_at < :end"
- arg_dict["end"] = end_datetime_utc
-
- sql_query += " GROUP BY date ORDER BY date"
+ stmt = stmt.group_by("date").order_by("date")
response_data = []
-
with db.engine.begin() as conn:
- rs = conn.execute(db.text(sql_query), arg_dict)
- for i in rs:
- response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
+ rs = conn.execute(stmt, {"tz": account.timezone})
+ for row in rs:
+ response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
return jsonify({"data": response_data})
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index cbbdd324ba..a9f088a276 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -1,5 +1,6 @@
import json
import logging
+from collections.abc import Sequence
from typing import cast
from flask import abort, request
@@ -18,10 +19,12 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
+from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.models import File
from extensions.ext_database import db
-from factories import variable_factory
+from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper
@@ -30,6 +33,7 @@ from libs.login import current_user, login_required
from models import App
from models.account import Account
from models.model import AppMode
+from models.workflow import Workflow
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
@@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__)
+# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
+# at the controller level rather than in the workflow logic. This would improve separation
+# of concerns and make the code more maintainable.
+def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
+ files = files or []
+
+ file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
+ file_objs: Sequence[File] = []
+ if file_extra_config is None:
+ return file_objs
+ file_objs = file_factory.build_from_mappings(
+ mappings=files,
+ tenant_id=workflow.tenant_id,
+ config=file_extra_config,
+ )
+ return file_objs
+
+
class DraftWorkflowApi(Resource):
@setup_required
@login_required
@@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("query", type=str, required=False, location="json", default="")
+ parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args()
- inputs = args.get("inputs")
- if inputs == None:
+ user_inputs = args.get("inputs")
+ if user_inputs is None:
raise ValueError("missing inputs")
+ workflow_srv = WorkflowService()
+ # fetch draft workflow by app_model
+ draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
+ if not draft_workflow:
+ raise ValueError("Workflow not initialized")
+ files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService()
+
workflow_node_execution = workflow_service.run_draft_workflow_node(
- app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user
+ app_model=app_model,
+ draft_workflow=draft_workflow,
+ node_id=node_id,
+ user_inputs=user_inputs,
+ account=current_user,
+ query=args.get("query", ""),
+ files=files,
)
return workflow_node_execution
@@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource):
return None, 204
+class DraftWorkflowNodeLastRunApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+ @marshal_with(workflow_run_node_execution_fields)
+ def get(self, app_model: App, node_id: str):
+ srv = WorkflowService()
+ workflow = srv.get_draft_workflow(app_model)
+ if not workflow:
+ raise NotFound("Workflow not found")
+ node_exec = srv.get_node_last_run(
+ app_model=app_model,
+ workflow=workflow,
+ node_id=node_id,
+ )
+ if node_exec is None:
+ raise NotFound("last run not found")
+ return node_exec
+
+
api.add_resource(
DraftWorkflowApi,
"/apps//workflows/draft",
@@ -795,3 +853,7 @@ api.add_resource(
WorkflowByIdApi,
"/apps//workflows/",
)
+api.add_resource(
+ DraftWorkflowNodeLastRunApi,
+ "/apps//workflows/draft/nodes//last-run",
+)
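A short, hedged sketch (not part of the patch) of calling the newly registered last-run endpoint; the base URL, ids, and token below are placeholders:

# Illustrative sketch only -- placeholder base URL, ids and token.
import requests

CONSOLE_API = "http://localhost:5001/console/api"             # hypothetical
APP_ID = "00000000-0000-0000-0000-000000000000"               # hypothetical app id
NODE_ID = "1718000000000"                                     # hypothetical node id
HEADERS = {"Authorization": "Bearer CONSOLE_ACCESS_TOKEN"}    # hypothetical token

# Returns the most recent draft execution of the node, or 404 if it has never run.
resp = requests.get(
    f"{CONSOLE_API}/apps/{APP_ID}/workflows/draft/nodes/{NODE_ID}/last-run",
    headers=HEADERS,
)
print(resp.status_code)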
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index b9579e2120..310146a5e7 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -34,6 +34,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
)
+ parser.add_argument(
+ "created_by_end_user_session_id",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "created_by_account",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+ )
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
@@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
+ created_by_end_user_session_id=args.created_by_end_user_session_id,
+ created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
new file mode 100644
index 0000000000..ba93f82756
--- /dev/null
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -0,0 +1,426 @@
+import logging
+from typing import Any, NoReturn
+
+from flask import Response
+from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
+from sqlalchemy.orm import Session
+from werkzeug.exceptions import Forbidden
+
+from controllers.console import api
+from controllers.console.app.error import (
+ DraftWorkflowNotExist,
+)
+from controllers.console.app.wraps import get_app_model
+from controllers.console.wraps import account_initialization_required, setup_required
+from controllers.web.error import InvalidArgumentError, NotFoundError
+from core.variables.segment_group import SegmentGroup
+from core.variables.segments import ArrayFileSegment, FileSegment, Segment
+from core.variables.types import SegmentType
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from factories.file_factory import build_from_mapping, build_from_mappings
+from factories.variable_factory import build_segment_with_type
+from libs.login import current_user, login_required
+from models import App, AppMode, db
+from models.workflow import WorkflowDraftVariable
+from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
+from services.workflow_service import WorkflowService
+
+logger = logging.getLogger(__name__)
+
+
+def _convert_values_to_json_serializable_object(value: Segment) -> Any:
+ if isinstance(value, FileSegment):
+ return value.value.model_dump()
+ elif isinstance(value, ArrayFileSegment):
+ return [i.model_dump() for i in value.value]
+ elif isinstance(value, SegmentGroup):
+ return [_convert_values_to_json_serializable_object(i) for i in value.value]
+ else:
+ return value.value
+
+
+def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
+ value = variable.get_value()
+ # create a copy of the value to avoid affecting the model cache.
+ value = value.model_copy(deep=True)
+ # Refresh the url signature before returning it to client.
+ if isinstance(value, FileSegment):
+ file = value.value
+ file.remote_url = file.generate_url()
+ elif isinstance(value, ArrayFileSegment):
+ files = value.value
+ for file in files:
+ file.remote_url = file.generate_url()
+ return _convert_values_to_json_serializable_object(value)
+
+
+def _create_pagination_parser():
+ parser = reqparse.RequestParser()
+ parser.add_argument(
+ "page",
+ type=inputs.int_range(1, 100_000),
+ required=False,
+ default=1,
+ location="args",
+ help="the page of data requested",
+ )
+ parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
+ return parser
+
+
+def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
+ value_type = workflow_draft_var.value_type
+ return value_type.exposed_type().value
+
+
+_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
+ "id": fields.String,
+ "type": fields.String(attribute=lambda model: model.get_variable_type()),
+ "name": fields.String,
+ "description": fields.String,
+ "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
+ "value_type": fields.String(attribute=_serialize_variable_type),
+ "edited": fields.Boolean(attribute=lambda model: model.edited),
+ "visible": fields.Boolean,
+}
+
+_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
+ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
+ value=fields.Raw(attribute=_serialize_var_value),
+)
+
+_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
+ "id": fields.String,
+ "type": fields.String(attribute=lambda _: "env"),
+ "name": fields.String,
+ "description": fields.String,
+ "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
+ "value_type": fields.String(attribute=_serialize_variable_type),
+ "edited": fields.Boolean(attribute=lambda model: model.edited),
+ "visible": fields.Boolean,
+}
+
+_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
+ "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
+}
+
+
+def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
+ return var_list.variables
+
+
+_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
+ "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
+ "total": fields.Raw(),
+}
+
+_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
+ "items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
+}
+
+
+def _api_prerequisite(f):
+ """Common prerequisites for all draft workflow variable APIs.
+
+ It ensures the following conditions are satisfied:
+
+ - Dify has been properly set up.
+ - The requesting user is logged in and has completed account initialization.
+ - The requested app is a workflow or a chat flow.
+ - The requesting user has edit permission for the app.
+ """
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
+ def wrapper(*args, **kwargs):
+ if not current_user.is_editor:
+ raise Forbidden()
+ return f(*args, **kwargs)
+
+ return wrapper
+
+
+class WorkflowVariableCollectionApi(Resource):
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
+ def get(self, app_model: App):
+ """
+ Get draft workflow variables (values omitted).
+ """
+ parser = _create_pagination_parser()
+ args = parser.parse_args()
+
+ # fetch draft workflow by app_model
+ workflow_service = WorkflowService()
+ workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
+ if not workflow_exist:
+ raise DraftWorkflowNotExist()
+
+ # fetch draft workflow by app_model
+ with Session(bind=db.engine, expire_on_commit=False) as session:
+ draft_var_srv = WorkflowDraftVariableService(
+ session=session,
+ )
+ workflow_vars = draft_var_srv.list_variables_without_values(
+ app_id=app_model.id,
+ page=args.page,
+ limit=args.limit,
+ )
+
+ return workflow_vars
+
+ @_api_prerequisite
+ def delete(self, app_model: App):
+ draft_var_srv = WorkflowDraftVariableService(
+ session=db.session(),
+ )
+ draft_var_srv.delete_workflow_variables(app_model.id)
+ db.session.commit()
+ return Response("", 204)
+
+
+def validate_node_id(node_id: str) -> NoReturn | None:
+ if node_id in [
+ CONVERSATION_VARIABLE_NODE_ID,
+ SYSTEM_VARIABLE_NODE_ID,
+ ]:
+ # NOTE(QuantumGhost): While we store the system and conversation variables as node variables
+ # with specific `node_id` in database, we still want to make the API separated. By disallowing
+ # accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
+ # we mitigate the risk of API users depending on implementation details of the API.
+ #
+ # ref: [Hyrum's Law](https://www.hyrumslaw.com/)
+
+ raise InvalidArgumentError(
+ f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
+ )
+ return None
+
+
+class NodeVariableCollectionApi(Resource):
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ def get(self, app_model: App, node_id: str):
+ validate_node_id(node_id)
+ with Session(bind=db.engine, expire_on_commit=False) as session:
+ draft_var_srv = WorkflowDraftVariableService(
+ session=session,
+ )
+ node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
+
+ return node_vars
+
+ @_api_prerequisite
+ def delete(self, app_model: App, node_id: str):
+ validate_node_id(node_id)
+ srv = WorkflowDraftVariableService(db.session())
+ srv.delete_node_variables(app_model.id, node_id)
+ db.session.commit()
+ return Response("", 204)
+
+
+class VariableApi(Resource):
+ _PATCH_NAME_FIELD = "name"
+ _PATCH_VALUE_FIELD = "value"
+
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ def get(self, app_model: App, variable_id: str):
+ draft_var_srv = WorkflowDraftVariableService(
+ session=db.session(),
+ )
+ variable = draft_var_srv.get_variable(variable_id=variable_id)
+ if variable is None:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ if variable.app_id != app_model.id:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ return variable
+
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
+ def patch(self, app_model: App, variable_id: str):
+ # Request payload for file types:
+ #
+ # Local File:
+ #
+ # {
+ # "type": "image",
+ # "transfer_method": "local_file",
+ # "url": "",
+ # "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
+ # }
+ #
+ # Remote File:
+ #
+ #
+ # {
+ # "type": "image",
+ # "transfer_method": "remote_url",
+ # "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
+ # "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
+ # }
+
+ parser = reqparse.RequestParser()
+ parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
+ # Parse 'value' field as-is to maintain its original data structure
+ parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
+
+ draft_var_srv = WorkflowDraftVariableService(
+ session=db.session(),
+ )
+ args = parser.parse_args(strict=True)
+
+ variable = draft_var_srv.get_variable(variable_id=variable_id)
+ if variable is None:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ if variable.app_id != app_model.id:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+
+ new_name = args.get(self._PATCH_NAME_FIELD, None)
+ raw_value = args.get(self._PATCH_VALUE_FIELD, None)
+ if new_name is None and raw_value is None:
+ return variable
+
+ new_value = None
+ if raw_value is not None:
+ if variable.value_type == SegmentType.FILE:
+ if not isinstance(raw_value, dict):
+ raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
+ raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
+ elif variable.value_type == SegmentType.ARRAY_FILE:
+ if not isinstance(raw_value, list):
+ raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
+ if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
+ raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
+ raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
+ new_value = build_segment_with_type(variable.value_type, raw_value)
+ draft_var_srv.update_variable(variable, name=new_name, value=new_value)
+ db.session.commit()
+ return variable
+
+ @_api_prerequisite
+ def delete(self, app_model: App, variable_id: str):
+ draft_var_srv = WorkflowDraftVariableService(
+ session=db.session(),
+ )
+ variable = draft_var_srv.get_variable(variable_id=variable_id)
+ if variable is None:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ if variable.app_id != app_model.id:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ draft_var_srv.delete_variable(variable)
+ db.session.commit()
+ return Response("", 204)
+
+
+class VariableResetApi(Resource):
+ @_api_prerequisite
+ def put(self, app_model: App, variable_id: str):
+ draft_var_srv = WorkflowDraftVariableService(
+ session=db.session(),
+ )
+
+ workflow_srv = WorkflowService()
+ draft_workflow = workflow_srv.get_draft_workflow(app_model)
+ if draft_workflow is None:
+ raise NotFoundError(
+ f"Draft workflow not found, app_id={app_model.id}",
+ )
+ variable = draft_var_srv.get_variable(variable_id=variable_id)
+ if variable is None:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+ if variable.app_id != app_model.id:
+ raise NotFoundError(description=f"variable not found, id={variable_id}")
+
+ resetted = draft_var_srv.reset_variable(draft_workflow, variable)
+ db.session.commit()
+ if resetted is None:
+ return Response("", 204)
+ else:
+ return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
+
+
+def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
+ with Session(bind=db.engine, expire_on_commit=False) as session:
+ draft_var_srv = WorkflowDraftVariableService(
+ session=session,
+ )
+ if node_id == CONVERSATION_VARIABLE_NODE_ID:
+ draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
+ elif node_id == SYSTEM_VARIABLE_NODE_ID:
+ draft_vars = draft_var_srv.list_system_variables(app_model.id)
+ else:
+ draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
+ return draft_vars
+
+
+class ConversationVariableCollectionApi(Resource):
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ def get(self, app_model: App):
+ # NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
+ # so their IDs can be returned to the caller.
+ workflow_srv = WorkflowService()
+ draft_workflow = workflow_srv.get_draft_workflow(app_model)
+ if draft_workflow is None:
+ raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
+ draft_var_srv = WorkflowDraftVariableService(db.session())
+ draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
+ db.session.commit()
+ return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
+
+
+class SystemVariableCollectionApi(Resource):
+ @_api_prerequisite
+ @marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
+ def get(self, app_model: App):
+ return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
+
+
+class EnvironmentVariableCollectionApi(Resource):
+ @_api_prerequisite
+ def get(self, app_model: App):
+ """
+ Get draft workflow environment variables.
+ """
+ # fetch draft workflow by app_model
+ workflow_service = WorkflowService()
+ workflow = workflow_service.get_draft_workflow(app_model=app_model)
+ if workflow is None:
+ raise DraftWorkflowNotExist()
+
+ env_vars = workflow.environment_variables
+ env_vars_list = []
+ for v in env_vars:
+ env_vars_list.append(
+ {
+ "id": v.id,
+ "type": "env",
+ "name": v.name,
+ "description": v.description,
+ "selector": v.selector,
+ "value_type": v.value_type.exposed_type().value,
+ "value": v.value,
+ # Do not track edited for env vars.
+ "edited": False,
+ "visible": True,
+ "editable": True,
+ }
+ )
+
+ return {"items": env_vars_list}
+
+
+api.add_resource(
+ WorkflowVariableCollectionApi,
+ "/apps//workflows/draft/variables",
+)
+api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
+api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
+api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
+
+api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
+api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
+api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
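For orientation, a hedged sketch (not part of the patch) of driving the same service outside a request handler, using only the methods and session pattern referenced in the controllers above; the app id is a placeholder and the snippet assumes it runs inside a Flask app context so db.engine is bound:

# Illustrative sketch only -- mirrors how the controllers above open a session
# and list draft variables without their values.
from sqlalchemy.orm import Session

from models import db
from services.workflow_draft_variable_service import WorkflowDraftVariableService

APP_ID = "00000000-0000-0000-0000-000000000000"  # hypothetical app id

with Session(bind=db.engine, expire_on_commit=False) as session:
    srv = WorkflowDraftVariableService(session=session)
    page = srv.list_variables_without_values(app_id=APP_ID, page=1, limit=20)
    for var in page.variables:
        print(var.id, var.get_variable_type(), var.name)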
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 9ad8c15847..3322350e25 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -8,6 +8,15 @@ from libs.login import current_user
from models import App, AppMode
+def _load_app_model(app_id: str) -> Optional[App]:
+ app_model = (
+ db.session.query(App)
+ .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .first()
+ )
+ return app_model
+
+
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func):
@wraps(view_func)
@@ -20,18 +29,12 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
del kwargs["app_id"]
- app_model = (
- db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
- .first()
- )
+ app_model = _load_app_model(app_id)
if not app_model:
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
- if app_mode == AppMode.CHANNEL:
- raise AppNotFoundError()
if mode is not None:
if isinstance(mode, list):
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 1795563ff7..2562fb5eb8 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,5 +1,3 @@
-import datetime
-
from flask import request
from flask_restful import Resource, reqparse
@@ -7,6 +5,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService
@@ -65,7 +64,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index 1049f864c3..4c9697cc32 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -41,7 +41,7 @@ class OAuthDataSource(Resource):
if not internal_secret:
return ({"error": "Internal secret is not set"},)
oauth_provider.save_internal_access_token(internal_secret)
- return {"data": ""}
+ return {"data": "internal"}
else:
auth_url = oauth_provider.get_authorization_url()
return {"data": auth_url}, 200
diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py
index b40934dbf5..8c5e23de58 100644
--- a/api/controllers/console/auth/error.py
+++ b/api/controllers/console/auth/error.py
@@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
- description = "Too many password reset emails have been sent. Please try again in 1 minutes."
+ description = "Too many password reset emails have been sent. Please try again in 1 minute."
+ code = 429
+
+
+class EmailChangeRateLimitExceededError(BaseHTTPException):
+ error_code = "email_change_rate_limit_exceeded"
+ description = "Too many email change emails have been sent. Please try again in 1 minute."
+ code = 429
+
+
+class OwnerTransferRateLimitExceededError(BaseHTTPException):
+ error_code = "owner_transfer_rate_limit_exceeded"
+ description = "Too many owner transfer emails have been sent. Please try again in 1 minute."
code = 429
@@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException):
error_code = "email_password_reset_limit"
description = "Too many failed password reset attempts. Please try again in 24 hours."
code = 429
+
+
+class EmailChangeLimitError(BaseHTTPException):
+ error_code = "email_change_limit"
+ description = "Too many failed email change attempts. Please try again in 24 hours."
+ code = 429
+
+
+class EmailAlreadyInUseError(BaseHTTPException):
+ error_code = "email_already_in_use"
+ description = "A user with this email already exists."
+ code = 400
+
+
+class OwnerTransferLimitError(BaseHTTPException):
+ error_code = "owner_transfer_limit"
+ description = "Too many failed owner transfer attempts. Please try again in 24 hours."
+ code = 429
+
+
+class NotOwnerError(BaseHTTPException):
+ error_code = "not_owner"
+ description = "You are not the owner of the workspace."
+ code = 400
+
+
+class CannotTransferOwnerToSelfError(BaseHTTPException):
+ error_code = "cannot_transfer_owner_to_self"
+ description = "You cannot transfer ownership to yourself."
+ code = 400
+
+
+class MemberNotInTenantError(BaseHTTPException):
+ error_code = "member_not_in_tenant"
+ description = "The member is not in the workspace."
+ code = 400
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 395367c9e2..d0a4f3ff6d 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from typing import Optional
import requests
@@ -13,6 +12,7 @@ from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
@@ -110,7 +110,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
try:
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 7b0d9373cf..b49f8affc8 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,4 +1,3 @@
-import datetime
import json
from flask import request
@@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
@@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
@@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index e68273afa6..4f62ac78b4 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -211,10 +211,6 @@ class DatasetApi(Resource):
else:
data["embedding_available"] = True
- if data.get("permission") == "partial_members":
- part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
- data.update({"partial_member_list": part_users_list})
-
return data, 200
@setup_required
@@ -686,6 +682,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
+ | VectorType.MATRIXONE
):
return {
"retrieval_method": [
@@ -733,6 +730,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TABLESTORE
| VectorType.TENCENT
| VectorType.HUAWEI_CLOUD
+ | VectorType.MATRIXONE
):
return {
"retrieval_method": [
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index f7c04102a9..28a2e93049 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,11 +1,10 @@
import logging
from argparse import ArgumentTypeError
-from datetime import UTC, datetime
from typing import cast
from flask import request
from flask_login import current_user
-from flask_restful import Resource, fields, marshal, marshal_with, reqparse
+from flask_restful import Resource, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -43,19 +42,17 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
-from extensions.ext_redis import redis_client
from fields.document_fields import (
dataset_and_document_fields,
document_fields,
document_status_fields,
document_with_segments_fields,
)
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
-from tasks.add_document_to_index_task import add_document_to_index_task
-from tasks.remove_document_from_index_task import remove_document_from_index_task
class DocumentResource(Resource):
@@ -242,12 +239,10 @@ class DatasetDocumentListApi(Resource):
return response
- documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
-
@setup_required
@login_required
@account_initialization_required
- @marshal_with(documents_and_batch_fields)
+ @marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
@@ -293,6 +288,8 @@ class DatasetDocumentListApi(Resource):
try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
+ dataset = DatasetService.get_dataset(dataset_id)
+
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -300,7 +297,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
- return {"documents": documents, "batch": batch}
+ return {"dataset": dataset, "documents": documents, "batch": batch}
@setup_required
@login_required
@@ -753,7 +750,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
- document.paused_at = datetime.now(UTC).replace(tzinfo=None)
+ document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
@@ -833,7 +830,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value
document.doc_type = doc_type
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ document.updated_at = naive_utc_now()
db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200
@@ -862,77 +859,16 @@ class DocumentStatusApi(DocumentResource):
DatasetService.check_dataset_permission(dataset, current_user)
document_ids = request.args.getlist("document_id")
- for document_id in document_ids:
- document = self.get_document(dataset_id, document_id)
-
- indexing_cache_key = "document_{}_indexing".format(document.id)
- cache_result = redis_client.get(indexing_cache_key)
- if cache_result is not None:
- raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
-
- if action == "enable":
- if document.enabled:
- continue
- document.enabled = True
- document.disabled_at = None
- document.disabled_by = None
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
-
- # Set cache to prevent indexing the same document multiple times
- redis_client.setex(indexing_cache_key, 600, 1)
-
- add_document_to_index_task.delay(document_id)
-
- elif action == "disable":
- if not document.completed_at or document.indexing_status != "completed":
- raise InvalidActionError(f"Document: {document.name} is not completed.")
- if not document.enabled:
- continue
-
- document.enabled = False
- document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
- document.disabled_by = current_user.id
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
-
- # Set cache to prevent indexing the same document multiple times
- redis_client.setex(indexing_cache_key, 600, 1)
-
- remove_document_from_index_task.delay(document_id)
-
- elif action == "archive":
- if document.archived:
- continue
-
- document.archived = True
- document.archived_at = datetime.now(UTC).replace(tzinfo=None)
- document.archived_by = current_user.id
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
-
- if document.enabled:
- # Set cache to prevent indexing the same document multiple times
- redis_client.setex(indexing_cache_key, 600, 1)
-
- remove_document_from_index_task.delay(document_id)
-
- elif action == "un_archive":
- if not document.archived:
- continue
- document.archived = False
- document.archived_at = None
- document.archived_by = None
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
- db.session.commit()
-
- # Set cache to prevent indexing the same document multiple times
- redis_client.setex(indexing_cache_key, 600, 1)
-
- add_document_to_index_task.delay(document_id)
- else:
- raise InvalidActionError()
+ try:
+ DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
+ except services.errors.document.DocumentIndexingError as e:
+ raise InvalidActionError(str(e))
+ except ValueError as e:
+ raise InvalidActionError(str(e))
+ except NotFound as e:
+ raise NotFound(str(e))
+
return {"result": "success"}, 200
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index bc37907a30..48142dbe73 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
# check file type
- if not file.filename or not file.filename.endswith(".csv"):
+ if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
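
The batch-import check above now lowercases the filename before testing the extension, so uploads such as `SEGMENTS.CSV` are no longer rejected. A quick illustration with hypothetical filenames:

    filenames = ["segments.csv", "SEGMENTS.CSV", "notes.txt"]

    # Old check (case-sensitive): only "segments.csv" passes.
    print([name for name in filenames if name.endswith(".csv")])

    # New check (case-insensitive): both CSV spellings pass, "notes.txt" is still rejected.
    print([name for name in filenames if name.lower().endswith(".csv")])
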
diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py
index 2f00a84de6..cb68bb5e81 100644
--- a/api/controllers/console/datasets/error.py
+++ b/api/controllers/console/datasets/error.py
@@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
-class HighQualityDatasetOnlyError(BaseHTTPException):
- error_code = "high_quality_dataset_only"
- description = "Current operation only supports 'high-quality' datasets."
- code = 400
-
-
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index 4200a51709..fcdc91ec67 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -4,7 +4,7 @@ from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
-from services.website_service import WebsiteService
+from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlApi(Resource):
@@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource):
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
- WebsiteService.document_create_args_validate(args)
- # crawl url
+
+ # Create typed request and validate
+ try:
+ api_request = WebsiteCrawlApiRequest.from_args(args)
+ except ValueError as e:
+ raise WebsiteCrawlError(str(e))
+
+ # Crawl URL using typed request
try:
- result = WebsiteService.crawl_url(args)
+ result = WebsiteService.crawl_url(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
@@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource):
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
- # get crawl status
+
+ # Create typed request and validate
+ try:
+ api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
+ except ValueError as e:
+ raise WebsiteCrawlError(str(e))
+
+ # Get crawl status using typed request
try:
- result = WebsiteService.get_crawl_status(job_id, args["provider"])
+ result = WebsiteService.get_crawl_status_typed(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
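
Both website endpoints now hand `WebsiteService` a typed request object instead of the raw argument dict, and map any `ValueError` raised during construction to `WebsiteCrawlError`. The actual request classes live in `services/website_service.py`, outside this diff; a rough sketch of the pattern the controller depends on, with field names assumed from the parsed arguments:

    from dataclasses import dataclass
    from typing import Any


    @dataclass
    class WebsiteCrawlApiRequest:
        provider: str
        url: str
        options: dict[str, Any]

        @classmethod
        def from_args(cls, args: dict[str, Any]) -> "WebsiteCrawlApiRequest":
            provider = args.get("provider")
            url = args.get("url")
            options = args.get("options")
            # Invalid input surfaces as ValueError, which the controller converts to WebsiteCrawlError.
            if not provider or not url or options is None:
                raise ValueError("provider, url and options are required")
            return cls(provider=provider, url=url, options=options)
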
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index 54bc590677..d564a00a76 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -18,7 +18,6 @@ from controllers.console.app.error import (
from controllers.console.explore.wraps import InstalledAppResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import AppMode
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -79,19 +78,9 @@ class ChatTextApi(InstalledAppResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
- if (
- app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
- and app_model.workflow
- and app_model.workflow.features_dict
- ):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech")
- voice = args.get("voice") or text_to_speech.get("voice")
- else:
- try:
- voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
- except Exception:
- voice = None
- response = AudioService.transcript_tts(app_model=app_model, message_id=message_id, voice=voice, text=text)
+ voice = args.get("voice", None)
+
+ response = AudioService.transcript_tts(app_model=app_model, text=text, voice=voice, message_id=message_id)
return response
except services.errors.app_model_config.AppModelConfigBrokenError:
logging.exception("App model config broken.")
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 4367da1162..4842fefc57 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from flask_login import current_user
from flask_restful import reqparse
@@ -27,6 +26,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
+from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
- installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+ installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:
@@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False
- installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+ installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 9d0c08564e..29111fb865 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from typing import Any
from flask import request
@@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
@@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
- last_used_at=datetime.now(UTC).replace(tzinfo=None),
+ last_used_at=naive_utc_now(),
)
db.session.add(new_installed_app)
db.session.commit()
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 7dea8e554e..447cc358f8 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -18,7 +18,7 @@ class VersionApi(Resource):
check_update_url = dify_config.CHECK_UPDATE_URL
result = {
- "version": dify_config.CURRENT_VERSION,
+ "version": dify_config.project.version,
"release_date": "",
"release_notes": "",
"can_auto_update": False,
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index a9dbf44456..7f7e64a59c 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,13 +1,21 @@
-import datetime
-
import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
+from controllers.console.auth.error import (
+ EmailAlreadyInUseError,
+ EmailChangeLimitError,
+ EmailCodeError,
+ InvalidEmailError,
+ InvalidTokenError,
+)
+from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
@@ -18,15 +26,18 @@ from controllers.console.workspace.error import (
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
+ enable_change_email,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
-from libs.helper import TimestampField, timezone
+from libs.datetime_utils import naive_utc_now
+from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
+from models.account import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -68,7 +79,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError()
invitation_code.status = "used"
- invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ invitation_code.used_at = naive_utc_now()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
@@ -76,7 +87,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = "active"
- account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
return {"result": "success"}
@@ -369,6 +380,134 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
+class ChangeEmailSendEmailApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ parser.add_argument("language", type=str, required=False, location="json")
+ parser.add_argument("phase", type=str, required=False, location="json")
+ parser.add_argument("token", type=str, required=False, location="json")
+ args = parser.parse_args()
+
+ ip_address = extract_remote_ip(request)
+ if AccountService.is_email_send_ip_limit(ip_address):
+ raise EmailSendIpLimitError()
+
+ if args["language"] is not None and args["language"] == "zh-Hans":
+ language = "zh-Hans"
+ else:
+ language = "en-US"
+ account = None
+ user_email = args["email"]
+ if args["phase"] is not None and args["phase"] == "new_email":
+ if args["token"] is None:
+ raise InvalidTokenError()
+
+ reset_data = AccountService.get_change_email_data(args["token"])
+ if reset_data is None:
+ raise InvalidTokenError()
+ user_email = reset_data.get("email", "")
+
+ if user_email != current_user.email:
+ raise InvalidEmailError()
+ else:
+ with Session(db.engine) as session:
+ account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ if account is None:
+ raise AccountNotFound()
+
+ token = AccountService.send_change_email_email(
+ account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
+ )
+ return {"result": "success", "data": token}
+
+
+class ChangeEmailCheckApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ parser.add_argument("code", type=str, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ user_email = args["email"]
+
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
+ if is_change_email_error_rate_limit:
+ raise EmailChangeLimitError()
+
+ token_data = AccountService.get_change_email_data(args["token"])
+ if token_data is None:
+ raise InvalidTokenError()
+
+ if user_email != token_data.get("email"):
+ raise InvalidEmailError()
+
+ if args["code"] != token_data.get("code"):
+ AccountService.add_change_email_error_rate_limit(args["email"])
+ raise EmailCodeError()
+
+ # Verified, revoke the first token
+ AccountService.revoke_change_email_token(args["token"])
+
+ # Refresh token data by generating a new token
+ _, new_token = AccountService.generate_change_email_token(
+ user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
+ )
+
+ AccountService.reset_change_email_error_rate_limit(args["email"])
+ return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+
+
+class ChangeEmailResetApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @marshal_with(account_fields)
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("new_email", type=email, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ reset_data = AccountService.get_change_email_data(args["token"])
+ if not reset_data:
+ raise InvalidTokenError()
+
+ AccountService.revoke_change_email_token(args["token"])
+
+ if not AccountService.check_email_unique(args["new_email"]):
+ raise EmailAlreadyInUseError()
+
+ old_email = reset_data.get("old_email", "")
+ if current_user.email != old_email:
+ raise AccountNotFound()
+
+ updated_account = AccountService.update_account(current_user, email=args["new_email"])
+
+ return updated_account
+
+
+class CheckEmailUnique(Resource):
+ @setup_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ args = parser.parse_args()
+ if not AccountService.check_email_unique(args["email"]):
+ raise EmailAlreadyInUseError()
+ return {"result": "success"}
+
+
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
@@ -385,5 +524,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
+# Change email
+api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
+api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
+api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
+api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
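
Taken together, the new resources give the console a code-verified email-change flow: verify the current address, then verify the new address (`phase="new_email"`), then commit the change. A hedged client-side sketch of the sequence; the `/console/api` prefix, the bearer header, and the example codes are assumptions, while the request and response fields come from the resources above:

    import requests

    BASE = "https://dify.example.com/console/api"  # hypothetical host; console API prefix assumed
    session = requests.Session()
    session.headers["Authorization"] = "Bearer <console-access-token>"  # placeholder auth

    # 1) Send a code to the current address; a change-email token comes back in "data".
    token = session.post(f"{BASE}/account/change-email",
                         json={"email": "old@example.com", "language": "en-US"}).json()["data"]

    # 2) Confirm that code; the token is revoked and reissued.
    token = session.post(f"{BASE}/account/change-email/validity",
                         json={"email": "old@example.com", "code": "111111", "token": token}).json()["token"]

    # 3) Repeat both steps for the new address (phase="new_email" carries the reissued token).
    token = session.post(f"{BASE}/account/change-email",
                         json={"email": "new@example.com", "phase": "new_email", "token": token}).json()["data"]
    token = session.post(f"{BASE}/account/change-email/validity",
                         json={"email": "new@example.com", "code": "222222", "token": token}).json()["token"]

    # 4) Finally switch the account over to the new address.
    session.post(f"{BASE}/account/change-email/reset",
                 json={"new_email": "new@example.com", "token": token})
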
diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py
index 8b70ca62b9..4427d1ff72 100644
--- a/api/controllers/console/workspace/error.py
+++ b/api/controllers/console/workspace/error.py
@@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException):
code = 400
-class ProviderRequestFailedError(BaseHTTPException):
- error_code = "provider_request_failed"
- description = None
- code = 400
-
-
class InvalidInvitationCodeError(BaseHTTPException):
error_code = "invalid_invitation_code"
description = "Invalid invitation code."
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index ba74e2c074..b4eb5e246b 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str):
- if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+ if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
@@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required
@account_initialization_required
def post(self, provider: str, config_id: str):
- if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role):
+ if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden()
tenant_id = current_user.current_tenant_id
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index db49da7840..b1f79ffdec 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,22 +1,34 @@
from urllib import parse
+from flask import request
from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
-from controllers.console.error import WorkspaceMembersLimitExceeded
+from controllers.console.auth.error import (
+ CannotTransferOwnerToSelfError,
+ EmailCodeError,
+ InvalidEmailError,
+ InvalidTokenError,
+ MemberNotInTenantError,
+ NotOwnerError,
+ OwnerTransferLimitError,
+)
+from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
+from libs.helper import extract_remote_ip
from libs.login import login_required
from models.account import Account, TenantAccountRole
-from services.account_service import RegisterService, TenantService
+from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@@ -85,6 +97,7 @@ class MemberInviteEmailApi(Resource):
return {
"result": "success",
"invitation_results": invitation_results,
+ "tenant_id": str(current_user.current_tenant.id),
}, 201
@@ -110,7 +123,7 @@ class MemberCancelInviteApi(Resource):
except Exception as e:
raise ValueError(str(e))
- return {"result": "success"}, 204
+ return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
class MemberUpdateRoleApi(Resource):
@@ -155,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
+class SendOwnerTransferEmailApi(Resource):
+ """Send owner transfer email."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("language", type=str, required=False, location="json")
+ args = parser.parse_args()
+ ip_address = extract_remote_ip(request)
+ if AccountService.is_email_send_ip_limit(ip_address):
+ raise EmailSendIpLimitError()
+
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ if args["language"] is not None and args["language"] == "zh-Hans":
+ language = "zh-Hans"
+ else:
+ language = "en-US"
+
+ email = current_user.email
+
+ token = AccountService.send_owner_transfer_email(
+ account=current_user,
+ email=email,
+ language=language,
+ workspace_name=current_user.current_tenant.name,
+ )
+
+ return {"result": "success", "data": token}
+
+
+class OwnerTransferCheckApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("code", type=str, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ user_email = current_user.email
+
+ is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email)
+ if is_owner_transfer_error_rate_limit:
+ raise OwnerTransferLimitError()
+
+ token_data = AccountService.get_owner_transfer_data(args["token"])
+ if token_data is None:
+ raise InvalidTokenError()
+
+ if user_email != token_data.get("email"):
+ raise InvalidEmailError()
+
+ if args["code"] != token_data.get("code"):
+ AccountService.add_owner_transfer_error_rate_limit(user_email)
+ raise EmailCodeError()
+
+ # Verified, revoke the first token
+ AccountService.revoke_owner_transfer_token(args["token"])
+
+ # Refresh token data by generating a new token
+ _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
+
+ AccountService.reset_owner_transfer_error_rate_limit(user_email)
+ return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+
+
+class OwnerTransfer(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self, member_id):
+ parser = reqparse.RequestParser()
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ if current_user.id == str(member_id):
+ raise CannotTransferOwnerToSelfError()
+
+ transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
+ if not transfer_token_data:
+ raise InvalidTokenError()
+
+ if transfer_token_data.get("email") != current_user.email:
+ raise InvalidEmailError()
+
+ AccountService.revoke_owner_transfer_token(args["token"])
+
+ member = db.session.get(Account, str(member_id))
+ if not member:
+ abort(404)
+ else:
+ member_account = member
+ if not TenantService.is_member(member_account, current_user.current_tenant):
+ raise MemberNotInTenantError()
+
+ try:
+ assert member is not None, "Member not found"
+ TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
+
+ AccountService.send_new_owner_transfer_notify_email(
+ account=member,
+ email=member.email,
+ workspace_name=current_user.current_tenant.name,
+ )
+
+ AccountService.send_old_owner_transfer_notify_email(
+ account=current_user,
+ email=current_user.email,
+ workspace_name=current_user.current_tenant.name,
+ new_owner_email=member.email,
+ )
+
+ except Exception as e:
+ raise ValueError(str(e))
+
+ return {"result": "success"}
+
+
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/<uuid:member_id>")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members/<uuid:member_id>/update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
+# owner transfer
+api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
+api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
+api.add_resource(OwnerTransfer, "/workspaces/current/members/<uuid:member_id>/owner-transfer")
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index 9bddbb4b4b..c0a4734828 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -13,6 +13,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
+from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@@ -497,6 +498,42 @@ class PluginFetchPermissionApi(Resource):
)
+class PluginFetchDynamicSelectOptionsApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ # check if the user is admin or owner
+ if not current_user.is_admin_or_owner:
+ raise Forbidden()
+
+ tenant_id = current_user.current_tenant_id
+ user_id = current_user.id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("plugin_id", type=str, required=True, location="args")
+ parser.add_argument("provider", type=str, required=True, location="args")
+ parser.add_argument("action", type=str, required=True, location="args")
+ parser.add_argument("parameter", type=str, required=True, location="args")
+ parser.add_argument("provider_type", type=str, required=True, location="args")
+ args = parser.parse_args()
+
+ try:
+ options = PluginParameterService.get_dynamic_select_options(
+ tenant_id,
+ user_id,
+ args["plugin_id"],
+ args["provider"],
+ args["action"],
+ args["parameter"],
+ args["provider_type"],
+ )
+ except PluginDaemonClientSideError as e:
+ raise ValueError(e)
+
+ return jsonable_encoder({"options": options})
+
+
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@@ -521,3 +558,5 @@ api.add_resource(PluginFetchMarketplacePkgApi, "/workspaces/current/plugin/marke
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
+
+api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
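
The new dynamic-options endpoint simply proxies a plugin's parameter query through `PluginParameterService`. A hedged usage sketch; the host, auth header, and example values are placeholders, while the query parameters and the `options` response key match the resource above:

    import requests

    resp = requests.get(
        "https://dify.example.com/console/api/workspaces/current/plugin/parameters/dynamic-options",
        headers={"Authorization": "Bearer <console-access-token>"},  # placeholder auth
        params={
            "plugin_id": "langgenius/example_plugin",  # hypothetical plugin
            "provider": "example_provider",
            "action": "example_tool",
            "parameter": "model",
            "provider_type": "tool",
        },
    )
    print(resp.json()["options"])
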
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 2b1379bfb2..c70bf84d2a 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,25 +1,52 @@
import io
+from urllib.parse import urlparse
-from flask import send_file
+from flask import make_response, redirect, request, send_file
from flask_login import current_user
-from flask_restful import Resource, reqparse
-from sqlalchemy.orm import Session
+from flask_restful import (
+ Resource,
+ reqparse,
+)
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
-from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
+from controllers.console.wraps import (
+ account_initialization_required,
+ enterprise_license_required,
+ setup_required,
+)
+from core.mcp.auth.auth_flow import auth, handle_callback
+from core.mcp.auth.auth_provider import OAuthClientProvider
+from core.mcp.error import MCPAuthError, MCPError
+from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
-from extensions.ext_database import db
-from libs.helper import alphanumeric, uuid_value
+from core.plugin.entities.plugin import ToolProviderID
+from core.plugin.impl.oauth import OAuthHandler
+from core.tools.entities.tool_entities import CredentialType
+from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
+from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
+from services.tools.tools_transform_service import ToolTransformService
from services.tools.workflow_tools_manage_service import WorkflowToolManageService
+def is_valid_url(url: str) -> bool:
+ if not url:
+ return False
+
+ try:
+ parsed = urlparse(url)
+ return all([parsed.scheme, parsed.netloc]) and parsed.scheme in ["http", "https"]
+ except Exception:
+ return False
+
+
class ToolProviderListApi(Resource):
@setup_required
@login_required
@@ -34,7 +61,7 @@ class ToolProviderListApi(Resource):
req.add_argument(
"type",
type=str,
- choices=["builtin", "model", "api", "workflow"],
+ choices=["builtin", "model", "api", "workflow", "mcp"],
required=False,
nullable=True,
location="args",
@@ -71,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
- return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
+ return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@@ -80,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required
def post(self, provider):
user = current_user
-
if not user.is_admin_or_owner:
raise Forbidden()
- user_id = user.id
tenant_id = user.current_tenant_id
+ req = reqparse.RequestParser()
+ req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+ args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
- user_id,
tenant_id,
provider,
+ args["credential_id"],
+ )
+
+
+class ToolBuiltinProviderAddApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ user = current_user
+
+ user_id = user.id
+ tenant_id = user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
+ parser.add_argument("type", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ if args["type"] not in CredentialType.values():
+ raise ValueError(f"Invalid credential type: {args['type']}")
+
+ return BuiltinToolManageService.add_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credentials=args["credentials"],
+ name=args["name"],
+ api_type=CredentialType.of(args["type"]),
)
@@ -108,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
- parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
- with Session(db.engine) as session:
- result = BuiltinToolManageService.update_builtin_tool_provider(
- session=session,
- user_id=user_id,
- tenant_id=tenant_id,
- provider_name=provider,
- credentials=args["credentials"],
- )
- session.commit()
+ result = BuiltinToolManageService.update_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credential_id=args["credential_id"],
+ credentials=args.get("credentials", None),
+ name=args.get("name", ""),
+ )
return result
@@ -131,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider):
tenant_id = current_user.current_tenant_id
- return BuiltinToolManageService.get_builtin_tool_provider_credentials(
- tenant_id=tenant_id,
- provider_name=provider,
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_credentials(
+ tenant_id=tenant_id,
+ provider_name=provider,
+ )
)
@@ -326,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
- def get(self, provider):
+ def get(self, provider, credential_type):
user = current_user
-
tenant_id = user.current_tenant_id
- return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
+ return jsonable_encoder(
+ BuiltinToolManageService.list_builtin_provider_credentials_schema(
+ provider, CredentialType.of(credential_type), tenant_id
+ )
+ )
class ToolApiProviderSchemaApi(Resource):
@@ -568,15 +631,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
-
- user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
- user_id,
tenant_id,
)
]
@@ -613,20 +673,369 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
+class ToolPluginOAuthApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ tool_provider = ToolProviderID(provider)
+ plugin_id = tool_provider.plugin_id
+ provider_name = tool_provider.provider_name
+
+ # todo check permission
+ user = current_user
+
+ if not user.is_admin_or_owner:
+ raise Forbidden()
+
+ tenant_id = user.current_tenant_id
+ oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
+ if oauth_client_params is None:
+ raise Forbidden("no oauth available client config found for this tool provider")
+
+ oauth_handler = OAuthHandler()
+ context_id = OAuthProxyService.create_proxy_context(
+ user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
+ )
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
+ authorization_url_response = oauth_handler.get_authorization_url(
+ tenant_id=tenant_id,
+ user_id=user.id,
+ plugin_id=plugin_id,
+ provider=provider_name,
+ redirect_uri=redirect_uri,
+ system_credentials=oauth_client_params,
+ )
+ response = make_response(jsonable_encoder(authorization_url_response))
+ response.set_cookie(
+ "context_id",
+ context_id,
+ httponly=True,
+ samesite="Lax",
+ max_age=OAuthProxyService.__MAX_AGE__,
+ )
+ return response
+
+
+class ToolOAuthCallback(Resource):
+ @setup_required
+ def get(self, provider):
+ context_id = request.cookies.get("context_id")
+ if not context_id:
+ raise Forbidden("context_id not found")
+
+ context = OAuthProxyService.use_proxy_context(context_id)
+ if context is None:
+ raise Forbidden("Invalid context_id")
+
+ tool_provider = ToolProviderID(provider)
+ plugin_id = tool_provider.plugin_id
+ provider_name = tool_provider.provider_name
+ user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
+
+ oauth_handler = OAuthHandler()
+ oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
+ if oauth_client_params is None:
+ raise Forbidden("no oauth available client config found for this tool provider")
+
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
+ credentials = oauth_handler.get_credentials(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ plugin_id=plugin_id,
+ provider=provider_name,
+ redirect_uri=redirect_uri,
+ system_credentials=oauth_client_params,
+ request=request,
+ ).credentials
+
+ if not credentials:
+ raise Exception("the plugin credentials failed")
+
+ # add credentials to database
+ BuiltinToolManageService.add_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credentials=dict(credentials),
+ api_type=CredentialType.OAUTH2,
+ )
+ return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
+
+
+class ToolBuiltinProviderSetDefaultApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ parser = reqparse.RequestParser()
+ parser.add_argument("id", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ return BuiltinToolManageService.set_default_provider(
+ tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
+ )
+
+
+class ToolOAuthCustomClient(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ parser = reqparse.RequestParser()
+ parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+ args = parser.parse_args()
+
+ user = current_user
+
+ if not user.is_admin_or_owner:
+ raise Forbidden()
+
+ return BuiltinToolManageService.save_custom_oauth_client_params(
+ tenant_id=user.current_tenant_id,
+ provider=provider,
+ client_params=args.get("client_params", {}),
+ enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
+ )
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.get_custom_oauth_client_params(
+ tenant_id=current_user.current_tenant_id, provider=provider
+ )
+ )
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def delete(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.delete_custom_oauth_client_params(
+ tenant_id=current_user.current_tenant_id, provider=provider
+ )
+ )
+
+
+class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
+ tenant_id=current_user.current_tenant_id, provider_name=provider
+ )
+ )
+
+
+class ToolBuiltinProviderGetCredentialInfoApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ tenant_id = current_user.current_tenant_id
+
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_credential_info(
+ tenant_id=tenant_id,
+ provider=provider,
+ )
+ )
+
+
+class ToolProviderMCPApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json", default="")
+ parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ user = current_user
+ if not is_valid_url(args["server_url"]):
+ raise ValueError("Server URL is not valid.")
+ return jsonable_encoder(
+ MCPToolManageService.create_mcp_provider(
+ tenant_id=user.current_tenant_id,
+ server_url=args["server_url"],
+ name=args["name"],
+ icon=args["icon"],
+ icon_type=args["icon_type"],
+ icon_background=args["icon_background"],
+ user_id=user.id,
+ server_identifier=args["server_identifier"],
+ )
+ )
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def put(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("server_url", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("name", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon_type", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("icon_background", type=str, required=False, nullable=True, location="json")
+ parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("server_identifier", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ if not is_valid_url(args["server_url"]):
+ if "[__HIDDEN__]" in args["server_url"]:
+ pass
+ else:
+ raise ValueError("Server URL is not valid.")
+ MCPToolManageService.update_mcp_provider(
+ tenant_id=current_user.current_tenant_id,
+ provider_id=args["provider_id"],
+ server_url=args["server_url"],
+ name=args["name"],
+ icon=args["icon"],
+ icon_type=args["icon_type"],
+ icon_background=args["icon_background"],
+ server_identifier=args["server_identifier"],
+ )
+ return {"result": "success"}
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def delete(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ MCPToolManageService.delete_mcp_tool(tenant_id=current_user.current_tenant_id, provider_id=args["provider_id"])
+ return {"result": "success"}
+
+
+class ToolMCPAuthApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("provider_id", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("authorization_code", type=str, required=False, nullable=True, location="json")
+ args = parser.parse_args()
+ provider_id = args["provider_id"]
+ tenant_id = current_user.current_tenant_id
+ provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ if not provider:
+ raise ValueError("provider not found")
+ try:
+ with MCPClient(
+ provider.decrypted_server_url,
+ provider_id,
+ tenant_id,
+ authed=False,
+ authorization_code=args["authorization_code"],
+ for_list=True,
+ ):
+ MCPToolManageService.update_mcp_provider_credentials(
+ mcp_provider=provider,
+ credentials=provider.decrypted_credentials,
+ authed=True,
+ )
+ return {"result": "success"}
+
+ except MCPAuthError:
+ auth_provider = OAuthClientProvider(provider_id, tenant_id, for_list=True)
+ return auth(auth_provider, provider.decrypted_server_url, args["authorization_code"])
+ except MCPError as e:
+ MCPToolManageService.update_mcp_provider_credentials(
+ mcp_provider=provider,
+ credentials={},
+ authed=False,
+ )
+ raise ValueError(f"Failed to connect to MCP server: {e}") from e
+
+
+class ToolMCPDetailApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider_id):
+ user = current_user
+ provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, user.current_tenant_id)
+ return jsonable_encoder(ToolTransformService.mcp_provider_to_user_provider(provider, for_list=True))
+
+
+class ToolMCPListAllApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ user = current_user
+ tenant_id = user.current_tenant_id
+
+ tools = MCPToolManageService.retrieve_mcp_tools(tenant_id=tenant_id)
+
+ return [tool.to_dict() for tool in tools]
+
+
+class ToolMCPUpdateApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider_id):
+ tenant_id = current_user.current_tenant_id
+ tools = MCPToolManageService.list_mcp_tool_from_remote_server(
+ tenant_id=tenant_id,
+ provider_id=provider_id,
+ )
+ return jsonable_encoder(tools)
+
+
+class ToolMCPCallbackApi(Resource):
+ def get(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("code", type=str, required=True, nullable=False, location="args")
+ parser.add_argument("state", type=str, required=True, nullable=False, location="args")
+ args = parser.parse_args()
+ state_key = args["state"]
+ authorization_code = args["code"]
+ handle_callback(state_key, authorization_code)
+ return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
+
+
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
+# tool oauth
+api.add_resource(ToolPluginOAuthApi, "/oauth/plugin/<path:provider>/tool/authorization-url")
+api.add_resource(ToolOAuthCallback, "/oauth/plugin/<path:provider>/tool/callback")
+api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin/<path:provider>/oauth/custom-client")
+
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/info")
+api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin/<path:provider>/add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin/<path:provider>/delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin/<path:provider>/update")
+api.add_resource(
+ ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin/<path:provider>/default-credential"
+)
+api.add_resource(
+ ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credential/info"
+)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin/<path:provider>/credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
- "/workspaces/current/tool-provider/builtin//credentials_schema",
+ "/workspaces/current/tool-provider/builtin//credential/schema/",
+)
+api.add_resource(
+ ToolBuiltinProviderGetOauthClientSchemaApi,
+ "/workspaces/current/tool-provider/builtin//oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin/<path:provider>/icon")
@@ -647,8 +1056,15 @@ api.add_resource(ToolWorkflowProviderDeleteApi, "/workspaces/current/tool-provid
api.add_resource(ToolWorkflowProviderGetApi, "/workspaces/current/tool-provider/workflow/get")
api.add_resource(ToolWorkflowProviderListToolApi, "/workspaces/current/tool-provider/workflow/tools")
+# mcp tool provider
+api.add_resource(ToolMCPDetailApi, "/workspaces/current/tool-provider/mcp/tools/<provider_id>")
+api.add_resource(ToolProviderMCPApi, "/workspaces/current/tool-provider/mcp")
+api.add_resource(ToolMCPUpdateApi, "/workspaces/current/tool-provider/mcp/update/<provider_id>")
+api.add_resource(ToolMCPAuthApi, "/workspaces/current/tool-provider/mcp/auth")
+api.add_resource(ToolMCPCallbackApi, "/mcp/oauth/callback")
+
api.add_resource(ToolBuiltinListApi, "/workspaces/current/tools/builtin")
api.add_resource(ToolApiListApi, "/workspaces/current/tools/api")
+api.add_resource(ToolMCPListAllApi, "/workspaces/current/tools/mcp")
api.add_resource(ToolWorkflowListApi, "/workspaces/current/tools/workflow")
-
api.add_resource(ToolLabelsApi, "/workspaces/current/tool-labels")
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index ca122772de..d862dac373 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -235,3 +235,29 @@ def email_password_login_enabled(view):
abort(403)
return decorated
+
+
+def enable_change_email(view):
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if features.enable_change_email:
+ return view(*args, **kwargs)
+
+ # otherwise, return 403
+ abort(403)
+
+ return decorated
+
+
+def is_allow_transfer_owner(view):
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_features(current_user.current_tenant_id)
+ if features.is_allow_transfer_workspace:
+ return view(*args, **kwargs)
+
+ # otherwise, return 403
+ abort(403)
+
+ return decorated
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index f1a15793c7..15f93d2774 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -87,7 +87,5 @@ class PluginUploadFileApi(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return tool_file, 201
-
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index 41063b35a5..5dfe41eb6b 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -17,6 +17,7 @@ from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeEncrypt,
RequestInvokeLLM,
+ RequestInvokeLLMWithStructuredOutput,
RequestInvokeModeration,
RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
@@ -47,6 +48,21 @@ class PluginInvokeLLMApi(Resource):
return length_prefixed_response(0xF, generator())
+class PluginInvokeLLMWithStructuredOutputApi(Resource):
+ @setup_required
+ @plugin_inner_api_only
+ @get_user_tenant
+ @plugin_data(payload_type=RequestInvokeLLMWithStructuredOutput)
+ def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLMWithStructuredOutput):
+ def generator():
+ response = PluginModelBackwardsInvocation.invoke_llm_with_structured_output(
+ user_model.id, tenant_model, payload
+ )
+ return PluginModelBackwardsInvocation.convert_to_event_stream(response)
+
+ return length_prefixed_response(0xF, generator())
+
+
class PluginInvokeTextEmbeddingApi(Resource):
@setup_required
@plugin_inner_api_only
@@ -159,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
+ credential_id=payload.credential_id,
),
)
@@ -291,6 +308,7 @@ class PluginFetchAppInfoApi(Resource):
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
+api.add_resource(PluginInvokeLLMWithStructuredOutputApi, "/invoke/llm/structured-output")
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py
index a2fc2d4675..77568b75f1 100644
--- a/api/controllers/inner_api/workspace/workspace.py
+++ b/api/controllers/inner_api/workspace/workspace.py
@@ -29,7 +29,19 @@ class EnterpriseWorkspace(Resource):
tenant_was_created.send(tenant)
- return {"message": "enterprise workspace created."}
+ resp = {
+ "id": tenant.id,
+ "name": tenant.name,
+ "plan": tenant.plan,
+ "status": tenant.status,
+ "created_at": tenant.created_at.isoformat() + "Z" if tenant.created_at else None,
+ "updated_at": tenant.updated_at.isoformat() + "Z" if tenant.updated_at else None,
+ }
+
+ return {
+ "message": "enterprise workspace created.",
+ "tenant": resp,
+ }
class EnterpriseWorkspaceNoOwnerEmail(Resource):
diff --git a/api/controllers/mcp/__init__.py b/api/controllers/mcp/__init__.py
new file mode 100644
index 0000000000..1b3e0a5621
--- /dev/null
+++ b/api/controllers/mcp/__init__.py
@@ -0,0 +1,8 @@
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint("mcp", __name__, url_prefix="/mcp")
+api = ExternalApi(bp)
+
+from . import mcp
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
new file mode 100644
index 0000000000..ead728bfb0
--- /dev/null
+++ b/api/controllers/mcp/mcp.py
@@ -0,0 +1,104 @@
+from flask_restful import Resource, reqparse
+from pydantic import ValidationError
+
+from controllers.console.app.mcp_server import AppMCPServerStatus
+from controllers.mcp import api
+from core.app.app_config.entities import VariableEntity
+from core.mcp import types
+from core.mcp.server.streamable_http import MCPServerStreamableHTTPRequestHandler
+from core.mcp.types import ClientNotification, ClientRequest
+from core.mcp.utils import create_mcp_error_response
+from extensions.ext_database import db
+from libs import helper
+from models.model import App, AppMCPServer, AppMode
+
+
+class MCPAppApi(Resource):
+ def post(self, server_code):
+ def int_or_str(value):
+ if isinstance(value, (int, str)):
+ return value
+ else:
+ return None
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("jsonrpc", type=str, required=True, location="json")
+ parser.add_argument("method", type=str, required=True, location="json")
+ parser.add_argument("params", type=dict, required=False, location="json")
+ parser.add_argument("id", type=int_or_str, required=False, location="json")
+ args = parser.parse_args()
+
+ request_id = args.get("id")
+
+ server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
+ if not server:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
+ )
+
+ if server.status != AppMCPServerStatus.ACTIVE:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
+ )
+
+ app = db.session.query(App).filter(App.id == server.app_id).first()
+ if not app:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
+ )
+
+ if app.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+ workflow = app.workflow
+ if workflow is None:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
+ )
+
+ user_input_form = workflow.user_input_form(to_old_structure=True)
+ else:
+ app_model_config = app.app_model_config
+ if app_model_config is None:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_REQUEST, "App is unavailable")
+ )
+
+ features_dict = app_model_config.to_dict()
+ user_input_form = features_dict.get("user_input_form", [])
+ converted_user_input_form: list[VariableEntity] = []
+ try:
+ for item in user_input_form:
+ variable_type = item.get("type", "") or list(item.keys())[0]
+ variable = item[variable_type]
+ converted_user_input_form.append(
+ VariableEntity(
+ type=variable_type,
+ variable=variable.get("variable"),
+ description=variable.get("description") or "",
+ label=variable.get("label"),
+ required=variable.get("required", False),
+ max_length=variable.get("max_length"),
+ options=variable.get("options") or [],
+ )
+ )
+ except ValidationError as e:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid user_input_form: {str(e)}")
+ )
+
+ try:
+ request: ClientRequest | ClientNotification = ClientRequest.model_validate(args)
+ except ValidationError as e:
+ try:
+ notification = ClientNotification.model_validate(args)
+ request = notification
+ except ValidationError as e:
+ return helper.compact_generate_response(
+ create_mcp_error_response(request_id, types.INVALID_PARAMS, f"Invalid MCP request: {str(e)}")
+ )
+
+ mcp_server_handler = MCPServerStreamableHTTPRequestHandler(app, request, converted_user_input_form)
+ response = mcp_server_handler.handle()
+ return helper.compact_generate_response(response)
+
+
+api.add_resource(MCPAppApi, "/server/<string:server_code>/mcp")
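
This new controller accepts plain JSON-RPC over HTTP for a published app's MCP server. A hedged sketch of a client call; the host and server code are placeholders, `tools/list` is a standard MCP method name, and the full URL assumes the blueprint's declared `/mcp` prefix plus the `/server/<server_code>/mcp` resource path:

    import requests

    url = "https://dify.example.com/mcp/server/AbCdEf123456/mcp"  # placeholder server_code

    payload = {
        "jsonrpc": "2.0",
        "id": 1,
        "method": "tools/list",  # parsed into a ClientRequest by the handler
        "params": {},
    }

    print(requests.post(url, json=payload).json())
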
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index 2682c2e7f1..848863cf1b 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -20,7 +20,7 @@ from controllers.service_api.app.error import (
from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppMode, EndUser
+from models.model import App, EndUser
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -78,20 +78,9 @@ class TextApi(Resource):
message_id = args.get("message_id", None)
text = args.get("text", None)
- if (
- app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
- and app_model.workflow
- and app_model.workflow.features_dict
- ):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
- voice = args.get("voice") or text_to_speech.get("voice")
- else:
- try:
- voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
- except Exception:
- voice = None
+ voice = args.get("voice", None)
response = AudioService.transcript_tts(
- app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
+ app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index df52b49424..ac2ebf2b09 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -3,7 +3,7 @@ import logging
from dateutil.parser import isoparse
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -30,7 +30,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
-from models.workflow import WorkflowRun
+from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@@ -63,7 +63,15 @@ class WorkflowRunDetailApi(Resource):
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
- workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
+ # Use repository to get workflow run
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ workflow_run = workflow_run_repo.get_workflow_run_by_id(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ run_id=workflow_run_id,
+ )
return workflow_run
@@ -135,6 +143,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args")
+ parser.add_argument(
+ "created_by_end_user_session_id",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+ )
+ parser.add_argument(
+ "created_by_account",
+ type=str,
+ location="args",
+ required=False,
+ default=None,
+ )
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args()
@@ -158,6 +180,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after,
page=args.page,
limit=args.limit,
+ created_by_end_user_session_id=args.created_by_end_user_session_id,
+ created_by_account=args.created_by_account,
)
return workflow_app_log_pagination
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index 27e8dd3fa6..a499719fc3 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -4,8 +4,12 @@ from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
from controllers.service_api import api
-from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError
-from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token
+from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
+from controllers.service_api.wraps import (
+ DatasetApiResource,
+ cloud_edition_billing_rate_limit_check,
+ validate_dataset_token,
+)
from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager
@@ -13,7 +17,7 @@ from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import tag_fields
from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum
-from services.dataset_service import DatasetPermissionService, DatasetService
+from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService
@@ -70,6 +74,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id):
"""Resource for creating datasets."""
parser = reqparse.RequestParser()
@@ -128,6 +133,22 @@ class DatasetListApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
+
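+ # Validate that any requested embedding / reranking models are configured for
+ # this tenant before creating the dataset.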
+ if args.get("embedding_model_provider"):
+ DatasetService.check_embedding_model_setting(
+ tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
+ )
+ if (
+ args.get("retrieval_model")
+ and args.get("retrieval_model").get("reranking_model")
+ and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
+ ):
+ DatasetService.check_reranking_model_setting(
+ tenant_id,
+ args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+ args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
+ )
+
try:
dataset = DatasetService.create_empty_dataset(
tenant_id=tenant_id,
@@ -193,6 +214,7 @@ class DatasetApi(DatasetApiResource):
return data, 200
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -259,10 +281,20 @@ class DatasetApi(DatasetApiResource):
data = request.get_json()
# check embedding model setting
- if data.get("indexing_technique") == "high_quality":
+ if data.get("indexing_technique") == "high_quality" or data.get("embedding_model_provider"):
DatasetService.check_embedding_model_setting(
dataset.tenant_id, data.get("embedding_model_provider"), data.get("embedding_model")
)
+ if (
+ data.get("retrieval_model")
+ and data.get("retrieval_model").get("reranking_model")
+ and data.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
+ ):
+ DatasetService.check_reranking_model_setting(
+ dataset.tenant_id,
+ data.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+ data.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
+ )
# The role of the current user in the ta table must be admin, owner, editor, or dataset_operator
DatasetPermissionService.check_permission(
@@ -293,6 +325,7 @@ class DatasetApi(DatasetApiResource):
return result_data, 200
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id):
"""
Deletes a dataset given its ID.
@@ -322,6 +355,56 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError()
+class DocumentStatusApi(DatasetApiResource):
+ """Resource for batch document status operations."""
+
+ def patch(self, tenant_id, dataset_id, action):
+ """
+ Batch update document status.
+
+ Args:
+ tenant_id: tenant id
+ dataset_id: dataset id
+ action: action to perform (enable, disable, archive, un_archive)
+
+ Returns:
+ dict: A dictionary with a key 'result' and a value 'success'
+ int: HTTP status code 200 indicating that the operation was successful.
+
+ Raises:
+ NotFound: If the dataset with the given ID does not exist.
+ Forbidden: If the user does not have permission.
+ InvalidActionError: If the action is invalid or cannot be performed.
+ """
+ dataset_id_str = str(dataset_id)
+ dataset = DatasetService.get_dataset(dataset_id_str)
+
+ if dataset is None:
+ raise NotFound("Dataset not found.")
+
+ # Check user's permission
+ try:
+ DatasetService.check_dataset_permission(dataset, current_user)
+ except services.errors.account.NoPermissionError as e:
+ raise Forbidden(str(e))
+
+ # Check dataset model setting
+ DatasetService.check_dataset_model_setting(dataset)
+
+ # Get document IDs from request body
+ data = request.get_json()
+ document_ids = data.get("document_ids", [])
+
+ try:
+ DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
+ except services.errors.document.DocumentIndexingError as e:
+ raise InvalidActionError(str(e))
+ except ValueError as e:
+ raise InvalidActionError(str(e))
+
+ return {"result": "success"}, 200
+
+
class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token
@marshal_with(tag_fields)
@@ -450,6 +533,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
+api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
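+# Illustrative usage (assuming the standard /v1 service API prefix):
+#   PATCH /v1/datasets/<dataset_id>/documents/status/enable
+#   body: {"document_ids": ["<document-id>", "..."]}
+# Supported actions: enable, disable, archive, un_archive.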
api.add_resource(DatasetTagsApi, "/datasets/tags")
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index ab7ab4dcf0..d571b21a0a 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -3,7 +3,7 @@ import json
from flask import request
from flask_restful import marshal, reqparse
from sqlalchemy import desc, select
-from werkzeug.exceptions import NotFound
+from werkzeug.exceptions import Forbidden, NotFound
import services
from controllers.common.errors import FilenameNotExistsError
@@ -18,14 +18,19 @@ from controllers.service_api.app.error import (
from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError,
DocumentIndexingError,
+ InvalidMetadataError,
+)
+from controllers.service_api.wraps import (
+ DatasetApiResource,
+ cloud_edition_billing_rate_limit_check,
+ cloud_edition_billing_resource_check,
)
-from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment
-from services.dataset_service import DocumentService
+from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService
@@ -35,6 +40,7 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by text."""
parser = reqparse.RequestParser()
@@ -54,6 +60,7 @@ class DocumentAddByTextApi(DatasetApiResource):
parser.add_argument("embedding_model_provider", type=str, required=False, nullable=True, location="json")
args = parser.parse_args()
+
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
@@ -69,6 +76,21 @@ class DocumentAddByTextApi(DatasetApiResource):
if text is None or name is None:
raise ValueError("Both 'text' and 'name' must be non-null values.")
+ if args.get("embedding_model_provider"):
+ DatasetService.check_embedding_model_setting(
+ tenant_id, args.get("embedding_model_provider"), args.get("embedding_model")
+ )
+ if (
+ args.get("retrieval_model")
+ and args.get("retrieval_model").get("reranking_model")
+ and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
+ ):
+ DatasetService.check_reranking_model_setting(
+ tenant_id,
+ args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+ args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
+ )
+
upload_file = FileService.upload_text(text=str(text), text_name=str(name))
data_source = {
"type": "upload_file",
@@ -99,6 +121,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by text."""
parser = reqparse.RequestParser()
@@ -118,6 +141,17 @@ class DocumentUpdateByTextApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
+ if (
+ args.get("retrieval_model")
+ and args.get("retrieval_model").get("reranking_model")
+ and args.get("retrieval_model").get("reranking_model").get("reranking_provider_name")
+ ):
+ DatasetService.check_reranking_model_setting(
+ tenant_id,
+ args.get("retrieval_model").get("reranking_model").get("reranking_provider_name"),
+ args.get("retrieval_model").get("reranking_model").get("reranking_model_name"),
+ )
+
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -158,6 +192,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
"""Create document by upload file."""
args = {}
@@ -176,11 +211,29 @@ class DocumentAddByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
+ if dataset.provider == "external":
+ raise ValueError("External datasets are not supported.")
+
indexing_technique = args.get("indexing_technique") or dataset.indexing_technique
if not indexing_technique:
raise ValueError("indexing_technique is required.")
args["indexing_technique"] = indexing_technique
+ if "embedding_model_provider" in args:
+ DatasetService.check_embedding_model_setting(
+ tenant_id, args["embedding_model_provider"], args["embedding_model"]
+ )
+ if (
+ "retrieval_model" in args
+ and args["retrieval_model"].get("reranking_model")
+ and args["retrieval_model"].get("reranking_model").get("reranking_provider_name")
+ ):
+ DatasetService.check_reranking_model_setting(
+ tenant_id,
+ args["retrieval_model"].get("reranking_model").get("reranking_provider_name"),
+ args["retrieval_model"].get("reranking_model").get("reranking_model_name"),
+ )
+
# save file info
file = request.files["file"]
# check file
@@ -232,6 +285,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file."""
args = {}
@@ -250,6 +304,9 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not dataset:
raise ValueError("Dataset does not exist.")
+ if dataset.provider == "external":
+ raise ValueError("External datasets are not supported.")
+
# indexing_technique is already set in dataset since this is an update
args["indexing_technique"] = dataset.indexing_technique
@@ -302,6 +359,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
class DocumentDeleteApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id):
"""Delete document."""
document_id = str(document_id)
@@ -415,6 +473,101 @@ class DocumentIndexingStatusApi(DatasetApiResource):
return data
+class DocumentDetailApi(DatasetApiResource):
+ METADATA_CHOICES = {"all", "only", "without"}
+
+ def get(self, tenant_id, dataset_id, document_id):
+ dataset_id = str(dataset_id)
+ document_id = str(document_id)
+
+ dataset = self.get_dataset(dataset_id, tenant_id)
+
+ document = DocumentService.get_document(dataset.id, document_id)
+
+ if not document:
+ raise NotFound("Document not found.")
+
+ if document.tenant_id != str(tenant_id):
+ raise Forbidden("No permission.")
+
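+ # The optional `metadata` query parameter controls the response shape: "only"
+ # returns just id/doc_type/doc_metadata, "without" omits the metadata fields,
+ # and "all" (the default) returns the full document including metadata.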
+ metadata = request.args.get("metadata", "all")
+ if metadata not in self.METADATA_CHOICES:
+ raise InvalidMetadataError(f"Invalid metadata value: {metadata}")
+
+ if metadata == "only":
+ response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata_details}
+ elif metadata == "without":
+ dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+ document_process_rules = document.dataset_process_rule.to_dict()
+ data_source_info = document.data_source_detail_dict
+ response = {
+ "id": document.id,
+ "position": document.position,
+ "data_source_type": document.data_source_type,
+ "data_source_info": data_source_info,
+ "dataset_process_rule_id": document.dataset_process_rule_id,
+ "dataset_process_rule": dataset_process_rules,
+ "document_process_rule": document_process_rules,
+ "name": document.name,
+ "created_from": document.created_from,
+ "created_by": document.created_by,
+ "created_at": document.created_at.timestamp(),
+ "tokens": document.tokens,
+ "indexing_status": document.indexing_status,
+ "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
+ "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
+ "indexing_latency": document.indexing_latency,
+ "error": document.error,
+ "enabled": document.enabled,
+ "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
+ "disabled_by": document.disabled_by,
+ "archived": document.archived,
+ "segment_count": document.segment_count,
+ "average_segment_length": document.average_segment_length,
+ "hit_count": document.hit_count,
+ "display_status": document.display_status,
+ "doc_form": document.doc_form,
+ "doc_language": document.doc_language,
+ }
+ else:
+ dataset_process_rules = DatasetService.get_process_rules(dataset_id)
+ document_process_rules = document.dataset_process_rule.to_dict()
+ data_source_info = document.data_source_detail_dict
+ response = {
+ "id": document.id,
+ "position": document.position,
+ "data_source_type": document.data_source_type,
+ "data_source_info": data_source_info,
+ "dataset_process_rule_id": document.dataset_process_rule_id,
+ "dataset_process_rule": dataset_process_rules,
+ "document_process_rule": document_process_rules,
+ "name": document.name,
+ "created_from": document.created_from,
+ "created_by": document.created_by,
+ "created_at": document.created_at.timestamp(),
+ "tokens": document.tokens,
+ "indexing_status": document.indexing_status,
+ "completed_at": int(document.completed_at.timestamp()) if document.completed_at else None,
+ "updated_at": int(document.updated_at.timestamp()) if document.updated_at else None,
+ "indexing_latency": document.indexing_latency,
+ "error": document.error,
+ "enabled": document.enabled,
+ "disabled_at": int(document.disabled_at.timestamp()) if document.disabled_at else None,
+ "disabled_by": document.disabled_by,
+ "archived": document.archived,
+ "doc_type": document.doc_type,
+ "doc_metadata": document.doc_metadata_details,
+ "segment_count": document.segment_count,
+ "average_segment_length": document.average_segment_length,
+ "hit_count": document.hit_count,
+ "display_status": document.display_status,
+ "doc_form": document.doc_form,
+ "doc_language": document.doc_language,
+ }
+
+ return response
+
+
api.add_resource(
DocumentAddByTextApi,
"/datasets//document/create_by_text",
@@ -438,3 +591,4 @@ api.add_resource(
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentListApi, "/datasets/<uuid:dataset_id>/documents")
api.add_resource(DocumentIndexingStatusApi, "/datasets/<uuid:dataset_id>/documents/<string:batch>/indexing-status")
+api.add_resource(DocumentDetailApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py
index 5ff5e08c72..ecc47b40a1 100644
--- a/api/controllers/service_api/dataset/error.py
+++ b/api/controllers/service_api/dataset/error.py
@@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
-class HighQualityDatasetOnlyError(BaseHTTPException):
- error_code = "high_quality_dataset_only"
- description = "Current operation only supports 'high-quality' datasets."
- code = 400
-
-
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py
index 465f71bf03..52e9bca5da 100644
--- a/api/controllers/service_api/dataset/hit_testing.py
+++ b/api/controllers/service_api/dataset/hit_testing.py
@@ -1,9 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 35582feea0..1968696ee5 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
from controllers.service_api import api
-from controllers.service_api.wraps import DatasetApiResource
+from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import (
@@ -14,6 +14,7 @@ from services.metadata_service import MetadataService
class DatasetMetadataCreateServiceApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json")
@@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
class DatasetMetadataServiceApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id):
parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json")
@@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return marshal(metadata, dataset_metadata_fields), 200
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id):
dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id)
@@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
@@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
class DocumentMetadataEditServiceApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index 337752275a..403b7f0a0c 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_knowledge_limit_check,
+ cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
@@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource):
class DatasetSegmentApi(DatasetApiResource):
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@@ -162,6 +165,7 @@ class DatasetSegmentApi(DatasetApiResource):
return 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
@@ -236,6 +240,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
"""Create child chunk."""
# check dataset
@@ -332,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks."""
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Delete child chunk."""
# check dataset
@@ -370,6 +376,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
+ @cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Update child chunk."""
# check dataset
diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py
index d24c4597e2..9bb5df4c4e 100644
--- a/api/controllers/service_api/index.py
+++ b/api/controllers/service_api/index.py
@@ -9,7 +9,7 @@ class IndexApi(Resource):
return {
"welcome": "Dify OpenAPI",
"api_version": "v1",
- "server_version": dify_config.CURRENT_VERSION,
+ "server_version": dify_config.project.version,
}
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index d3316a5159..eeed321430 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,6 +1,6 @@
import time
from collections.abc import Callable
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@@ -11,13 +11,14 @@ from flask_restful import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
-from werkzeug.exceptions import Forbidden, Unauthorized
+from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
-from models.dataset import RateLimitLog
+from models.dataset import Dataset, RateLimitLog
from models.model import ApiToken, App, EndUser
from services.feature_service import FeatureService
@@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
- current_time = datetime.now(UTC).replace(tzinfo=None)
+ current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
@@ -317,3 +318,11 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
+
+ def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
+ dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
+
+ if not dataset:
+ raise NotFound("Dataset not found.")
+
+ return dataset
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index 06d9ad7564..2919ca9af4 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -19,7 +19,7 @@ from controllers.web.error import (
from controllers.web.wraps import WebApiResource
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.model_runtime.errors.invoke import InvokeError
-from models.model import App, AppMode
+from models.model import App
from services.audio_service import AudioService
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -77,21 +77,9 @@ class TextApi(WebApiResource):
message_id = args.get("message_id", None)
text = args.get("text", None)
- if (
- app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}
- and app_model.workflow
- and app_model.workflow.features_dict
- ):
- text_to_speech = app_model.workflow.features_dict.get("text_to_speech", {})
- voice = args.get("voice") or text_to_speech.get("voice")
- else:
- try:
- voice = args.get("voice") or app_model.app_model_config.text_to_speech_dict.get("voice")
- except Exception:
- voice = None
-
+ voice = args.get("voice", None)
response = AudioService.transcript_tts(
- app_model=app_model, message_id=message_id, end_user=end_user.external_user_id, voice=voice, text=text
+ app_model=app_model, text=text, voice=voice, end_user=end_user.external_user_id, message_id=message_id
)
return response
diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py
index 4371e679db..036e11d5c5 100644
--- a/api/controllers/web/error.py
+++ b/api/controllers/web/error.py
@@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
+
+
+class NotFoundError(BaseHTTPException):
+ error_code = "not_found"
+ code = 404
+
+
+class InvalidArgumentError(BaseHTTPException):
+ error_code = "invalid_param"
+ code = 400
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 6998e4d29a..28bf4a9a23 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -3,6 +3,8 @@ import logging
import uuid
from typing import Optional, Union, cast
+from sqlalchemy import select
+
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -161,10 +163,14 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
- message_tool.parameters["properties"][parameter.name] = {
- "type": parameter_type,
- "description": parameter.llm_description or "",
- }
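+ # Prefer the tool-provided input_schema for this parameter when it exists;
+ # otherwise keep building the minimal {type, description} schema as before.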
+ message_tool.parameters["properties"][parameter.name] = (
+ {
+ "type": parameter_type,
+ "description": parameter.llm_description or "",
+ }
+ if parameter.input_schema is None
+ else parameter.input_schema
+ )
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
@@ -254,10 +260,14 @@ class BaseAgentRunner(AppRunner):
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
- prompt_tool.parameters["properties"][parameter.name] = {
- "type": parameter_type,
- "description": parameter.llm_description or "",
- }
+ prompt_tool.parameters["properties"][parameter.name] = (
+ {
+ "type": parameter_type,
+ "description": parameter.llm_description or "",
+ }
+ if parameter.input_schema is None
+ else parameter.input_schema
+ )
if len(enum) > 0:
prompt_tool.parameters["properties"][parameter.name]["enum"] = enum
@@ -409,12 +419,15 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
- messages: list[Message] = (
- db.session.query(Message)
- .filter(
- Message.conversation_id == self.message.conversation_id,
+ messages = (
+ (
+ db.session.execute(
+ select(Message)
+ .where(Message.conversation_id == self.message.conversation_id)
+ .order_by(Message.created_at.desc())
+ )
)
- .order_by(Message.created_at.desc())
+ .scalars()
.all()
)
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 143a3a51aa..a31c1050bd 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
+ credential_id: str | None = None
class AgentPromptEntity(BaseModel):
diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py
index 9c722baa23..a3438fc2c7 100644
--- a/api/core/agent/plugin_entities.py
+++ b/api/core/agent/plugin_entities.py
@@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter):
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
+ ANY = CommonParameterType.ANY.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
@@ -85,7 +86,7 @@ class AgentStrategyEntity(BaseModel):
description: I18nObject = Field(..., description="The description of the agent strategy")
output_schema: Optional[dict] = None
features: Optional[list[AgentFeature]] = None
-
+ meta_version: Optional[str] = None
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py
index ead81a7a0e..a52a1dfd7a 100644
--- a/api/core/agent/strategy/base.py
+++ b/api/core/agent/strategy/base.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
+from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC):
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
- yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
+ yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass
diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py
index 79b074cf95..04661581a7 100644
--- a/api/core/agent/strategy/plugin.py
+++ b/api/core/agent/strategy/plugin.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
+from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -15,10 +16,12 @@ class PluginAgentStrategy(BaseAgentStrategy):
tenant_id: str
declaration: AgentStrategyEntity
+ meta_version: str | None = None
- def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
+ def __init__(self, tenant_id: str, declaration: AgentStrategyEntity, meta_version: str | None):
self.tenant_id = tenant_id
self.declaration = declaration
+ self.meta_version = meta_version
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
return self.declaration.parameters
@@ -38,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@@ -56,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
+ context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
)
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
index 590b944c0d..8887d2500c 100644
--- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -39,6 +39,7 @@ class AgentConfigManager:
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
+ "credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
index 20189053f4..a5492d70bd 100644
--- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py
@@ -138,14 +138,11 @@ class DatasetConfigManager:
if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"}
- if not config["dataset_configs"].get("datasets"):
- config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
-
if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type")
- if not isinstance(config["dataset_configs"], dict):
- raise ValueError("dataset_configs must be of object type")
+ if not config["dataset_configs"].get("datasets"):
+ config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {}
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py
index 3f31b1c3d5..75bd2f677a 100644
--- a/api/core/app/app_config/entities.py
+++ b/api/core/app/app_config/entities.py
@@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
Variable Entity.
"""
+ # `variable` records the name of the variable in user inputs.
variable: str
label: str
description: str = ""
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 9e6adc4b08..bd5ad9c51b 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
@@ -25,17 +26,23 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
+from core.workflow.repositories.draft_variable_repository import (
+ DraftVariableSaverFactory,
+)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
+from services.workflow_draft_variable_service import (
+ DraftVarLoader,
+ WorkflowDraftVariableService,
+)
logger = logging.getLogger(__name__)
@@ -116,6 +123,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
)
# parse files
+ # TODO(QuantumGhost): Move file parsing logic to the API controller layer
+ # for better separation of concerns.
+ #
+ # For implementation reference, see the `_parse_file` function and
+ # `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config:
@@ -171,14 +183,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -248,19 +260,26 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
+ var_loader = DraftVarLoader(
+ engine=db.engine,
+ app_id=application_generate_entity.app_config.app_id,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ )
+ draft_var_srv = WorkflowDraftVariableService(db.session())
+ draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
@@ -271,6 +290,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
+ variable_loader=var_loader,
)
def single_loop_generate(
@@ -323,19 +343,26 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
+ var_loader = DraftVarLoader(
+ engine=db.engine,
+ app_id=application_generate_entity.app_config.app_id,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ )
+ draft_var_srv = WorkflowDraftVariableService(db.session())
+ draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate(
workflow=workflow,
@@ -346,6 +373,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
+ variable_loader=var_loader,
)
def _generate(
@@ -359,6 +387,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
+ variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@@ -367,6 +396,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
+ :param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation
:param stream: is stream
@@ -409,6 +439,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"conversation_id": conversation.id,
"message_id": message.id,
"context": context,
+ "variable_loader": variable_loader,
},
)
@@ -425,6 +456,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
+ draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from),
)
return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
@@ -437,6 +469,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id: str,
message_id: str,
context: contextvars.Context,
+ variable_loader: VariableLoader,
) -> None:
"""
Generate worker in a new thread.
@@ -453,8 +486,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
- if message is None:
- raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AdvancedChatAppRunner(
@@ -463,6 +494,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
+ variable_loader=variable_loader,
)
runner.run()
@@ -496,6 +528,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+ draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -522,6 +555,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
+ draft_var_saver_factory=draft_var_saver_factory,
)
try:
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index d9b3833862..af15324f46 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -16,9 +16,11 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
+from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
+from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -40,14 +42,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation,
message: Message,
dialogue_count: int,
+ variable_loader: VariableLoader,
) -> None:
- super().__init__(queue_manager)
-
+ super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
+ def _get_app_id(self) -> str:
+ return self.application_generate_entity.app_config.app_id
+
def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
@@ -60,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not workflow:
raise ValueError("Workflow not initialized")
- user_id = None
+ user_id: str | None = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
@@ -132,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
session.commit()
# Create a variable pool.
- system_inputs = {
- SystemVariableKey.QUERY: query,
- SystemVariableKey.FILES: files,
- SystemVariableKey.CONVERSATION_ID: self.conversation.id,
- SystemVariableKey.USER_ID: user_id,
- SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
- SystemVariableKey.APP_ID: app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
- }
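+ # System inputs are now passed as a typed SystemVariable object instead of a
+ # dict keyed by SystemVariableKey.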
+ system_inputs = SystemVariable(
+ query=query,
+ files=files,
+ conversation_id=self.conversation.id,
+ user_id=user_id,
+ dialogue_count=self._dialogue_count,
+ app_id=app_config.app_id,
+ workflow_id=app_config.workflow_id,
+ workflow_execution_id=self.application_generate_entity.workflow_run_id,
+ )
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
- conversation_variables=conversation_variables,
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 8c5645bbb7..337b779b50 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -1,6 +1,7 @@
import logging
import time
-from collections.abc import Generator, Mapping
+from collections.abc import Callable, Generator, Mapping
+from contextlib import contextmanager
from threading import Thread
from typing import Any, Optional, Union
@@ -15,6 +16,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
+ MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
@@ -44,6 +46,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
+ WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
@@ -52,6 +55,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
+ PingStreamResponse,
StreamResponse,
WorkflowTaskState,
)
@@ -61,11 +65,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
+from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
@@ -94,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline:
dialogue_count: int,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+ draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@@ -114,16 +120,16 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
- workflow_system_variables={
- SystemVariableKey.QUERY: message.query,
- SystemVariableKey.FILES: application_generate_entity.files,
- SystemVariableKey.CONVERSATION_ID: conversation.id,
- SystemVariableKey.USER_ID: user_session_id,
- SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
- SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: workflow.id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
- },
+ workflow_system_variables=SystemVariable(
+ query=message.query,
+ files=application_generate_entity.files,
+ conversation_id=conversation.id,
+ user_id=user_session_id,
+ dialogue_count=dialogue_count,
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=workflow.id,
+ workflow_execution_id=application_generate_entity.workflow_run_id,
+ ),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
@@ -153,13 +159,13 @@ class AdvancedChatAppGenerateTaskPipeline:
self._conversation_name_generate_thread: Thread | None = None
self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id: str = ""
+ self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
:return:
"""
- # start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
@@ -251,15 +257,12 @@ class AdvancedChatAppGenerateTaskPipeline:
yield response
start_listener_time = time.time()
- # timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
- # release cpu
- # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
@@ -273,400 +276,613 @@ class AdvancedChatAppGenerateTaskPipeline:
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
- def _process_stream_response(
+ @contextmanager
+ def _database_session(self):
+ """Context manager for database sessions."""
+ with Session(db.engine, expire_on_commit=False) as session:
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+
+ def _ensure_workflow_initialized(self) -> None:
+ """Fluent validation for workflow state."""
+ if not self._workflow_run_id:
+ raise ValueError("workflow run not initialized.")
+
+ def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+ """Fluent validation for graph runtime state."""
+ if not graph_runtime_state:
+ raise ValueError("graph runtime state not initialized.")
+ return graph_runtime_state
+
+ def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
+ """Handle ping events."""
+ yield self._base_task_pipeline._ping_stream_response()
+
+ def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
+ """Handle error events."""
+ with self._database_session() as session:
+ err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
+ yield self._base_task_pipeline._error_to_stream_response(err)
+
+ def _handle_workflow_started_event(
+ self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow started events."""
+        # The graph runtime state comes from the event; the stream loop also captures it for later handlers
+ graph_runtime_state = event.graph_runtime_state
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
+ self._workflow_run_id = workflow_execution.id_
+
+ message = self._get_message(session=session)
+ if not message:
+ raise ValueError(f"Message not found: {self._message_id}")
+
+ message.workflow_run_id = workflow_execution.id_
+ workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+
+ yield workflow_start_resp
+
+ def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle node retry events."""
+ self._ensure_workflow_initialized()
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+ node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_retry_resp:
+ yield node_retry_resp
+
+ def _handle_node_started_event(
+ self, event: QueueNodeStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node started events."""
+ self._ensure_workflow_initialized()
+
+ workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+
+ node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_start_resp:
+ yield node_start_resp
+
+ def _handle_node_succeeded_event(
+ self, event: QueueNodeSucceededEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node succeeded events."""
+ # Record files if it's an answer node or end node
+ if event.node_type in [NodeType.ANSWER, NodeType.END]:
+ self._recorded_files.extend(
+ self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
+ )
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
+ node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_finish_resp:
+ yield node_finish_resp
+
+ def _handle_node_failed_events(
+ self,
+ event: Union[
+ QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
+ ],
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle various node failure events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
+
+ node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if isinstance(event, QueueNodeExceptionEvent):
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_finish_resp:
+ yield node_finish_resp
+
+ def _handle_text_chunk_event(
self,
+ event: QueueTextChunkEvent,
+ *,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
- trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ **kwargs,
) -> Generator[StreamResponse, None, None]:
- """
- Process stream response.
- :return:
- """
- # init fake graph runtime state
- graph_runtime_state: Optional[GraphRuntimeState] = None
+ """Handle text chunk events."""
+ delta_text = event.text
+ if delta_text is None:
+ return
+
+ # Handle output moderation chunk
+ should_direct_answer = self._handle_output_moderation_chunk(delta_text)
+ if should_direct_answer:
+ return
+
+ # Only publish tts message at text chunk streaming
+ if tts_publisher and queue_message:
+ tts_publisher.publish(queue_message)
+
+ self._task_state.answer += delta_text
+ yield self._message_cycle_manager.message_to_stream_response(
+ answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
+ )
- for queue_message in self._base_task_pipeline._queue_manager.listen():
- event = queue_message.event
+ def _handle_parallel_branch_started_event(
+ self, event: QueueParallelBranchRunStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch started events."""
+ self._ensure_workflow_initialized()
- if isinstance(event, QueuePingEvent):
- yield self._base_task_pipeline._ping_stream_response()
- elif isinstance(event, QueueErrorEvent):
- with Session(db.engine, expire_on_commit=False) as session:
- err = self._base_task_pipeline._handle_error(
- event=event, session=session, message_id=self._message_id
- )
- session.commit()
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueWorkflowStartedEvent):
- # override graph runtime state
- graph_runtime_state = event.graph_runtime_state
-
- with Session(db.engine, expire_on_commit=False) as session:
- # init workflow run
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
- self._workflow_run_id = workflow_execution.id_
- message = self._get_message(session=session)
- if not message:
- raise ValueError(f"Message not found: {self._message_id}")
- message.workflow_run_id = workflow_execution.id_
- workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
+ parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_start_resp
- yield workflow_start_resp
- elif isinstance(
- event,
- QueueNodeRetryEvent,
- ):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
- workflow_execution_id=self._workflow_run_id, event=event
- )
- node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ def _handle_parallel_branch_finished_events(
+ self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch finished events."""
+ self._ensure_workflow_initialized()
- if node_retry_resp:
- yield node_retry_resp
- elif isinstance(event, QueueNodeStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_finish_resp
- workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
- workflow_execution_id=self._workflow_run_id, event=event
- )
+ def _handle_iteration_start_event(
+ self, event: QueueIterationStartEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration start events."""
+ self._ensure_workflow_initialized()
- node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_start_resp
- if node_start_resp:
- yield node_start_resp
- elif isinstance(event, QueueNodeSucceededEvent):
- # Record files if it's an answer node or end node
- if event.node_type in [NodeType.ANSWER, NodeType.END]:
- self._recorded_files.extend(
- self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
- )
+ def _handle_iteration_next_event(
+ self, event: QueueIterationNextEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration next events."""
+ self._ensure_workflow_initialized()
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
- event=event
- )
+ iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_next_resp
- node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ def _handle_iteration_completed_event(
+ self, event: QueueIterationCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration completed events."""
+ self._ensure_workflow_initialized()
- if node_finish_resp:
- yield node_finish_resp
- elif isinstance(
- event,
- QueueNodeFailedEvent
- | QueueNodeInIterationFailedEvent
- | QueueNodeInLoopFailedEvent
- | QueueNodeExceptionEvent,
- ):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
- event=event
- )
+ iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_finish_resp
- node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop start events."""
+ self._ensure_workflow_initialized()
- if node_finish_resp:
- yield node_finish_resp
- elif isinstance(event, QueueParallelBranchRunStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
-
- parallel_start_resp = (
- self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_start_resp
- yield parallel_start_resp
- elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop next events."""
+ self._ensure_workflow_initialized()
- parallel_finish_resp = (
- self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_next_resp
- yield parallel_finish_resp
- elif isinstance(event, QueueIterationStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_loop_completed_event(
+ self, event: QueueLoopCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle loop completed events."""
+ self._ensure_workflow_initialized()
- iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_finish_resp
+
+ def _handle_workflow_succeeded_event(
+ self,
+ event: QueueWorkflowSucceededEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow succeeded events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield iter_start_resp
- elif isinstance(event, QueueIterationNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
- iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_partial_success_event(
+ self,
+ event: QueueWorkflowPartialSuccessEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow partial success events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield iter_next_resp
- elif isinstance(event, QueueIterationCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
- iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_failed_event(
+ self,
+ event: QueueWorkflowFailedEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow failed events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ status=WorkflowExecutionStatus.FAILED,
+ error_message=event.error,
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count,
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+ err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
+ err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
- yield iter_finish_resp
- elif isinstance(event, QueueLoopStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ yield self._base_task_pipeline._error_to_stream_response(err)
- loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+ def _handle_stop_event(
+ self,
+ event: QueueStopEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle stop events."""
+ if self._workflow_run_id and graph_runtime_state:
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ status=WorkflowExecutionStatus.STOPPED,
+ error_message=event.get_stop_reason(),
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
+ workflow_execution=workflow_execution,
)
+ # Save message
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- yield loop_start_resp
- elif isinstance(event, QueueLoopNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ elif event.stopped_by in (
+ QueueStopEvent.StopBy.INPUT_MODERATION,
+ QueueStopEvent.StopBy.ANNOTATION_REPLY,
+ ):
+ # When hitting input-moderation or annotation-reply, the workflow will not start
+ with self._database_session() as session:
+ # Save message
+ self._save_message(session=session)
- loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ yield self._message_end_to_stream_response()
- yield loop_next_resp
- elif isinstance(event, QueueLoopCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_advanced_chat_message_end_event(
+ self,
+ event: QueueAdvancedChatMessageEndEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle advanced chat message end events."""
+ self._ensure_graph_runtime_initialized(graph_runtime_state)
- loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+ self._task_state.answer
+ )
+ if output_moderation_answer:
+ self._task_state.answer = output_moderation_answer
+ yield self._message_cycle_manager.message_replace_to_stream_response(
+ answer=output_moderation_answer,
+ reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
+ )
- yield loop_finish_resp
- elif isinstance(event, QueueWorkflowSucceededEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ # Save message
+ with self._database_session() as session:
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- if not graph_runtime_state:
- raise ValueError("workflow run not initialized.")
+ yield self._message_end_to_stream_response()
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- )
+ def _handle_retriever_resources_event(
+ self, event: QueueRetrieverResourcesEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle retriever resources events."""
+ self._message_cycle_manager.handle_retriever_resources(event)
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ with self._database_session() as session:
+ message = self._get_message(session=session)
+ message.message_metadata = self._task_state.metadata.model_dump_json()
+ return
+        yield  # unreachable, but required so this method is a generator
- yield workflow_finish_resp
- self._base_task_pipeline._queue_manager.publish(
- QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
- )
- elif isinstance(event, QueueWorkflowPartialSuccessEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ def _handle_annotation_reply_event(
+ self, event: QueueAnnotationReplyEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle annotation reply events."""
+ self._message_cycle_manager.handle_annotation_reply(event)
- yield workflow_finish_resp
- self._base_task_pipeline._queue_manager.publish(
- QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
- )
- elif isinstance(event, QueueWorkflowFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.FAILED,
- error_message=event.error,
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
- err = self._base_task_pipeline._handle_error(
- event=err_event, session=session, message_id=self._message_id
- )
+ with self._database_session() as session:
+ message = self._get_message(session=session)
+ message.message_metadata = self._task_state.metadata.model_dump_json()
+ return
+        yield  # unreachable, but required so this method is a generator
- yield workflow_finish_resp
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueStopEvent):
- if self._workflow_run_id and graph_runtime_state:
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.STOPPED,
- error_message=event.get_stop_reason(),
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- # Save message
- self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- session.commit()
-
- yield workflow_finish_resp
- elif event.stopped_by in (
- QueueStopEvent.StopBy.INPUT_MODERATION,
- QueueStopEvent.StopBy.ANNOTATION_REPLY,
- ):
- # When hitting input-moderation or annotation-reply, the workflow will not start
- with Session(db.engine, expire_on_commit=False) as session:
- # Save message
- self._save_message(session=session)
- session.commit()
-
- yield self._message_end_to_stream_response()
- break
- elif isinstance(event, QueueRetrieverResourcesEvent):
- self._message_cycle_manager.handle_retriever_resources(event)
-
- with Session(db.engine, expire_on_commit=False) as session:
- message = self._get_message(session=session)
- message.message_metadata = self._task_state.metadata.model_dump_json()
- session.commit()
- elif isinstance(event, QueueAnnotationReplyEvent):
- self._message_cycle_manager.handle_annotation_reply(event)
-
- with Session(db.engine, expire_on_commit=False) as session:
- message = self._get_message(session=session)
- message.message_metadata = self._task_state.metadata.model_dump_json()
- session.commit()
- elif isinstance(event, QueueTextChunkEvent):
- delta_text = event.text
- if delta_text is None:
- continue
+ def _handle_message_replace_event(
+ self, event: QueueMessageReplaceEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle message replace events."""
+ yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
- # handle output moderation chunk
- should_direct_answer = self._handle_output_moderation_chunk(delta_text)
- if should_direct_answer:
- continue
+ def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle agent log events."""
+ yield self._workflow_response_converter.handle_agent_log(
+ task_id=self._application_generate_entity.task_id, event=event
+ )
- # only publish tts message at text chunk streaming
- if tts_publisher:
- tts_publisher.publish(queue_message)
+ def _get_event_handlers(self) -> dict[type, Callable]:
+ """Get mapping of event types to their handlers using fluent pattern."""
+ return {
+ # Basic events
+ QueuePingEvent: self._handle_ping_event,
+ QueueErrorEvent: self._handle_error_event,
+ QueueTextChunkEvent: self._handle_text_chunk_event,
+ # Workflow events
+ QueueWorkflowStartedEvent: self._handle_workflow_started_event,
+ QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
+ QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
+ QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
+ # Node events
+ QueueNodeRetryEvent: self._handle_node_retry_event,
+ QueueNodeStartedEvent: self._handle_node_started_event,
+ QueueNodeSucceededEvent: self._handle_node_succeeded_event,
+ # Parallel branch events
+ QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
+ # Iteration events
+ QueueIterationStartEvent: self._handle_iteration_start_event,
+ QueueIterationNextEvent: self._handle_iteration_next_event,
+ QueueIterationCompletedEvent: self._handle_iteration_completed_event,
+ # Loop events
+ QueueLoopStartEvent: self._handle_loop_start_event,
+ QueueLoopNextEvent: self._handle_loop_next_event,
+ QueueLoopCompletedEvent: self._handle_loop_completed_event,
+ # Control events
+ QueueStopEvent: self._handle_stop_event,
+ # Message events
+ QueueRetrieverResourcesEvent: self._handle_retriever_resources_event,
+ QueueAnnotationReplyEvent: self._handle_annotation_reply_event,
+ QueueMessageReplaceEvent: self._handle_message_replace_event,
+ QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
+ QueueAgentLogEvent: self._handle_agent_log_event,
+ }
+
+ def _dispatch_event(
+ self,
+ event: Any,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """Dispatch events using elegant pattern matching."""
+ handlers = self._get_event_handlers()
+ event_type = type(event)
- self._task_state.answer += delta_text
- yield self._message_cycle_manager.message_to_stream_response(
- answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
- )
- elif isinstance(event, QueueMessageReplaceEvent):
- # published by moderation
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=event.text, reason=event.reason
- )
- elif isinstance(event, QueueAdvancedChatMessageEndEvent):
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
+ # Direct handler lookup
+ if handler := handlers.get(event_type):
+ yield from handler(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
+
+ # Handle node failure events with isinstance check
+ if isinstance(
+ event,
+ (
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ),
+ ):
+ yield from self._handle_node_failed_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
- self._task_state.answer
- )
- if output_moderation_answer:
- self._task_state.answer = output_moderation_answer
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=output_moderation_answer,
- reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
+ # Handle parallel branch finished events with isinstance check
+ if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
+ yield from self._handle_parallel_branch_finished_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
+
+        # Unhandled events are ignored, matching the original behavior
+ return
+
+ def _process_stream_response(
+ self,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """
+        Process the stream response by dispatching queue events to dedicated handlers.
+        Behavior is intended to match the original single-method if/elif implementation.
+ """
+ # Initialize graph runtime state
+ graph_runtime_state: Optional[GraphRuntimeState] = None
+
+ for queue_message in self._base_task_pipeline._queue_manager.listen():
+ event = queue_message.event
+
+ match event:
+ case QueueWorkflowStartedEvent():
+ graph_runtime_state = event.graph_runtime_state
+ yield from self._handle_workflow_started_event(event)
+
+ case QueueTextChunkEvent():
+ yield from self._handle_text_chunk_event(
+ event, tts_publisher=tts_publisher, queue_message=queue_message
)
- # Save message
- with Session(db.engine, expire_on_commit=False) as session:
- self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- session.commit()
-
- yield self._message_end_to_stream_response()
- elif isinstance(event, QueueAgentLogEvent):
- yield self._workflow_response_converter.handle_agent_log(
- task_id=self._application_generate_entity.task_id, event=event
- )
- else:
- continue
+ case QueueErrorEvent():
+ yield from self._handle_error_event(event)
+ break
+
+ case QueueWorkflowFailedEvent():
+ yield from self._handle_workflow_failed_event(
+ event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
+ )
+ break
+
+ case QueueStopEvent():
+ yield from self._handle_stop_event(
+ event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
+ )
+ break
+
+                # Handle all other events through the handler dispatch table
+ case _:
+ if responses := list(
+ self._dispatch_event(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ ):
+ yield from responses
- # publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
@@ -738,7 +954,6 @@ class AdvancedChatAppGenerateTaskPipeline:
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
- # stop subscribe new token when output moderation should direct output
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
@@ -759,3 +974,15 @@ class AdvancedChatAppGenerateTaskPipeline:
if not message:
raise ValueError(f"Message not found: {self._message_id}")
return message
+
+ def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
+ with Session(db.engine) as session, session.begin():
+ saver = self._draft_var_saver_factory(
+ session=session,
+ app_id=self._application_generate_entity.app_config.app_id,
+ node_id=event.node_id,
+ node_type=event.node_type,
+ node_execution_id=node_execution_id,
+ enclosing_node_id=event.in_loop_id or event.in_iteration_id,
+ )
+ saver.save(event.process_data, event.outputs)
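The refactor above replaces one long if/elif chain with per-event handler methods plus a type-keyed registry (`_get_event_handlers` / `_dispatch_event`). A self-contained sketch of that dispatch shape, using stand-in event classes rather than the real queue entities:

```python
# Standalone sketch of the handler-registry dispatch used above: exact-type
# lookup first, then an isinstance check for grouped event types.
from collections.abc import Callable, Generator


class PingEvent: ...
class NodeFailedEvent: ...
class NodeExceptionEvent(NodeFailedEvent): ...


class Dispatcher:
    def _handle_ping(self, event: PingEvent) -> Generator[str, None, None]:
        yield "pong"

    def _handle_node_failed(self, event: NodeFailedEvent) -> Generator[str, None, None]:
        yield f"failed: {type(event).__name__}"

    def _handlers(self) -> dict[type, Callable]:
        return {PingEvent: self._handle_ping}

    def dispatch(self, event: object) -> Generator[str, None, None]:
        # Exact-type lookup covers most events; grouped failure events fall
        # back to an isinstance check, mirroring _dispatch_event above.
        if handler := self._handlers().get(type(event)):
            yield from handler(event)
        elif isinstance(event, NodeFailedEvent):
            yield from self._handle_node_failed(event)
        # Unhandled events are silently ignored.


for resp in Dispatcher().dispatch(NodeExceptionEvent()):
    print(resp)  # -> failed: NodeExceptionEvent
```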
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index a448bf8a94..8665bc9d11 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
@@ -26,7 +27,6 @@ from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser
from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -124,6 +124,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
+ # TODO(QuantumGhost): Move file parsing logic to the API controller layer
+ # for better separation of concerns.
+ #
+ # For implementation reference, see the `_parse_file` function and
+ # `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@@ -233,8 +238,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
- if message is None:
- raise MessageNotExistsError("Message not exists")
# chatbot app
runner = AgentChatAppRunner()
diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py
index a83b75cc1a..beece1d77e 100644
--- a/api/core/app/apps/base_app_generator.py
+++ b/api/core/app/apps/base_app_generator.py
@@ -1,10 +1,20 @@
import json
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union, final
+
+from sqlalchemy.orm import Session
from core.app.app_config.entities import VariableEntityType
+from core.app.entities.app_invoke_entities import InvokeFrom
from core.file import File, FileUploadConfig
+from core.workflow.nodes.enums import NodeType
+from core.workflow.repositories.draft_variable_repository import (
+ DraftVariableSaver,
+ DraftVariableSaverFactory,
+ NoopDraftVariableSaver,
+)
from factories import file_factory
+from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl
if TYPE_CHECKING:
from core.app.app_config.entities import VariableEntity
@@ -159,3 +169,38 @@ class BaseAppGenerator:
yield f"event: {message}\n\n"
return gen()
+
+ @final
+ @staticmethod
+ def _get_draft_var_saver_factory(invoke_from: InvokeFrom) -> DraftVariableSaverFactory:
+ if invoke_from == InvokeFrom.DEBUGGER:
+
+ def draft_var_saver_factory(
+ session: Session,
+ app_id: str,
+ node_id: str,
+ node_type: NodeType,
+ node_execution_id: str,
+ enclosing_node_id: str | None = None,
+ ) -> DraftVariableSaver:
+ return DraftVariableSaverImpl(
+ session=session,
+ app_id=app_id,
+ node_id=node_id,
+ node_type=node_type,
+ node_execution_id=node_execution_id,
+ enclosing_node_id=enclosing_node_id,
+ )
+ else:
+
+ def draft_var_saver_factory(
+ session: Session,
+ app_id: str,
+ node_id: str,
+ node_type: NodeType,
+ node_execution_id: str,
+ enclosing_node_id: str | None = None,
+ ) -> DraftVariableSaver:
+ return NoopDraftVariableSaver()
+
+ return draft_var_saver_factory
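For reference, a hedged sketch of how the factory returned here is meant to be consumed, mirroring `_save_output_for_event` in the task-pipeline hunk earlier in this patch; argument values are placeholders:

```python
# Illustrative only, not part of the patch. Any NodeType member works here.
from sqlalchemy.orm import Session

from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db

factory = BaseAppGenerator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER)

with Session(db.engine) as session, session.begin():
    saver = factory(
        session=session,
        app_id="app-id",
        node_id="node-id",
        node_type=NodeType.LLM,
        node_execution_id="node-execution-id",
        enclosing_node_id=None,
    )
    # Positional (process_data, outputs), as in _save_output_for_event.
    saver.save({}, {})
```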
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 0ba33fbe0d..9da0bae56a 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -169,7 +169,3 @@ class AppQueueManager:
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)
-
-
-class GenerateTaskStoppedError(Exception):
- pass
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index a3f0cf7f9f..6e8c261a6a 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__)
class AppRunner:
- def get_pre_calculate_rest_tokens(
- self,
- app_record: App,
- model_config: ModelConfigWithCredentialsEntity,
- prompt_template_entity: PromptTemplateEntity,
- inputs: Mapping[str, str],
- files: Sequence["File"],
- query: Optional[str] = None,
- ) -> int:
- """
- Get pre calculate rest tokens
- :param app_record: app record
- :param model_config: model config entity
- :param prompt_template_entity: prompt template entity
- :param inputs: inputs
- :param files: files
- :param query: query
- :return:
- """
- # Invoke model
- model_instance = ModelInstance(
- provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
- )
-
- model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
- max_tokens = 0
- for parameter_rule in model_config.model_schema.parameter_rules:
- if parameter_rule.name == "max_tokens" or (
- parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
- ):
- max_tokens = (
- model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template or "")
- ) or 0
-
- if model_context_tokens is None:
- return -1
-
- if max_tokens is None:
- max_tokens = 0
-
- # get prompt messages without memory and context
- prompt_messages, stop = self.organize_prompt_messages(
- app_record=app_record,
- model_config=model_config,
- prompt_template_entity=prompt_template_entity,
- inputs=inputs,
- files=files,
- query=query,
- )
-
- prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
- rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
- if rest_tokens < 0:
- raise InvokeBadRequestError(
- "Query or prefix prompt is too long, you can reduce the prefix prompt, "
- "or shrink the max token, or switch to a llm with a larger token limit size."
- )
-
- return rest_tokens
-
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
@@ -181,7 +118,7 @@ class AppRunner:
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
- model_mode = ModelMode.value_of(model_config.mode)
+ model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
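The `ModelMode.value_of(...)` to `ModelMode(...)` change above relies on standard `Enum` lookup-by-value, so a custom helper is unnecessary. A tiny illustration with a stand-in enum:

```python
# ExampleMode is a stand-in; calling an Enum with a value returns the matching
# member and raises ValueError for unknown values.
from enum import Enum


class ExampleMode(str, Enum):
    COMPLETION = "completion"
    CHAT = "chat"


assert ExampleMode("chat") is ExampleMode.CHAT

try:
    ExampleMode("unknown")
except ValueError as exc:
    print(exc)  # 'unknown' is not a valid ExampleMode
```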
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index a1329cb938..0c76cc39ae 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -11,10 +11,11 @@ from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
@@ -25,7 +26,6 @@ from factories import file_factory
from models.account import Account
from models.model import App, EndUser
from services.conversation_service import ConversationService
-from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -115,6 +115,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files
+ # TODO(QuantumGhost): Move file parsing logic to the API controller layer
+ # for better separation of concerns.
+ #
+ # For implementation reference, see the `_parse_file` function and
+ # `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@@ -219,8 +224,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
- if message is None:
- raise MessageNotExistsError("Message not exists")
# chatbot app
runner = ChatAppRunner()
diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py
index 6f524a5872..34a1da2227 100644
--- a/api/core/app/apps/common/workflow_response_converter.py
+++ b/api/core/app/apps/common/workflow_response_converter.py
@@ -44,10 +44,12 @@ from core.app.entities.task_entities import (
)
from core.file import FILE_MODEL_IDENTITY, File
from core.tools.tool_manager import ToolManager
+from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import (
Account,
CreatorUserRole,
@@ -125,7 +127,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status,
- outputs=workflow_execution.outputs,
+ outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens,
@@ -202,6 +204,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
+ json_converter = WorkflowRuntimeTypeConverter()
+
return NodeFinishStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@@ -214,7 +218,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
- outputs=workflow_node_execution.outputs,
+ outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@@ -245,6 +249,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at:
return None
+ json_converter = WorkflowRuntimeTypeConverter()
+
return NodeRetryStreamResponse(
task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id,
@@ -257,7 +263,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data,
- outputs=workflow_node_execution.outputs,
+ outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status,
error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time,
@@ -376,6 +382,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str,
event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
+ json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse(
task_id=task_id,
workflow_run_id=workflow_execution_id,
@@ -384,7 +391,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
- outputs=event.outputs,
+ outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@@ -463,7 +470,7 @@ class WorkflowResponseConverter:
node_id=event.node_id,
node_type=event.node_type.value,
title=event.node_data.title,
- outputs=event.outputs,
+ outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()),
extras={},
inputs=event.inputs or {},
@@ -500,7 +507,8 @@ class WorkflowResponseConverter:
# Convert to tuple to match Sequence type
return tuple(flattened_files)
- def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
+ @classmethod
+ def _fetch_files_from_variable_value(cls, value: Union[dict, list, Segment]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
@@ -509,20 +517,30 @@ class WorkflowResponseConverter:
if not value:
return []
- files = []
- if isinstance(value, list):
+ files: list[Mapping[str, Any]] = []
+ if isinstance(value, FileSegment):
+ files.append(value.value.to_dict())
+ elif isinstance(value, ArrayFileSegment):
+ files.extend([i.to_dict() for i in value.value])
+ elif isinstance(value, File):
+ files.append(value.to_dict())
+ elif isinstance(value, list):
for item in value:
- file = self._get_file_var_from_value(item)
+ file = cls._get_file_var_from_value(item)
if file:
files.append(file)
- elif isinstance(value, dict):
- file = self._get_file_var_from_value(value)
+ elif isinstance(
+ value,
+ dict,
+ ):
+ file = cls._get_file_var_from_value(value)
if file:
files.append(file)
return files
- def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
+ @classmethod
+ def _get_file_var_from_value(cls, value: Union[dict, list]) -> Mapping[str, Any] | None:
"""
Get file var from value
:param value: variable value
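The `_fetch_files_from_variable_value` hunk above adds segment-aware branches ahead of the existing list/dict scanning. A standalone sketch of that branching with stand-in classes (the real ones are `core.file.File` and the segment types from `core.variables.segments`):

```python
from collections.abc import Mapping, Sequence
from typing import Any, Union

FILE_IDENTITY = "__dify__file__"  # stand-in for core.file.FILE_MODEL_IDENTITY


class FileLike:
    def to_dict(self) -> Mapping[str, Any]:
        return {"dify_model_identity": FILE_IDENTITY}


class FileSegmentLike:
    def __init__(self, value: FileLike):
        self.value = value


class ArrayFileSegmentLike:
    def __init__(self, value: list[FileLike]):
        self.value = value


def fetch_files(value: Union[dict, list, FileLike, FileSegmentLike, ArrayFileSegmentLike]) -> Sequence[Mapping[str, Any]]:
    files: list[Mapping[str, Any]] = []
    if isinstance(value, FileSegmentLike):  # single file wrapped in a segment
        files.append(value.value.to_dict())
    elif isinstance(value, ArrayFileSegmentLike):  # array-of-files segment
        files.extend(item.to_dict() for item in value.value)
    elif isinstance(value, FileLike):  # bare file object
        files.append(value.to_dict())
    elif isinstance(value, list):  # legacy path: scan list items
        files.extend(i for i in value if isinstance(i, dict) and i.get("dify_model_identity") == FILE_IDENTITY)
    elif isinstance(value, dict):  # legacy path: single dict
        if value.get("dify_model_identity") == FILE_IDENTITY:
            files.append(value)
    return files


print(len(fetch_files(ArrayFileSegmentLike([FileLike(), FileLike()]))))  # 2
```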
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index adcbaad3ec..195e7e2e3d 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -10,10 +10,11 @@ from pydantic import ValidationError
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
@@ -101,6 +102,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
)
# parse files
+ # TODO(QuantumGhost): Move file parsing logic to the API controller layer
+ # for better separation of concerns.
+ #
+ # For implementation reference, see the `_parse_file` function and
+ # `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config:
@@ -196,8 +202,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try:
# get message
message = self._get_message(message_id)
- if message is None:
- raise MessageNotExistsError()
# chatbot app
runner = CompletionAppRunner()
diff --git a/api/core/app/apps/exc.py b/api/core/app/apps/exc.py
new file mode 100644
index 0000000000..4187118b9b
--- /dev/null
+++ b/api/core/app/apps/exc.py
@@ -0,0 +1,2 @@
+class GenerateTaskStoppedError(Exception):
+ pass
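Relocating the exception into its own `core.app.apps.exc` module lets generators and queue managers share it, presumably without importing `base_app_queue_manager`. A hedged usage sketch of catching it from the new location:

```python
# Usage sketch; the stream object stands in for whatever a generate() call returns.
from collections.abc import Iterable

from core.app.apps.exc import GenerateTaskStoppedError


def drain(stream: Iterable) -> None:
    try:
        for chunk in stream:
            print(chunk)
    except GenerateTaskStoppedError:
        # Raised when generation is stopped mid-stream (illustrative); treat as shutdown.
        pass
```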
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index 58b94f4d43..d50cf1c941 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,12 +1,12 @@
import json
import logging
from collections.abc import Generator
-from datetime import UTC, datetime
from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
@@ -24,11 +24,13 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
+from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__)
@@ -182,7 +184,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit()
db.session.refresh(conversation)
else:
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
@@ -251,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction or ""
- def _get_conversation(self, conversation_id: str):
+ def _get_conversation(self, conversation_id: str) -> Conversation:
"""
Get conversation by conversation id
:param conversation_id: conversation id
@@ -260,11 +262,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation:
- raise ConversationNotExistsError()
+            raise ConversationNotExistsError("Conversation does not exist")
return conversation
- def _get_message(self, message_id: str) -> Optional[Message]:
+ def _get_message(self, message_id: str) -> Message:
"""
Get message by message id
:param message_id: message id
@@ -272,4 +274,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
"""
message = db.session.query(Message).filter(Message.id == message_id).first()
+ if message is None:
+            raise MessageNotExistsError("Message does not exist")
+
return message
diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py
index 363c3c82bb..8507f23f17 100644
--- a/api/core/app/apps/message_based_app_queue_manager.py
+++ b/api/core/app/apps/message_based_app_queue_manager.py
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 7f4770fc97..6f560b3253 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -13,7 +13,8 @@ import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
@@ -23,15 +24,17 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
+from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom
+from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__)
@@ -94,6 +97,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files
+ # TODO(QuantumGhost): Move file parsing logic to the API controller layer
+ # for better separation of concerns.
+ #
+ # For implementation reference, see the `_parse_file` function and
+ # `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings(
mappings=files,
@@ -148,14 +156,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -186,6 +194,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
+ variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -195,6 +204,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
+ :param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
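+        :param variable_loader: loader used to fill variables into the variable pool; defaults to a dummy loader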
@@ -210,6 +220,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
# new thread with request context and contextvars
context = contextvars.copy_context()
+        # Release the database connection before starting the worker thread; the work in that thread may take a long time.
+ db.session.close()
+
worker_thread = threading.Thread(
target=self._generate_worker,
kwargs={
@@ -218,11 +231,16 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager,
"context": context,
"workflow_thread_pool_id": workflow_thread_pool_id,
+ "variable_loader": variable_loader,
},
)
worker_thread.start()
+ draft_var_saver_factory = self._get_draft_var_saver_factory(
+ invoke_from,
+ )
+
# return response or stream generator
response = self._handle_response(
application_generate_entity=application_generate_entity,
@@ -231,6 +249,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
+ draft_var_saver_factory=draft_var_saver_factory,
stream=streaming,
)
@@ -287,21 +306,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
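+        # Prefill conversation variable defaults and build a loader that resolves
+        # draft variables for this single-step debug run.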
+ draft_var_srv = WorkflowDraftVariableService(db.session())
+ draft_var_srv.prefill_conversation_variable_default_values(workflow)
+ var_loader = DraftVarLoader(
+ engine=db.engine,
+ app_id=application_generate_entity.app_config.app_id,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ )
return self._generate(
app_model=app_model,
@@ -312,6 +336,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
+ variable_loader=var_loader,
)
def single_loop_generate(
@@ -363,22 +388,26 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
)
-
+ draft_var_srv = WorkflowDraftVariableService(db.session())
+ draft_var_srv.prefill_conversation_variable_default_values(workflow)
+ var_loader = DraftVarLoader(
+ engine=db.engine,
+ app_id=application_generate_entity.app_config.app_id,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ )
return self._generate(
app_model=app_model,
workflow=workflow,
@@ -388,6 +417,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
+ variable_loader=var_loader,
)
def _generate_worker(
@@ -396,6 +426,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
context: contextvars.Context,
+ variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -414,6 +445,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
+ variable_loader=variable_loader,
)
runner.run()
@@ -444,6 +476,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+ draft_var_saver_factory: DraftVariableSaverFactory,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -464,6 +497,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user=user,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
+ draft_var_saver_factory=draft_var_saver_factory,
stream=stream,
)
diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py
index 349b8eb51b..40fc03afb7 100644
--- a/api/core/app/apps/workflow/app_queue_manager.py
+++ b/api/core/app/apps/workflow/app_queue_manager.py
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index b59e34e222..3a66ffa578 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -11,7 +11,8 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
+from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
@@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
+ variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
) -> None:
"""
@@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
+ super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity
- self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id
+ def _get_app_id(self) -> str:
+ return self.application_generate_entity.app_config.app_id
+
def run(self) -> None:
"""
Run application
@@ -90,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
files = self.application_generate_entity.files
# Create a variable pool.
- system_inputs = {
- SystemVariableKey.FILES: files,
- SystemVariableKey.USER_ID: user_id,
- SystemVariableKey.APP_ID: app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
- }
+
+ system_inputs = SystemVariable(
+ files=files,
+ user_id=user_id,
+ app_id=app_config.app_id,
+ workflow_id=app_config.workflow_id,
+ workflow_execution_id=self.application_generate_entity.workflow_execution_id,
+ )
variable_pool = VariablePool(
system_variables=system_inputs,
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 1734dbb598..9a39b2e01e 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -1,9 +1,9 @@
import logging
import time
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Callable, Generator
+from contextlib import contextmanager
+from typing import Any, Optional, Union
-from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
+ MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@@ -39,11 +40,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
+ WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
+ PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
@@ -55,9 +58,11 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
-from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
@@ -67,7 +72,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
- WorkflowRun,
)
logger = logging.getLogger(__name__)
@@ -87,6 +91,7 @@ class WorkflowAppGenerateTaskPipeline:
stream: bool,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
+ draft_var_saver_factory: DraftVariableSaverFactory,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@@ -107,13 +112,13 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
- workflow_system_variables={
- SystemVariableKey.FILES: application_generate_entity.files,
- SystemVariableKey.USER_ID: user_session_id,
- SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: workflow.id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
- },
+ workflow_system_variables=SystemVariable(
+ files=application_generate_entity.files,
+ user_id=user_session_id,
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=workflow.id,
+ workflow_execution_id=application_generate_entity.workflow_execution_id,
+ ),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
@@ -131,6 +136,8 @@ class WorkflowAppGenerateTaskPipeline:
self._application_generate_entity = application_generate_entity
self._workflow_features_dict = workflow.features_dict
self._workflow_run_id = ""
+ self._invoke_from = queue_manager._invoke_from
+ self._draft_var_saver_factory = draft_var_saver_factory
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -244,318 +251,497 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
- def _process_stream_response(
+ @contextmanager
+ def _database_session(self):
+ """Context manager for database sessions."""
+ with Session(db.engine, expire_on_commit=False) as session:
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+
+ def _ensure_workflow_initialized(self) -> None:
+ """Fluent validation for workflow state."""
+ if not self._workflow_run_id:
+ raise ValueError("workflow run not initialized.")
+
+ def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+ """Fluent validation for graph runtime state."""
+ if not graph_runtime_state:
+ raise ValueError("graph runtime state not initialized.")
+ return graph_runtime_state
+
+ def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
+ """Handle ping events."""
+ yield self._base_task_pipeline._ping_stream_response()
+
+ def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
+ """Handle error events."""
+ err = self._base_task_pipeline._handle_error(event=event)
+ yield self._base_task_pipeline._error_to_stream_response(err)
+
+ def _handle_workflow_started_event(
+ self, event: QueueWorkflowStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow started events."""
+ # init workflow run
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
+ self._workflow_run_id = workflow_execution.id_
+ start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+ yield start_resp
+
+ def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle node retry events."""
+ self._ensure_workflow_initialized()
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if response:
+ yield response
+
+ def _handle_node_started_event(
+ self, event: QueueNodeStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node started events."""
+ self._ensure_workflow_initialized()
+
+ workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+ node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_start_response:
+ yield node_start_response
+
+ def _handle_node_succeeded_event(
+ self, event: QueueNodeSucceededEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node succeeded events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
+ node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_success_response:
+ yield node_success_response
+
+ def _handle_node_failed_events(
self,
- tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
- trace_manager: Optional[TraceQueueManager] = None,
+ event: Union[
+ QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
+ ],
+ **kwargs,
) -> Generator[StreamResponse, None, None]:
- """
- Process stream response.
- :return:
- """
- graph_runtime_state = None
+ """Handle various node failure events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
+ event=event,
+ )
+ node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
- for queue_message in self._base_task_pipeline._queue_manager.listen():
- event = queue_message.event
+ if isinstance(event, QueueNodeExceptionEvent):
+ self._save_output_for_event(event, workflow_node_execution.id)
- if isinstance(event, QueuePingEvent):
- yield self._base_task_pipeline._ping_stream_response()
- elif isinstance(event, QueueErrorEvent):
- err = self._base_task_pipeline._handle_error(event=event)
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueWorkflowStartedEvent):
- # override graph runtime state
- graph_runtime_state = event.graph_runtime_state
-
- # init workflow run
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
- self._workflow_run_id = workflow_execution.id_
- start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ if node_failed_response:
+ yield node_failed_response
- yield start_resp
- elif isinstance(
- event,
- QueueNodeRetryEvent,
- ):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ def _handle_parallel_branch_started_event(
+ self, event: QueueParallelBranchRunStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch started events."""
+ self._ensure_workflow_initialized()
- if response:
- yield response
- elif isinstance(event, QueueNodeStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_start_resp
- workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
- workflow_execution_id=self._workflow_run_id, event=event
- )
- node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ def _handle_parallel_branch_finished_events(
+ self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch finished events."""
+ self._ensure_workflow_initialized()
- if node_start_response:
- yield node_start_response
- elif isinstance(event, QueueNodeSucceededEvent):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
- event=event
- )
- node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_finish_resp
- if node_success_response:
- yield node_success_response
- elif isinstance(
- event,
- QueueNodeFailedEvent
- | QueueNodeInIterationFailedEvent
- | QueueNodeInLoopFailedEvent
- | QueueNodeExceptionEvent,
- ):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
- event=event,
- )
- node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ def _handle_iteration_start_event(
+ self, event: QueueIterationStartEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration start events."""
+ self._ensure_workflow_initialized()
- if node_failed_response:
- yield node_failed_response
+ iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_start_resp
- elif isinstance(event, QueueParallelBranchRunStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_iteration_next_event(
+ self, event: QueueIterationNextEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration next events."""
+ self._ensure_workflow_initialized()
- parallel_start_resp = (
- self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_next_resp
- yield parallel_start_resp
+ def _handle_iteration_completed_event(
+ self, event: QueueIterationCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration completed events."""
+ self._ensure_workflow_initialized()
- elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_finish_resp
- parallel_finish_resp = (
- self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop start events."""
+ self._ensure_workflow_initialized()
- yield parallel_finish_resp
+ loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_start_resp
- elif isinstance(event, QueueIterationStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop next events."""
+ self._ensure_workflow_initialized()
- iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_next_resp
- yield iter_start_resp
+ def _handle_loop_completed_event(
+ self, event: QueueLoopCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle loop completed events."""
+ self._ensure_workflow_initialized()
- elif isinstance(event, QueueIterationNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_finish_resp
- iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_succeeded_event(
+ self,
+ event: QueueWorkflowSucceededEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow succeeded events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
- yield iter_next_resp
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- elif isinstance(event, QueueIterationCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ yield workflow_finish_resp
- yield iter_finish_resp
+ def _handle_workflow_partial_success_event(
+ self,
+ event: QueueWorkflowPartialSuccessEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow partial success events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ )
- elif isinstance(event, QueueLoopStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield loop_start_resp
+ yield workflow_finish_resp
- elif isinstance(event, QueueLoopNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_workflow_failed_and_stop_events(
+ self,
+ event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow failed and stop events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ status=WorkflowExecutionStatus.FAILED
+ if isinstance(event, QueueWorkflowFailedEvent)
+ else WorkflowExecutionStatus.STOPPED,
+ error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+ conversation_id=None,
+ trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+ )
- loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- yield loop_next_resp
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- elif isinstance(event, QueueLoopCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
- loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_text_chunk_event(
+ self,
+ event: QueueTextChunkEvent,
+ *,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle text chunk events."""
+ delta_text = event.text
+ if delta_text is None:
+ return
- yield loop_finish_resp
-
- elif isinstance(event, QueueWorkflowSucceededEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=None,
- trace_manager=trace_manager,
- )
+ # only publish tts message at text chunk streaming
+ if tts_publisher and queue_message:
+ tts_publisher.publish(queue_message)
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
-
- yield workflow_finish_resp
- elif isinstance(event, QueueWorkflowPartialSuccessEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
+ def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle agent log events."""
+ yield self._workflow_response_converter.handle_agent_log(
+ task_id=self._application_generate_entity.task_id, event=event
+ )
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ def _get_event_handlers(self) -> dict[type, Callable]:
+ """Get mapping of event types to their handlers using fluent pattern."""
+ return {
+ # Basic events
+ QueuePingEvent: self._handle_ping_event,
+ QueueErrorEvent: self._handle_error_event,
+ QueueTextChunkEvent: self._handle_text_chunk_event,
+ # Workflow events
+ QueueWorkflowStartedEvent: self._handle_workflow_started_event,
+ QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
+ QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
+ # Node events
+ QueueNodeRetryEvent: self._handle_node_retry_event,
+ QueueNodeStartedEvent: self._handle_node_started_event,
+ QueueNodeSucceededEvent: self._handle_node_succeeded_event,
+ # Parallel branch events
+ QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
+ # Iteration events
+ QueueIterationStartEvent: self._handle_iteration_start_event,
+ QueueIterationNextEvent: self._handle_iteration_next_event,
+ QueueIterationCompletedEvent: self._handle_iteration_completed_event,
+ # Loop events
+ QueueLoopStartEvent: self._handle_loop_start_event,
+ QueueLoopNextEvent: self._handle_loop_next_event,
+ QueueLoopCompletedEvent: self._handle_loop_completed_event,
+ # Agent events
+ QueueAgentLogEvent: self._handle_agent_log_event,
+ }
+
+ def _dispatch_event(
+ self,
+ event: Any,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """Dispatch events using elegant pattern matching."""
+ handlers = self._get_event_handlers()
+ event_type = type(event)
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
-
- yield workflow_finish_resp
- elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.FAILED
- if isinstance(event, QueueWorkflowFailedEvent)
- else WorkflowExecutionStatus.STOPPED,
- error_message=event.error
- if isinstance(event, QueueWorkflowFailedEvent)
- else event.get_stop_reason(),
- conversation_id=None,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
- )
+ # Direct handler lookup
+ if handler := handlers.get(event_type):
+ yield from handler(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ # Handle node failure events with isinstance check
+ if isinstance(
+ event,
+ (
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ),
+ ):
+ yield from self._handle_node_failed_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
+ # Handle parallel branch finished events with isinstance check
+ if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
+ yield from self._handle_parallel_branch_finished_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- yield workflow_finish_resp
- elif isinstance(event, QueueTextChunkEvent):
- delta_text = event.text
- if delta_text is None:
- continue
+ # Handle workflow failed and stop events with isinstance check
+ if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
+ yield from self._handle_workflow_failed_and_stop_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- # only publish tts message at text chunk streaming
- if tts_publisher:
- tts_publisher.publish(queue_message)
+        # Unhandled event types are ignored.
+ return
- yield self._text_chunk_to_stream_response(
- delta_text, from_variable_selector=event.from_variable_selector
- )
- elif isinstance(event, QueueAgentLogEvent):
- yield self._workflow_response_converter.handle_agent_log(
- task_id=self._application_generate_entity.task_id, event=event
- )
- else:
- continue
+ def _process_stream_response(
+ self,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """
+        Process the stream response by dispatching queue events to their handlers.
+ """
+ # Initialize graph runtime state
+ graph_runtime_state = None
+
+ for queue_message in self._base_task_pipeline._queue_manager.listen():
+ event = queue_message.event
+
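+            # Workflow-started, text-chunk and error events need special treatment
+            # here (capturing graph runtime state, TTS publishing, breaking the loop);
+            # all other event types are routed through _dispatch_event().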
+ match event:
+ case QueueWorkflowStartedEvent():
+ graph_runtime_state = event.graph_runtime_state
+ yield from self._handle_workflow_started_event(event)
+
+ case QueueTextChunkEvent():
+ yield from self._handle_text_chunk_event(
+ event, tts_publisher=tts_publisher, queue_message=queue_message
+ )
+
+ case QueueErrorEvent():
+ yield from self._handle_error_event(event)
+ break
+
+                # All other events go through the handler dispatch table
+ case _:
+                    yield from self._dispatch_event(
+                        event,
+                        graph_runtime_state=graph_runtime_state,
+                        tts_publisher=tts_publisher,
+                        trace_manager=trace_manager,
+                        queue_message=queue_message,
+                    )
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
- workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
- assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -568,10 +754,10 @@ class WorkflowAppGenerateTaskPipeline:
return
workflow_app_log = WorkflowAppLog()
- workflow_app_log.tenant_id = workflow_run.tenant_id
- workflow_app_log.app_id = workflow_run.app_id
- workflow_app_log.workflow_id = workflow_run.workflow_id
- workflow_app_log.workflow_run_id = workflow_run.id
+ workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
+ workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
+ workflow_app_log.workflow_id = workflow_execution.workflow_id
+ workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
@@ -593,3 +779,15 @@ class WorkflowAppGenerateTaskPipeline:
)
return response
+
+ def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str):
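+        # Persist the node's process data and outputs via the draft variable saver;
+        # which saver is used depends on the invoke source (see _get_draft_var_saver_factory).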
+ with Session(db.engine) as session, session.begin():
+ saver = self._draft_var_saver_factory(
+ session=session,
+ app_id=self._application_generate_entity.app_config.app_id,
+ node_id=event.node_id,
+ node_type=event.node_type,
+ node_execution_id=node_execution_id,
+ enclosing_node_id=event.in_loop_id or event.in_iteration_id,
+ )
+ saver.save(event.process_data, event.outputs)
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index facc24b4ca..2f4d234ecd 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -62,6 +62,8 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
+from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
@@ -69,8 +71,12 @@ from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
- def __init__(self, queue_manager: AppQueueManager):
+ def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager
+ self._variable_loader = variable_loader
+
+ def _get_app_id(self) -> str:
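+        # Subclasses are expected to override this and return the id of the running app.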
+ raise NotImplementedError("not implemented")
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@@ -161,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
@@ -173,6 +179,13 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError:
variable_mapping = {}
+ load_into_variable_pool(
+ variable_loader=self._variable_loader,
+ variable_pool=variable_pool,
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ )
+
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
@@ -251,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
@@ -262,6 +275,12 @@ class WorkflowBasedAppRunner(AppRunner):
)
except NotImplementedError:
variable_mapping = {}
+ load_into_variable_pool(
+ self._variable_loader,
+ variable_pool=variable_pool,
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ )
WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
@@ -376,6 +395,7 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
+
elif isinstance(event, NodeRunFailedEvent):
self._publish_event(
QueueNodeFailedEvent(
@@ -438,6 +458,7 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id,
)
)
+
elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event(
QueueNodeInIterationFailedEvent(
diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py
index c0d99693b0..65ed267959 100644
--- a/api/core/app/entities/app_invoke_entities.py
+++ b/api/core/app/entities/app_invoke_entities.py
@@ -17,9 +17,24 @@ class InvokeFrom(Enum):
Invoke From.
"""
+    # SERVICE_API indicates that this invocation is from an API call to a Dify app.
+    #
+    # Description of the Service API in the Dify docs:
+ # https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api"
+
+ # WEB_APP indicates that this invocation is from
+ # the web app of the workflow (or chatflow).
+ #
+ # Description of web app in Dify docs:
+ # https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app"
+
+ # EXPLORE indicates that this invocation is from
+ # the workflow (or chatflow) explore page.
EXPLORE = "explore"
+ # DEBUGGER indicates that this invocation is from
+ # the workflow (or chatflow) edit page.
DEBUGGER = "debugger"
@classmethod
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index 5331c0cc94..3ed0c3352f 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -19,6 +19,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration
+from models.enums import MessageStatus
from models.model import Message
logger = logging.getLogger(__name__)
@@ -62,7 +63,7 @@ class BasedGenerateTaskPipeline:
return err
err_desc = self._error_to_desc(err)
- message.status = "error"
+ message.status = MessageStatus.ERROR
message.error = err_desc
return err
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index d535e1f835..3c8c7bb5a2 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -395,6 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.provider_response_latency = time.perf_counter() - self._start_at
message.total_price = usage.total_price
message.currency = usage.currency
+ self._task_state.llm_result.usage.latency = message.provider_response_latency
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py
index e4b4168d08..df62776977 100644
--- a/api/core/app/task_pipeline/exc.py
+++ b/api/core/app/task_pipeline/exc.py
@@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError):
class WorkflowRunNotFoundError(RecordNotFoundError):
def __init__(self, workflow_run_id: str):
super().__init__("WorkflowRun", workflow_run_id)
-
-
-class WorkflowNodeExecutionNotFoundError(RecordNotFoundError):
- def __init__(self, workflow_node_execution_id: str):
- super().__init__("WorkflowNodeExecution", workflow_node_execution_id)
diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py
index 36800bc263..fbd62437e6 100644
--- a/api/core/entities/parameter_entities.py
+++ b/api/core/entities/parameter_entities.py
@@ -14,8 +14,17 @@ class CommonParameterType(StrEnum):
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
+ ANY = "any"
+
+ # Dynamic select parameter
+    # Used when the available options are not known until authorization is done,
+    # e.g. selecting a Slack channel from a Slack workspace
+ DYNAMIC_SELECT = "dynamic-select"
# TOOL_SELECTOR = "tool-selector"
+ # MCP object and array type parameters
+ ARRAY = "array"
+ OBJECT = "object"
class AppSelectorScope(StrEnum):
diff --git a/api/core/file/constants.py b/api/core/file/constants.py
index ce1d238e93..0665ed7e0d 100644
--- a/api/core/file/constants.py
+++ b/api/core/file/constants.py
@@ -1 +1,11 @@
+from typing import Any
+
+# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
+# comparing `dify_model_identity` with constants throughout the codebase, extract
+# this logic into a dedicated function. This would encapsulate the implementation
+# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__"
+
+
+def maybe_file_object(o: Any) -> bool:
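+    """Return True if `o` looks like a serialized File object, i.e. a dict carrying the File model identity."""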
+ return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index ada19ef8ce..f8c050c2ac 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -7,6 +7,7 @@ from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
+ TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
@@ -44,11 +45,44 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
+ """
+ Convert a file to prompt message content.
+
+ This function converts files to their appropriate prompt message content types.
+ For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
+ corresponding message content with proper encoding/URL.
+
+ For unsupported file types, instead of raising an error, it returns a
+ TextPromptMessageContent with a descriptive message about the file.
+
+ Args:
+ f: The file to convert
+ image_detail_config: Optional detail configuration for image files
+
+ Returns:
+ PromptMessageContentUnionTypes: The appropriate message content type
+
+ Raises:
+ ValueError: If file extension or mime_type is missing
+ """
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
+ prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
+ FileType.IMAGE: ImagePromptMessageContent,
+ FileType.AUDIO: AudioPromptMessageContent,
+ FileType.VIDEO: VideoPromptMessageContent,
+ FileType.DOCUMENT: DocumentPromptMessageContent,
+ }
+
+ # Check if file type is supported
+ if f.type not in prompt_class_map:
+ # For unsupported file types, return a text description
+ return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
+
+ # Process supported file types
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
@@ -58,17 +92,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
- prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
- FileType.IMAGE: ImagePromptMessageContent,
- FileType.AUDIO: AudioPromptMessageContent,
- FileType.VIDEO: VideoPromptMessageContent,
- FileType.DOCUMENT: DocumentPromptMessageContent,
- }
-
- try:
- return prompt_class_map[f.type].model_validate(params)
- except KeyError:
- raise ValueError(f"file type {f.type} is not supported")
+ return prompt_class_map[f.type].model_validate(params)
def download(f: File, /):
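The docstring above describes a dispatch-with-fallback shape: supported file types go through a type map, everything else degrades to a text description. A tiny standalone sketch of that pattern (hypothetical FileType/to_content stand-ins for illustration only, not the real File or prompt-content models):

    from enum import StrEnum

    class FileType(StrEnum):  # hypothetical stand-in for illustration only
        IMAGE = "image"
        CUSTOM = "custom"

    def to_content(file_type: FileType, filename: str) -> str:
        handlers = {FileType.IMAGE: lambda name: f"<image {name}>"}
        if file_type not in handlers:
            # unsupported types degrade to a text placeholder instead of raising
            return f"[Unsupported file type: {filename} ({file_type.value})]"
        return handlers[file_type](filename)

    print(to_content(FileType.CUSTOM, "report.xyz"))  # [Unsupported file type: report.xyz (custom)]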
diff --git a/api/core/file/helpers.py b/api/core/file/helpers.py
index 73fabdb11b..335ad2266a 100644
--- a/api/core/file/helpers.py
+++ b/api/core/file/helpers.py
@@ -21,7 +21,9 @@ def get_signed_file_url(upload_file_id: str) -> str:
def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str:
- url = f"{dify_config.FILES_URL}/files/upload/for-plugin"
+ # Plugin access should use internal URL for Docker network communication
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
+ url = f"{base_url}/files/upload/for-plugin"
if user_id is None:
user_id = "DEFAULT-USER"
diff --git a/api/core/file/models.py b/api/core/file/models.py
index aa3b5f629c..f61334e7bc 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -51,7 +51,7 @@ class File(BaseModel):
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
- extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
+ extension: Optional[str] = Field(default=None, description="File extension, should contain dot")
mime_type: Optional[str] = None
size: int = -1
diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py
index 656c9d48ed..fac68beb0f 100644
--- a/api/core/file/tool_file_parser.py
+++ b/api/core/file/tool_file_parser.py
@@ -7,13 +7,6 @@ if TYPE_CHECKING:
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
-class ToolFileParser:
- @staticmethod
- def get_tool_file_manager() -> "ToolFileManager":
- assert _tool_file_manager_factory is not None
- return _tool_file_manager_factory()
-
-
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
global _tool_file_manager_factory
_tool_file_manager_factory = factory
diff --git a/api/core/file/upload_file_parser.py b/api/core/file/upload_file_parser.py
deleted file mode 100644
index 96b2884811..0000000000
--- a/api/core/file/upload_file_parser.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import base64
-import logging
-import time
-from typing import Optional
-
-from configs import dify_config
-from constants import IMAGE_EXTENSIONS
-from core.helper.url_signer import UrlSigner
-from extensions.ext_storage import storage
-
-
-class UploadFileParser:
- @classmethod
- def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
- if not upload_file:
- return None
-
- if upload_file.extension not in IMAGE_EXTENSIONS:
- return None
-
- if dify_config.MULTIMODAL_SEND_FORMAT == "url" or force_url:
- return cls.get_signed_temp_image_url(upload_file.id)
- else:
- # get image file base64
- try:
- data = storage.load(upload_file.key)
- except FileNotFoundError:
- logging.exception(f"File not found: {upload_file.key}")
- return None
-
- encoded_string = base64.b64encode(data).decode("utf-8")
- return f"data:{upload_file.mime_type};base64,{encoded_string}"
-
- @classmethod
- def get_signed_temp_image_url(cls, upload_file_id) -> str:
- """
- get signed url from upload file
-
- :param upload_file_id: the id of UploadFile object
- :return:
- """
- base_url = dify_config.FILES_URL
- image_preview_url = f"{base_url}/files/{upload_file_id}/image-preview"
-
- return UrlSigner.get_signed_url(url=image_preview_url, sign_key=upload_file_id, prefix="image-preview")
-
- @classmethod
- def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
- """
- verify signature
-
- :param upload_file_id: file id
- :param timestamp: timestamp
- :param nonce: nonce
- :param sign: signature
- :return:
- """
- result = UrlSigner.verify(
- sign_key=upload_file_id, timestamp=timestamp, nonce=nonce, sign=sign, prefix="image-preview"
- )
-
- # verify signature
- if not result:
- return False
-
- current_time = int(time.time())
- return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index baa792b5bc..b416e48ce4 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -5,6 +5,8 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
+from core.variables.utils import SegmentJSONEncoder
+
class TemplateTransformer(ABC):
_code_placeholder: str = "{{code}}"
@@ -28,7 +30,7 @@ class TemplateTransformer(ABC):
def extract_result_str_from_response(cls, response: str):
result = re.search(rf"{cls._result_tag}(.*){cls._result_tag}", response, re.DOTALL)
if not result:
- raise ValueError("Failed to parse result")
+ raise ValueError(f"Failed to parse result: no result tag found in response. Response: {response[:200]}...")
return result.group(1)
@classmethod
@@ -38,16 +40,49 @@ class TemplateTransformer(ABC):
:param response: response
:return:
"""
+
try:
- result = json.loads(cls.extract_result_str_from_response(response))
- except json.JSONDecodeError:
- raise ValueError("failed to parse response")
+ result_str = cls.extract_result_str_from_response(response)
+ result = json.loads(result_str)
+ except json.JSONDecodeError as e:
+ raise ValueError(f"Failed to parse JSON response: {str(e)}.")
+ except ValueError as e:
+ # Re-raise ValueError from extract_result_str_from_response
+ raise e
+ except Exception as e:
+ raise ValueError(f"Unexpected error during response transformation: {str(e)}")
+
if not isinstance(result, dict):
- raise ValueError("result must be a dict")
+ raise ValueError(f"Result must be a dict, got {type(result).__name__}")
if not all(isinstance(k, str) for k in result):
- raise ValueError("result keys must be strings")
+ raise ValueError("Result keys must be strings")
+
+ # Post-process the result to convert scientific notation strings back to numbers
+ result = cls._post_process_result(result)
return result
+ @classmethod
+ def _post_process_result(cls, result: dict[Any, Any]) -> dict[Any, Any]:
+ """
+ Post-process the result to convert scientific notation strings back to numbers
+ """
+
+ def convert_scientific_notation(value):
+ if isinstance(value, str):
+ # Check if the string looks like scientific notation
+ if re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
+ try:
+ return float(value)
+ except ValueError:
+ pass
+ elif isinstance(value, dict):
+ return {k: convert_scientific_notation(v) for k, v in value.items()}
+ elif isinstance(value, list):
+ return [convert_scientific_notation(v) for v in value]
+ return value
+
+ return convert_scientific_notation(result) # type: ignore[no-any-return]
+
@classmethod
@abstractmethod
def get_runner_script(cls) -> str:
@@ -58,7 +93,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
+ inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
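The new _post_process_result hook exists because the sandboxed runner can serialize very large or very small numbers as scientific-notation strings. A self-contained sketch of the same conversion (logic mirrored from the helper above; note the pattern requires a signed exponent, so a string like "1e10" is left untouched):

    import re

    def convert_scientific_notation(value):
        # strings matching the scientific-notation pattern become floats; containers recurse
        if isinstance(value, str) and re.match(r"^-?\d+\.?\d*e[+-]\d+$", value, re.IGNORECASE):
            return float(value)
        elif isinstance(value, dict):
            return {k: convert_scientific_notation(v) for k, v in value.items()}
        elif isinstance(value, list):
            return [convert_scientific_notation(v) for v in value]
        return value

    print(convert_scientific_notation({"x": "1.5e+10", "y": "plain text"}))
    # {'x': 15000000000.0, 'y': 'plain text'}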
diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py
index 744fce1cf9..1e40997a8b 100644
--- a/api/core/helper/encrypter.py
+++ b/api/core/helper/encrypter.py
@@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str):
return base64.b64encode(encrypted_token).decode()
-def decrypt_token(tenant_id: str, token: str):
+def decrypt_token(tenant_id: str, token: str) -> str:
return rsa.decrypt(base64.b64decode(token), tenant_id)
diff --git a/api/core/helper/lru_cache.py b/api/core/helper/lru_cache.py
deleted file mode 100644
index 81501d2e4e..0000000000
--- a/api/core/helper/lru_cache.py
+++ /dev/null
@@ -1,22 +0,0 @@
-from collections import OrderedDict
-from typing import Any
-
-
-class LRUCache:
- def __init__(self, capacity: int):
- self.cache: OrderedDict[Any, Any] = OrderedDict()
- self.capacity = capacity
-
- def get(self, key: Any) -> Any:
- if key not in self.cache:
- return None
- else:
- self.cache.move_to_end(key) # move the key to the end of the OrderedDict
- return self.cache[key]
-
- def put(self, key: Any, value: Any) -> None:
- if key in self.cache:
- self.cache.move_to_end(key)
- self.cache[key] = value
- if len(self.cache) > self.capacity:
- self.cache.popitem(last=False) # pop the first item
diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py
new file mode 100644
index 0000000000..48ec3be5c8
--- /dev/null
+++ b/api/core/helper/provider_cache.py
@@ -0,0 +1,84 @@
+import json
+from abc import ABC, abstractmethod
+from json import JSONDecodeError
+from typing import Any, Optional
+
+from extensions.ext_redis import redis_client
+
+
+class ProviderCredentialsCache(ABC):
+ """Base class for provider credentials cache"""
+
+ def __init__(self, **kwargs):
+ self.cache_key = self._generate_cache_key(**kwargs)
+
+ @abstractmethod
+ def _generate_cache_key(self, **kwargs) -> str:
+ """Generate cache key based on subclass implementation"""
+ pass
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider credentials"""
+ cached_credentials = redis_client.get(self.cache_key)
+ if cached_credentials:
+ try:
+ cached_credentials = cached_credentials.decode("utf-8")
+ return dict(json.loads(cached_credentials))
+ except JSONDecodeError:
+ return None
+ return None
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider credentials"""
+ redis_client.setex(self.cache_key, 86400, json.dumps(config))
+
+ def delete(self) -> None:
+ """Delete cached provider credentials"""
+ redis_client.delete(self.cache_key)
+
+
+class SingletonProviderCredentialsCache(ProviderCredentialsCache):
+ """Cache for tool single provider credentials"""
+
+ def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
+ super().__init__(
+ tenant_id=tenant_id,
+ provider_type=provider_type,
+ provider_identity=provider_identity,
+ )
+
+ def _generate_cache_key(self, **kwargs) -> str:
+ tenant_id = kwargs["tenant_id"]
+ provider_type = kwargs["provider_type"]
+ identity_name = kwargs["provider_identity"]
+ identity_id = f"{provider_type}.{identity_name}"
+ return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
+
+
+class ToolProviderCredentialsCache(ProviderCredentialsCache):
+ """Cache for tool provider credentials"""
+
+ def __init__(self, tenant_id: str, provider: str, credential_id: str):
+ super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
+
+ def _generate_cache_key(self, **kwargs) -> str:
+ tenant_id = kwargs["tenant_id"]
+ provider = kwargs["provider"]
+ credential_id = kwargs["credential_id"]
+ return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
+
+
+class NoOpProviderCredentialCache:
+ """No-op provider credential cache"""
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider credentials"""
+ return None
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider credentials"""
+ pass
+
+ def delete(self) -> None:
+ """Delete cached provider credentials"""
+ pass
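A short usage sketch for the new cache classes (assumes the usual Redis connection behind redis_client; the tenant, provider and credential IDs below are placeholders):

    from core.helper.provider_cache import ToolProviderCredentialsCache

    cache = ToolProviderCredentialsCache(tenant_id="tenant-1", provider="google", credential_id="cred-1")
    cache.set({"api_key": "***"})  # stored for 24h under tool_credentials:tenant_id:tenant-1:provider:google:credential_id:cred-1
    credentials = cache.get()      # -> {"api_key": "***"}, or None on a miss / undecodable JSON
    cache.delete()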
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
deleted file mode 100644
index 2e4a04c579..0000000000
--- a/api/core/helper/tool_provider_cache.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import json
-from enum import Enum
-from json import JSONDecodeError
-from typing import Optional
-
-from extensions.ext_redis import redis_client
-
-
-class ToolProviderCredentialsCacheType(Enum):
- PROVIDER = "tool_provider"
- ENDPOINT = "endpoint"
-
-
-class ToolProviderCredentialsCache:
- def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
- self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
-
- def get(self) -> Optional[dict]:
- """
- Get cached model provider credentials.
-
- :return:
- """
- cached_provider_credentials = redis_client.get(self.cache_key)
- if cached_provider_credentials:
- try:
- cached_provider_credentials = cached_provider_credentials.decode("utf-8")
- cached_provider_credentials = json.loads(cached_provider_credentials)
- except JSONDecodeError:
- return None
-
- return dict(cached_provider_credentials)
- else:
- return None
-
- def set(self, credentials: dict) -> None:
- """
- Cache model provider credentials.
-
- :param credentials: provider credentials
- :return:
- """
- redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
-
- def delete(self) -> None:
- """
- Delete cached model provider credentials.
-
- :return:
- """
- redis_client.delete(self.cache_key)
diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py
deleted file mode 100644
index dfb143f4c4..0000000000
--- a/api/core/helper/url_signer.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import base64
-import hashlib
-import hmac
-import os
-import time
-
-from pydantic import BaseModel, Field
-
-from configs import dify_config
-
-
-class SignedUrlParams(BaseModel):
- sign_key: str = Field(..., description="The sign key")
- timestamp: str = Field(..., description="Timestamp")
- nonce: str = Field(..., description="Nonce")
- sign: str = Field(..., description="Signature")
-
-
-class UrlSigner:
- @classmethod
- def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
- signed_url_params = cls.get_signed_url_params(sign_key, prefix)
- return (
- f"{url}?timestamp={signed_url_params.timestamp}"
- f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
- )
-
- @classmethod
- def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
- timestamp = str(int(time.time()))
- nonce = os.urandom(16).hex()
- sign = cls._sign(sign_key, timestamp, nonce, prefix)
-
- return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
-
- @classmethod
- def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
- recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
-
- return sign == recalculated_sign
-
- @classmethod
- def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
- if not dify_config.SECRET_KEY:
- raise Exception("SECRET_KEY is not set")
-
- data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
- secret_key = dify_config.SECRET_KEY.encode()
- sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
- encoded_sign = base64.urlsafe_b64encode(sign).decode()
-
- return encoded_sign
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 848d897779..305a9190d5 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -317,9 +317,10 @@ class IndexingRunner:
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ if image_file is None:
+ continue
try:
- if image_file:
- storage.delete(image_file.key)
+ storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed while indexing_estimate, \
@@ -534,7 +535,7 @@ class IndexingRunner:
# chunk nodes by chunk size
indexing_start_at = time.perf_counter()
tokens = 0
- if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+ if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
@@ -572,7 +573,7 @@ class IndexingRunner:
for future in futures:
tokens += future.result()
- if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
+ if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
create_keyword_thread.join()
indexing_end_at = time.perf_counter()
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index e01896a491..f7fd93be4a 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -148,9 +148,11 @@ class LLMGenerator:
model_manager = ModelManager()
- model_instance = model_manager.get_default_model_instance(
+ model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
+ provider=model_config.get("provider", ""),
+ model=model_config.get("name", ""),
)
try:
diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py
new file mode 100644
index 0000000000..151cef1bc3
--- /dev/null
+++ b/api/core/llm_generator/output_parser/structured_output.py
@@ -0,0 +1,380 @@
+import json
+from collections.abc import Generator, Mapping, Sequence
+from copy import deepcopy
+from enum import StrEnum
+from typing import Any, Literal, Optional, cast, overload
+
+import json_repair
+from pydantic import TypeAdapter, ValidationError
+
+from core.llm_generator.output_parser.errors import OutputParserError
+from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
+from core.model_manager import ModelInstance
+from core.model_runtime.callbacks.base_callback import Callback
+from core.model_runtime.entities.llm_entities import (
+ LLMResult,
+ LLMResultChunk,
+ LLMResultChunkDelta,
+ LLMResultChunkWithStructuredOutput,
+ LLMResultWithStructuredOutput,
+)
+from core.model_runtime.entities.message_entities import (
+ AssistantPromptMessage,
+ PromptMessage,
+ PromptMessageTool,
+ SystemPromptMessage,
+ TextPromptMessageContent,
+)
+from core.model_runtime.entities.model_entities import AIModelEntity, ParameterRule
+
+
+class ResponseFormat(StrEnum):
+ """Constants for model response formats"""
+
+ JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
+ JSON = "JSON" # model's json mode. some model like claude support this mode.
+ JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
+
+
+class SpecialModelType(StrEnum):
+ """Constants for identifying model types"""
+
+ GEMINI = "gemini"
+ OLLAMA = "ollama"
+
+
+@overload
+def invoke_llm_with_structured_output(
+ provider: str,
+ model_schema: AIModelEntity,
+ model_instance: ModelInstance,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Optional[Mapping] = None,
+ tools: Sequence[PromptMessageTool] | None = None,
+ stop: Optional[list[str]] = None,
+ stream: Literal[True] = True,
+ user: Optional[str] = None,
+ callbacks: Optional[list[Callback]] = None,
+) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+
+@overload
+def invoke_llm_with_structured_output(
+ provider: str,
+ model_schema: AIModelEntity,
+ model_instance: ModelInstance,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Optional[Mapping] = None,
+ tools: Sequence[PromptMessageTool] | None = None,
+ stop: Optional[list[str]] = None,
+ stream: Literal[False] = False,
+ user: Optional[str] = None,
+ callbacks: Optional[list[Callback]] = None,
+) -> LLMResultWithStructuredOutput: ...
+
+
+@overload
+def invoke_llm_with_structured_output(
+ provider: str,
+ model_schema: AIModelEntity,
+ model_instance: ModelInstance,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Optional[Mapping] = None,
+ tools: Sequence[PromptMessageTool] | None = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ callbacks: Optional[list[Callback]] = None,
+) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
+
+
+def invoke_llm_with_structured_output(
+ provider: str,
+ model_schema: AIModelEntity,
+ model_instance: ModelInstance,
+ prompt_messages: Sequence[PromptMessage],
+ json_schema: Mapping[str, Any],
+ model_parameters: Optional[Mapping] = None,
+ tools: Sequence[PromptMessageTool] | None = None,
+ stop: Optional[list[str]] = None,
+ stream: bool = True,
+ user: Optional[str] = None,
+ callbacks: Optional[list[Callback]] = None,
+) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
+ """
+ Invoke large language model with structured output
+ 1. This method invokes model_instance.invoke_llm with json_schema
+ 2. Try to parse the result as structured output
+
+ :param provider: model provider name
+ :param model_schema: AIModelEntity describing the target model
+ :param model_instance: ModelInstance used to perform the invocation
+ :param prompt_messages: prompt messages
+ :param json_schema: json schema
+ :param model_parameters: model parameters
+ :param tools: tools for tool calling
+ :param stop: stop words
+ :param stream: is stream response
+ :param user: unique user id
+ :param callbacks: callbacks
+ :return: full response or stream response chunk generator result
+ """
+
+ # handle native json schema
+ model_parameters_with_json_schema: dict[str, Any] = {
+ **(model_parameters or {}),
+ }
+
+ if model_schema.support_structure_output:
+ model_parameters = _handle_native_json_schema(
+ provider, model_schema, json_schema, model_parameters_with_json_schema, model_schema.parameter_rules
+ )
+ else:
+ # Set appropriate response format based on model capabilities
+ _set_response_format(model_parameters_with_json_schema, model_schema.parameter_rules)
+
+ # handle prompt based schema
+ prompt_messages = _handle_prompt_based_schema(
+ prompt_messages=prompt_messages,
+ structured_output_schema=json_schema,
+ )
+
+ llm_result = model_instance.invoke_llm(
+ prompt_messages=list(prompt_messages),
+ model_parameters=model_parameters_with_json_schema,
+ tools=tools,
+ stop=stop,
+ stream=stream,
+ user=user,
+ callbacks=callbacks,
+ )
+
+ if isinstance(llm_result, LLMResult):
+ if not isinstance(llm_result.message.content, str):
+ raise OutputParserError(
+ f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
+ )
+
+ return LLMResultWithStructuredOutput(
+ structured_output=_parse_structured_output(llm_result.message.content),
+ model=llm_result.model,
+ message=llm_result.message,
+ usage=llm_result.usage,
+ system_fingerprint=llm_result.system_fingerprint,
+ prompt_messages=llm_result.prompt_messages,
+ )
+ else:
+
+ def generator() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+ result_text: str = ""
+ prompt_messages: Sequence[PromptMessage] = []
+ system_fingerprint: Optional[str] = None
+ for event in llm_result:
+ if isinstance(event, LLMResultChunk):
+ prompt_messages = event.prompt_messages
+ system_fingerprint = event.system_fingerprint
+
+ if isinstance(event.delta.message.content, str):
+ result_text += event.delta.message.content
+ elif isinstance(event.delta.message.content, list):
+ for item in event.delta.message.content:
+ if isinstance(item, TextPromptMessageContent):
+ result_text += item.data
+
+ yield LLMResultChunkWithStructuredOutput(
+ model=model_schema.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=system_fingerprint,
+ delta=event.delta,
+ )
+
+ yield LLMResultChunkWithStructuredOutput(
+ structured_output=_parse_structured_output(result_text),
+ model=model_schema.model,
+ prompt_messages=prompt_messages,
+ system_fingerprint=system_fingerprint,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=AssistantPromptMessage(content=""),
+ usage=None,
+ finish_reason=None,
+ ),
+ )
+
+ return generator()
+
+
+def _handle_native_json_schema(
+ provider: str,
+ model_schema: AIModelEntity,
+ structured_output_schema: Mapping,
+ model_parameters: dict,
+ rules: list[ParameterRule],
+) -> dict:
+ """
+ Handle structured output for models with native JSON schema support.
+
+ :param provider: model provider name
+ :param model_schema: AIModelEntity for the target model
+ :param structured_output_schema: JSON schema the output must follow
+ :param model_parameters: Model parameters to update
+ :param rules: Model parameter rules
+ :return: Updated model parameters with JSON schema configuration
+ """
+ # Process schema according to model requirements
+ schema_json = _prepare_schema_for_model(provider, model_schema, structured_output_schema)
+
+ # Set JSON schema in parameters
+ model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
+
+ # Set appropriate response format if required by the model
+ for rule in rules:
+ if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
+ model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
+
+ return model_parameters
+
+
+def _set_response_format(model_parameters: dict, rules: list) -> None:
+ """
+ Set the appropriate response format parameter based on model rules.
+
+ :param model_parameters: Model parameters to update
+ :param rules: Model parameter rules
+ """
+ for rule in rules:
+ if rule.name == "response_format":
+ if ResponseFormat.JSON.value in rule.options:
+ model_parameters["response_format"] = ResponseFormat.JSON.value
+ elif ResponseFormat.JSON_OBJECT.value in rule.options:
+ model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
+
+
+def _handle_prompt_based_schema(
+ prompt_messages: Sequence[PromptMessage], structured_output_schema: Mapping
+) -> list[PromptMessage]:
+ """
+ Handle structured output for models without native JSON schema support.
+ This function modifies the prompt messages to include schema-based output requirements.
+
+ Args:
+ prompt_messages: Original sequence of prompt messages
+ structured_output_schema: JSON schema describing the required output structure
+
+ Returns:
+ list[PromptMessage]: Updated prompt messages with structured output requirements
+ """
+ # Convert schema to string format
+ schema_str = json.dumps(structured_output_schema, ensure_ascii=False)
+
+ # Find existing system prompt with schema placeholder
+ system_prompt = next(
+ (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
+ None,
+ )
+ structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
+ # Prepare system prompt content
+ system_prompt_content = (
+ structured_output_prompt + "\n\n" + system_prompt.content
+ if system_prompt and isinstance(system_prompt.content, str)
+ else structured_output_prompt
+ )
+ system_prompt = SystemPromptMessage(content=system_prompt_content)
+
+ # Drop any existing system prompts; the combined system prompt is prepended below
+
+ filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
+ updated_prompt = [system_prompt] + filtered_prompts
+
+ return updated_prompt
+
+
+def _parse_structured_output(result_text: str) -> Mapping[str, Any]:
+ structured_output: Mapping[str, Any] = {}
+ parsed: Mapping[str, Any] = {}
+ try:
+ parsed = TypeAdapter(Mapping).validate_json(result_text)
+ if not isinstance(parsed, dict):
+ raise OutputParserError(f"Failed to parse structured output: {result_text}")
+ structured_output = parsed
+ except ValidationError:
+ # if the result_text is not a valid json, try to repair it
+ temp_parsed = json_repair.loads(result_text)
+ if not isinstance(temp_parsed, dict):
+ # handle output from reasoning models (e.g. deepseek-r1) whose '\n\n\n' prefix makes json_repair return a list
+ if isinstance(temp_parsed, list):
+ temp_parsed = next((item for item in temp_parsed if isinstance(item, dict)), {})
+ else:
+ raise OutputParserError(f"Failed to parse structured output: {result_text}")
+ structured_output = cast(dict, temp_parsed)
+ return structured_output
+
+
+def _prepare_schema_for_model(provider: str, model_schema: AIModelEntity, schema: Mapping) -> dict:
+ """
+ Prepare JSON schema based on model requirements.
+
+ Different models have different requirements for JSON schema formatting.
+ This function handles these differences.
+
+ :param schema: The original JSON schema
+ :return: Processed schema compatible with the current model
+ """
+
+ # Deep copy to avoid modifying the original schema
+ processed_schema = dict(deepcopy(schema))
+
+ # Convert boolean types to string types (common requirement)
+ convert_boolean_to_string(processed_schema)
+
+ # Apply model-specific transformations
+ if SpecialModelType.GEMINI in model_schema.model:
+ remove_additional_properties(processed_schema)
+ return processed_schema
+ elif SpecialModelType.OLLAMA in provider:
+ return processed_schema
+ else:
+ # Default format with name field
+ return {"schema": processed_schema, "name": "llm_response"}
+
+
+def remove_additional_properties(schema: dict) -> None:
+ """
+ Remove additionalProperties fields from JSON schema.
+ Used for models like Gemini that don't support this property.
+
+ :param schema: JSON schema to modify in-place
+ """
+ if not isinstance(schema, dict):
+ return
+
+ # Remove additionalProperties at current level
+ schema.pop("additionalProperties", None)
+
+ # Process nested structures recursively
+ for value in schema.values():
+ if isinstance(value, dict):
+ remove_additional_properties(value)
+ elif isinstance(value, list):
+ for item in value:
+ if isinstance(item, dict):
+ remove_additional_properties(item)
+
+
+def convert_boolean_to_string(schema: dict) -> None:
+ """
+ Convert boolean type specifications to string in JSON schema.
+
+ :param schema: JSON schema to modify in-place
+ """
+ if not isinstance(schema, dict):
+ return
+
+ # Check for boolean type at current level
+ if schema.get("type") == "boolean":
+ schema["type"] = "string"
+
+ # Process nested dictionaries and lists recursively
+ for value in schema.values():
+ if isinstance(value, dict):
+ convert_boolean_to_string(value)
+ elif isinstance(value, list):
+ for item in value:
+ if isinstance(item, dict):
+ convert_boolean_to_string(item)
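A minimal sketch exercising the two schema-normalisation helpers defined above; both mutate the schema dict in place and involve no model call:

    from core.llm_generator.output_parser.structured_output import (
        convert_boolean_to_string,
        remove_additional_properties,
    )

    schema = {
        "type": "object",
        "additionalProperties": False,
        "properties": {"active": {"type": "boolean"}},
    }
    convert_boolean_to_string(schema)     # boolean types become string types
    remove_additional_properties(schema)  # additionalProperties removed at every level
    print(schema)
    # {'type': 'object', 'properties': {'active': {'type': 'string'}}}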
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index ddfa1e7a66..ef81e38dc5 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -291,3 +291,21 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
Now, generate a JSON Schema based on my description
""" # noqa: E501
+
+STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You can answer questions and output in JSON format.
+constraints:
+ - You must output in JSON format.
+ - Do not output boolean value, use string type instead.
+ - Do not output integer or float value, use number type instead.
+eg:
+ Here is the JSON schema:
+ {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
+
+ Here is the user's question:
+ My name is John Doe and I am 30 years old.
+
+ output:
+ {"name": "John Doe", "age": 30}
+Here is the JSON schema:
+{{schema}}
+""" # noqa: E501
diff --git a/api/core/mcp/__init__.py b/api/core/mcp/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py
new file mode 100644
index 0000000000..bcb31a816f
--- /dev/null
+++ b/api/core/mcp/auth/auth_flow.py
@@ -0,0 +1,342 @@
+import base64
+import hashlib
+import json
+import os
+import secrets
+import urllib.parse
+from typing import Optional
+from urllib.parse import urljoin
+
+import requests
+from pydantic import BaseModel, ValidationError
+
+from core.mcp.auth.auth_provider import OAuthClientProvider
+from core.mcp.types import (
+ OAuthClientInformation,
+ OAuthClientInformationFull,
+ OAuthClientMetadata,
+ OAuthMetadata,
+ OAuthTokens,
+)
+from extensions.ext_redis import redis_client
+
+LATEST_PROTOCOL_VERSION = "1.0"
+OAUTH_STATE_EXPIRY_SECONDS = 5 * 60 # 5 minutes expiry
+OAUTH_STATE_REDIS_KEY_PREFIX = "oauth_state:"
+
+
+class OAuthCallbackState(BaseModel):
+ provider_id: str
+ tenant_id: str
+ server_url: str
+ metadata: OAuthMetadata | None = None
+ client_information: OAuthClientInformation
+ code_verifier: str
+ redirect_uri: str
+
+
+def generate_pkce_challenge() -> tuple[str, str]:
+ """Generate PKCE challenge and verifier."""
+ code_verifier = base64.urlsafe_b64encode(os.urandom(40)).decode("utf-8")
+ code_verifier = code_verifier.replace("=", "").replace("+", "-").replace("/", "_")
+
+ code_challenge_hash = hashlib.sha256(code_verifier.encode("utf-8")).digest()
+ code_challenge = base64.urlsafe_b64encode(code_challenge_hash).decode("utf-8")
+ code_challenge = code_challenge.replace("=", "").replace("+", "-").replace("/", "_")
+
+ return code_verifier, code_challenge
+
+
+def _create_secure_redis_state(state_data: OAuthCallbackState) -> str:
+ """Create a secure state parameter by storing state data in Redis and returning a random state key."""
+ # Generate a secure random state key
+ state_key = secrets.token_urlsafe(32)
+
+ # Store the state data in Redis with expiration
+ redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
+ redis_client.setex(redis_key, OAUTH_STATE_EXPIRY_SECONDS, state_data.model_dump_json())
+
+ return state_key
+
+
+def _retrieve_redis_state(state_key: str) -> OAuthCallbackState:
+ """Retrieve and decode OAuth state data from Redis using the state key, then delete it."""
+ redis_key = f"{OAUTH_STATE_REDIS_KEY_PREFIX}{state_key}"
+
+ # Get state data from Redis
+ state_data = redis_client.get(redis_key)
+
+ if not state_data:
+ raise ValueError("State parameter has expired or does not exist")
+
+ # Delete the state data from Redis immediately after retrieval to prevent reuse
+ redis_client.delete(redis_key)
+
+ try:
+ # Parse and validate the state data
+ oauth_state = OAuthCallbackState.model_validate_json(state_data)
+
+ return oauth_state
+ except ValidationError as e:
+ raise ValueError(f"Invalid state parameter: {str(e)}")
+
+
+def handle_callback(state_key: str, authorization_code: str) -> OAuthCallbackState:
+ """Handle the callback from the OAuth provider."""
+ # Retrieve state data from Redis (state is automatically deleted after retrieval)
+ full_state_data = _retrieve_redis_state(state_key)
+
+ tokens = exchange_authorization(
+ full_state_data.server_url,
+ full_state_data.metadata,
+ full_state_data.client_information,
+ authorization_code,
+ full_state_data.code_verifier,
+ full_state_data.redirect_uri,
+ )
+ provider = OAuthClientProvider(full_state_data.provider_id, full_state_data.tenant_id, for_list=True)
+ provider.save_tokens(tokens)
+ return full_state_data
+
+
+def discover_oauth_metadata(server_url: str, protocol_version: Optional[str] = None) -> Optional[OAuthMetadata]:
+ """Looks up RFC 8414 OAuth 2.0 Authorization Server Metadata."""
+ url = urljoin(server_url, "/.well-known/oauth-authorization-server")
+
+ try:
+ headers = {"MCP-Protocol-Version": protocol_version or LATEST_PROTOCOL_VERSION}
+ response = requests.get(url, headers=headers)
+ if response.status_code == 404:
+ return None
+ if not response.ok:
+ raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
+ return OAuthMetadata.model_validate(response.json())
+ except requests.RequestException as e:
+ if isinstance(e, requests.ConnectionError):
+ response = requests.get(url)
+ if response.status_code == 404:
+ return None
+ if not response.ok:
+ raise ValueError(f"HTTP {response.status_code} trying to load well-known OAuth metadata")
+ return OAuthMetadata.model_validate(response.json())
+ raise
+
+
+def start_authorization(
+ server_url: str,
+ metadata: Optional[OAuthMetadata],
+ client_information: OAuthClientInformation,
+ redirect_url: str,
+ provider_id: str,
+ tenant_id: str,
+) -> tuple[str, str]:
+ """Begins the authorization flow with secure Redis state storage."""
+ response_type = "code"
+ code_challenge_method = "S256"
+
+ if metadata:
+ authorization_url = metadata.authorization_endpoint
+ if response_type not in metadata.response_types_supported:
+ raise ValueError(f"Incompatible auth server: does not support response type {response_type}")
+ if (
+ not metadata.code_challenge_methods_supported
+ or code_challenge_method not in metadata.code_challenge_methods_supported
+ ):
+ raise ValueError(
+ f"Incompatible auth server: does not support code challenge method {code_challenge_method}"
+ )
+ else:
+ authorization_url = urljoin(server_url, "/authorize")
+
+ code_verifier, code_challenge = generate_pkce_challenge()
+
+ # Prepare state data with all necessary information
+ state_data = OAuthCallbackState(
+ provider_id=provider_id,
+ tenant_id=tenant_id,
+ server_url=server_url,
+ metadata=metadata,
+ client_information=client_information,
+ code_verifier=code_verifier,
+ redirect_uri=redirect_url,
+ )
+
+ # Store state data in Redis and generate secure state key
+ state_key = _create_secure_redis_state(state_data)
+
+ params = {
+ "response_type": response_type,
+ "client_id": client_information.client_id,
+ "code_challenge": code_challenge,
+ "code_challenge_method": code_challenge_method,
+ "redirect_uri": redirect_url,
+ "state": state_key,
+ }
+
+ authorization_url = f"{authorization_url}?{urllib.parse.urlencode(params)}"
+ return authorization_url, code_verifier
+
+
+def exchange_authorization(
+ server_url: str,
+ metadata: Optional[OAuthMetadata],
+ client_information: OAuthClientInformation,
+ authorization_code: str,
+ code_verifier: str,
+ redirect_uri: str,
+) -> OAuthTokens:
+ """Exchanges an authorization code for an access token."""
+ grant_type = "authorization_code"
+
+ if metadata:
+ token_url = metadata.token_endpoint
+ if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
+ raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
+ else:
+ token_url = urljoin(server_url, "/token")
+
+ params = {
+ "grant_type": grant_type,
+ "client_id": client_information.client_id,
+ "code": authorization_code,
+ "code_verifier": code_verifier,
+ "redirect_uri": redirect_uri,
+ }
+
+ if client_information.client_secret:
+ params["client_secret"] = client_information.client_secret
+
+ response = requests.post(token_url, data=params)
+ if not response.ok:
+ raise ValueError(f"Token exchange failed: HTTP {response.status_code}")
+ return OAuthTokens.model_validate(response.json())
+
+
+def refresh_authorization(
+ server_url: str,
+ metadata: Optional[OAuthMetadata],
+ client_information: OAuthClientInformation,
+ refresh_token: str,
+) -> OAuthTokens:
+ """Exchange a refresh token for an updated access token."""
+ grant_type = "refresh_token"
+
+ if metadata:
+ token_url = metadata.token_endpoint
+ if metadata.grant_types_supported and grant_type not in metadata.grant_types_supported:
+ raise ValueError(f"Incompatible auth server: does not support grant type {grant_type}")
+ else:
+ token_url = urljoin(server_url, "/token")
+
+ params = {
+ "grant_type": grant_type,
+ "client_id": client_information.client_id,
+ "refresh_token": refresh_token,
+ }
+
+ if client_information.client_secret:
+ params["client_secret"] = client_information.client_secret
+
+ response = requests.post(token_url, data=params)
+ if not response.ok:
+ raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
+ return OAuthTokens.model_validate(response.json())
+
+
+def register_client(
+ server_url: str,
+ metadata: Optional[OAuthMetadata],
+ client_metadata: OAuthClientMetadata,
+) -> OAuthClientInformationFull:
+ """Performs OAuth 2.0 Dynamic Client Registration."""
+ if metadata:
+ if not metadata.registration_endpoint:
+ raise ValueError("Incompatible auth server: does not support dynamic client registration")
+ registration_url = metadata.registration_endpoint
+ else:
+ registration_url = urljoin(server_url, "/register")
+
+ response = requests.post(
+ registration_url,
+ json=client_metadata.model_dump(),
+ headers={"Content-Type": "application/json"},
+ )
+ if not response.ok:
+ response.raise_for_status()
+ return OAuthClientInformationFull.model_validate(response.json())
+
+
+def auth(
+ provider: OAuthClientProvider,
+ server_url: str,
+ authorization_code: Optional[str] = None,
+ state_param: Optional[str] = None,
+ for_list: bool = False,
+) -> dict[str, str]:
+ """Orchestrates the full auth flow with a server using secure Redis state storage."""
+ metadata = discover_oauth_metadata(server_url)
+
+ # Handle client registration if needed
+ client_information = provider.client_information()
+ if not client_information:
+ if authorization_code is not None:
+ raise ValueError("Existing OAuth client information is required when exchanging an authorization code")
+ try:
+ full_information = register_client(server_url, metadata, provider.client_metadata)
+ except requests.RequestException as e:
+ raise ValueError(f"Could not register OAuth client: {e}")
+ provider.save_client_information(full_information)
+ client_information = full_information
+
+ # Exchange authorization code for tokens
+ if authorization_code is not None:
+ if not state_param:
+ raise ValueError("State parameter is required when exchanging authorization code")
+
+ try:
+ # Retrieve state data from Redis using state key
+ full_state_data = _retrieve_redis_state(state_param)
+
+ code_verifier = full_state_data.code_verifier
+ redirect_uri = full_state_data.redirect_uri
+
+ if not code_verifier or not redirect_uri:
+ raise ValueError("Missing code_verifier or redirect_uri in state data")
+
+ except (json.JSONDecodeError, ValueError) as e:
+ raise ValueError(f"Invalid state parameter: {e}")
+
+ tokens = exchange_authorization(
+ server_url,
+ metadata,
+ client_information,
+ authorization_code,
+ code_verifier,
+ redirect_uri,
+ )
+ provider.save_tokens(tokens)
+ return {"result": "success"}
+
+ provider_tokens = provider.tokens()
+
+ # Handle token refresh or new authorization
+ if provider_tokens and provider_tokens.refresh_token:
+ try:
+ new_tokens = refresh_authorization(server_url, metadata, client_information, provider_tokens.refresh_token)
+ provider.save_tokens(new_tokens)
+ return {"result": "success"}
+ except Exception as e:
+ raise ValueError(f"Could not refresh OAuth tokens: {e}")
+
+ # Start new authorization flow
+ authorization_url, code_verifier = start_authorization(
+ server_url,
+ metadata,
+ client_information,
+ provider.redirect_url,
+ provider.mcp_provider.id,
+ provider.mcp_provider.tenant_id,
+ )
+
+ provider.save_code_verifier(code_verifier)
+ return {"authorization_url": authorization_url}
diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py
new file mode 100644
index 0000000000..00d5a25956
--- /dev/null
+++ b/api/core/mcp/auth/auth_provider.py
@@ -0,0 +1,81 @@
+from typing import Optional
+
+from configs import dify_config
+from core.mcp.types import (
+ OAuthClientInformation,
+ OAuthClientInformationFull,
+ OAuthClientMetadata,
+ OAuthTokens,
+)
+from models.tools import MCPToolProvider
+from services.tools.mcp_tools_manage_service import MCPToolManageService
+
+LATEST_PROTOCOL_VERSION = "1.0"
+
+
+class OAuthClientProvider:
+ mcp_provider: MCPToolProvider
+
+ def __init__(self, provider_id: str, tenant_id: str, for_list: bool = False):
+ if for_list:
+ self.mcp_provider = MCPToolManageService.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ else:
+ self.mcp_provider = MCPToolManageService.get_mcp_provider_by_server_identifier(provider_id, tenant_id)
+
+ @property
+ def redirect_url(self) -> str:
+ """The URL to redirect the user agent to after authorization."""
+ return dify_config.CONSOLE_API_URL + "/console/api/mcp/oauth/callback"
+
+ @property
+ def client_metadata(self) -> OAuthClientMetadata:
+ """Metadata about this OAuth client."""
+ return OAuthClientMetadata(
+ redirect_uris=[self.redirect_url],
+ token_endpoint_auth_method="none",
+ grant_types=["authorization_code", "refresh_token"],
+ response_types=["code"],
+ client_name="Dify",
+ client_uri="https://github.com/langgenius/dify",
+ )
+
+ def client_information(self) -> Optional[OAuthClientInformation]:
+ """Loads information about this OAuth client."""
+ client_information = self.mcp_provider.decrypted_credentials.get("client_information", {})
+ if not client_information:
+ return None
+ return OAuthClientInformation.model_validate(client_information)
+
+ def save_client_information(self, client_information: OAuthClientInformationFull) -> None:
+ """Saves client information after dynamic registration."""
+ MCPToolManageService.update_mcp_provider_credentials(
+ self.mcp_provider,
+ {"client_information": client_information.model_dump()},
+ )
+
+ def tokens(self) -> Optional[OAuthTokens]:
+ """Loads any existing OAuth tokens for the current session."""
+ credentials = self.mcp_provider.decrypted_credentials
+ if not credentials:
+ return None
+ return OAuthTokens(
+ access_token=credentials.get("access_token", ""),
+ token_type=credentials.get("token_type", "Bearer"),
+ expires_in=int(credentials.get("expires_in", "3600") or 3600),
+ refresh_token=credentials.get("refresh_token", ""),
+ )
+
+ def save_tokens(self, tokens: OAuthTokens) -> None:
+ """Stores new OAuth tokens for the current session."""
+ # update mcp provider credentials
+ token_dict = tokens.model_dump()
+ MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, token_dict, authed=True)
+
+ def save_code_verifier(self, code_verifier: str) -> None:
+ """Saves a PKCE code verifier for the current session."""
+ MCPToolManageService.update_mcp_provider_credentials(self.mcp_provider, {"code_verifier": code_verifier})
+
+ def code_verifier(self) -> str:
+ """Loads the PKCE code verifier for the current session."""
+ # get code verifier from mcp provider credentials
+ return str(self.mcp_provider.decrypted_credentials.get("code_verifier", ""))
diff --git a/api/core/mcp/client/sse_client.py b/api/core/mcp/client/sse_client.py
new file mode 100644
index 0000000000..91debcc8f9
--- /dev/null
+++ b/api/core/mcp/client/sse_client.py
@@ -0,0 +1,361 @@
+import logging
+import queue
+from collections.abc import Generator
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from typing import Any, TypeAlias, final
+from urllib.parse import urljoin, urlparse
+
+import httpx
+from sseclient import SSEClient
+
+from core.mcp import types
+from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.types import SessionMessage
+from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_QUEUE_READ_TIMEOUT = 3
+
+
+@final
+class _StatusReady:
+ def __init__(self, endpoint_url: str):
+ self._endpoint_url = endpoint_url
+
+
+@final
+class _StatusError:
+ def __init__(self, exc: Exception):
+ self._exc = exc
+
+
+# Type aliases for better readability
+ReadQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
+WriteQueue: TypeAlias = queue.Queue[SessionMessage | Exception | None]
+StatusQueue: TypeAlias = queue.Queue[_StatusReady | _StatusError]
+
+
+def remove_request_params(url: str) -> str:
+ """Remove request parameters from URL, keeping only the path."""
+ return urljoin(url, urlparse(url).path)
+
+
+class SSETransport:
+ """SSE client transport implementation."""
+
+ def __init__(
+ self,
+ url: str,
+ headers: dict[str, Any] | None = None,
+ timeout: float = 5.0,
+ sse_read_timeout: float = 5 * 60,
+ ) -> None:
+ """Initialize the SSE transport.
+
+ Args:
+ url: The SSE endpoint URL.
+ headers: Optional headers to include in requests.
+ timeout: HTTP timeout for regular operations.
+ sse_read_timeout: Timeout for SSE read operations.
+ """
+ self.url = url
+ self.headers = headers or {}
+ self.timeout = timeout
+ self.sse_read_timeout = sse_read_timeout
+ self.endpoint_url: str | None = None
+
+ def _validate_endpoint_url(self, endpoint_url: str) -> bool:
+ """Validate that the endpoint URL matches the connection origin.
+
+ Args:
+ endpoint_url: The endpoint URL to validate.
+
+ Returns:
+ True if valid, False otherwise.
+ """
+ url_parsed = urlparse(self.url)
+ endpoint_parsed = urlparse(endpoint_url)
+
+ return url_parsed.netloc == endpoint_parsed.netloc and url_parsed.scheme == endpoint_parsed.scheme
+
+ def _handle_endpoint_event(self, sse_data: str, status_queue: StatusQueue) -> None:
+ """Handle an 'endpoint' SSE event.
+
+ Args:
+ sse_data: The SSE event data.
+ status_queue: Queue to put status updates.
+ """
+ endpoint_url = urljoin(self.url, sse_data)
+ logger.info(f"Received endpoint URL: {endpoint_url}")
+
+ if not self._validate_endpoint_url(endpoint_url):
+ error_msg = f"Endpoint origin does not match connection origin: {endpoint_url}"
+ logger.error(error_msg)
+ status_queue.put(_StatusError(ValueError(error_msg)))
+ return
+
+ status_queue.put(_StatusReady(endpoint_url))
+
+ def _handle_message_event(self, sse_data: str, read_queue: ReadQueue) -> None:
+ """Handle a 'message' SSE event.
+
+ Args:
+ sse_data: The SSE event data.
+ read_queue: Queue to put parsed messages.
+ """
+ try:
+ message = types.JSONRPCMessage.model_validate_json(sse_data)
+ logger.debug(f"Received server message: {message}")
+ session_message = SessionMessage(message)
+ read_queue.put(session_message)
+ except Exception as exc:
+ logger.exception("Error parsing server message")
+ read_queue.put(exc)
+
+ def _handle_sse_event(self, sse, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
+ """Handle a single SSE event.
+
+ Args:
+ sse: The SSE event object.
+ read_queue: Queue for message events.
+ status_queue: Queue for status events.
+ """
+ match sse.event:
+ case "endpoint":
+ self._handle_endpoint_event(sse.data, status_queue)
+ case "message":
+ self._handle_message_event(sse.data, read_queue)
+ case _:
+ logger.warning(f"Unknown SSE event: {sse.event}")
+
+ def sse_reader(self, event_source, read_queue: ReadQueue, status_queue: StatusQueue) -> None:
+ """Read and process SSE events.
+
+ Args:
+ event_source: The SSE event source.
+ read_queue: Queue to put received messages.
+ status_queue: Queue to put status updates.
+ """
+ try:
+ for sse in event_source.iter_sse():
+ self._handle_sse_event(sse, read_queue, status_queue)
+ except httpx.ReadError as exc:
+ logger.debug(f"SSE reader shutting down normally: {exc}")
+ except Exception as exc:
+ read_queue.put(exc)
+ finally:
+ read_queue.put(None)
+
+ def _send_message(self, client: httpx.Client, endpoint_url: str, message: SessionMessage) -> None:
+ """Send a single message to the server.
+
+ Args:
+ client: HTTP client to use.
+ endpoint_url: The endpoint URL to send to.
+ message: The message to send.
+ """
+ response = client.post(
+ endpoint_url,
+ json=message.message.model_dump(
+ by_alias=True,
+ mode="json",
+ exclude_none=True,
+ ),
+ )
+ response.raise_for_status()
+ logger.debug(f"Client message sent successfully: {response.status_code}")
+
+ def post_writer(self, client: httpx.Client, endpoint_url: str, write_queue: WriteQueue) -> None:
+ """Handle writing messages to the server.
+
+ Args:
+ client: HTTP client to use.
+ endpoint_url: The endpoint URL to send messages to.
+ write_queue: Queue to read messages from.
+ """
+ try:
+ while True:
+ try:
+ message = write_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
+ if message is None:
+ break
+ if isinstance(message, Exception):
+ write_queue.put(message)
+ continue
+
+ self._send_message(client, endpoint_url, message)
+
+ except queue.Empty:
+ continue
+ except httpx.ReadError as exc:
+ logger.debug(f"Post writer shutting down normally: {exc}")
+ except Exception as exc:
+ logger.exception("Error writing messages")
+ write_queue.put(exc)
+ finally:
+ write_queue.put(None)
+
+ def _wait_for_endpoint(self, status_queue: StatusQueue) -> str:
+ """Wait for the endpoint URL from the status queue.
+
+ Args:
+ status_queue: Queue to read status from.
+
+ Returns:
+ The endpoint URL.
+
+ Raises:
+ ValueError: If endpoint URL is not received or there's an error.
+ """
+ try:
+ status = status_queue.get(timeout=1)
+ except queue.Empty:
+ raise ValueError("failed to get endpoint URL")
+
+ if isinstance(status, _StatusReady):
+ return status._endpoint_url
+ elif isinstance(status, _StatusError):
+ raise status._exc
+ else:
+ raise ValueError("failed to get endpoint URL")
+
+ def connect(
+ self,
+ executor: ThreadPoolExecutor,
+ client: httpx.Client,
+ event_source,
+ ) -> tuple[ReadQueue, WriteQueue]:
+ """Establish connection and start worker threads.
+
+ Args:
+ executor: Thread pool executor.
+ client: HTTP client.
+ event_source: SSE event source.
+
+ Returns:
+ Tuple of (read_queue, write_queue).
+ """
+ read_queue: ReadQueue = queue.Queue()
+ write_queue: WriteQueue = queue.Queue()
+ status_queue: StatusQueue = queue.Queue()
+
+ # Start SSE reader thread
+ executor.submit(self.sse_reader, event_source, read_queue, status_queue)
+
+ # Wait for endpoint URL
+ endpoint_url = self._wait_for_endpoint(status_queue)
+ self.endpoint_url = endpoint_url
+
+ # Start post writer thread
+ executor.submit(self.post_writer, client, endpoint_url, write_queue)
+
+ return read_queue, write_queue
+
+
+@contextmanager
+def sse_client(
+ url: str,
+ headers: dict[str, Any] | None = None,
+ timeout: float = 5.0,
+ sse_read_timeout: float = 5 * 60,
+) -> Generator[tuple[ReadQueue, WriteQueue], None, None]:
+ """
+ Client transport for SSE.
+ `sse_read_timeout` determines how long (in seconds) the client will wait for a new
+ event before disconnecting. All other HTTP operations are controlled by `timeout`.
+
+ Args:
+ url: The SSE endpoint URL.
+ headers: Optional headers to include in requests.
+ timeout: HTTP timeout for regular operations.
+ sse_read_timeout: Timeout for SSE read operations.
+
+ Yields:
+ Tuple of (read_queue, write_queue) for message communication.
+ """
+ transport = SSETransport(url, headers, timeout, sse_read_timeout)
+
+ read_queue: ReadQueue | None = None
+ write_queue: WriteQueue | None = None
+
+ with ThreadPoolExecutor() as executor:
+ try:
+ with create_ssrf_proxy_mcp_http_client(headers=transport.headers) as client:
+ with ssrf_proxy_sse_connect(
+ url, timeout=httpx.Timeout(timeout, read=sse_read_timeout), client=client
+ ) as event_source:
+ event_source.response.raise_for_status()
+
+ read_queue, write_queue = transport.connect(executor, client, event_source)
+
+ yield read_queue, write_queue
+
+ except httpx.HTTPStatusError as exc:
+ if exc.response.status_code == 401:
+ raise MCPAuthError()
+ raise MCPConnectionError()
+ except Exception:
+ logger.exception("Error connecting to SSE endpoint")
+ raise
+ finally:
+ # Clean up queues
+ if read_queue:
+ read_queue.put(None)
+ if write_queue:
+ write_queue.put(None)
+
+
+def send_message(http_client: httpx.Client, endpoint_url: str, session_message: SessionMessage) -> None:
+ """
+ Send a message to the server using the provided HTTP client.
+
+ Args:
+ http_client: The HTTP client to use for sending
+ endpoint_url: The endpoint URL to send the message to
+ session_message: The message to send
+ """
+ try:
+ response = http_client.post(
+ endpoint_url,
+ json=session_message.message.model_dump(
+ by_alias=True,
+ mode="json",
+ exclude_none=True,
+ ),
+ )
+ response.raise_for_status()
+ logger.debug(f"Client message sent successfully: {response.status_code}")
+ except Exception as exc:
+ logger.exception("Error sending message")
+ raise
+
+
+def read_messages(
+ sse_client: SSEClient,
+) -> Generator[SessionMessage | Exception, None, None]:
+ """
+ Read messages from the SSE client.
+
+ Args:
+ sse_client: The SSE client to read from
+
+ Yields:
+ SessionMessage or Exception for each event received
+ """
+ try:
+ for sse in sse_client.events():
+ if sse.event == "message":
+ try:
+ message = types.JSONRPCMessage.model_validate_json(sse.data)
+ logger.debug(f"Received server message: {message}")
+ yield SessionMessage(message)
+ except Exception as exc:
+ logger.exception("Error parsing server message")
+ yield exc
+ else:
+ logger.warning(f"Unknown SSE event: {sse.event}")
+ except Exception as exc:
+ logger.exception("Error reading SSE messages")
+ yield exc
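+
+
+# Illustrative usage sketch (the endpoint URL is hypothetical; all names below are
+# already imported by this module). The queues yielded by `sse_client` carry
+# `SessionMessage` objects in both directions:
+#
+#     with sse_client("http://localhost:8000/sse") as (read_queue, write_queue):
+#         ping = types.JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")
+#         write_queue.put(SessionMessage(types.JSONRPCMessage(ping)))
+#         reply = read_queue.get(timeout=10)  # SessionMessage, Exception, or None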
diff --git a/api/core/mcp/client/streamable_client.py b/api/core/mcp/client/streamable_client.py
new file mode 100644
index 0000000000..fbd8d05f9e
--- /dev/null
+++ b/api/core/mcp/client/streamable_client.py
@@ -0,0 +1,476 @@
+"""
+StreamableHTTP Client Transport Module
+
+This module implements the StreamableHTTP transport for MCP clients,
+providing support for HTTP POST requests with optional SSE streaming responses
+and session management.
+"""
+
+import logging
+import queue
+from collections.abc import Callable, Generator
+from concurrent.futures import ThreadPoolExecutor
+from contextlib import contextmanager
+from dataclasses import dataclass
+from datetime import timedelta
+from typing import Any, cast
+
+import httpx
+from httpx_sse import EventSource, ServerSentEvent
+
+from core.mcp.types import (
+ ClientMessageMetadata,
+ ErrorData,
+ JSONRPCError,
+ JSONRPCMessage,
+ JSONRPCNotification,
+ JSONRPCRequest,
+ JSONRPCResponse,
+ RequestId,
+ SessionMessage,
+)
+from core.mcp.utils import create_ssrf_proxy_mcp_http_client, ssrf_proxy_sse_connect
+
+logger = logging.getLogger(__name__)
+
+
+SessionMessageOrError = SessionMessage | Exception | None
+# Queue types with clearer names for their roles
+ServerToClientQueue = queue.Queue[SessionMessageOrError] # Server to client messages
+ClientToServerQueue = queue.Queue[SessionMessage | None] # Client to server messages
+GetSessionIdCallback = Callable[[], str | None]
+
+MCP_SESSION_ID = "mcp-session-id"
+LAST_EVENT_ID = "last-event-id"
+CONTENT_TYPE = "content-type"
+ACCEPT = "Accept"
+
+
+JSON = "application/json"
+SSE = "text/event-stream"
+
+DEFAULT_QUEUE_READ_TIMEOUT = 3
+
+
+class StreamableHTTPError(Exception):
+ """Base exception for StreamableHTTP transport errors."""
+
+ pass
+
+
+class ResumptionError(StreamableHTTPError):
+ """Raised when resumption request is invalid."""
+
+ pass
+
+
+@dataclass
+class RequestContext:
+ """Context for a request operation."""
+
+ client: httpx.Client
+ headers: dict[str, str]
+ session_id: str | None
+ session_message: SessionMessage
+ metadata: ClientMessageMetadata | None
+    server_to_client_queue: ServerToClientQueue  # Queue for responses back to the client
+ sse_read_timeout: timedelta
+
+
+class StreamableHTTPTransport:
+ """StreamableHTTP client transport implementation."""
+
+ def __init__(
+ self,
+ url: str,
+ headers: dict[str, Any] | None = None,
+ timeout: timedelta = timedelta(seconds=30),
+ sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
+ ) -> None:
+ """Initialize the StreamableHTTP transport.
+
+ Args:
+ url: The endpoint URL.
+ headers: Optional headers to include in requests.
+ timeout: HTTP timeout for regular operations.
+ sse_read_timeout: Timeout for SSE read operations.
+ """
+ self.url = url
+ self.headers = headers or {}
+ self.timeout = timeout
+ self.sse_read_timeout = sse_read_timeout
+ self.session_id: str | None = None
+ self.request_headers = {
+ ACCEPT: f"{JSON}, {SSE}",
+ CONTENT_TYPE: JSON,
+ **self.headers,
+ }
+
+ def _update_headers_with_session(self, base_headers: dict[str, str]) -> dict[str, str]:
+ """Update headers with session ID if available."""
+ headers = base_headers.copy()
+ if self.session_id:
+ headers[MCP_SESSION_ID] = self.session_id
+ return headers
+
+ def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
+ """Check if the message is an initialization request."""
+ return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
+
+ def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
+ """Check if the message is an initialized notification."""
+ return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
+
+ def _maybe_extract_session_id_from_response(
+ self,
+ response: httpx.Response,
+ ) -> None:
+ """Extract and store session ID from response headers."""
+ new_session_id = response.headers.get(MCP_SESSION_ID)
+ if new_session_id:
+ self.session_id = new_session_id
+ logger.info(f"Received session ID: {self.session_id}")
+
+ def _handle_sse_event(
+ self,
+ sse: ServerSentEvent,
+ server_to_client_queue: ServerToClientQueue,
+ original_request_id: RequestId | None = None,
+ resumption_callback: Callable[[str], None] | None = None,
+ ) -> bool:
+ """Handle an SSE event, returning True if the response is complete."""
+ if sse.event == "message":
+ try:
+ message = JSONRPCMessage.model_validate_json(sse.data)
+ logger.debug(f"SSE message: {message}")
+
+ # If this is a response and we have original_request_id, replace it
+ if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
+ message.root.id = original_request_id
+
+ session_message = SessionMessage(message)
+ # Put message in queue that goes to client
+ server_to_client_queue.put(session_message)
+
+ # Call resumption token callback if we have an ID
+ if sse.id and resumption_callback:
+ resumption_callback(sse.id)
+
+ # If this is a response or error return True indicating completion
+ # Otherwise, return False to continue listening
+ return isinstance(message.root, JSONRPCResponse | JSONRPCError)
+
+ except Exception as exc:
+ # Put exception in queue that goes to client
+ server_to_client_queue.put(exc)
+ return False
+ elif sse.event == "ping":
+ logger.debug("Received ping event")
+ return False
+ else:
+ logger.warning(f"Unknown SSE event: {sse.event}")
+ return False
+
+ def handle_get_stream(
+ self,
+ client: httpx.Client,
+ server_to_client_queue: ServerToClientQueue,
+ ) -> None:
+ """Handle GET stream for server-initiated messages."""
+ try:
+ if not self.session_id:
+ return
+
+ headers = self._update_headers_with_session(self.request_headers)
+
+ with ssrf_proxy_sse_connect(
+ self.url,
+ headers=headers,
+                timeout=httpx.Timeout(self.timeout.total_seconds(), read=self.sse_read_timeout.total_seconds()),
+ client=client,
+ method="GET",
+ ) as event_source:
+ event_source.response.raise_for_status()
+ logger.debug("GET SSE connection established")
+
+ for sse in event_source.iter_sse():
+ self._handle_sse_event(sse, server_to_client_queue)
+
+ except Exception as exc:
+ logger.debug(f"GET stream error (non-fatal): {exc}")
+
+ def _handle_resumption_request(self, ctx: RequestContext) -> None:
+ """Handle a resumption request using GET with SSE."""
+ headers = self._update_headers_with_session(ctx.headers)
+ if ctx.metadata and ctx.metadata.resumption_token:
+ headers[LAST_EVENT_ID] = ctx.metadata.resumption_token
+ else:
+ raise ResumptionError("Resumption request requires a resumption token")
+
+ # Extract original request ID to map responses
+ original_request_id = None
+ if isinstance(ctx.session_message.message.root, JSONRPCRequest):
+ original_request_id = ctx.session_message.message.root.id
+
+ with ssrf_proxy_sse_connect(
+ self.url,
+ headers=headers,
+            timeout=httpx.Timeout(self.timeout.total_seconds(), read=ctx.sse_read_timeout.total_seconds()),
+ client=ctx.client,
+ method="GET",
+ ) as event_source:
+ event_source.response.raise_for_status()
+ logger.debug("Resumption GET SSE connection established")
+
+ for sse in event_source.iter_sse():
+ is_complete = self._handle_sse_event(
+ sse,
+ ctx.server_to_client_queue,
+ original_request_id,
+ ctx.metadata.on_resumption_token_update if ctx.metadata else None,
+ )
+ if is_complete:
+ break
+
+ def _handle_post_request(self, ctx: RequestContext) -> None:
+ """Handle a POST request with response processing."""
+ headers = self._update_headers_with_session(ctx.headers)
+ message = ctx.session_message.message
+ is_initialization = self._is_initialization_request(message)
+
+ with ctx.client.stream(
+ "POST",
+ self.url,
+ json=message.model_dump(by_alias=True, mode="json", exclude_none=True),
+ headers=headers,
+ ) as response:
+ if response.status_code == 202:
+ logger.debug("Received 202 Accepted")
+ return
+
+ if response.status_code == 404:
+ if isinstance(message.root, JSONRPCRequest):
+ self._send_session_terminated_error(
+ ctx.server_to_client_queue,
+ message.root.id,
+ )
+ return
+
+ response.raise_for_status()
+ if is_initialization:
+ self._maybe_extract_session_id_from_response(response)
+
+ content_type = cast(str, response.headers.get(CONTENT_TYPE, "").lower())
+
+ if content_type.startswith(JSON):
+ self._handle_json_response(response, ctx.server_to_client_queue)
+ elif content_type.startswith(SSE):
+ self._handle_sse_response(response, ctx)
+ else:
+ self._handle_unexpected_content_type(
+ content_type,
+ ctx.server_to_client_queue,
+ )
+
+ def _handle_json_response(
+ self,
+ response: httpx.Response,
+ server_to_client_queue: ServerToClientQueue,
+ ) -> None:
+ """Handle JSON response from the server."""
+ try:
+ content = response.read()
+ message = JSONRPCMessage.model_validate_json(content)
+ session_message = SessionMessage(message)
+ server_to_client_queue.put(session_message)
+ except Exception as exc:
+ server_to_client_queue.put(exc)
+
+ def _handle_sse_response(self, response: httpx.Response, ctx: RequestContext) -> None:
+ """Handle SSE response from the server."""
+ try:
+ event_source = EventSource(response)
+ for sse in event_source.iter_sse():
+ is_complete = self._handle_sse_event(
+ sse,
+ ctx.server_to_client_queue,
+ resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
+ )
+ if is_complete:
+ break
+ except Exception as e:
+ ctx.server_to_client_queue.put(e)
+
+ def _handle_unexpected_content_type(
+ self,
+ content_type: str,
+ server_to_client_queue: ServerToClientQueue,
+ ) -> None:
+ """Handle unexpected content type in response."""
+ error_msg = f"Unexpected content type: {content_type}"
+ logger.error(error_msg)
+ server_to_client_queue.put(ValueError(error_msg))
+
+ def _send_session_terminated_error(
+ self,
+ server_to_client_queue: ServerToClientQueue,
+ request_id: RequestId,
+ ) -> None:
+ """Send a session terminated error response."""
+ jsonrpc_error = JSONRPCError(
+ jsonrpc="2.0",
+ id=request_id,
+ error=ErrorData(code=32600, message="Session terminated by server"),
+ )
+ session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
+ server_to_client_queue.put(session_message)
+
+ def post_writer(
+ self,
+ client: httpx.Client,
+ client_to_server_queue: ClientToServerQueue,
+ server_to_client_queue: ServerToClientQueue,
+ start_get_stream: Callable[[], None],
+ ) -> None:
+ """Handle writing requests to the server.
+
+ This method processes messages from the client_to_server_queue and sends them to the server.
+ Responses are written to the server_to_client_queue.
+ """
+ while True:
+ try:
+                # Read the next message from the client queue, using a timeout so shutdown can be detected
+ session_message = client_to_server_queue.get(timeout=DEFAULT_QUEUE_READ_TIMEOUT)
+ if session_message is None:
+ break
+
+ message = session_message.message
+ metadata = (
+ session_message.metadata if isinstance(session_message.metadata, ClientMessageMetadata) else None
+ )
+
+ # Check if this is a resumption request
+ is_resumption = bool(metadata and metadata.resumption_token)
+
+ logger.debug(f"Sending client message: {message}")
+
+ # Handle initialized notification
+ if self._is_initialized_notification(message):
+ start_get_stream()
+
+ ctx = RequestContext(
+ client=client,
+ headers=self.request_headers,
+ session_id=self.session_id,
+ session_message=session_message,
+ metadata=metadata,
+ server_to_client_queue=server_to_client_queue, # Queue to write responses to client
+ sse_read_timeout=self.sse_read_timeout,
+ )
+
+ if is_resumption:
+ self._handle_resumption_request(ctx)
+ else:
+ self._handle_post_request(ctx)
+ except queue.Empty:
+ continue
+ except Exception as exc:
+ server_to_client_queue.put(exc)
+
+ def terminate_session(self, client: httpx.Client) -> None:
+ """Terminate the session by sending a DELETE request."""
+ if not self.session_id:
+ return
+
+ try:
+ headers = self._update_headers_with_session(self.request_headers)
+ response = client.delete(self.url, headers=headers)
+
+ if response.status_code == 405:
+ logger.debug("Server does not allow session termination")
+ elif response.status_code != 200:
+ logger.warning(f"Session termination failed: {response.status_code}")
+ except Exception as exc:
+ logger.warning(f"Session termination failed: {exc}")
+
+ def get_session_id(self) -> str | None:
+ """Get the current session ID."""
+ return self.session_id
+
+
+@contextmanager
+def streamablehttp_client(
+ url: str,
+ headers: dict[str, Any] | None = None,
+ timeout: timedelta = timedelta(seconds=30),
+ sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
+ terminate_on_close: bool = True,
+) -> Generator[
+ tuple[
+ ServerToClientQueue, # Queue for receiving messages FROM server
+ ClientToServerQueue, # Queue for sending messages TO server
+ GetSessionIdCallback,
+ ],
+ None,
+ None,
+]:
+ """
+ Client transport for StreamableHTTP.
+
+ `sse_read_timeout` determines how long (in seconds) the client will wait for a new
+ event before disconnecting. All other HTTP operations are controlled by `timeout`.
+
+ Yields:
+ Tuple containing:
+ - server_to_client_queue: Queue for reading messages FROM the server
+ - client_to_server_queue: Queue for sending messages TO the server
+ - get_session_id_callback: Function to retrieve the current session ID
+ """
+ transport = StreamableHTTPTransport(url, headers, timeout, sse_read_timeout)
+
+ # Create queues with clear directional meaning
+ server_to_client_queue: ServerToClientQueue = queue.Queue() # For messages FROM server TO client
+ client_to_server_queue: ClientToServerQueue = queue.Queue() # For messages FROM client TO server
+
+ with ThreadPoolExecutor(max_workers=2) as executor:
+ try:
+ with create_ssrf_proxy_mcp_http_client(
+ headers=transport.request_headers,
+                timeout=httpx.Timeout(transport.timeout.total_seconds(), read=transport.sse_read_timeout.total_seconds()),
+ ) as client:
+ # Define callbacks that need access to thread pool
+ def start_get_stream() -> None:
+ """Start a worker thread to handle server-initiated messages."""
+ executor.submit(transport.handle_get_stream, client, server_to_client_queue)
+
+ # Start the post_writer worker thread
+ executor.submit(
+ transport.post_writer,
+ client,
+ client_to_server_queue, # Queue for messages FROM client TO server
+ server_to_client_queue, # Queue for messages FROM server TO client
+ start_get_stream,
+ )
+
+ try:
+ yield (
+ server_to_client_queue, # Queue for receiving messages FROM server
+ client_to_server_queue, # Queue for sending messages TO server
+ transport.get_session_id,
+ )
+ finally:
+ if transport.session_id and terminate_on_close:
+ transport.terminate_session(client)
+
+ # Signal threads to stop
+ client_to_server_queue.put(None)
+ finally:
+ # Clear any remaining items and add None sentinel to unblock any waiting threads
+ try:
+ while not client_to_server_queue.empty():
+ client_to_server_queue.get_nowait()
+ except queue.Empty:
+ pass
+
+ client_to_server_queue.put(None)
+ server_to_client_queue.put(None)
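+
+
+# Illustrative usage sketch (not part of the module API; the URL is hypothetical).
+# `streamablehttp_client` yields two queues plus a callback for the negotiated session
+# ID, which becomes available only after the server responds to `initialize`:
+#
+#     with streamablehttp_client("http://localhost:8000/mcp") as (
+#         server_to_client_queue,
+#         client_to_server_queue,
+#         get_session_id,
+#     ):
+#         client_to_server_queue.put(some_session_message)  # a SessionMessage (placeholder name)
+#         incoming = server_to_client_queue.get(timeout=10)
+#         session_id = get_session_id()  # None until the server assigns one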
diff --git a/api/core/mcp/entities.py b/api/core/mcp/entities.py
new file mode 100644
index 0000000000..7553c10a2e
--- /dev/null
+++ b/api/core/mcp/entities.py
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+from typing import Any, Generic, TypeVar
+
+from core.mcp.session.base_session import BaseSession
+from core.mcp.types import LATEST_PROTOCOL_VERSION, RequestId, RequestParams
+
+SUPPORTED_PROTOCOL_VERSIONS: list[str] = ["2024-11-05", LATEST_PROTOCOL_VERSION]
+
+
+SessionT = TypeVar("SessionT", bound=BaseSession[Any, Any, Any, Any, Any])
+LifespanContextT = TypeVar("LifespanContextT")
+
+
+@dataclass
+class RequestContext(Generic[SessionT, LifespanContextT]):
+ request_id: RequestId
+ meta: RequestParams.Meta | None
+ session: SessionT
+ lifespan_context: LifespanContextT
diff --git a/api/core/mcp/error.py b/api/core/mcp/error.py
new file mode 100644
index 0000000000..92ea7bde09
--- /dev/null
+++ b/api/core/mcp/error.py
@@ -0,0 +1,10 @@
+class MCPError(Exception):
+ pass
+
+
+class MCPConnectionError(MCPError):
+ pass
+
+
+class MCPAuthError(MCPConnectionError):
+ pass
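+
+
+# Note: MCPAuthError subclasses MCPConnectionError, so `except MCPConnectionError`
+# also catches authentication failures. Catch MCPAuthError first when the two need
+# different handling, as MCPClient.connect_server does.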
diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py
new file mode 100644
index 0000000000..5fe52c008a
--- /dev/null
+++ b/api/core/mcp/mcp_client.py
@@ -0,0 +1,153 @@
+import logging
+from collections.abc import Callable
+from contextlib import AbstractContextManager, ExitStack
+from types import TracebackType
+from typing import Any, Optional, cast
+from urllib.parse import urlparse
+
+from core.mcp.client.sse_client import sse_client
+from core.mcp.client.streamable_client import streamablehttp_client
+from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.session.client_session import ClientSession
+from core.mcp.types import Tool
+
+logger = logging.getLogger(__name__)
+
+
+class MCPClient:
+ def __init__(
+ self,
+ server_url: str,
+ provider_id: str,
+ tenant_id: str,
+ authed: bool = True,
+ authorization_code: Optional[str] = None,
+ for_list: bool = False,
+ ):
+ # Initialize info
+ self.provider_id = provider_id
+ self.tenant_id = tenant_id
+ self.client_type = "streamable"
+ self.server_url = server_url
+
+ # Authentication info
+ self.authed = authed
+ self.authorization_code = authorization_code
+ if authed:
+ from core.mcp.auth.auth_provider import OAuthClientProvider
+
+ self.provider = OAuthClientProvider(self.provider_id, self.tenant_id, for_list=for_list)
+ self.token = self.provider.tokens()
+
+ # Initialize session and client objects
+ self._session: Optional[ClientSession] = None
+ self._streams_context: Optional[AbstractContextManager[Any]] = None
+ self._session_context: Optional[ClientSession] = None
+ self.exit_stack = ExitStack()
+
+ # Whether the client has been initialized
+ self._initialized = False
+
+ def __enter__(self):
+ self._initialize()
+ self._initialized = True
+ return self
+
+ def __exit__(
+ self, exc_type: Optional[type], exc_value: Optional[BaseException], traceback: Optional[TracebackType]
+ ):
+ self.cleanup()
+
+ def _initialize(
+ self,
+ ):
+ """Initialize the client with fallback to SSE if streamable connection fails"""
+ connection_methods: dict[str, Callable[..., AbstractContextManager[Any]]] = {
+ "mcp": streamablehttp_client,
+ "sse": sse_client,
+ }
+
+ parsed_url = urlparse(self.server_url)
+ path = parsed_url.path or ""
+ method_name = path.rstrip("/").split("/")[-1] if path else ""
+ if method_name in connection_methods:
+ client_factory = connection_methods[method_name]
+ self.connect_server(client_factory, method_name)
+ else:
+ try:
+ logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
+ self.connect_server(sse_client, "sse")
+ except MCPConnectionError:
+ logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
+ self.connect_server(streamablehttp_client, "mcp")
+
+ def connect_server(
+ self, client_factory: Callable[..., AbstractContextManager[Any]], method_name: str, first_try: bool = True
+ ):
+ from core.mcp.auth.auth_flow import auth
+
+ try:
+ headers = (
+ {"Authorization": f"{self.token.token_type.capitalize()} {self.token.access_token}"}
+ if self.authed and self.token
+ else {}
+ )
+ self._streams_context = client_factory(url=self.server_url, headers=headers)
+ if not self._streams_context:
+ raise MCPConnectionError("Failed to create connection context")
+
+ # Use exit_stack to manage context managers properly
+ if method_name == "mcp":
+ read_stream, write_stream, _ = self.exit_stack.enter_context(self._streams_context)
+ streams = (read_stream, write_stream)
+ else: # sse_client
+ streams = self.exit_stack.enter_context(self._streams_context)
+
+ self._session_context = ClientSession(*streams)
+ self._session = self.exit_stack.enter_context(self._session_context)
+ session = cast(ClientSession, self._session)
+ session.initialize()
+ return
+
+ except MCPAuthError:
+ if not self.authed:
+ raise
+ try:
+ auth(self.provider, self.server_url, self.authorization_code)
+ except Exception as e:
+ raise ValueError(f"Failed to authenticate: {e}")
+ self.token = self.provider.tokens()
+ if first_try:
+ return self.connect_server(client_factory, method_name, first_try=False)
+
+ except MCPConnectionError:
+ raise
+
+ def list_tools(self) -> list[Tool]:
+ """Connect to an MCP server running with SSE transport"""
+ # List available tools to verify connection
+ if not self._initialized or not self._session:
+ raise ValueError("Session not initialized.")
+ response = self._session.list_tools()
+ tools = response.tools
+ return tools
+
+ def invoke_tool(self, tool_name: str, tool_args: dict):
+ """Call a tool"""
+ if not self._initialized or not self._session:
+ raise ValueError("Session not initialized.")
+ return self._session.call_tool(tool_name, tool_args)
+
+ def cleanup(self):
+ """Clean up resources"""
+ try:
+ # ExitStack will handle proper cleanup of all managed context managers
+ self.exit_stack.close()
+ except Exception as e:
+            logger.exception("Error during cleanup")
+            raise ValueError(f"Error during cleanup: {e}") from e
+ finally:
+ self._session = None
+ self._session_context = None
+ self._streams_context = None
+ self._initialized = False
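+
+
+# Illustrative usage sketch (the server URL, IDs, tool name, and arguments are
+# hypothetical). MCPClient picks the transport from the URL path suffix ("/mcp" or
+# "/sse"), otherwise it tries SSE first and falls back to StreamableHTTP, and it must
+# be used as a context manager before listing or invoking tools:
+#
+#     with MCPClient(
+#         server_url="https://example.com/mcp",
+#         provider_id="provider-id",
+#         tenant_id="tenant-id",
+#         authed=False,
+#     ) as client:
+#         tools = client.list_tools()
+#         result = client.invoke_tool("some_tool", {"query": "hello"})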
diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py
new file mode 100644
index 0000000000..20ff7e7524
--- /dev/null
+++ b/api/core/mcp/server/streamable_http.py
@@ -0,0 +1,226 @@
+import json
+import logging
+from collections.abc import Mapping
+from typing import Any, cast
+
+from configs import dify_config
+from controllers.web.passport import generate_session_id
+from core.app.app_config.entities import VariableEntity, VariableEntityType
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
+from core.mcp import types
+from core.mcp.types import INTERNAL_ERROR, INVALID_PARAMS, METHOD_NOT_FOUND
+from core.mcp.utils import create_mcp_error_response
+from core.model_runtime.utils.encoders import jsonable_encoder
+from extensions.ext_database import db
+from models.model import App, AppMCPServer, AppMode, EndUser
+from services.app_generate_service import AppGenerateService
+
+"""
+Apply to MCP HTTP streamable server with stateless http
+"""
+logger = logging.getLogger(__name__)
+
+
+class MCPServerStreamableHTTPRequestHandler:
+ def __init__(
+ self, app: App, request: types.ClientRequest | types.ClientNotification, user_input_form: list[VariableEntity]
+ ):
+ self.app = app
+ self.request = request
+ mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
+ if not mcp_server:
+ raise ValueError("MCP server not found")
+ self.mcp_server: AppMCPServer = mcp_server
+ self.end_user = self.retrieve_end_user()
+ self.user_input_form = user_input_form
+
+ @property
+ def request_type(self):
+ return type(self.request.root)
+
+ @property
+ def parameter_schema(self):
+ parameters, required = self._convert_input_form_to_parameters(self.user_input_form)
+ if self.app.mode in {AppMode.COMPLETION.value, AppMode.WORKFLOW.value}:
+ return {
+ "type": "object",
+ "properties": parameters,
+ "required": required,
+ }
+ return {
+ "type": "object",
+ "properties": {
+ "query": {"type": "string", "description": "User Input/Question content"},
+ **parameters,
+ },
+ "required": ["query", *required],
+ }
+
+ @property
+ def capabilities(self):
+ return types.ServerCapabilities(
+ tools=types.ToolsCapability(listChanged=False),
+ )
+
+ def response(self, response: types.Result | str):
+ if isinstance(response, str):
+ sse_content = f"event: ping\ndata: {response}\n\n".encode()
+ yield sse_content
+ return
+ json_response = types.JSONRPCResponse(
+ jsonrpc="2.0",
+ id=(self.request.root.model_extra or {}).get("id", 1),
+ result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ json_data = json.dumps(jsonable_encoder(json_response))
+
+ sse_content = f"event: message\ndata: {json_data}\n\n".encode()
+
+ yield sse_content
+
+ def error_response(self, code: int, message: str, data=None):
+ request_id = (self.request.root.model_extra or {}).get("id", 1) or 1
+ return create_mcp_error_response(request_id, code, message, data)
+
+ def handle(self):
+ handle_map = {
+ types.InitializeRequest: self.initialize,
+ types.ListToolsRequest: self.list_tools,
+ types.CallToolRequest: self.invoke_tool,
+ types.InitializedNotification: self.handle_notification,
+ types.PingRequest: self.handle_ping,
+ }
+ try:
+ if self.request_type in handle_map:
+ return self.response(handle_map[self.request_type]())
+ else:
+ return self.error_response(METHOD_NOT_FOUND, f"Method not found: {self.request_type}")
+ except ValueError as e:
+ logger.exception("Invalid params")
+ return self.error_response(INVALID_PARAMS, str(e))
+ except Exception as e:
+ logger.exception("Internal server error")
+ return self.error_response(INTERNAL_ERROR, f"Internal server error: {str(e)}")
+
+ def handle_notification(self):
+ return "ping"
+
+ def handle_ping(self):
+ return types.EmptyResult()
+
+ def initialize(self):
+ request = cast(types.InitializeRequest, self.request.root)
+ client_info = request.params.clientInfo
+ client_name = f"{client_info.name}@{client_info.version}"
+ if not self.end_user:
+ end_user = EndUser(
+ tenant_id=self.app.tenant_id,
+ app_id=self.app.id,
+ type="mcp",
+ name=client_name,
+ session_id=generate_session_id(),
+ external_user_id=self.mcp_server.id,
+ )
+ db.session.add(end_user)
+ db.session.commit()
+ return types.InitializeResult(
+ protocolVersion=types.SERVER_LATEST_PROTOCOL_VERSION,
+ capabilities=self.capabilities,
+ serverInfo=types.Implementation(name="Dify", version=dify_config.project.version),
+ instructions=self.mcp_server.description,
+ )
+
+ def list_tools(self):
+ if not self.end_user:
+ raise ValueError("User not found")
+ return types.ListToolsResult(
+ tools=[
+ types.Tool(
+ name=self.app.name,
+ description=self.mcp_server.description,
+ inputSchema=self.parameter_schema,
+ )
+ ],
+ )
+
+ def invoke_tool(self):
+ if not self.end_user:
+ raise ValueError("User not found")
+ request = cast(types.CallToolRequest, self.request.root)
+ args = request.params.arguments or {}
+ if self.app.mode in {AppMode.WORKFLOW.value}:
+ args = {"inputs": args}
+ elif self.app.mode in {AppMode.COMPLETION.value}:
+ args = {"query": "", "inputs": args}
+ else:
+ args = {"query": args["query"], "inputs": {k: v for k, v in args.items() if k != "query"}}
+ response = AppGenerateService.generate(
+ self.app,
+ self.end_user,
+ args,
+ InvokeFrom.SERVICE_API,
+ streaming=self.app.mode == AppMode.AGENT_CHAT.value,
+ )
+ answer = ""
+ if isinstance(response, RateLimitGenerator):
+ for item in response.generator:
+ data = item
+ if isinstance(data, str) and data.startswith("data: "):
+ try:
+ json_str = data[6:].strip()
+ parsed_data = json.loads(json_str)
+ if parsed_data.get("event") == "agent_thought":
+ answer += parsed_data.get("thought", "")
+ except json.JSONDecodeError:
+ continue
+ if isinstance(response, Mapping):
+ if self.app.mode in {
+ AppMode.ADVANCED_CHAT.value,
+ AppMode.COMPLETION.value,
+ AppMode.CHAT.value,
+ AppMode.AGENT_CHAT.value,
+ }:
+ answer = response["answer"]
+ elif self.app.mode in {AppMode.WORKFLOW.value}:
+ answer = json.dumps(response["data"]["outputs"], ensure_ascii=False)
+ else:
+ raise ValueError("Invalid app mode")
+        # Image output is not supported yet
+ return types.CallToolResult(content=[types.TextContent(text=answer, type="text")])
+
+ def retrieve_end_user(self):
+ return (
+ db.session.query(EndUser)
+ .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
+ .first()
+ )
+
+ def _convert_input_form_to_parameters(self, user_input_form: list[VariableEntity]):
+ parameters: dict[str, dict[str, Any]] = {}
+ required = []
+ for item in user_input_form:
+ parameters[item.variable] = {}
+ if item.type in (
+ VariableEntityType.FILE,
+ VariableEntityType.FILE_LIST,
+ VariableEntityType.EXTERNAL_DATA_TOOL,
+ ):
+ continue
+ if item.required:
+ required.append(item.variable)
+            # If the workflow was republished, the stored MCP server parameters may be
+            # out of date, so a missing description should not raise an error here.
+ try:
+ description = self.mcp_server.parameters_dict[item.variable]
+ except KeyError:
+ description = ""
+ parameters[item.variable]["description"] = description
+ if item.type in (VariableEntityType.TEXT_INPUT, VariableEntityType.PARAGRAPH):
+ parameters[item.variable]["type"] = "string"
+ elif item.type == VariableEntityType.SELECT:
+ parameters[item.variable]["type"] = "string"
+ parameters[item.variable]["enum"] = item.options
+ elif item.type == VariableEntityType.NUMBER:
+ parameters[item.variable]["type"] = "float"
+ return parameters, required
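+
+
+# Illustrative example of the generated tool schema (the variable name "topic" and its
+# description are hypothetical). For a chat-mode app whose user_input_form contains one
+# required text input named "topic", parameter_schema produces roughly:
+#
+#     {
+#         "type": "object",
+#         "properties": {
+#             "query": {"type": "string", "description": "User Input/Question content"},
+#             "topic": {"type": "string", "description": "<from mcp_server.parameters_dict>"},
+#         },
+#         "required": ["query", "topic"],
+#     }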
diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py
new file mode 100644
index 0000000000..7734b8fdd9
--- /dev/null
+++ b/api/core/mcp/session/base_session.py
@@ -0,0 +1,415 @@
+import logging
+import queue
+from collections.abc import Callable
+from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
+from contextlib import ExitStack
+from datetime import timedelta
+from types import TracebackType
+from typing import Any, Generic, Self, TypeVar
+
+from httpx import HTTPStatusError
+from pydantic import BaseModel
+
+from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.types import (
+ CancelledNotification,
+ ClientNotification,
+ ClientRequest,
+ ClientResult,
+ ErrorData,
+ JSONRPCError,
+ JSONRPCMessage,
+ JSONRPCNotification,
+ JSONRPCRequest,
+ JSONRPCResponse,
+ MessageMetadata,
+ RequestId,
+ RequestParams,
+ ServerMessageMetadata,
+ ServerNotification,
+ ServerRequest,
+ ServerResult,
+ SessionMessage,
+)
+
+SendRequestT = TypeVar("SendRequestT", ClientRequest, ServerRequest)
+SendResultT = TypeVar("SendResultT", ClientResult, ServerResult)
+SendNotificationT = TypeVar("SendNotificationT", ClientNotification, ServerNotification)
+ReceiveRequestT = TypeVar("ReceiveRequestT", ClientRequest, ServerRequest)
+ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel)
+ReceiveNotificationT = TypeVar("ReceiveNotificationT", ClientNotification, ServerNotification)
+DEFAULT_RESPONSE_READ_TIMEOUT = 1.0
+
+
+class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
+ """Handles responding to MCP requests and manages request lifecycle.
+
+ This class MUST be used as a context manager to ensure proper cleanup and
+ cancellation handling:
+
+ Example:
+ with request_responder as resp:
+ resp.respond(result)
+
+ The context manager ensures:
+ 1. Proper cancellation scope setup and cleanup
+ 2. Request completion tracking
+ 3. Cleanup of in-flight requests
+ """
+
+ request: ReceiveRequestT
+ _session: Any
+ _on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any]
+
+ def __init__(
+ self,
+ request_id: RequestId,
+ request_meta: RequestParams.Meta | None,
+ request: ReceiveRequestT,
+ session: """BaseSession[
+ SendRequestT,
+ SendNotificationT,
+ SendResultT,
+ ReceiveRequestT,
+ ReceiveNotificationT
+ ]""",
+ on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
+ ) -> None:
+ self.request_id = request_id
+ self.request_meta = request_meta
+ self.request = request
+ self._session = session
+ self._completed = False
+ self._on_complete = on_complete
+ self._entered = False # Track if we're in a context manager
+
+ def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
+ """Enter the context manager, enabling request cancellation tracking."""
+ self._entered = True
+ return self
+
+ def __exit__(
+ self,
+ exc_type: type[BaseException] | None,
+ exc_val: BaseException | None,
+ exc_tb: TracebackType | None,
+ ) -> None:
+ """Exit the context manager, performing cleanup and notifying completion."""
+ try:
+ if self._completed:
+ self._on_complete(self)
+ finally:
+ self._entered = False
+
+ def respond(self, response: SendResultT | ErrorData) -> None:
+ """Send a response for this request.
+
+ Must be called within a context manager block.
+ Raises:
+ RuntimeError: If not used within a context manager
+ AssertionError: If request was already responded to
+ """
+ if not self._entered:
+ raise RuntimeError("RequestResponder must be used as a context manager")
+ assert not self._completed, "Request already responded to"
+
+ self._completed = True
+
+ self._session._send_response(request_id=self.request_id, response=response)
+
+ def cancel(self) -> None:
+ """Cancel this request and mark it as completed."""
+ if not self._entered:
+ raise RuntimeError("RequestResponder must be used as a context manager")
+
+ self._completed = True # Mark as completed so it's removed from in_flight
+ # Send an error response to indicate cancellation
+ self._session._send_response(
+ request_id=self.request_id,
+ response=ErrorData(code=0, message="Request cancelled", data=None),
+ )
+
+
+class BaseSession(
+ Generic[
+ SendRequestT,
+ SendNotificationT,
+ SendResultT,
+ ReceiveRequestT,
+ ReceiveNotificationT,
+ ],
+):
+ """
+ Implements an MCP "session" on top of read/write streams, including features
+ like request/response linking, notifications, and progress.
+
+ This class is a context manager that automatically starts processing
+ messages when entered.
+ """
+
+ _response_streams: dict[RequestId, queue.Queue[JSONRPCResponse | JSONRPCError]]
+ _request_id: int
+ _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
+ _receive_request_type: type[ReceiveRequestT]
+ _receive_notification_type: type[ReceiveNotificationT]
+
+ def __init__(
+ self,
+ read_stream: queue.Queue,
+ write_stream: queue.Queue,
+ receive_request_type: type[ReceiveRequestT],
+ receive_notification_type: type[ReceiveNotificationT],
+ # If none, reading will never time out
+ read_timeout_seconds: timedelta | None = None,
+ ) -> None:
+ self._read_stream = read_stream
+ self._write_stream = write_stream
+ self._response_streams = {}
+ self._request_id = 0
+ self._receive_request_type = receive_request_type
+ self._receive_notification_type = receive_notification_type
+ self._session_read_timeout_seconds = read_timeout_seconds
+ self._in_flight = {}
+ self._exit_stack = ExitStack()
+ # Initialize executor and future to None for proper cleanup checks
+ self._executor: ThreadPoolExecutor | None = None
+ self._receiver_future: Future | None = None
+
+ def __enter__(self) -> Self:
+ # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
+ # ensures no unnecessary threads are created.
+ self._executor = ThreadPoolExecutor(max_workers=1)
+ self._receiver_future = self._executor.submit(self._receive_loop)
+ return self
+
+ def check_receiver_status(self) -> None:
+ """`check_receiver_status` ensures that any exceptions raised during the
+ execution of `_receive_loop` are retrieved and propagated."""
+ if self._receiver_future and self._receiver_future.done():
+ self._receiver_future.result()
+
+ def __exit__(
+ self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
+ ) -> None:
+ self._read_stream.put(None)
+ self._write_stream.put(None)
+
+ # Wait for the receiver loop to finish
+ if self._receiver_future:
+ try:
+ self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
+ except TimeoutError:
+ # If the receiver loop is still running after timeout, we'll force shutdown
+ pass
+
+ # Shutdown the executor
+ if self._executor:
+ self._executor.shutdown(wait=True)
+
+ def send_request(
+ self,
+ request: SendRequestT,
+ result_type: type[ReceiveResultT],
+ request_read_timeout_seconds: timedelta | None = None,
+ metadata: MessageMetadata = None,
+ ) -> ReceiveResultT:
+ """
+        Sends a request and waits for a response. Raises an MCPConnectionError
+        (or MCPAuthError for 401 responses) if the response contains an error.
+        If a request read timeout is provided, it takes precedence over the
+        session read timeout.
+
+ Do not use this method to emit notifications! Use send_notification()
+ instead.
+ """
+ self.check_receiver_status()
+
+ request_id = self._request_id
+ self._request_id = request_id + 1
+
+ response_queue: queue.Queue[JSONRPCResponse | JSONRPCError] = queue.Queue()
+ self._response_streams[request_id] = response_queue
+
+ try:
+ jsonrpc_request = JSONRPCRequest(
+ jsonrpc="2.0",
+ id=request_id,
+ **request.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+
+ self._write_stream.put(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata))
+ timeout = DEFAULT_RESPONSE_READ_TIMEOUT
+ if request_read_timeout_seconds is not None:
+ timeout = float(request_read_timeout_seconds.total_seconds())
+ elif self._session_read_timeout_seconds is not None:
+ timeout = float(self._session_read_timeout_seconds.total_seconds())
+ while True:
+ try:
+ response_or_error = response_queue.get(timeout=timeout)
+ break
+ except queue.Empty:
+ self.check_receiver_status()
+ continue
+
+ if response_or_error is None:
+ raise MCPConnectionError(
+ ErrorData(
+ code=500,
+ message="No response received",
+ )
+ )
+ elif isinstance(response_or_error, JSONRPCError):
+ if response_or_error.error.code == 401:
+ raise MCPAuthError(
+ ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
+ )
+ else:
+ raise MCPConnectionError(
+ ErrorData(code=response_or_error.error.code, message=response_or_error.error.message)
+ )
+ else:
+ return result_type.model_validate(response_or_error.result)
+
+ finally:
+ self._response_streams.pop(request_id, None)
+
+ def send_notification(
+ self,
+ notification: SendNotificationT,
+ related_request_id: RequestId | None = None,
+ ) -> None:
+ """
+ Emits a notification, which is a one-way message that does not expect
+ a response.
+ """
+ self.check_receiver_status()
+
+ # Some transport implementations may need to set the related_request_id
+        # to attribute the notification to the request that triggered it.
+ jsonrpc_notification = JSONRPCNotification(
+ jsonrpc="2.0",
+ **notification.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ session_message = SessionMessage(
+ message=JSONRPCMessage(jsonrpc_notification),
+ metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None,
+ )
+ self._write_stream.put(session_message)
+
+ def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None:
+ if isinstance(response, ErrorData):
+ jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response)
+ session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
+ self._write_stream.put(session_message)
+ else:
+ jsonrpc_response = JSONRPCResponse(
+ jsonrpc="2.0",
+ id=request_id,
+ result=response.model_dump(by_alias=True, mode="json", exclude_none=True),
+ )
+ session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response))
+ self._write_stream.put(session_message)
+
+ def _receive_loop(self) -> None:
+ """
+ Main message processing loop.
+        Runs in the dedicated worker thread started by __enter__.
+ """
+ while True:
+ try:
+                # Block waiting for the next message from the read queue (the timeout keeps the loop responsive)
+ message = self._read_stream.get(timeout=DEFAULT_RESPONSE_READ_TIMEOUT)
+ if message is None:
+ break
+ if isinstance(message, HTTPStatusError):
+ response_queue = self._response_streams.get(self._request_id - 1)
+ if response_queue is not None:
+ response_queue.put(
+ JSONRPCError(
+ jsonrpc="2.0",
+ id=self._request_id - 1,
+ error=ErrorData(code=message.response.status_code, message=message.args[0]),
+ )
+ )
+ else:
+ self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}"))
+ elif isinstance(message, Exception):
+ self._handle_incoming(message)
+ elif isinstance(message.message.root, JSONRPCRequest):
+ validated_request = self._receive_request_type.model_validate(
+ message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+
+ responder = RequestResponder(
+ request_id=message.message.root.id,
+ request_meta=validated_request.root.params.meta if validated_request.root.params else None,
+ request=validated_request,
+ session=self,
+ on_complete=lambda r: self._in_flight.pop(r.request_id, None),
+ )
+
+ self._in_flight[responder.request_id] = responder
+ self._received_request(responder)
+
+ if not responder._completed:
+ self._handle_incoming(responder)
+
+ elif isinstance(message.message.root, JSONRPCNotification):
+ try:
+ notification = self._receive_notification_type.model_validate(
+ message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True)
+ )
+ # Handle cancellation notifications
+ if isinstance(notification.root, CancelledNotification):
+ cancelled_id = notification.root.params.requestId
+ if cancelled_id in self._in_flight:
+ self._in_flight[cancelled_id].cancel()
+ else:
+ self._received_notification(notification)
+ self._handle_incoming(notification)
+ except Exception as e:
+ # For other validation errors, log and continue
+ logging.warning(f"Failed to validate notification: {e}. Message was: {message.message.root}")
+ else: # Response or error
+ response_queue = self._response_streams.get(message.message.root.id)
+ if response_queue is not None:
+ response_queue.put(message.message.root)
+ else:
+ self._handle_incoming(RuntimeError(f"Server Error: {message}"))
+ except queue.Empty:
+ continue
+ except Exception as e:
+ logging.exception("Error in message processing loop")
+ raise
+
+ def _received_request(self, responder: RequestResponder[ReceiveRequestT, SendResultT]) -> None:
+ """
+ Can be overridden by subclasses to handle a request without needing to
+ listen on the message stream.
+
+ If the request is responded to within this method, it will not be
+ forwarded on to the message stream.
+ """
+ pass
+
+ def _received_notification(self, notification: ReceiveNotificationT) -> None:
+ """
+ Can be overridden by subclasses to handle a notification without needing
+ to listen on the message stream.
+ """
+ pass
+
+ def send_progress_notification(
+ self, progress_token: str | int, progress: float, total: float | None = None
+ ) -> None:
+ """
+ Sends a progress notification for a request that is currently being
+ processed.
+ """
+ pass
+
+ def _handle_incoming(
+ self,
+ req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
+ ) -> None:
+ """A generic handler for incoming messages. Overwritten by subclasses."""
+ pass
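+
+
+# Threading model note: __enter__ starts _receive_loop on a single-worker executor.
+# send_request registers a per-request response queue keyed by the JSON-RPC request id,
+# and the receive loop routes each JSONRPCResponse/JSONRPCError back to that queue, so
+# each caller blocks only on its own request.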
diff --git a/api/core/mcp/session/client_session.py b/api/core/mcp/session/client_session.py
new file mode 100644
index 0000000000..ed2ad508ab
--- /dev/null
+++ b/api/core/mcp/session/client_session.py
@@ -0,0 +1,365 @@
+from datetime import timedelta
+from typing import Any, Protocol
+
+from pydantic import AnyUrl, TypeAdapter
+
+from configs import dify_config
+from core.mcp import types
+from core.mcp.entities import SUPPORTED_PROTOCOL_VERSIONS, RequestContext
+from core.mcp.session.base_session import BaseSession, RequestResponder
+
+DEFAULT_CLIENT_INFO = types.Implementation(name="Dify", version=dify_config.project.version)
+
+
+class SamplingFnT(Protocol):
+ def __call__(
+ self,
+ context: RequestContext["ClientSession", Any],
+ params: types.CreateMessageRequestParams,
+ ) -> types.CreateMessageResult | types.ErrorData: ...
+
+
+class ListRootsFnT(Protocol):
+ def __call__(self, context: RequestContext["ClientSession", Any]) -> types.ListRootsResult | types.ErrorData: ...
+
+
+class LoggingFnT(Protocol):
+ def __call__(
+ self,
+ params: types.LoggingMessageNotificationParams,
+ ) -> None: ...
+
+
+class MessageHandlerFnT(Protocol):
+ def __call__(
+ self,
+ message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
+ ) -> None: ...
+
+
+def _default_message_handler(
+ message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
+) -> None:
+ if isinstance(message, Exception):
+ raise ValueError(str(message))
+ elif isinstance(message, (types.ServerNotification | RequestResponder)):
+ pass
+
+
+def _default_sampling_callback(
+ context: RequestContext["ClientSession", Any],
+ params: types.CreateMessageRequestParams,
+) -> types.CreateMessageResult | types.ErrorData:
+ return types.ErrorData(
+ code=types.INVALID_REQUEST,
+ message="Sampling not supported",
+ )
+
+
+def _default_list_roots_callback(
+ context: RequestContext["ClientSession", Any],
+) -> types.ListRootsResult | types.ErrorData:
+ return types.ErrorData(
+ code=types.INVALID_REQUEST,
+ message="List roots not supported",
+ )
+
+
+def _default_logging_callback(
+ params: types.LoggingMessageNotificationParams,
+) -> None:
+ pass
+
+
+ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(types.ClientResult | types.ErrorData)
+
+
+class ClientSession(
+ BaseSession[
+ types.ClientRequest,
+ types.ClientNotification,
+ types.ClientResult,
+ types.ServerRequest,
+ types.ServerNotification,
+ ]
+):
+ def __init__(
+ self,
+ read_stream,
+ write_stream,
+ read_timeout_seconds: timedelta | None = None,
+ sampling_callback: SamplingFnT | None = None,
+ list_roots_callback: ListRootsFnT | None = None,
+ logging_callback: LoggingFnT | None = None,
+ message_handler: MessageHandlerFnT | None = None,
+ client_info: types.Implementation | None = None,
+ ) -> None:
+ super().__init__(
+ read_stream,
+ write_stream,
+ types.ServerRequest,
+ types.ServerNotification,
+ read_timeout_seconds=read_timeout_seconds,
+ )
+ self._client_info = client_info or DEFAULT_CLIENT_INFO
+ self._sampling_callback = sampling_callback or _default_sampling_callback
+ self._list_roots_callback = list_roots_callback or _default_list_roots_callback
+ self._logging_callback = logging_callback or _default_logging_callback
+ self._message_handler = message_handler or _default_message_handler
+
+ def initialize(self) -> types.InitializeResult:
+ sampling = types.SamplingCapability()
+ roots = types.RootsCapability(
+ # TODO: Should this be based on whether we
+ # _will_ send notifications, or only whether
+ # they're supported?
+ listChanged=True,
+ )
+
+ result = self.send_request(
+ types.ClientRequest(
+ types.InitializeRequest(
+ method="initialize",
+ params=types.InitializeRequestParams(
+ protocolVersion=types.LATEST_PROTOCOL_VERSION,
+ capabilities=types.ClientCapabilities(
+ sampling=sampling,
+ experimental=None,
+ roots=roots,
+ ),
+ clientInfo=self._client_info,
+ ),
+ )
+ ),
+ types.InitializeResult,
+ )
+
+ if result.protocolVersion not in SUPPORTED_PROTOCOL_VERSIONS:
+ raise RuntimeError(f"Unsupported protocol version from the server: {result.protocolVersion}")
+
+ self.send_notification(
+ types.ClientNotification(types.InitializedNotification(method="notifications/initialized"))
+ )
+
+ return result
+
+ def send_ping(self) -> types.EmptyResult:
+ """Send a ping request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.PingRequest(
+ method="ping",
+ )
+ ),
+ types.EmptyResult,
+ )
+
+ def send_progress_notification(
+ self, progress_token: str | int, progress: float, total: float | None = None
+ ) -> None:
+ """Send a progress notification."""
+ self.send_notification(
+ types.ClientNotification(
+ types.ProgressNotification(
+ method="notifications/progress",
+ params=types.ProgressNotificationParams(
+ progressToken=progress_token,
+ progress=progress,
+ total=total,
+ ),
+ ),
+ )
+ )
+
+ def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
+ """Send a logging/setLevel request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.SetLevelRequest(
+ method="logging/setLevel",
+ params=types.SetLevelRequestParams(level=level),
+ )
+ ),
+ types.EmptyResult,
+ )
+
+ def list_resources(self) -> types.ListResourcesResult:
+ """Send a resources/list request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.ListResourcesRequest(
+ method="resources/list",
+ )
+ ),
+ types.ListResourcesResult,
+ )
+
+ def list_resource_templates(self) -> types.ListResourceTemplatesResult:
+ """Send a resources/templates/list request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.ListResourceTemplatesRequest(
+ method="resources/templates/list",
+ )
+ ),
+ types.ListResourceTemplatesResult,
+ )
+
+ def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
+ """Send a resources/read request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.ReadResourceRequest(
+ method="resources/read",
+ params=types.ReadResourceRequestParams(uri=uri),
+ )
+ ),
+ types.ReadResourceResult,
+ )
+
+ def subscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
+ """Send a resources/subscribe request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.SubscribeRequest(
+ method="resources/subscribe",
+ params=types.SubscribeRequestParams(uri=uri),
+ )
+ ),
+ types.EmptyResult,
+ )
+
+ def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
+ """Send a resources/unsubscribe request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.UnsubscribeRequest(
+ method="resources/unsubscribe",
+ params=types.UnsubscribeRequestParams(uri=uri),
+ )
+ ),
+ types.EmptyResult,
+ )
+
+ def call_tool(
+ self,
+ name: str,
+ arguments: dict[str, Any] | None = None,
+ read_timeout_seconds: timedelta | None = None,
+ ) -> types.CallToolResult:
+ """Send a tools/call request."""
+
+ return self.send_request(
+ types.ClientRequest(
+ types.CallToolRequest(
+ method="tools/call",
+ params=types.CallToolRequestParams(name=name, arguments=arguments),
+ )
+ ),
+ types.CallToolResult,
+ request_read_timeout_seconds=read_timeout_seconds,
+ )
+
+ def list_prompts(self) -> types.ListPromptsResult:
+ """Send a prompts/list request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.ListPromptsRequest(
+ method="prompts/list",
+ )
+ ),
+ types.ListPromptsResult,
+ )
+
+ def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
+ """Send a prompts/get request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.GetPromptRequest(
+ method="prompts/get",
+ params=types.GetPromptRequestParams(name=name, arguments=arguments),
+ )
+ ),
+ types.GetPromptResult,
+ )
+
+ def complete(
+ self,
+ ref: types.ResourceReference | types.PromptReference,
+ argument: dict[str, str],
+ ) -> types.CompleteResult:
+ """Send a completion/complete request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.CompleteRequest(
+ method="completion/complete",
+ params=types.CompleteRequestParams(
+ ref=ref,
+ argument=types.CompletionArgument(**argument),
+ ),
+ )
+ ),
+ types.CompleteResult,
+ )
+
+ def list_tools(self) -> types.ListToolsResult:
+ """Send a tools/list request."""
+ return self.send_request(
+ types.ClientRequest(
+ types.ListToolsRequest(
+ method="tools/list",
+ )
+ ),
+ types.ListToolsResult,
+ )
+
+ def send_roots_list_changed(self) -> None:
+ """Send a roots/list_changed notification."""
+ self.send_notification(
+ types.ClientNotification(
+ types.RootsListChangedNotification(
+ method="notifications/roots/list_changed",
+ )
+ )
+ )
+
+ def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
+ ctx = RequestContext[ClientSession, Any](
+ request_id=responder.request_id,
+ meta=responder.request_meta,
+ session=self,
+ lifespan_context=None,
+ )
+
+ match responder.request.root:
+ case types.CreateMessageRequest(params=params):
+ with responder:
+ response = self._sampling_callback(ctx, params)
+ client_response = ClientResponse.validate_python(response)
+ responder.respond(client_response)
+
+ case types.ListRootsRequest():
+ with responder:
+ list_roots_response = self._list_roots_callback(ctx)
+ client_response = ClientResponse.validate_python(list_roots_response)
+ responder.respond(client_response)
+
+ case types.PingRequest():
+ with responder:
+ return responder.respond(types.ClientResult(root=types.EmptyResult()))
+
+ def _handle_incoming(
+ self,
+ req: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception,
+ ) -> None:
+ """Handle incoming messages by forwarding to the message handler."""
+ self._message_handler(req)
+
+ def _received_notification(self, notification: types.ServerNotification) -> None:
+ """Handle notifications from the server."""
+ # Process specific notification types
+ match notification.root:
+ case types.LoggingMessageNotification(params=params):
+ self._logging_callback(params)
+ case _:
+ pass
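+
+
+# Illustrative usage sketch (the endpoint URL is hypothetical). A ClientSession runs on
+# top of the read/write queues produced by one of the client transports and must be
+# initialized before issuing requests:
+#
+#     from core.mcp.client.sse_client import sse_client
+#
+#     with sse_client("http://localhost:8000/sse") as (read_queue, write_queue):
+#         with ClientSession(read_queue, write_queue) as session:
+#             session.initialize()
+#             tools = session.list_tools().tools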
diff --git a/api/core/mcp/types.py b/api/core/mcp/types.py
new file mode 100644
index 0000000000..99d985a781
--- /dev/null
+++ b/api/core/mcp/types.py
@@ -0,0 +1,1217 @@
+from collections.abc import Callable
+from dataclasses import dataclass
+from typing import (
+ Annotated,
+ Any,
+ Generic,
+ Literal,
+ Optional,
+ TypeAlias,
+ TypeVar,
+)
+
+from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel
+from pydantic.networks import AnyUrl, UrlConstraints
+
+"""
+Model Context Protocol bindings for Python
+
+These bindings were generated from https://github.com/modelcontextprotocol/specification,
+using Claude, with a prompt something like the following:
+
+Generate idiomatic Python bindings for this schema for MCP, or the "Model Context
+Protocol." The schema is defined in TypeScript, but there's also a JSON Schema version
+for reference.
+
+* For the bindings, let's use Pydantic V2 models.
+* Each model should allow extra fields everywhere, by specifying `model_config =
+ ConfigDict(extra='allow')`. Do this in every case, instead of a custom base class.
+* Union types should be represented with a Pydantic `RootModel`.
+* Define additional model classes instead of using dictionaries. Do this even if they're
+ not separate types in the schema.
+"""
+# The client supports both 2024-11-05 and 2025-03-26; 2025-06-18 is not supported yet.
+LATEST_PROTOCOL_VERSION = "2025-03-26"
+# The server advertises 2024-11-05 so that Claude can use it.
+SERVER_LATEST_PROTOCOL_VERSION = "2024-11-05"
+ProgressToken = str | int
+Cursor = str
+Role = Literal["user", "assistant"]
+RequestId = Annotated[int | str, Field(union_mode="left_to_right")]
+AnyFunction: TypeAlias = Callable[..., Any]
+
+
+class RequestParams(BaseModel):
+ class Meta(BaseModel):
+ progressToken: ProgressToken | None = None
+ """
+ If specified, the caller requests out-of-band progress notifications for
+ this request (as represented by notifications/progress). The value of this
+ parameter is an opaque token that will be attached to any subsequent
+ notifications. The receiver is not obligated to provide these notifications.
+ """
+
+ model_config = ConfigDict(extra="allow")
+
+ meta: Meta | None = Field(alias="_meta", default=None)
+
+
+class NotificationParams(BaseModel):
+ class Meta(BaseModel):
+ model_config = ConfigDict(extra="allow")
+
+ meta: Meta | None = Field(alias="_meta", default=None)
+ """
+ This parameter name is reserved by MCP to allow clients and servers to attach
+ additional metadata to their notifications.
+ """
+
+
+RequestParamsT = TypeVar("RequestParamsT", bound=RequestParams | dict[str, Any] | None)
+NotificationParamsT = TypeVar("NotificationParamsT", bound=NotificationParams | dict[str, Any] | None)
+MethodT = TypeVar("MethodT", bound=str)
+
+
+class Request(BaseModel, Generic[RequestParamsT, MethodT]):
+ """Base class for JSON-RPC requests."""
+
+ method: MethodT
+ params: RequestParamsT
+ model_config = ConfigDict(extra="allow")
+
+
+class PaginatedRequest(Request[RequestParamsT, MethodT]):
+ cursor: Cursor | None = None
+ """
+ An opaque token representing the current pagination position.
+ If provided, the server should return results starting after this cursor.
+ """
+
+
+class Notification(BaseModel, Generic[NotificationParamsT, MethodT]):
+ """Base class for JSON-RPC notifications."""
+
+ method: MethodT
+ params: NotificationParamsT
+ model_config = ConfigDict(extra="allow")
+
+
+class Result(BaseModel):
+ """Base class for JSON-RPC results."""
+
+ model_config = ConfigDict(extra="allow")
+
+ meta: dict[str, Any] | None = Field(alias="_meta", default=None)
+ """
+ This result property is reserved by the protocol to allow clients and servers to
+ attach additional metadata to their responses.
+ """
+
+
+class PaginatedResult(Result):
+ nextCursor: Cursor | None = None
+ """
+ An opaque token representing the pagination position after the last returned result.
+ If present, there may be more results available.
+ """
+
+
+class JSONRPCRequest(Request[dict[str, Any] | None, str]):
+ """A request that expects a response."""
+
+ jsonrpc: Literal["2.0"]
+ id: RequestId
+ method: str
+ params: dict[str, Any] | None = None
+
+
+class JSONRPCNotification(Notification[dict[str, Any] | None, str]):
+ """A notification which does not expect a response."""
+
+ jsonrpc: Literal["2.0"]
+ params: dict[str, Any] | None = None
+
+
+class JSONRPCResponse(BaseModel):
+ """A successful (non-error) response to a request."""
+
+ jsonrpc: Literal["2.0"]
+ id: RequestId
+ result: dict[str, Any]
+ model_config = ConfigDict(extra="allow")
+
+
+# Standard JSON-RPC error codes
+PARSE_ERROR = -32700
+INVALID_REQUEST = -32600
+METHOD_NOT_FOUND = -32601
+INVALID_PARAMS = -32602
+INTERNAL_ERROR = -32603
+
+
+class ErrorData(BaseModel):
+ """Error information for JSON-RPC error responses."""
+
+ code: int
+ """The error type that occurred."""
+
+ message: str
+ """
+ A short description of the error. The message SHOULD be limited to a concise single
+ sentence.
+ """
+
+ data: Any | None = None
+ """
+ Additional information about the error. The value of this member is defined by the
+ sender (e.g. detailed error information, nested errors etc.).
+ """
+
+ model_config = ConfigDict(extra="allow")
+
+
+class JSONRPCError(BaseModel):
+ """A response to a request that indicates an error occurred."""
+
+ jsonrpc: Literal["2.0"]
+ id: str | int
+ error: ErrorData
+ model_config = ConfigDict(extra="allow")
+
+
+class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]):
+ pass
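+
+
+# A minimal usage sketch, assuming a Pydantic v2 runtime: a raw JSON-RPC payload is
+# validated against the JSONRPCMessage union and dumped back with protocol aliases
+# such as "_meta". The "ping" payload and the helper name are illustrative only.
+def _example_jsonrpc_roundtrip() -> dict[str, Any]:
+    raw = {"jsonrpc": "2.0", "id": 1, "method": "ping"}
+    message = JSONRPCMessage.model_validate(raw)
+    # by_alias keeps reserved names like "_meta"; exclude_none drops unset optional fields.
+    return message.model_dump(by_alias=True, exclude_none=True)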
+
+
+class EmptyResult(Result):
+ """A response that indicates success but carries no data."""
+
+
+class Implementation(BaseModel):
+ """Describes the name and version of an MCP implementation."""
+
+ name: str
+ version: str
+ model_config = ConfigDict(extra="allow")
+
+
+class RootsCapability(BaseModel):
+ """Capability for root operations."""
+
+ listChanged: bool | None = None
+ """Whether the client supports notifications for changes to the roots list."""
+ model_config = ConfigDict(extra="allow")
+
+
+class SamplingCapability(BaseModel):
+ """Capability for logging operations."""
+
+ model_config = ConfigDict(extra="allow")
+
+
+class ClientCapabilities(BaseModel):
+ """Capabilities a client may support."""
+
+ experimental: dict[str, dict[str, Any]] | None = None
+ """Experimental, non-standard capabilities that the client supports."""
+ sampling: SamplingCapability | None = None
+ """Present if the client supports sampling from an LLM."""
+ roots: RootsCapability | None = None
+ """Present if the client supports listing roots."""
+ model_config = ConfigDict(extra="allow")
+
+
+class PromptsCapability(BaseModel):
+ """Capability for prompts operations."""
+
+ listChanged: bool | None = None
+ """Whether this server supports notifications for changes to the prompt list."""
+ model_config = ConfigDict(extra="allow")
+
+
+class ResourcesCapability(BaseModel):
+ """Capability for resources operations."""
+
+ subscribe: bool | None = None
+ """Whether this server supports subscribing to resource updates."""
+ listChanged: bool | None = None
+ """Whether this server supports notifications for changes to the resource list."""
+ model_config = ConfigDict(extra="allow")
+
+
+class ToolsCapability(BaseModel):
+ """Capability for tools operations."""
+
+ listChanged: bool | None = None
+ """Whether this server supports notifications for changes to the tool list."""
+ model_config = ConfigDict(extra="allow")
+
+
+class LoggingCapability(BaseModel):
+ """Capability for logging operations."""
+
+ model_config = ConfigDict(extra="allow")
+
+
+class ServerCapabilities(BaseModel):
+ """Capabilities that a server may support."""
+
+ experimental: dict[str, dict[str, Any]] | None = None
+ """Experimental, non-standard capabilities that the server supports."""
+ logging: LoggingCapability | None = None
+ """Present if the server supports sending log messages to the client."""
+ prompts: PromptsCapability | None = None
+ """Present if the server offers any prompt templates."""
+ resources: ResourcesCapability | None = None
+ """Present if the server offers any resources to read."""
+ tools: ToolsCapability | None = None
+ """Present if the server offers any tools to call."""
+ model_config = ConfigDict(extra="allow")
+
+
+class InitializeRequestParams(RequestParams):
+ """Parameters for the initialize request."""
+
+ protocolVersion: str | int
+ """The latest version of the Model Context Protocol that the client supports."""
+ capabilities: ClientCapabilities
+ clientInfo: Implementation
+ model_config = ConfigDict(extra="allow")
+
+
+class InitializeRequest(Request[InitializeRequestParams, Literal["initialize"]]):
+ """
+ This request is sent from the client to the server when it first connects, asking it
+ to begin initialization.
+ """
+
+ method: Literal["initialize"]
+ params: InitializeRequestParams
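+
+
+# A minimal sketch of assembling an initialize request from the models above; the
+# client name, version, and helper name are illustrative only.
+def _example_initialize_request() -> InitializeRequest:
+    return InitializeRequest(
+        method="initialize",
+        params=InitializeRequestParams(
+            protocolVersion=LATEST_PROTOCOL_VERSION,
+            capabilities=ClientCapabilities(roots=RootsCapability(listChanged=True)),
+            clientInfo=Implementation(name="example-client", version="0.0.1"),
+        ),
+    )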
+
+
+class InitializeResult(Result):
+ """After receiving an initialize request from the client, the server sends this."""
+
+ protocolVersion: str | int
+ """The version of the Model Context Protocol that the server wants to use."""
+ capabilities: ServerCapabilities
+ serverInfo: Implementation
+ instructions: str | None = None
+ """Instructions describing how to use the server and its features."""
+
+
+class InitializedNotification(Notification[NotificationParams | None, Literal["notifications/initialized"]]):
+ """
+ This notification is sent from the client to the server after initialization has
+ finished.
+ """
+
+ method: Literal["notifications/initialized"]
+ params: NotificationParams | None = None
+
+
+class PingRequest(Request[RequestParams | None, Literal["ping"]]):
+ """
+ A ping, issued by either the server or the client, to check that the other party is
+ still alive.
+ """
+
+ method: Literal["ping"]
+ params: RequestParams | None = None
+
+
+class ProgressNotificationParams(NotificationParams):
+ """Parameters for progress notifications."""
+
+ progressToken: ProgressToken
+ """
+ The progress token which was given in the initial request, used to associate this
+ notification with the request that is proceeding.
+ """
+ progress: float
+ """
+ The progress thus far. This should increase every time progress is made, even if the
+ total is unknown.
+ """
+ total: float | None = None
+ """Total number of items to process (or total progress required), if known."""
+ model_config = ConfigDict(extra="allow")
+
+
+class ProgressNotification(Notification[ProgressNotificationParams, Literal["notifications/progress"]]):
+ """
+ An out-of-band notification used to inform the receiver of a progress update for a
+ long-running request.
+ """
+
+ method: Literal["notifications/progress"]
+ params: ProgressNotificationParams
+
+
+class ListResourcesRequest(PaginatedRequest[RequestParams | None, Literal["resources/list"]]):
+ """Sent from the client to request a list of resources the server has."""
+
+ method: Literal["resources/list"]
+ params: RequestParams | None = None
+
+
+class Annotations(BaseModel):
+ audience: list[Role] | None = None
+ priority: Annotated[float, Field(ge=0.0, le=1.0)] | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class Resource(BaseModel):
+ """A known resource that the server is capable of reading."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """The URI of this resource."""
+ name: str
+ """A human-readable name for this resource."""
+ description: str | None = None
+ """A description of what this resource represents."""
+ mimeType: str | None = None
+ """The MIME type of this resource, if known."""
+ size: int | None = None
+ """
+ The size of the raw resource content, in bytes (i.e., before base64 encoding
+ or any tokenization), if known.
+
+ This can be used by Hosts to display file sizes and estimate context window usage.
+ """
+ annotations: Annotations | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ResourceTemplate(BaseModel):
+ """A template description for resources available on the server."""
+
+ uriTemplate: str
+ """
+ A URI template (according to RFC 6570) that can be used to construct resource
+ URIs.
+ """
+ name: str
+ """A human-readable name for the type of resource this template refers to."""
+ description: str | None = None
+ """A human-readable description of what this template is for."""
+ mimeType: str | None = None
+ """
+ The MIME type for all resources that match this template. This should only be
+ included if all resources matching this template have the same type.
+ """
+ annotations: Annotations | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ListResourcesResult(PaginatedResult):
+ """The server's response to a resources/list request from the client."""
+
+ resources: list[Resource]
+
+
+class ListResourceTemplatesRequest(PaginatedRequest[RequestParams | None, Literal["resources/templates/list"]]):
+ """Sent from the client to request a list of resource templates the server has."""
+
+ method: Literal["resources/templates/list"]
+ params: RequestParams | None = None
+
+
+class ListResourceTemplatesResult(PaginatedResult):
+ """The server's response to a resources/templates/list request from the client."""
+
+ resourceTemplates: list[ResourceTemplate]
+
+
+class ReadResourceRequestParams(RequestParams):
+ """Parameters for reading a resource."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """
+ The URI of the resource to read. The URI can use any protocol; it is up to the
+ server how to interpret it.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class ReadResourceRequest(Request[ReadResourceRequestParams, Literal["resources/read"]]):
+ """Sent from the client to the server, to read a specific resource URI."""
+
+ method: Literal["resources/read"]
+ params: ReadResourceRequestParams
+
+
+class ResourceContents(BaseModel):
+ """The contents of a specific resource or sub-resource."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """The URI of this resource."""
+ mimeType: str | None = None
+ """The MIME type of this resource, if known."""
+ model_config = ConfigDict(extra="allow")
+
+
+class TextResourceContents(ResourceContents):
+ """Text contents of a resource."""
+
+ text: str
+ """
+ The text of the item. This must only be set if the item can actually be represented
+ as text (not binary data).
+ """
+
+
+class BlobResourceContents(ResourceContents):
+ """Binary contents of a resource."""
+
+ blob: str
+ """A base64-encoded string representing the binary data of the item."""
+
+
+class ReadResourceResult(Result):
+ """The server's response to a resources/read request from the client."""
+
+ contents: list[TextResourceContents | BlobResourceContents]
+
+
+class ResourceListChangedNotification(
+ Notification[NotificationParams | None, Literal["notifications/resources/list_changed"]]
+):
+ """
+ An optional notification from the server to the client, informing it that the list
+ of resources it can read from has changed.
+ """
+
+ method: Literal["notifications/resources/list_changed"]
+ params: NotificationParams | None = None
+
+
+class SubscribeRequestParams(RequestParams):
+ """Parameters for subscribing to a resource."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """
+ The URI of the resource to subscribe to. The URI can use any protocol; it is up to
+ the server how to interpret it.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class SubscribeRequest(Request[SubscribeRequestParams, Literal["resources/subscribe"]]):
+ """
+ Sent from the client to request resources/updated notifications from the server
+ whenever a particular resource changes.
+ """
+
+ method: Literal["resources/subscribe"]
+ params: SubscribeRequestParams
+
+
+class UnsubscribeRequestParams(RequestParams):
+ """Parameters for unsubscribing from a resource."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """The URI of the resource to unsubscribe from."""
+ model_config = ConfigDict(extra="allow")
+
+
+class UnsubscribeRequest(Request[UnsubscribeRequestParams, Literal["resources/unsubscribe"]]):
+ """
+ Sent from the client to request cancellation of resources/updated notifications from
+ the server.
+ """
+
+ method: Literal["resources/unsubscribe"]
+ params: UnsubscribeRequestParams
+
+
+class ResourceUpdatedNotificationParams(NotificationParams):
+ """Parameters for resource update notifications."""
+
+ uri: Annotated[AnyUrl, UrlConstraints(host_required=False)]
+ """
+ The URI of the resource that has been updated. This might be a sub-resource of the
+ one that the client actually subscribed to.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class ResourceUpdatedNotification(
+ Notification[ResourceUpdatedNotificationParams, Literal["notifications/resources/updated"]]
+):
+ """
+ A notification from the server to the client, informing it that a resource has
+ changed and may need to be read again.
+ """
+
+ method: Literal["notifications/resources/updated"]
+ params: ResourceUpdatedNotificationParams
+
+
+class ListPromptsRequest(PaginatedRequest[RequestParams | None, Literal["prompts/list"]]):
+ """Sent from the client to request a list of prompts and prompt templates."""
+
+ method: Literal["prompts/list"]
+ params: RequestParams | None = None
+
+
+class PromptArgument(BaseModel):
+ """An argument for a prompt template."""
+
+ name: str
+ """The name of the argument."""
+ description: str | None = None
+ """A human-readable description of the argument."""
+ required: bool | None = None
+ """Whether this argument must be provided."""
+ model_config = ConfigDict(extra="allow")
+
+
+class Prompt(BaseModel):
+ """A prompt or prompt template that the server offers."""
+
+ name: str
+ """The name of the prompt or prompt template."""
+ description: str | None = None
+ """An optional description of what this prompt provides."""
+ arguments: list[PromptArgument] | None = None
+ """A list of arguments to use for templating the prompt."""
+ model_config = ConfigDict(extra="allow")
+
+
+class ListPromptsResult(PaginatedResult):
+ """The server's response to a prompts/list request from the client."""
+
+ prompts: list[Prompt]
+
+
+class GetPromptRequestParams(RequestParams):
+ """Parameters for getting a prompt."""
+
+ name: str
+ """The name of the prompt or prompt template."""
+ arguments: dict[str, str] | None = None
+ """Arguments to use for templating the prompt."""
+ model_config = ConfigDict(extra="allow")
+
+
+class GetPromptRequest(Request[GetPromptRequestParams, Literal["prompts/get"]]):
+ """Used by the client to get a prompt provided by the server."""
+
+ method: Literal["prompts/get"]
+ params: GetPromptRequestParams
+
+
+class TextContent(BaseModel):
+ """Text content for a message."""
+
+ type: Literal["text"]
+ text: str
+ """The text content of the message."""
+ annotations: Annotations | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class ImageContent(BaseModel):
+ """Image content for a message."""
+
+ type: Literal["image"]
+ data: str
+ """The base64-encoded image data."""
+ mimeType: str
+ """
+ The MIME type of the image. Different providers may support different
+ image types.
+ """
+ annotations: Annotations | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class SamplingMessage(BaseModel):
+ """Describes a message issued to or received from an LLM API."""
+
+ role: Role
+ content: TextContent | ImageContent
+ model_config = ConfigDict(extra="allow")
+
+
+class EmbeddedResource(BaseModel):
+ """
+ The contents of a resource, embedded into a prompt or tool call result.
+
+ It is up to the client how best to render embedded resources for the benefit
+ of the LLM and/or the user.
+ """
+
+ type: Literal["resource"]
+ resource: TextResourceContents | BlobResourceContents
+ annotations: Annotations | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class PromptMessage(BaseModel):
+ """Describes a message returned as part of a prompt."""
+
+ role: Role
+ content: TextContent | ImageContent | EmbeddedResource
+ model_config = ConfigDict(extra="allow")
+
+
+class GetPromptResult(Result):
+ """The server's response to a prompts/get request from the client."""
+
+ description: str | None = None
+ """An optional description for the prompt."""
+ messages: list[PromptMessage]
+
+
+class PromptListChangedNotification(
+ Notification[NotificationParams | None, Literal["notifications/prompts/list_changed"]]
+):
+ """
+ An optional notification from the server to the client, informing it that the list
+ of prompts it offers has changed.
+ """
+
+ method: Literal["notifications/prompts/list_changed"]
+ params: NotificationParams | None = None
+
+
+class ListToolsRequest(PaginatedRequest[RequestParams | None, Literal["tools/list"]]):
+ """Sent from the client to request a list of tools the server has."""
+
+ method: Literal["tools/list"]
+ params: RequestParams | None = None
+
+
+class ToolAnnotations(BaseModel):
+ """
+ Additional properties describing a Tool to clients.
+
+ NOTE: all properties in ToolAnnotations are **hints**.
+ They are not guaranteed to provide a faithful description of
+ tool behavior (including descriptive properties like `title`).
+
+ Clients should never make tool use decisions based on ToolAnnotations
+ received from untrusted servers.
+ """
+
+ title: str | None = None
+ """A human-readable title for the tool."""
+
+ readOnlyHint: bool | None = None
+ """
+ If true, the tool does not modify its environment.
+ Default: false
+ """
+
+ destructiveHint: bool | None = None
+ """
+ If true, the tool may perform destructive updates to its environment.
+ If false, the tool performs only additive updates.
+ (This property is meaningful only when `readOnlyHint == false`)
+ Default: true
+ """
+
+ idempotentHint: bool | None = None
+ """
+ If true, calling the tool repeatedly with the same arguments
+    will have no additional effect on its environment.
+ (This property is meaningful only when `readOnlyHint == false`)
+ Default: false
+ """
+
+ openWorldHint: bool | None = None
+ """
+ If true, this tool may interact with an "open world" of external
+ entities. If false, the tool's domain of interaction is closed.
+ For example, the world of a web search tool is open, whereas that
+ of a memory tool is not.
+ Default: true
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class Tool(BaseModel):
+ """Definition for a tool the client can call."""
+
+ name: str
+ """The name of the tool."""
+ description: str | None = None
+ """A human-readable description of the tool."""
+ inputSchema: dict[str, Any]
+ """A JSON Schema object defining the expected parameters for the tool."""
+ annotations: ToolAnnotations | None = None
+ """Optional additional tool information."""
+ model_config = ConfigDict(extra="allow")
+
+
+class ListToolsResult(PaginatedResult):
+ """The server's response to a tools/list request from the client."""
+
+ tools: list[Tool]
+
+
+class CallToolRequestParams(RequestParams):
+ """Parameters for calling a tool."""
+
+ name: str
+ arguments: dict[str, Any] | None = None
+ model_config = ConfigDict(extra="allow")
+
+
+class CallToolRequest(Request[CallToolRequestParams, Literal["tools/call"]]):
+ """Used by the client to invoke a tool provided by the server."""
+
+ method: Literal["tools/call"]
+ params: CallToolRequestParams
+
+
+class CallToolResult(Result):
+ """The server's response to a tool call."""
+
+ content: list[TextContent | ImageContent | EmbeddedResource]
+ isError: bool = False
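+
+
+# A minimal sketch of a tools/call exchange built from the models above; the tool
+# name, arguments, and helper name are illustrative only.
+def _example_tool_call() -> tuple[CallToolRequest, CallToolResult]:
+    request = CallToolRequest(
+        method="tools/call",
+        params=CallToolRequestParams(name="echo", arguments={"text": "hello"}),
+    )
+    result = CallToolResult(content=[TextContent(type="text", text="hello")], isError=False)
+    return request, result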
+
+
+class ToolListChangedNotification(Notification[NotificationParams | None, Literal["notifications/tools/list_changed"]]):
+ """
+ An optional notification from the server to the client, informing it that the list
+ of tools it offers has changed.
+ """
+
+ method: Literal["notifications/tools/list_changed"]
+ params: NotificationParams | None = None
+
+
+LoggingLevel = Literal["debug", "info", "notice", "warning", "error", "critical", "alert", "emergency"]
+
+
+class SetLevelRequestParams(RequestParams):
+ """Parameters for setting the logging level."""
+
+ level: LoggingLevel
+ """The level of logging that the client wants to receive from the server."""
+ model_config = ConfigDict(extra="allow")
+
+
+class SetLevelRequest(Request[SetLevelRequestParams, Literal["logging/setLevel"]]):
+ """A request from the client to the server, to enable or adjust logging."""
+
+ method: Literal["logging/setLevel"]
+ params: SetLevelRequestParams
+
+
+class LoggingMessageNotificationParams(NotificationParams):
+ """Parameters for logging message notifications."""
+
+ level: LoggingLevel
+ """The severity of this log message."""
+ logger: str | None = None
+ """An optional name of the logger issuing this message."""
+ data: Any
+ """
+ The data to be logged, such as a string message or an object. Any JSON serializable
+ type is allowed here.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class LoggingMessageNotification(Notification[LoggingMessageNotificationParams, Literal["notifications/message"]]):
+ """Notification of a log message passed from server to client."""
+
+ method: Literal["notifications/message"]
+ params: LoggingMessageNotificationParams
+
+
+IncludeContext = Literal["none", "thisServer", "allServers"]
+
+
+class ModelHint(BaseModel):
+ """Hints to use for model selection."""
+
+ name: str | None = None
+ """A hint for a model name."""
+
+ model_config = ConfigDict(extra="allow")
+
+
+class ModelPreferences(BaseModel):
+ """
+ The server's preferences for model selection, requested by the client during
+ sampling.
+
+ Because LLMs can vary along multiple dimensions, choosing the "best" model is
+ rarely straightforward. Different models excel in different areas—some are
+ faster but less capable, others are more capable but more expensive, and so
+ on. This interface allows servers to express their priorities across multiple
+ dimensions to help clients make an appropriate selection for their use case.
+
+ These preferences are always advisory. The client MAY ignore them. It is also
+ up to the client to decide how to interpret these preferences and how to
+ balance them against other considerations.
+ """
+
+ hints: list[ModelHint] | None = None
+ """
+ Optional hints to use for model selection.
+
+ If multiple hints are specified, the client MUST evaluate them in order
+ (such that the first match is taken).
+
+ The client SHOULD prioritize these hints over the numeric priorities, but
+ MAY still use the priorities to select from ambiguous matches.
+ """
+
+ costPriority: float | None = None
+ """
+ How much to prioritize cost when selecting a model. A value of 0 means cost
+ is not important, while a value of 1 means cost is the most important
+ factor.
+ """
+
+ speedPriority: float | None = None
+ """
+ How much to prioritize sampling speed (latency) when selecting a model. A
+ value of 0 means speed is not important, while a value of 1 means speed is
+ the most important factor.
+ """
+
+ intelligencePriority: float | None = None
+ """
+ How much to prioritize intelligence and capabilities when selecting a
+ model. A value of 0 means intelligence is not important, while a value of 1
+ means intelligence is the most important factor.
+ """
+
+ model_config = ConfigDict(extra="allow")
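+
+
+# A minimal sketch of expressing model preferences; the hint names and priority
+# values are illustrative only. Hints are evaluated in order, and the numeric
+# priorities are advisory weights between 0 and 1.
+def _example_model_preferences() -> ModelPreferences:
+    return ModelPreferences(
+        hints=[ModelHint(name="claude-3-5-sonnet"), ModelHint(name="claude")],
+        costPriority=0.3,
+        speedPriority=0.2,
+        intelligencePriority=0.9,
+    )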
+
+
+class CreateMessageRequestParams(RequestParams):
+ """Parameters for creating a message."""
+
+ messages: list[SamplingMessage]
+ modelPreferences: ModelPreferences | None = None
+ """
+ The server's preferences for which model to select. The client MAY ignore
+ these preferences.
+ """
+ systemPrompt: str | None = None
+ """An optional system prompt the server wants to use for sampling."""
+ includeContext: IncludeContext | None = None
+ """
+ A request to include context from one or more MCP servers (including the caller), to
+ be attached to the prompt.
+ """
+ temperature: float | None = None
+ maxTokens: int
+ """The maximum number of tokens to sample, as requested by the server."""
+ stopSequences: list[str] | None = None
+ metadata: dict[str, Any] | None = None
+ """Optional metadata to pass through to the LLM provider."""
+ model_config = ConfigDict(extra="allow")
+
+
+class CreateMessageRequest(Request[CreateMessageRequestParams, Literal["sampling/createMessage"]]):
+ """A request from the server to sample an LLM via the client."""
+
+ method: Literal["sampling/createMessage"]
+ params: CreateMessageRequestParams
+
+
+StopReason = Literal["endTurn", "stopSequence", "maxTokens"] | str
+
+
+class CreateMessageResult(Result):
+ """The client's response to a sampling/create_message request from the server."""
+
+ role: Role
+ content: TextContent | ImageContent
+ model: str
+ """The name of the model that generated the message."""
+ stopReason: StopReason | None = None
+ """The reason why sampling stopped, if known."""
+
+
+class ResourceReference(BaseModel):
+ """A reference to a resource or resource template definition."""
+
+ type: Literal["ref/resource"]
+ uri: str
+ """The URI or URI template of the resource."""
+ model_config = ConfigDict(extra="allow")
+
+
+class PromptReference(BaseModel):
+ """Identifies a prompt."""
+
+ type: Literal["ref/prompt"]
+ name: str
+ """The name of the prompt or prompt template"""
+ model_config = ConfigDict(extra="allow")
+
+
+class CompletionArgument(BaseModel):
+ """The argument's information for completion requests."""
+
+ name: str
+ """The name of the argument"""
+ value: str
+ """The value of the argument to use for completion matching."""
+ model_config = ConfigDict(extra="allow")
+
+
+class CompleteRequestParams(RequestParams):
+ """Parameters for completion requests."""
+
+ ref: ResourceReference | PromptReference
+ argument: CompletionArgument
+ model_config = ConfigDict(extra="allow")
+
+
+class CompleteRequest(Request[CompleteRequestParams, Literal["completion/complete"]]):
+ """A request from the client to the server, to ask for completion options."""
+
+ method: Literal["completion/complete"]
+ params: CompleteRequestParams
+
+
+class Completion(BaseModel):
+ """Completion information."""
+
+ values: list[str]
+ """An array of completion values. Must not exceed 100 items."""
+ total: int | None = None
+ """
+ The total number of completion options available. This can exceed the number of
+ values actually sent in the response.
+ """
+ hasMore: bool | None = None
+ """
+ Indicates whether there are additional completion options beyond those provided in
+ the current response, even if the exact total is unknown.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class CompleteResult(Result):
+ """The server's response to a completion/complete request"""
+
+ completion: Completion
+
+
+class ListRootsRequest(Request[RequestParams | None, Literal["roots/list"]]):
+ """
+ Sent from the server to request a list of root URIs from the client. Roots allow
+ servers to ask for specific directories or files to operate on. A common example
+ for roots is providing a set of repositories or directories a server should operate
+ on.
+
+ This request is typically used when the server needs to understand the file system
+ structure or access specific locations that the client has permission to read from.
+ """
+
+ method: Literal["roots/list"]
+ params: RequestParams | None = None
+
+
+class Root(BaseModel):
+ """Represents a root directory or file that the server can operate on."""
+
+ uri: FileUrl
+ """
+ The URI identifying the root. This *must* start with file:// for now.
+ This restriction may be relaxed in future versions of the protocol to allow
+ other URI schemes.
+ """
+ name: str | None = None
+ """
+ An optional name for the root. This can be used to provide a human-readable
+ identifier for the root, which may be useful for display purposes or for
+ referencing the root in other parts of the application.
+ """
+ model_config = ConfigDict(extra="allow")
+
+
+class ListRootsResult(Result):
+ """
+ The client's response to a roots/list request from the server.
+ This result contains an array of Root objects, each representing a root directory
+ or file that the server can operate on.
+ """
+
+ roots: list[Root]
+
+
+class RootsListChangedNotification(
+ Notification[NotificationParams | None, Literal["notifications/roots/list_changed"]]
+):
+ """
+ A notification from the client to the server, informing it that the list of
+ roots has changed.
+
+ This notification should be sent whenever the client adds, removes, or
+ modifies any root. The server should then request an updated list of roots
+ using the ListRootsRequest.
+ """
+
+ method: Literal["notifications/roots/list_changed"]
+ params: NotificationParams | None = None
+
+
+class CancelledNotificationParams(NotificationParams):
+ """Parameters for cancellation notifications."""
+
+ requestId: RequestId
+ """The ID of the request to cancel."""
+ reason: str | None = None
+ """An optional string describing the reason for the cancellation."""
+ model_config = ConfigDict(extra="allow")
+
+
+class CancelledNotification(Notification[CancelledNotificationParams, Literal["notifications/cancelled"]]):
+ """
+ This notification can be sent by either side to indicate that it is canceling a
+ previously-issued request.
+ """
+
+ method: Literal["notifications/cancelled"]
+ params: CancelledNotificationParams
+
+
+class ClientRequest(
+ RootModel[
+ PingRequest
+ | InitializeRequest
+ | CompleteRequest
+ | SetLevelRequest
+ | GetPromptRequest
+ | ListPromptsRequest
+ | ListResourcesRequest
+ | ListResourceTemplatesRequest
+ | ReadResourceRequest
+ | SubscribeRequest
+ | UnsubscribeRequest
+ | CallToolRequest
+ | ListToolsRequest
+ ]
+):
+ pass
+
+
+class ClientNotification(
+ RootModel[CancelledNotification | ProgressNotification | InitializedNotification | RootsListChangedNotification]
+):
+ pass
+
+
+class ClientResult(RootModel[EmptyResult | CreateMessageResult | ListRootsResult]):
+ pass
+
+
+class ServerRequest(RootModel[PingRequest | CreateMessageRequest | ListRootsRequest]):
+ pass
+
+
+class ServerNotification(
+ RootModel[
+ CancelledNotification
+ | ProgressNotification
+ | LoggingMessageNotification
+ | ResourceUpdatedNotification
+ | ResourceListChangedNotification
+ | ToolListChangedNotification
+ | PromptListChangedNotification
+ ]
+):
+ pass
+
+
+class ServerResult(
+ RootModel[
+ EmptyResult
+ | InitializeResult
+ | CompleteResult
+ | GetPromptResult
+ | ListPromptsResult
+ | ListResourcesResult
+ | ListResourceTemplatesResult
+ | ReadResourceResult
+ | CallToolResult
+ | ListToolsResult
+ ]
+):
+ pass
+
+
+ResumptionToken = str
+
+ResumptionTokenUpdateCallback = Callable[[ResumptionToken], None]
+
+
+@dataclass
+class ClientMessageMetadata:
+ """Metadata specific to client messages."""
+
+ resumption_token: ResumptionToken | None = None
+ on_resumption_token_update: Callable[[ResumptionToken], None] | None = None
+
+
+@dataclass
+class ServerMessageMetadata:
+ """Metadata specific to server messages."""
+
+ related_request_id: RequestId | None = None
+ request_context: object | None = None
+
+
+MessageMetadata = ClientMessageMetadata | ServerMessageMetadata | None
+
+
+@dataclass
+class SessionMessage:
+ """A message with specific metadata for transport-specific features."""
+
+ message: JSONRPCMessage
+ metadata: MessageMetadata = None
+
+
+class OAuthClientMetadata(BaseModel):
+ client_name: str
+ redirect_uris: list[str]
+ grant_types: Optional[list[str]] = None
+ response_types: Optional[list[str]] = None
+ token_endpoint_auth_method: Optional[str] = None
+ client_uri: Optional[str] = None
+ scope: Optional[str] = None
+
+
+class OAuthClientInformation(BaseModel):
+ client_id: str
+ client_secret: Optional[str] = None
+
+
+class OAuthClientInformationFull(OAuthClientInformation):
+ client_name: str | None = None
+ redirect_uris: list[str]
+ scope: Optional[str] = None
+ grant_types: Optional[list[str]] = None
+ response_types: Optional[list[str]] = None
+ token_endpoint_auth_method: Optional[str] = None
+
+
+class OAuthTokens(BaseModel):
+ access_token: str
+ token_type: str
+ expires_in: Optional[int] = None
+ refresh_token: Optional[str] = None
+ scope: Optional[str] = None
+
+
+class OAuthMetadata(BaseModel):
+ authorization_endpoint: str
+ token_endpoint: str
+ registration_endpoint: Optional[str] = None
+ response_types_supported: list[str]
+ grant_types_supported: Optional[list[str]] = None
+ code_challenge_methods_supported: Optional[list[str]] = None
diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py
new file mode 100644
index 0000000000..a54badcd4c
--- /dev/null
+++ b/api/core/mcp/utils.py
@@ -0,0 +1,114 @@
+import json
+
+import httpx
+
+from configs import dify_config
+from core.mcp.types import ErrorData, JSONRPCError
+from core.model_runtime.utils.encoders import jsonable_encoder
+
+HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY
+
+STATUS_FORCELIST = [429, 500, 502, 503, 504]
+
+
+def create_ssrf_proxy_mcp_http_client(
+ headers: dict[str, str] | None = None,
+ timeout: httpx.Timeout | None = None,
+) -> httpx.Client:
+ """Create an HTTPX client with SSRF proxy configuration for MCP connections.
+
+ Args:
+ headers: Optional headers to include in the client
+ timeout: Optional timeout configuration
+
+    Returns:
+        A configured httpx.Client, routed through the SSRF proxy when one is configured
+ """
+ if dify_config.SSRF_PROXY_ALL_URL:
+ return httpx.Client(
+ verify=HTTP_REQUEST_NODE_SSL_VERIFY,
+ headers=headers or {},
+ timeout=timeout,
+ follow_redirects=True,
+ proxy=dify_config.SSRF_PROXY_ALL_URL,
+ )
+ elif dify_config.SSRF_PROXY_HTTP_URL and dify_config.SSRF_PROXY_HTTPS_URL:
+ proxy_mounts = {
+ "http://": httpx.HTTPTransport(proxy=dify_config.SSRF_PROXY_HTTP_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY),
+ "https://": httpx.HTTPTransport(
+ proxy=dify_config.SSRF_PROXY_HTTPS_URL, verify=HTTP_REQUEST_NODE_SSL_VERIFY
+ ),
+ }
+ return httpx.Client(
+ verify=HTTP_REQUEST_NODE_SSL_VERIFY,
+ headers=headers or {},
+ timeout=timeout,
+ follow_redirects=True,
+ mounts=proxy_mounts,
+ )
+ else:
+ return httpx.Client(
+ verify=HTTP_REQUEST_NODE_SSL_VERIFY,
+ headers=headers or {},
+ timeout=timeout,
+ follow_redirects=True,
+ )
+
+
+def ssrf_proxy_sse_connect(url, **kwargs):
+ """Connect to SSE endpoint with SSRF proxy protection.
+
+ This function creates an SSE connection using the configured proxy settings
+ to prevent SSRF attacks when connecting to external endpoints.
+
+ Args:
+ url: The SSE endpoint URL
+ **kwargs: Additional arguments passed to the SSE connection
+
+    Returns:
+        A context manager that yields an EventSource for SSE streaming
+ """
+ from httpx_sse import connect_sse
+
+ # Extract client if provided, otherwise create one
+ client = kwargs.pop("client", None)
+ if client is None:
+ # Create client with SSRF proxy configuration
+ timeout = kwargs.pop(
+ "timeout",
+ httpx.Timeout(
+ timeout=dify_config.SSRF_DEFAULT_TIME_OUT,
+ connect=dify_config.SSRF_DEFAULT_CONNECT_TIME_OUT,
+ read=dify_config.SSRF_DEFAULT_READ_TIME_OUT,
+ write=dify_config.SSRF_DEFAULT_WRITE_TIME_OUT,
+ ),
+ )
+ headers = kwargs.pop("headers", {})
+ client = create_ssrf_proxy_mcp_http_client(headers=headers, timeout=timeout)
+ client_provided = False
+ else:
+ client_provided = True
+
+ # Extract method if provided, default to GET
+ method = kwargs.pop("method", "GET")
+
+ try:
+ return connect_sse(client, method, url, **kwargs)
+ except Exception:
+ # If we created the client, we need to clean it up on error
+ if not client_provided:
+ client.close()
+ raise
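+
+
+# A minimal usage sketch: ssrf_proxy_sse_connect returns the httpx_sse context
+# manager, which yields an EventSource once the connection is established. The
+# URL and helper name below are placeholders.
+def _example_consume_sse(url: str = "https://example.com/sse") -> None:
+    with ssrf_proxy_sse_connect(url, headers={"Accept": "text/event-stream"}) as event_source:
+        for sse in event_source.iter_sse():
+            print(sse.event, sse.data)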
+
+
+def create_mcp_error_response(request_id: int | str | None, code: int, message: str, data=None):
+ """Create MCP error response"""
+ error_data = ErrorData(code=code, message=message, data=data)
+ json_response = JSONRPCError(
+ jsonrpc="2.0",
+ id=request_id or 1,
+ error=error_data,
+ )
+ json_data = json.dumps(jsonable_encoder(json_response))
+ sse_content = f"event: message\ndata: {json_data}\n\n".encode()
+ yield sse_content
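+
+
+# A minimal usage sketch: create_mcp_error_response is a generator, so the SSE bytes
+# are obtained by iterating it. INVALID_REQUEST comes from core.mcp.types; the
+# request id and helper name are illustrative only.
+def _example_error_sse() -> bytes:
+    from core.mcp.types import INVALID_REQUEST
+
+    return b"".join(create_mcp_error_response(request_id=1, code=INVALID_REQUEST, message="Invalid request"))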
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 2254b3d4d5..a9f0a92e5d 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -1,6 +1,8 @@
from collections.abc import Sequence
from typing import Optional
+from sqlalchemy import select
+
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.model_manager import ModelInstance
@@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
-from models.workflow import WorkflowRun
+from models.workflow import Workflow, WorkflowRun
class TokenBufferMemory:
- def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
+ def __init__(
+ self,
+ conversation: Conversation,
+ model_instance: ModelInstance,
+ ) -> None:
self.conversation = conversation
self.model_instance = model_instance
@@ -36,20 +42,8 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
- query = (
- db.session.query(
- Message.id,
- Message.query,
- Message.answer,
- Message.created_at,
- Message.workflow_run_id,
- Message.parent_message_id,
- Message.answer_tokens,
- )
- .filter(
- Message.conversation_id == self.conversation.id,
- )
- .order_by(Message.created_at.desc())
+ stmt = (
+ select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
@@ -57,7 +51,9 @@ class TokenBufferMemory:
else:
message_limit = 500
- messages = query.limit(message_limit).all()
+ stmt = stmt.limit(message_limit)
+
+ messages = db.session.scalars(stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
@@ -74,18 +70,20 @@ class TokenBufferMemory:
files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
- if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+ elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ workflow_run = db.session.scalar(
+ select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
+ )
+ if not workflow_run:
+ raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+ workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+ if not workflow:
+ raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+ file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
- if message.workflow_run_id:
- workflow_run = (
- db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
- )
-
- if workflow_run and workflow_run.workflow:
- file_extra_config = FileUploadConfigManager.convert(
- workflow_run.workflow.features_dict, is_vision=False
- )
+ raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 995a30d44c..4886ffe244 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -542,8 +542,6 @@ class LBModelManager:
return config
- return None
-
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
"""
Cooldown model load balancing config
diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py
index de5a748d4f..ace2c1f770 100644
--- a/api/core/model_runtime/entities/llm_entities.py
+++ b/api/core/model_runtime/entities/llm_entities.py
@@ -1,7 +1,7 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
from decimal import Decimal
from enum import StrEnum
-from typing import Optional
+from typing import Any, Optional
from pydantic import BaseModel, Field
@@ -53,6 +53,37 @@ class LLMUsage(ModelUsage):
latency=0.0,
)
+ @classmethod
+ def from_metadata(cls, metadata: dict) -> "LLMUsage":
+ """
+ Create LLMUsage instance from metadata dictionary with default values.
+
+ Args:
+ metadata: Dictionary containing usage metadata
+
+ Returns:
+ LLMUsage instance with values from metadata or defaults
+ """
+ total_tokens = metadata.get("total_tokens", 0)
+ completion_tokens = metadata.get("completion_tokens", 0)
+ if total_tokens > 0 and completion_tokens == 0:
+ completion_tokens = total_tokens
+
+ return cls(
+ prompt_tokens=metadata.get("prompt_tokens", 0),
+ completion_tokens=completion_tokens,
+ total_tokens=total_tokens,
+ prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))),
+ completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))),
+ total_price=Decimal(str(metadata.get("total_price", 0))),
+ currency=metadata.get("currency", "USD"),
+ prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))),
+ completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))),
+ prompt_price=Decimal(str(metadata.get("prompt_price", 0))),
+ completion_price=Decimal(str(metadata.get("completion_price", 0))),
+ latency=metadata.get("latency", 0.0),
+ )
+
def plus(self, other: "LLMUsage") -> "LLMUsage":
"""
Add two LLMUsage instances together.
@@ -101,6 +132,20 @@ class LLMResult(BaseModel):
system_fingerprint: Optional[str] = None
+class LLMStructuredOutput(BaseModel):
+ """
+ Model class for llm structured output.
+ """
+
+ structured_output: Optional[Mapping[str, Any]] = None
+
+
+class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput):
+ """
+ Model class for llm result with structured output.
+ """
+
+
class LLMResultChunkDelta(BaseModel):
"""
Model class for llm result chunk delta.
@@ -123,6 +168,12 @@ class LLMResultChunk(BaseModel):
delta: LLMResultChunkDelta
+class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput):
+ """
+ Model class for llm result chunk with structured output.
+ """
+
+
class NumTokensResult(PriceInfo):
"""
Model class for number of tokens result.
diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py
index d0f9ee13e5..c9aa8d1474 100644
--- a/api/core/model_runtime/entities/provider_entities.py
+++ b/api/core/model_runtime/entities/provider_entities.py
@@ -123,6 +123,8 @@ class ProviderEntity(BaseModel):
description: Optional[I18nObject] = None
icon_small: Optional[I18nObject] = None
icon_large: Optional[I18nObject] = None
+ icon_small_dark: Optional[I18nObject] = None
+ icon_large_dark: Optional[I18nObject] = None
background: Optional[str] = None
help: Optional[ProviderHelpEntity] = None
supported_model_types: Sequence[ModelType]
diff --git a/api/core/ops/aliyun_trace/__init__.py b/api/core/ops/aliyun_trace/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
new file mode 100644
index 0000000000..db8fec4ee9
--- /dev/null
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -0,0 +1,488 @@
+import json
+import logging
+from collections.abc import Sequence
+from typing import Optional
+from urllib.parse import urljoin
+
+from opentelemetry.trace import Status, StatusCode
+from sqlalchemy.orm import Session, sessionmaker
+
+from core.ops.aliyun_trace.data_exporter.traceclient import (
+ TraceClient,
+ convert_datetime_to_nanoseconds,
+ convert_to_span_id,
+ convert_to_trace_id,
+ generate_span_id,
+)
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+from core.ops.aliyun_trace.entities.semconv import (
+ GEN_AI_COMPLETION,
+ GEN_AI_FRAMEWORK,
+ GEN_AI_MODEL_NAME,
+ GEN_AI_PROMPT,
+ GEN_AI_PROMPT_TEMPLATE_TEMPLATE,
+ GEN_AI_PROMPT_TEMPLATE_VARIABLE,
+ GEN_AI_RESPONSE_FINISH_REASON,
+ GEN_AI_SESSION_ID,
+ GEN_AI_SPAN_KIND,
+ GEN_AI_SYSTEM,
+ GEN_AI_USAGE_INPUT_TOKENS,
+ GEN_AI_USAGE_OUTPUT_TOKENS,
+ GEN_AI_USAGE_TOTAL_TOKENS,
+ GEN_AI_USER_ID,
+ INPUT_VALUE,
+ OUTPUT_VALUE,
+ RETRIEVAL_DOCUMENT,
+ RETRIEVAL_QUERY,
+ TOOL_DESCRIPTION,
+ TOOL_NAME,
+ TOOL_PARAMETERS,
+ GenAISpanKind,
+)
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import AliyunConfig
+from core.ops.entities.trace_entity import (
+ BaseTraceInfo,
+ DatasetRetrievalTraceInfo,
+ GenerateNameTraceInfo,
+ MessageTraceInfo,
+ ModerationTraceInfo,
+ SuggestedQuestionTraceInfo,
+ ToolTraceInfo,
+ WorkflowTraceInfo,
+)
+from core.rag.models.document import Document
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.entities.workflow_node_execution import (
+ WorkflowNodeExecution,
+ WorkflowNodeExecutionMetadataKey,
+ WorkflowNodeExecutionStatus,
+)
+from core.workflow.nodes import NodeType
+from models import Account, App, EndUser, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom, db
+
+logger = logging.getLogger(__name__)
+
+
+class AliyunDataTrace(BaseTraceInstance):
+ def __init__(
+ self,
+ aliyun_config: AliyunConfig,
+ ):
+ super().__init__(aliyun_config)
+ base_url = aliyun_config.endpoint.rstrip("/")
+ endpoint = urljoin(base_url, f"adapt_{aliyun_config.license_key}/api/otlp/traces")
+ self.trace_client = TraceClient(service_name=aliyun_config.app_name, endpoint=endpoint)
+
+ def trace(self, trace_info: BaseTraceInfo):
+ if isinstance(trace_info, WorkflowTraceInfo):
+ self.workflow_trace(trace_info)
+ if isinstance(trace_info, MessageTraceInfo):
+ self.message_trace(trace_info)
+ if isinstance(trace_info, ModerationTraceInfo):
+ pass
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
+ self.suggested_question_trace(trace_info)
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
+ self.dataset_retrieval_trace(trace_info)
+ if isinstance(trace_info, ToolTraceInfo):
+ self.tool_trace(trace_info)
+ if isinstance(trace_info, GenerateNameTraceInfo):
+ pass
+
+ def api_check(self):
+ return self.trace_client.api_check()
+
+ def get_project_url(self):
+ try:
+ return self.trace_client.get_project_url()
+ except Exception as e:
+ logger.info(f"Aliyun get run url failed: {str(e)}", exc_info=True)
+ raise ValueError(f"Aliyun get run url failed: {str(e)}")
+
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
+ trace_id = convert_to_trace_id(trace_info.workflow_run_id)
+ workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
+ self.add_workflow_span(trace_id, workflow_span_id, trace_info)
+
+ workflow_node_executions = self.get_workflow_node_executions(trace_info)
+ for node_execution in workflow_node_executions:
+ node_span = self.build_workflow_node_span(node_execution, trace_id, trace_info, workflow_span_id)
+ self.trace_client.add_span(node_span)
+
+ def message_trace(self, trace_info: MessageTraceInfo):
+ message_data = trace_info.message_data
+ if message_data is None:
+ return
+ message_id = trace_info.message_id
+
+ user_id = message_data.from_account_id
+ if message_data.from_end_user_id:
+ end_user_data: Optional[EndUser] = (
+ db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ )
+ if end_user_data is not None:
+ user_id = end_user_data.session_id
+
+ status: Status = Status(StatusCode.OK)
+ if trace_info.error:
+ status = Status(StatusCode.ERROR, trace_info.error)
+
+ trace_id = convert_to_trace_id(message_id)
+ message_span_id = convert_to_span_id(message_id, "message")
+ message_span = SpanData(
+ trace_id=trace_id,
+ parent_span_id=None,
+ span_id=message_span_id,
+ name="message",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+ GEN_AI_USER_ID: str(user_id),
+ GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
+ GEN_AI_FRAMEWORK: "dify",
+ INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ OUTPUT_VALUE: str(trace_info.outputs),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(message_span)
+
+ app_model_config = getattr(trace_info.message_data, "app_model_config", {})
+ pre_prompt = getattr(app_model_config, "pre_prompt", "")
+ inputs_data = getattr(trace_info.message_data, "inputs", {})
+ llm_span = SpanData(
+ trace_id=trace_id,
+ parent_span_id=message_span_id,
+ span_id=convert_to_span_id(message_id, "llm"),
+ name="llm",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+ GEN_AI_USER_ID: str(user_id),
+ GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
+ GEN_AI_FRAMEWORK: "dify",
+ GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
+ GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
+ GEN_AI_USAGE_INPUT_TOKENS: str(trace_info.message_tokens),
+ GEN_AI_USAGE_OUTPUT_TOKENS: str(trace_info.answer_tokens),
+ GEN_AI_USAGE_TOTAL_TOKENS: str(trace_info.total_tokens),
+ GEN_AI_PROMPT_TEMPLATE_VARIABLE: json.dumps(inputs_data, ensure_ascii=False),
+ GEN_AI_PROMPT_TEMPLATE_TEMPLATE: pre_prompt,
+ GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
+ GEN_AI_COMPLETION: str(trace_info.outputs),
+ INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ OUTPUT_VALUE: str(trace_info.outputs),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(llm_span)
+
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+ if trace_info.message_data is None:
+ return
+ message_id = trace_info.message_id
+
+ documents_data = extract_retrieval_documents(trace_info.documents)
+ dataset_retrieval_span = SpanData(
+ trace_id=convert_to_trace_id(message_id),
+ parent_span_id=convert_to_span_id(message_id, "message"),
+ span_id=generate_span_id(),
+ name="dataset_retrieval",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
+ GEN_AI_FRAMEWORK: "dify",
+ RETRIEVAL_QUERY: str(trace_info.inputs),
+ RETRIEVAL_DOCUMENT: json.dumps(documents_data, ensure_ascii=False),
+ INPUT_VALUE: str(trace_info.inputs),
+ OUTPUT_VALUE: json.dumps(documents_data, ensure_ascii=False),
+ },
+ )
+ self.trace_client.add_span(dataset_retrieval_span)
+
+ def tool_trace(self, trace_info: ToolTraceInfo):
+ if trace_info.message_data is None:
+ return
+ message_id = trace_info.message_id
+
+ status: Status = Status(StatusCode.OK)
+ if trace_info.error:
+ status = Status(StatusCode.ERROR, trace_info.error)
+
+ tool_span = SpanData(
+ trace_id=convert_to_trace_id(message_id),
+ parent_span_id=convert_to_span_id(message_id, "message"),
+ span_id=generate_span_id(),
+ name=trace_info.tool_name,
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
+ GEN_AI_FRAMEWORK: "dify",
+ TOOL_NAME: trace_info.tool_name,
+ TOOL_DESCRIPTION: json.dumps(trace_info.tool_config, ensure_ascii=False),
+ TOOL_PARAMETERS: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
+ INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ OUTPUT_VALUE: str(trace_info.tool_outputs),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(tool_span)
+
+ def get_workflow_node_executions(self, trace_info: WorkflowTraceInfo) -> Sequence[WorkflowNodeExecution]:
+        # Fetch all node executions for this workflow run via the repository
+ session_factory = sessionmaker(bind=db.engine)
+ # Find the app's creator account
+ with Session(db.engine, expire_on_commit=False) as session:
+ # Get the app to find its creator
+ app_id = trace_info.metadata.get("app_id")
+ if not app_id:
+ raise ValueError("No app_id found in trace_info metadata")
+
+ app = session.query(App).filter(App.id == app_id).first()
+ if not app:
+ raise ValueError(f"App with id {app_id} not found")
+
+ if not app.created_by:
+ raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
+
+ service_account = session.query(Account).filter(Account.id == app.created_by).first()
+ if not service_account:
+ raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
+ current_tenant = (
+ session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first()
+ )
+ if not current_tenant:
+ raise ValueError(f"Current tenant not found for account {service_account.id}")
+ service_account.set_tenant_id(current_tenant.tenant_id)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ user=service_account,
+ app_id=trace_info.metadata.get("app_id"),
+ triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
+ )
+ # Get all executions for this workflow run
+ workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run(
+ workflow_run_id=trace_info.workflow_run_id
+ )
+ return workflow_node_executions
+
+ def build_workflow_node_span(
+ self, node_execution: WorkflowNodeExecution, trace_id: int, trace_info: WorkflowTraceInfo, workflow_span_id: int
+ ):
+ try:
+ if node_execution.node_type == NodeType.LLM:
+ node_span = self.build_workflow_llm_span(trace_id, workflow_span_id, trace_info, node_execution)
+ elif node_execution.node_type == NodeType.KNOWLEDGE_RETRIEVAL:
+ node_span = self.build_workflow_retrieval_span(trace_id, workflow_span_id, trace_info, node_execution)
+ elif node_execution.node_type == NodeType.TOOL:
+ node_span = self.build_workflow_tool_span(trace_id, workflow_span_id, trace_info, node_execution)
+ else:
+ node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
+ return node_span
+ except Exception as e:
+ logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
+ return None
+
+ def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
+ span_status: Status = Status(StatusCode.UNSET)
+ if node_execution.status == WorkflowNodeExecutionStatus.SUCCEEDED:
+ span_status = Status(StatusCode.OK)
+ elif node_execution.status in [WorkflowNodeExecutionStatus.FAILED, WorkflowNodeExecutionStatus.EXCEPTION]:
+ span_status = Status(StatusCode.ERROR, str(node_execution.error))
+ return span_status
+
+ def build_workflow_task_span(
+ self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
+ ) -> SpanData:
+ return SpanData(
+ trace_id=trace_id,
+ parent_span_id=workflow_span_id,
+ span_id=convert_to_span_id(node_execution.id, "node"),
+ name=node_execution.title,
+ start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
+ end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
+ attributes={
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
+ GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
+ GEN_AI_FRAMEWORK: "dify",
+ INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
+ OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
+ },
+ status=self.get_workflow_node_status(node_execution),
+ )
+
+ def build_workflow_tool_span(
+ self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
+ ) -> SpanData:
+ tool_des = {}
+ if node_execution.metadata:
+ tool_des = node_execution.metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO, {})
+ return SpanData(
+ trace_id=trace_id,
+ parent_span_id=workflow_span_id,
+ span_id=convert_to_span_id(node_execution.id, "node"),
+ name=node_execution.title,
+ start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
+ end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
+ attributes={
+ GEN_AI_SPAN_KIND: GenAISpanKind.TOOL.value,
+ GEN_AI_FRAMEWORK: "dify",
+ TOOL_NAME: node_execution.title,
+ TOOL_DESCRIPTION: json.dumps(tool_des, ensure_ascii=False),
+ TOOL_PARAMETERS: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
+ INPUT_VALUE: json.dumps(node_execution.inputs if node_execution.inputs else {}, ensure_ascii=False),
+ OUTPUT_VALUE: json.dumps(node_execution.outputs, ensure_ascii=False),
+ },
+ status=self.get_workflow_node_status(node_execution),
+ )
+
+ def build_workflow_retrieval_span(
+ self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
+ ) -> SpanData:
+ input_value = ""
+ if node_execution.inputs:
+ input_value = str(node_execution.inputs.get("query", ""))
+ output_value = ""
+ if node_execution.outputs:
+ output_value = json.dumps(node_execution.outputs.get("result", []), ensure_ascii=False)
+ return SpanData(
+ trace_id=trace_id,
+ parent_span_id=workflow_span_id,
+ span_id=convert_to_span_id(node_execution.id, "node"),
+ name=node_execution.title,
+ start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
+ end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
+ attributes={
+ GEN_AI_SPAN_KIND: GenAISpanKind.RETRIEVER.value,
+ GEN_AI_FRAMEWORK: "dify",
+ RETRIEVAL_QUERY: input_value,
+ RETRIEVAL_DOCUMENT: output_value,
+ INPUT_VALUE: input_value,
+ OUTPUT_VALUE: output_value,
+ },
+ status=self.get_workflow_node_status(node_execution),
+ )
+
+ def build_workflow_llm_span(
+ self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo, node_execution: WorkflowNodeExecution
+ ) -> SpanData:
+ process_data = node_execution.process_data or {}
+ outputs = node_execution.outputs or {}
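+ # Token usage may be recorded in process_data or in outputs; prefer process_data when present.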
+ usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ return SpanData(
+ trace_id=trace_id,
+ parent_span_id=workflow_span_id,
+ span_id=convert_to_span_id(node_execution.id, "node"),
+ name=node_execution.title,
+ start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
+ end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
+ attributes={
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
+ GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
+ GEN_AI_FRAMEWORK: "dify",
+ GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
+ GEN_AI_SYSTEM: process_data.get("model_provider", ""),
+ GEN_AI_USAGE_INPUT_TOKENS: str(usage_data.get("prompt_tokens", 0)),
+ GEN_AI_USAGE_OUTPUT_TOKENS: str(usage_data.get("completion_tokens", 0)),
+ GEN_AI_USAGE_TOTAL_TOKENS: str(usage_data.get("total_tokens", 0)),
+ GEN_AI_PROMPT: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+ GEN_AI_COMPLETION: str(outputs.get("text", "")),
+ GEN_AI_RESPONSE_FINISH_REASON: outputs.get("finish_reason", ""),
+ INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+ OUTPUT_VALUE: str(outputs.get("text", "")),
+ },
+ status=self.get_workflow_node_status(node_execution),
+ )
+
+ def add_workflow_span(self, trace_id: int, workflow_span_id: int, trace_info: WorkflowTraceInfo):
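+ # Chatflow runs (message_id present) get a root "message" span with the workflow span nested under it; plain workflow runs emit the workflow span as the root.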
+ message_span_id = None
+ if trace_info.message_id:
+ message_span_id = convert_to_span_id(trace_info.message_id, "message")
+ user_id = trace_info.metadata.get("user_id")
+ status: Status = Status(StatusCode.OK)
+ if trace_info.error:
+ status = Status(StatusCode.ERROR, trace_info.error)
+ if message_span_id: # chatflow
+ message_span = SpanData(
+ trace_id=trace_id,
+ parent_span_id=None,
+ span_id=message_span_id,
+ name="message",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
+ GEN_AI_USER_ID: str(user_id),
+ GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
+ GEN_AI_FRAMEWORK: "dify",
+ INPUT_VALUE: trace_info.workflow_run_inputs.get("sys.query", ""),
+ OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(message_span)
+
+ workflow_span = SpanData(
+ trace_id=trace_id,
+ parent_span_id=message_span_id,
+ span_id=workflow_span_id,
+ name="workflow",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_USER_ID: str(user_id),
+ GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
+ GEN_AI_FRAMEWORK: "dify",
+ INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
+ OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(workflow_span)
+
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
+ message_id = trace_info.message_id
+ status: Status = Status(StatusCode.OK)
+ if trace_info.error:
+ status = Status(StatusCode.ERROR, trace_info.error)
+ suggested_question_span = SpanData(
+ trace_id=convert_to_trace_id(message_id),
+ parent_span_id=convert_to_span_id(message_id, "message"),
+ span_id=convert_to_span_id(message_id, "suggested_question"),
+ name="suggested_question",
+ start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
+ end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
+ attributes={
+ GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
+ GEN_AI_FRAMEWORK: "dify",
+ GEN_AI_MODEL_NAME: trace_info.metadata.get("ls_model_name", ""),
+ GEN_AI_SYSTEM: trace_info.metadata.get("ls_provider", ""),
+ GEN_AI_PROMPT: json.dumps(trace_info.inputs, ensure_ascii=False),
+ GEN_AI_COMPLETION: json.dumps(trace_info.suggested_question, ensure_ascii=False),
+ INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
+ },
+ status=status,
+ )
+ self.trace_client.add_span(suggested_question_span)
+
+
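+# Flatten retrieved Document objects into JSON-serializable dicts (content, dataset/doc/document ids, score).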
+def extract_retrieval_documents(documents: list[Document]):
+ documents_data = []
+ for document in documents:
+ document_data = {
+ "content": document.page_content,
+ "metadata": {
+ "dataset_id": document.metadata.get("dataset_id"),
+ "doc_id": document.metadata.get("doc_id"),
+ "document_id": document.metadata.get("document_id"),
+ },
+ "score": document.metadata.get("score"),
+ }
+ documents_data.append(document_data)
+ return documents_data
diff --git a/api/core/ops/aliyun_trace/data_exporter/__init__.py b/api/core/ops/aliyun_trace/data_exporter/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/aliyun_trace/data_exporter/traceclient.py b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
new file mode 100644
index 0000000000..ba5ac3f420
--- /dev/null
+++ b/api/core/ops/aliyun_trace/data_exporter/traceclient.py
@@ -0,0 +1,200 @@
+import hashlib
+import logging
+import random
+import socket
+import threading
+import uuid
+from collections import deque
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional
+
+import requests
+from opentelemetry import trace as trace_api
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace import ReadableSpan
+from opentelemetry.sdk.util.instrumentation import InstrumentationScope
+from opentelemetry.semconv.resource import ResourceAttributes
+
+from configs import dify_config
+from core.ops.aliyun_trace.entities.aliyun_trace_entity import SpanData
+
+INVALID_SPAN_ID = 0x0000000000000000
+INVALID_TRACE_ID = 0x00000000000000000000000000000000
+
+logger = logging.getLogger(__name__)
+
+
+class TraceClient:
+ def __init__(
+ self,
+ service_name: str,
+ endpoint: str,
+ max_queue_size: int = 1000,
+ schedule_delay_sec: int = 5,
+ max_export_batch_size: int = 50,
+ ):
+ self.endpoint = endpoint
+ self.resource = Resource(
+ attributes={
+ ResourceAttributes.SERVICE_NAME: service_name,
+ ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
+ ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
+ ResourceAttributes.HOST_NAME: socket.gethostname(),
+ }
+ )
+ self.span_builder = SpanBuilder(self.resource)
+ self.exporter = OTLPSpanExporter(endpoint=endpoint)
+
+ self.max_queue_size = max_queue_size
+ self.schedule_delay_sec = schedule_delay_sec
+ self.max_export_batch_size = max_export_batch_size
+
+ self.queue: deque = deque(maxlen=max_queue_size)
+ self.condition = threading.Condition(threading.Lock())
+ self.done = False
+
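+ # A daemon worker thread drains the queue and exports spans in batches in the background.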
+ self.worker_thread = threading.Thread(target=self._worker, daemon=True)
+ self.worker_thread.start()
+
+ self._spans_dropped = False
+
+ def export(self, spans: Sequence[ReadableSpan]):
+ self.exporter.export(spans)
+
+ def api_check(self):
+ try:
+ response = requests.head(self.endpoint, timeout=5)
+ if response.status_code == 405:
+ return True
+ else:
+ logger.debug(f"AliyunTrace API check failed: Unexpected status code: {response.status_code}")
+ return False
+ except requests.exceptions.RequestException as e:
+ logger.debug(f"AliyunTrace API check failed: {str(e)}")
+ raise ValueError(f"AliyunTrace API check failed: {str(e)}")
+
+ def get_project_url(self):
+ return "https://arms.console.aliyun.com/#/llm"
+
+ def add_span(self, span_data: SpanData):
+ if span_data is None:
+ return
+ span: ReadableSpan = self.span_builder.build_span(span_data)
+ with self.condition:
+ if len(self.queue) == self.max_queue_size:
+ if not self._spans_dropped:
+ logger.warning("Queue is full, likely spans will be dropped.")
+ self._spans_dropped = True
+
+ self.queue.appendleft(span)
+ if len(self.queue) >= self.max_export_batch_size:
+ self.condition.notify()
+
+ def _worker(self):
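+ # Export whenever a full batch has accumulated or the flush interval elapses.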
+ while not self.done:
+ with self.condition:
+ if len(self.queue) < self.max_export_batch_size and not self.done:
+ self.condition.wait(timeout=self.schedule_delay_sec)
+ self._export_batch()
+
+ def _export_batch(self):
+ spans_to_export: list[ReadableSpan] = []
+ with self.condition:
+ while len(spans_to_export) < self.max_export_batch_size and self.queue:
+ spans_to_export.append(self.queue.pop())
+
+ if spans_to_export:
+ try:
+ self.exporter.export(spans_to_export)
+ except Exception as e:
+ logger.debug(f"Error exporting spans: {e}")
+
+ def shutdown(self):
+ with self.condition:
+ self.done = True
+ self.condition.notify_all()
+ self.worker_thread.join()
+ self._export_batch()
+ self.exporter.shutdown()
+
+
+class SpanBuilder:
+ def __init__(self, resource):
+ self.resource = resource
+ self.instrumentation_scope = InstrumentationScope(
+ __name__,
+ "",
+ None,
+ None,
+ )
+
+ def build_span(self, span_data: SpanData) -> ReadableSpan:
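+ # Construct a ReadableSpan directly so the pre-computed ids, timestamps and status are preserved as-is.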
+ span_context = trace_api.SpanContext(
+ trace_id=span_data.trace_id,
+ span_id=span_data.span_id,
+ is_remote=False,
+ trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
+ trace_state=None,
+ )
+
+ parent_span_context = None
+ if span_data.parent_span_id is not None:
+ parent_span_context = trace_api.SpanContext(
+ trace_id=span_data.trace_id,
+ span_id=span_data.parent_span_id,
+ is_remote=False,
+ trace_flags=trace_api.TraceFlags(trace_api.TraceFlags.SAMPLED),
+ trace_state=None,
+ )
+
+ span = ReadableSpan(
+ name=span_data.name,
+ context=span_context,
+ parent=parent_span_context,
+ resource=self.resource,
+ attributes=span_data.attributes,
+ events=span_data.events,
+ links=span_data.links,
+ kind=trace_api.SpanKind.INTERNAL,
+ status=span_data.status,
+ start_time=span_data.start_time,
+ end_time=span_data.end_time,
+ instrumentation_scope=self.instrumentation_scope,
+ )
+ return span
+
+
+def generate_span_id() -> int:
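+ # Retry until a non-zero id is produced; 0 is reserved as the invalid span id.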
+ span_id = random.getrandbits(64)
+ while span_id == INVALID_SPAN_ID:
+ span_id = random.getrandbits(64)
+ return span_id
+
+
+def convert_to_trace_id(uuid_v4: Optional[str]) -> int:
+ try:
+ uuid_obj = uuid.UUID(uuid_v4)
+ return uuid_obj.int
+ except Exception as e:
+ raise ValueError(f"Invalid UUID input: {e}")
+
+
+def convert_to_span_id(uuid_v4: Optional[str], span_type: str) -> int:
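+ # Deterministically derive a 64-bit span id from the entity UUID plus span type, so the same entity and type always yield the same span id.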
+ try:
+ uuid_obj = uuid.UUID(uuid_v4)
+ except Exception as e:
+ raise ValueError(f"Invalid UUID input: {e}")
+ combined_key = f"{uuid_obj.hex}-{span_type}"
+ hash_bytes = hashlib.sha256(combined_key.encode("utf-8")).digest()
+ span_id = int.from_bytes(hash_bytes[:8], byteorder="big", signed=False)
+ return span_id
+
+
+def convert_datetime_to_nanoseconds(start_time_a: Optional[datetime]) -> Optional[int]:
+ if start_time_a is None:
+ return None
+ timestamp_in_seconds = start_time_a.timestamp()
+ timestamp_in_nanoseconds = int(timestamp_in_seconds * 1e9)
+ return timestamp_in_nanoseconds
diff --git a/api/core/ops/aliyun_trace/entities/__init__.py b/api/core/ops/aliyun_trace/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
new file mode 100644
index 0000000000..1caa822cd0
--- /dev/null
+++ b/api/core/ops/aliyun_trace/entities/aliyun_trace_entity.py
@@ -0,0 +1,21 @@
+from collections.abc import Sequence
+from typing import Optional
+
+from opentelemetry import trace as trace_api
+from opentelemetry.sdk.trace import Event, Status, StatusCode
+from pydantic import BaseModel, Field
+
+
+class SpanData(BaseModel):
+ model_config = {"arbitrary_types_allowed": True}
+
+ trace_id: int = Field(..., description="The unique identifier for the trace.")
+ parent_span_id: Optional[int] = Field(None, description="The ID of the parent span, if any.")
+ span_id: int = Field(..., description="The unique identifier for this span.")
+ name: str = Field(..., description="The name of the span.")
+ attributes: dict[str, str] = Field(default_factory=dict, description="Attributes associated with the span.")
+ events: Sequence[Event] = Field(default_factory=list, description="Events recorded in the span.")
+ links: Sequence[trace_api.Link] = Field(default_factory=list, description="Links to other spans.")
+ status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
+ start_time: Optional[int] = Field(..., description="The start time of the span in nanoseconds.")
+ end_time: Optional[int] = Field(..., description="The end time of the span in nanoseconds.")
diff --git a/api/core/ops/aliyun_trace/entities/semconv.py b/api/core/ops/aliyun_trace/entities/semconv.py
new file mode 100644
index 0000000000..5d70264320
--- /dev/null
+++ b/api/core/ops/aliyun_trace/entities/semconv.py
@@ -0,0 +1,64 @@
+from enum import Enum
+
+# public
+GEN_AI_SESSION_ID = "gen_ai.session.id"
+
+GEN_AI_USER_ID = "gen_ai.user.id"
+
+GEN_AI_USER_NAME = "gen_ai.user.name"
+
+GEN_AI_SPAN_KIND = "gen_ai.span.kind"
+
+GEN_AI_FRAMEWORK = "gen_ai.framework"
+
+
+# Chain
+INPUT_VALUE = "input.value"
+
+OUTPUT_VALUE = "output.value"
+
+
+# Retriever
+RETRIEVAL_QUERY = "retrieval.query"
+
+RETRIEVAL_DOCUMENT = "retrieval.document"
+
+
+# LLM
+GEN_AI_MODEL_NAME = "gen_ai.model_name"
+
+GEN_AI_SYSTEM = "gen_ai.system"
+
+GEN_AI_USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
+
+GEN_AI_USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
+
+GEN_AI_USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
+
+GEN_AI_PROMPT_TEMPLATE_TEMPLATE = "gen_ai.prompt_template.template"
+
+GEN_AI_PROMPT_TEMPLATE_VARIABLE = "gen_ai.prompt_template.variable"
+
+GEN_AI_PROMPT = "gen_ai.prompt"
+
+GEN_AI_COMPLETION = "gen_ai.completion"
+
+GEN_AI_RESPONSE_FINISH_REASON = "gen_ai.response.finish_reason"
+
+# Tool
+TOOL_NAME = "tool.name"
+
+TOOL_DESCRIPTION = "tool.description"
+
+TOOL_PARAMETERS = "tool.parameters"
+
+
+class GenAISpanKind(Enum):
+ CHAIN = "CHAIN"
+ RETRIEVER = "RETRIEVER"
+ RERANKER = "RERANKER"
+ LLM = "LLM"
+ EMBEDDING = "EMBEDDING"
+ TOOL = "TOOL"
+ AGENT = "AGENT"
+ TASK = "TASK"
diff --git a/api/core/ops/arize_phoenix_trace/__init__.py b/api/core/ops/arize_phoenix_trace/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
new file mode 100644
index 0000000000..8b3ce0c448
--- /dev/null
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -0,0 +1,729 @@
+import hashlib
+import json
+import logging
+import os
+from datetime import datetime, timedelta
+from typing import Any, Optional, Union, cast
+
+from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GrpcOTLPSpanExporter
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HttpOTLPSpanExporter
+from opentelemetry.sdk import trace as trace_sdk
+from opentelemetry.sdk.resources import Resource
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.id_generator import RandomIdGenerator
+from opentelemetry.trace import SpanContext, TraceFlags, TraceState
+
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import ArizeConfig, PhoenixConfig
+from core.ops.entities.trace_entity import (
+ BaseTraceInfo,
+ DatasetRetrievalTraceInfo,
+ GenerateNameTraceInfo,
+ MessageTraceInfo,
+ ModerationTraceInfo,
+ SuggestedQuestionTraceInfo,
+ ToolTraceInfo,
+ TraceTaskName,
+ WorkflowTraceInfo,
+)
+from extensions.ext_database import db
+from models.model import EndUser, MessageFile
+from models.workflow import WorkflowNodeExecutionModel
+
+logger = logging.getLogger(__name__)
+
+
+def setup_tracer(arize_phoenix_config: ArizeConfig | PhoenixConfig) -> tuple[trace_sdk.Tracer, SimpleSpanProcessor]:
+ """Configure OpenTelemetry tracer with OTLP exporter for Arize/Phoenix."""
+ try:
+ # Choose the appropriate exporter based on config type
+ exporter: Union[GrpcOTLPSpanExporter, HttpOTLPSpanExporter]
+ if isinstance(arize_phoenix_config, ArizeConfig):
+ arize_endpoint = f"{arize_phoenix_config.endpoint}/v1"
+ arize_headers = {
+ "api_key": arize_phoenix_config.api_key or "",
+ "space_id": arize_phoenix_config.space_id or "",
+ "authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
+ }
+ exporter = GrpcOTLPSpanExporter(
+ endpoint=arize_endpoint,
+ headers=arize_headers,
+ timeout=30,
+ )
+ else:
+ phoenix_endpoint = f"{arize_phoenix_config.endpoint}/v1/traces"
+ phoenix_headers = {
+ "api_key": arize_phoenix_config.api_key or "",
+ "authorization": f"Bearer {arize_phoenix_config.api_key or ''}",
+ }
+ exporter = HttpOTLPSpanExporter(
+ endpoint=phoenix_endpoint,
+ headers=phoenix_headers,
+ timeout=30,
+ )
+
+ attributes = {
+ "openinference.project.name": arize_phoenix_config.project or "",
+ "model_id": arize_phoenix_config.project or "",
+ }
+ resource = Resource(attributes=attributes)
+ provider = trace_sdk.TracerProvider(resource=resource)
+ processor = SimpleSpanProcessor(
+ exporter,
+ )
+ provider.add_span_processor(processor)
+
+ # Create a named tracer instead of setting the global provider
+ tracer_name = f"arize_phoenix_tracer_{arize_phoenix_config.project}"
+ logger.info(f"[Arize/Phoenix] Created tracer with name: {tracer_name}")
+ return cast(trace_sdk.Tracer, provider.get_tracer(tracer_name)), processor
+ except Exception as e:
+ logger.error(f"[Arize/Phoenix] Failed to setup the tracer: {str(e)}", exc_info=True)
+ raise
+
+
+def datetime_to_nanos(dt: Optional[datetime]) -> int:
+ """Convert datetime to nanoseconds since epoch. If None, use current time."""
+ if dt is None:
+ dt = datetime.now()
+ return int(dt.timestamp() * 1_000_000_000)
+
+
+def uuid_to_trace_id(string: Optional[str]) -> int:
+ """Convert UUID string to a valid trace ID (16-byte integer)."""
+ if string is None:
+ string = ""
+ hash_object = hashlib.sha256(string.encode())
+
+ # Take the first 16 bytes (128 bits) of the hash
+ digest = hash_object.digest()[:16]
+
+ # Convert to integer (128 bits)
+ return int.from_bytes(digest, byteorder="big")
+
+
+class ArizePhoenixDataTrace(BaseTraceInstance):
+ def __init__(
+ self,
+ arize_phoenix_config: ArizeConfig | PhoenixConfig,
+ ):
+ super().__init__(arize_phoenix_config)
+ logging.basicConfig()
+ logging.getLogger().setLevel(logging.DEBUG)
+ self.arize_phoenix_config = arize_phoenix_config
+ self.tracer, self.processor = setup_tracer(arize_phoenix_config)
+ self.project = arize_phoenix_config.project
+ self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+
+ def trace(self, trace_info: BaseTraceInfo):
+ logger.info(f"[Arize/Phoenix] Trace: {trace_info}")
+ try:
+ if isinstance(trace_info, WorkflowTraceInfo):
+ self.workflow_trace(trace_info)
+ if isinstance(trace_info, MessageTraceInfo):
+ self.message_trace(trace_info)
+ if isinstance(trace_info, ModerationTraceInfo):
+ self.moderation_trace(trace_info)
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
+ self.suggested_question_trace(trace_info)
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
+ self.dataset_retrieval_trace(trace_info)
+ if isinstance(trace_info, ToolTraceInfo):
+ self.tool_trace(trace_info)
+ if isinstance(trace_info, GenerateNameTraceInfo):
+ self.generate_name_trace(trace_info)
+
+ except Exception as e:
+ logger.error(f"[Arize/Phoenix] Error in the trace: {str(e)}", exc_info=True)
+ raise
+
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
+ workflow_metadata = {
+ "workflow_run_id": trace_info.workflow_run_id or "",
+ "message_id": trace_info.message_id or "",
+ "workflow_app_log_id": trace_info.workflow_app_log_id or "",
+ "status": trace_info.workflow_run_status or "",
+ "status_message": trace_info.error or "",
+ "level": "ERROR" if trace_info.error else "DEFAULT",
+ "total_tokens": trace_info.total_tokens or 0,
+ }
+ workflow_metadata.update(trace_info.metadata)
+
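+ # Derive a stable trace id from the workflow run id so every span of this run shares the same trace.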
+ trace_id = uuid_to_trace_id(trace_info.workflow_run_id)
+ span_id = RandomIdGenerator().generate_span_id()
+ context = SpanContext(
+ trace_id=trace_id,
+ span_id=span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ workflow_span = self.tracer.start_span(
+ name=TraceTaskName.WORKFLOW_TRACE.value,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.workflow_run_inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.workflow_run_outputs, ensure_ascii=False),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
+ SpanAttributes.METADATA: json.dumps(workflow_metadata, ensure_ascii=False),
+ SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
+ },
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ # Process workflow nodes
+ for node_execution in self._get_workflow_nodes(trace_info.workflow_run_id):
+ created_at = node_execution.created_at or datetime.now()
+ elapsed_time = node_execution.elapsed_time
+ finished_at = created_at + timedelta(seconds=elapsed_time)
+
+ process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+
+ node_metadata = {
+ "node_id": node_execution.id,
+ "node_type": node_execution.node_type,
+ "node_status": node_execution.status,
+ "tenant_id": node_execution.tenant_id,
+ "app_id": node_execution.app_id,
+ "app_name": node_execution.title,
+ "status": node_execution.status,
+ "level": "ERROR" if node_execution.status != "succeeded" else "DEFAULT",
+ }
+
+ if node_execution.execution_metadata:
+ node_metadata.update(json.loads(node_execution.execution_metadata))
+
+ # Determine the correct span kind based on node type
+ span_kind = OpenInferenceSpanKindValues.CHAIN.value
+ if node_execution.node_type == "llm":
+ span_kind = OpenInferenceSpanKindValues.LLM.value
+ provider = process_data.get("model_provider")
+ model = process_data.get("model_name")
+ if provider:
+ node_metadata["ls_provider"] = provider
+ if model:
+ node_metadata["ls_model_name"] = model
+
+ outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+ usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ if usage_data:
+ node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
+ node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
+ node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
+ elif node_execution.node_type == "dataset_retrieval":
+ span_kind = OpenInferenceSpanKindValues.RETRIEVER.value
+ elif node_execution.node_type == "tool":
+ span_kind = OpenInferenceSpanKindValues.TOOL.value
+ else:
+ span_kind = OpenInferenceSpanKindValues.CHAIN.value
+
+ node_span = self.tracer.start_span(
+ name=node_execution.node_type,
+ attributes={
+ SpanAttributes.INPUT_VALUE: node_execution.inputs or "{}",
+ SpanAttributes.OUTPUT_VALUE: node_execution.outputs or "{}",
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: span_kind,
+ SpanAttributes.METADATA: json.dumps(node_metadata, ensure_ascii=False),
+ SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
+ },
+ start_time=datetime_to_nanos(created_at),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ if node_execution.node_type == "llm":
+ llm_attributes: dict[str, Any] = {
+ SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+ }
+ provider = process_data.get("model_provider")
+ model = process_data.get("model_name")
+ if provider:
+ llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
+ if model:
+ llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
+ outputs = (
+ json.loads(node_execution.outputs) if node_execution.outputs else {}
+ )
+ usage_data = (
+ process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ )
+ if usage_data:
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
+ "completion_tokens", 0
+ )
+ llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
+ node_span.set_attributes(llm_attributes)
+ finally:
+ node_span.end(end_time=datetime_to_nanos(finished_at))
+ finally:
+ workflow_span.end(end_time=datetime_to_nanos(trace_info.end_time))
+
+ def message_trace(self, trace_info: MessageTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ file_list = cast(list[str], trace_info.file_list) or []
+ message_file_data: Optional[MessageFile] = trace_info.message_file_data
+
+ if message_file_data is not None:
+ file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
+ file_list.append(file_url)
+
+ message_metadata = {
+ "message_id": trace_info.message_id or "",
+ "conversation_mode": str(trace_info.conversation_mode or ""),
+ "user_id": trace_info.message_data.from_account_id or "",
+ "file_list": json.dumps(file_list),
+ "status": trace_info.message_data.status or "",
+ "status_message": trace_info.error or "",
+ "level": "ERROR" if trace_info.error else "DEFAULT",
+ "total_tokens": trace_info.total_tokens or 0,
+ "prompt_tokens": trace_info.message_tokens or 0,
+ "completion_tokens": trace_info.answer_tokens or 0,
+ "ls_provider": trace_info.message_data.model_provider or "",
+ "ls_model_name": trace_info.message_data.model_id or "",
+ }
+ message_metadata.update(trace_info.metadata)
+
+ # Add end user data if available
+ if trace_info.message_data.from_end_user_id:
+ end_user_data: Optional[EndUser] = (
+ db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
+ )
+ if end_user_data is not None:
+ message_metadata["end_user_id"] = end_user_data.session_id
+
+ attributes = {
+ SpanAttributes.INPUT_VALUE: trace_info.message_data.query,
+ SpanAttributes.OUTPUT_VALUE: trace_info.message_data.answer,
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
+ SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
+ SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+ }
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ message_span_id = RandomIdGenerator().generate_span_id()
+ span_context = SpanContext(
+ trace_id=trace_id,
+ span_id=message_span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ message_span = self.tracer.start_span(
+ name=TraceTaskName.MESSAGE_TRACE.value,
+ attributes=attributes,
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+ )
+
+ try:
+ if trace_info.error:
+ message_span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.error,
+ },
+ )
+
+ # Convert outputs to string based on type
+ if isinstance(trace_info.outputs, dict | list):
+ outputs_str = json.dumps(trace_info.outputs, ensure_ascii=False)
+ elif isinstance(trace_info.outputs, str):
+ outputs_str = trace_info.outputs
+ else:
+ outputs_str = str(trace_info.outputs)
+
+ llm_attributes = {
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.LLM.value,
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: outputs_str,
+ SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
+ SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+ }
+ llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
+ if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
+ if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = trace_info.message_tokens
+ if trace_info.answer_tokens is not None and trace_info.answer_tokens > 0:
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = trace_info.answer_tokens
+
+ if trace_info.message_data.model_id is not None:
+ llm_attributes[SpanAttributes.LLM_MODEL_NAME] = trace_info.message_data.model_id
+ if trace_info.message_data.model_provider is not None:
+ llm_attributes[SpanAttributes.LLM_PROVIDER] = trace_info.message_data.model_provider
+
+ if trace_info.message_data and trace_info.message_data.message_metadata:
+ metadata_dict = json.loads(trace_info.message_data.message_metadata)
+ if model_params := metadata_dict.get("model_parameters"):
+ llm_attributes[SpanAttributes.LLM_INVOCATION_PARAMETERS] = json.dumps(model_params)
+
+ llm_span = self.tracer.start_span(
+ name="llm",
+ attributes=llm_attributes,
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+ )
+
+ try:
+ if trace_info.error:
+ llm_span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.error,
+ },
+ )
+ finally:
+ llm_span.end(end_time=datetime_to_nanos(trace_info.end_time))
+ finally:
+ message_span.end(end_time=datetime_to_nanos(trace_info.end_time))
+
+ def moderation_trace(self, trace_info: ModerationTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ metadata = {
+ "message_id": trace_info.message_id,
+ "tool_name": "moderation",
+ "status": trace_info.message_data.status,
+ "status_message": trace_info.message_data.error or "",
+ "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
+ }
+ metadata.update(trace_info.metadata)
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ span_id = RandomIdGenerator().generate_span_id()
+ context = SpanContext(
+ trace_id=trace_id,
+ span_id=span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ span = self.tracer.start_span(
+ name=TraceTaskName.MODERATION_TRACE.value,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: json.dumps(
+ {
+ "action": trace_info.action,
+ "flagged": trace_info.flagged,
+ "preset_response": trace_info.preset_response,
+ "inputs": trace_info.inputs,
+ },
+ ensure_ascii=False,
+ ),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
+ SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ },
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ if trace_info.message_data.error:
+ span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.message_data.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.message_data.error,
+ },
+ )
+ finally:
+ span.end(end_time=datetime_to_nanos(trace_info.end_time))
+
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ start_time = trace_info.start_time or trace_info.message_data.created_at
+ end_time = trace_info.end_time or trace_info.message_data.updated_at
+
+ metadata = {
+ "message_id": trace_info.message_id,
+ "tool_name": "suggested_question",
+ "status": trace_info.status,
+ "status_message": trace_info.error or "",
+ "level": "ERROR" if trace_info.error else "DEFAULT",
+ "total_tokens": trace_info.total_tokens,
+ "ls_provider": trace_info.model_provider or "",
+ "ls_model_name": trace_info.model_id or "",
+ }
+ metadata.update(trace_info.metadata)
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ span_id = RandomIdGenerator().generate_span_id()
+ context = SpanContext(
+ trace_id=trace_id,
+ span_id=span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ span = self.tracer.start_span(
+ name=TraceTaskName.SUGGESTED_QUESTION_TRACE.value,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.suggested_question, ensure_ascii=False),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
+ SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ },
+ start_time=datetime_to_nanos(start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ if trace_info.error:
+ span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.error,
+ },
+ )
+ finally:
+ span.end(end_time=datetime_to_nanos(end_time))
+
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ start_time = trace_info.start_time or trace_info.message_data.created_at
+ end_time = trace_info.end_time or trace_info.message_data.updated_at
+
+ metadata = {
+ "message_id": trace_info.message_id,
+ "tool_name": "dataset_retrieval",
+ "status": trace_info.message_data.status,
+ "status_message": trace_info.message_data.error or "",
+ "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
+ "ls_provider": trace_info.message_data.model_provider or "",
+ "ls_model_name": trace_info.message_data.model_id or "",
+ }
+ metadata.update(trace_info.metadata)
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ span_id = RandomIdGenerator().generate_span_id()
+ context = SpanContext(
+ trace_id=trace_id,
+ span_id=span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ span = self.tracer.start_span(
+ name=TraceTaskName.DATASET_RETRIEVAL_TRACE.value,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: json.dumps({"documents": trace_info.documents}, ensure_ascii=False),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.RETRIEVER.value,
+ SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ "start_time": start_time.isoformat() if start_time else "",
+ "end_time": end_time.isoformat() if end_time else "",
+ },
+ start_time=datetime_to_nanos(start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ if trace_info.message_data.error:
+ span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.message_data.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.message_data.error,
+ },
+ )
+ finally:
+ span.end(end_time=datetime_to_nanos(end_time))
+
+ def tool_trace(self, trace_info: ToolTraceInfo):
+ if trace_info.message_data is None:
+ logger.warning("[Arize/Phoenix] Message data is None, skipping tool trace.")
+ return
+
+ metadata = {
+ "message_id": trace_info.message_id,
+ "tool_config": json.dumps(trace_info.tool_config, ensure_ascii=False),
+ }
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ tool_span_id = RandomIdGenerator().generate_span_id()
+ logger.info(f"[Arize/Phoenix] Creating tool trace with trace_id: {trace_id}, span_id: {tool_span_id}")
+
+ # Create span context with the same trace_id as the parent
+ # todo: Create with the appropriate parent span context, so that the tool span is
+ # a child of the appropriate span (e.g. message span)
+ span_context = SpanContext(
+ trace_id=trace_id,
+ span_id=tool_span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ tool_params_str = (
+ json.dumps(trace_info.tool_parameters, ensure_ascii=False)
+ if isinstance(trace_info.tool_parameters, dict)
+ else str(trace_info.tool_parameters)
+ )
+
+ span = self.tracer.start_span(
+ name=trace_info.tool_name,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.tool_inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: trace_info.tool_outputs,
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.TOOL.value,
+ SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ SpanAttributes.TOOL_NAME: trace_info.tool_name,
+ SpanAttributes.TOOL_PARAMETERS: tool_params_str,
+ },
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(span_context)),
+ )
+
+ try:
+ if trace_info.error:
+ span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.error,
+ },
+ )
+ finally:
+ span.end(end_time=datetime_to_nanos(trace_info.end_time))
+
+ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ metadata = {
+ "project_name": self.project,
+ "message_id": trace_info.message_id,
+ "status": trace_info.message_data.status,
+ "status_message": trace_info.message_data.error or "",
+ "level": "ERROR" if trace_info.message_data.error else "DEFAULT",
+ }
+ metadata.update(trace_info.metadata)
+
+ trace_id = uuid_to_trace_id(trace_info.message_id)
+ span_id = RandomIdGenerator().generate_span_id()
+ context = SpanContext(
+ trace_id=trace_id,
+ span_id=span_id,
+ is_remote=False,
+ trace_flags=TraceFlags(TraceFlags.SAMPLED),
+ trace_state=TraceState(),
+ )
+
+ span = self.tracer.start_span(
+ name=TraceTaskName.GENERATE_NAME_TRACE.value,
+ attributes={
+ SpanAttributes.INPUT_VALUE: json.dumps(trace_info.inputs, ensure_ascii=False),
+ SpanAttributes.OUTPUT_VALUE: json.dumps(trace_info.outputs, ensure_ascii=False),
+ SpanAttributes.OPENINFERENCE_SPAN_KIND: OpenInferenceSpanKindValues.CHAIN.value,
+ SpanAttributes.METADATA: json.dumps(metadata, ensure_ascii=False),
+ SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
+ "start_time": trace_info.start_time.isoformat() if trace_info.start_time else "",
+ "end_time": trace_info.end_time.isoformat() if trace_info.end_time else "",
+ },
+ start_time=datetime_to_nanos(trace_info.start_time),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
+ )
+
+ try:
+ if trace_info.message_data.error:
+ span.add_event(
+ "exception",
+ attributes={
+ "exception.message": trace_info.message_data.error,
+ "exception.type": "Error",
+ "exception.stacktrace": trace_info.message_data.error,
+ },
+ )
+ finally:
+ span.end(end_time=datetime_to_nanos(trace_info.end_time))
+
+ def api_check(self):
+ try:
+ with self.tracer.start_span("api_check") as span:
+ span.set_attribute("test", "true")
+ return True
+ except Exception as e:
+ logger.info(f"[Arize/Phoenix] API check failed: {str(e)}", exc_info=True)
+ raise ValueError(f"[Arize/Phoenix] API check failed: {str(e)}")
+
+ def get_project_url(self):
+ try:
+ if self.arize_phoenix_config.endpoint == "https://otlp.arize.com":
+ return "https://app.arize.com/"
+ else:
+ return f"{self.arize_phoenix_config.endpoint}/projects/"
+ except Exception as e:
+ logger.info(f"[Arize/Phoenix] Get run url failed: {str(e)}", exc_info=True)
+ raise ValueError(f"[Arize/Phoenix] Get run url failed: {str(e)}")
+
+ def _get_workflow_nodes(self, workflow_run_id: str):
+ """Helper method to get workflow nodes"""
+ workflow_nodes = (
+ db.session.query(
+ WorkflowNodeExecutionModel.id,
+ WorkflowNodeExecutionModel.tenant_id,
+ WorkflowNodeExecutionModel.app_id,
+ WorkflowNodeExecutionModel.title,
+ WorkflowNodeExecutionModel.node_type,
+ WorkflowNodeExecutionModel.status,
+ WorkflowNodeExecutionModel.inputs,
+ WorkflowNodeExecutionModel.outputs,
+ WorkflowNodeExecutionModel.created_at,
+ WorkflowNodeExecutionModel.elapsed_time,
+ WorkflowNodeExecutionModel.process_data,
+ WorkflowNodeExecutionModel.execution_metadata,
+ )
+ .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
+ .all()
+ )
+ return workflow_nodes
+
+ def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+ """Helper method to construct LLM attributes with passed prompts."""
+ attributes = {}
+ if isinstance(prompts, list):
+ for i, msg in enumerate(prompts):
+ if isinstance(msg, dict):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
+ # todo: handle assistant and tool role messages, as they don't always
+ # have a text field, but may have a tool_calls field instead
+ # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
+ # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
+ elif isinstance(prompts, dict):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+ elif isinstance(prompts, str):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+
+ return attributes
diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py
index c988bf48d1..89ff0cfded 100644
--- a/api/core/ops/entities/config_entity.py
+++ b/api/core/ops/entities/config_entity.py
@@ -2,20 +2,92 @@ from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator
+from core.ops.utils import validate_project_name, validate_url, validate_url_with_path
+
class TracingProviderEnum(StrEnum):
+ ARIZE = "arize"
+ PHOENIX = "phoenix"
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
WEAVE = "weave"
+ ALIYUN = "aliyun"
class BaseTracingConfig(BaseModel):
"""
- Base model class for tracing
+ Base model class for tracing configurations
"""
- ...
+ @classmethod
+ def validate_endpoint_url(cls, v: str, default_url: str) -> str:
+ """
+ Common endpoint URL validation logic
+
+ Args:
+ v: URL value to validate
+ default_url: Default URL to use if input is None or empty
+
+ Returns:
+ Validated and normalized URL
+ """
+ return validate_url(v, default_url)
+
+ @classmethod
+ def validate_project_field(cls, v: str, default_name: str) -> str:
+ """
+ Common project name validation logic
+
+ Args:
+ v: Project name to validate
+ default_name: Default name to use if input is None or empty
+
+ Returns:
+ Validated project name
+ """
+ return validate_project_name(v, default_name)
+
+
+class ArizeConfig(BaseTracingConfig):
+ """
+ Model class for Arize tracing config.
+ """
+
+ api_key: str | None = None
+ space_id: str | None = None
+ project: str | None = None
+ endpoint: str = "https://otlp.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://otlp.arize.com")
+
+
+class PhoenixConfig(BaseTracingConfig):
+ """
+ Model class for Phoenix tracing config.
+ """
+
+ api_key: str | None = None
+ project: str | None = None
+ endpoint: str = "https://app.phoenix.arize.com"
+
+ @field_validator("project")
+ @classmethod
+ def project_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "default")
+
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://app.phoenix.arize.com")
class LangfuseConfig(BaseTracingConfig):
@@ -29,13 +101,8 @@ class LangfuseConfig(BaseTracingConfig):
@field_validator("host")
@classmethod
- def set_value(cls, v, info: ValidationInfo):
- if v is None or v == "":
- v = "https://api.langfuse.com"
- if not v.startswith("https://") and not v.startswith("http://"):
- raise ValueError("host must start with https:// or http://")
-
- return v
+ def host_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://api.langfuse.com")
class LangSmithConfig(BaseTracingConfig):
@@ -49,13 +116,9 @@ class LangSmithConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
- def set_value(cls, v, info: ValidationInfo):
- if v is None or v == "":
- v = "https://api.smith.langchain.com"
- if not v.startswith("https://"):
- raise ValueError("endpoint must start with https://")
-
- return v
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # LangSmith only allows HTTPS
+ return validate_url(v, "https://api.smith.langchain.com", allowed_schemes=("https",))
class OpikConfig(BaseTracingConfig):
@@ -71,22 +134,12 @@ class OpikConfig(BaseTracingConfig):
@field_validator("project")
@classmethod
def project_validator(cls, v, info: ValidationInfo):
- if v is None or v == "":
- v = "Default Project"
-
- return v
+ return cls.validate_project_field(v, "Default Project")
@field_validator("url")
@classmethod
def url_validator(cls, v, info: ValidationInfo):
- if v is None or v == "":
- v = "https://www.comet.com/opik/api/"
- if not v.startswith(("https://", "http://")):
- raise ValueError("url must start with https:// or http://")
- if not v.endswith("/api/"):
- raise ValueError("url should ends with /api/")
-
- return v
+ return validate_url_with_path(v, "https://www.comet.com/opik/api/", required_suffix="/api/")
class WeaveConfig(BaseTracingConfig):
@@ -102,22 +155,44 @@ class WeaveConfig(BaseTracingConfig):
@field_validator("endpoint")
@classmethod
- def set_value(cls, v, info: ValidationInfo):
- if v is None or v == "":
- v = "https://trace.wandb.ai"
- if not v.startswith("https://"):
- raise ValueError("endpoint must start with https://")
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ # Weave only allows HTTPS for endpoint
+ return validate_url(v, "https://trace.wandb.ai", allowed_schemes=("https",))
+ @field_validator("host")
+ @classmethod
+ def host_validator(cls, v, info: ValidationInfo):
+ if v is not None and v.strip() != "":
+ return validate_url(v, v, allowed_schemes=("https", "http"))
return v
- @field_validator("host")
+
+class AliyunConfig(BaseTracingConfig):
+ """
+ Model class for Aliyun tracing config.
+ """
+
+ app_name: str = "dify_app"
+ license_key: str
+ endpoint: str
+
+ @field_validator("app_name")
+ @classmethod
+ def app_name_validator(cls, v, info: ValidationInfo):
+ return cls.validate_project_field(v, "dify_app")
+
+ @field_validator("license_key")
@classmethod
- def validate_host(cls, v, info: ValidationInfo):
- if v is not None and v != "":
- if not v.startswith(("https://", "http://")):
- raise ValueError("host must start with https:// or http://")
+ def license_key_validator(cls, v, info: ValidationInfo):
+ if not v or v.strip() == "":
+ raise ValueError("License key cannot be empty")
return v
+ @field_validator("endpoint")
+ @classmethod
+ def endpoint_validator(cls, v, info: ValidationInfo):
+ return cls.validate_endpoint_url(v, "https://tracing-analysis-dc-hz.aliyuncs.com")
+
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index 0ea74e9ef0..4a7e66d27c 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -28,10 +28,11 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
+from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@@ -83,6 +84,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["message", "workflow"],
+ version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
workflow_span_data = LangfuseSpan(
@@ -108,6 +110,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
session_id=trace_info.conversation_id,
tags=["workflow"],
+ version=trace_info.workflow_run_version,
)
self.add_trace(langfuse_trace_data=trace_data)
@@ -120,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -172,48 +175,15 @@ class LangFuseDataTrace(BaseTraceInstance):
}
)
- # add span
- if trace_info.message_id:
- span_data = LangfuseSpan(
- id=node_execution_id,
- name=node_type,
- input=inputs,
- output=outputs,
- trace_id=trace_id,
- start_time=created_at,
- end_time=finished_at,
- metadata=metadata,
- level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
- status_message=trace_info.error or "",
- parent_observation_id=trace_info.workflow_run_id,
- )
- else:
- span_data = LangfuseSpan(
- id=node_execution_id,
- name=node_type,
- input=inputs,
- output=outputs,
- trace_id=trace_id,
- start_time=created_at,
- end_time=finished_at,
- metadata=metadata,
- level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
- status_message=trace_info.error or "",
- )
-
- self.add_span(langfuse_span_data=span_data)
-
+ # add generation span
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
prompt_tokens = 0
completion_tokens = 0
try:
- if outputs.get("usage"):
- prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
- completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
- else:
- prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
- completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
+ usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ prompt_tokens = usage_data.get("prompt_tokens", 0)
+ completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@@ -226,10 +196,10 @@ class LangFuseDataTrace(BaseTraceInstance):
)
node_generation_data = LangfuseGeneration(
- name="llm",
+ id=node_execution_id,
+ name=node_name,
trace_id=trace_id,
model=process_data.get("model_name"),
- parent_observation_id=node_execution_id,
start_time=created_at,
end_time=finished_at,
input=inputs,
@@ -237,11 +207,30 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
+ parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
usage=generation_usage,
)
self.add_generation(langfuse_generation_data=node_generation_data)
+ # add normal span
+ else:
+ span_data = LangfuseSpan(
+ id=node_execution_id,
+ name=node_name,
+ input=inputs,
+ output=outputs,
+ trace_id=trace_id,
+ start_time=created_at,
+ end_time=finished_at,
+ metadata=metadata,
+ level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
+ status_message=trace_info.error or "",
+ parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
+ )
+
+ self.add_span(langfuse_span_data=span_data)
+
def message_trace(self, trace_info: MessageTraceInfo, **kwargs):
# get message file data
file_list = trace_info.file_list
@@ -284,7 +273,7 @@ class LangFuseDataTrace(BaseTraceInstance):
)
self.add_trace(langfuse_trace_data=trace_data)
- # start add span
+ # add generation
generation_usage = GenerationUsage(
input=trace_info.message_tokens,
output=trace_info.answer_tokens,
@@ -302,7 +291,7 @@ class LangFuseDataTrace(BaseTraceInstance):
input=trace_info.inputs,
output=message_data.answer,
metadata=metadata,
- level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
+ level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
@@ -348,7 +337,7 @@ class LangFuseDataTrace(BaseTraceInstance):
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=trace_info.metadata,
- level=(LevelEnum.DEFAULT if message_data.status != "error" else LevelEnum.ERROR),
+ level=(LevelEnum.DEFAULT if message_data.status != MessageStatus.ERROR else LevelEnum.ERROR),
status_message=message_data.error or "",
usage=generation_usage,
)
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 8a392940db..8a559c4929 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -206,12 +206,9 @@ class LangSmithDataTrace(BaseTraceInstance):
prompt_tokens = 0
completion_tokens = 0
try:
- if outputs.get("usage"):
- prompt_tokens = outputs.get("usage", {}).get("prompt_tokens", 0)
- completion_tokens = outputs.get("usage", {}).get("completion_tokens", 0)
- else:
- prompt_tokens = process_data.get("usage", {}).get("prompt_tokens", 0)
- completion_tokens = process_data.get("usage", {}).get("completion_tokens", 0)
+ usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ prompt_tokens = usage_data.get("prompt_tokens", 0)
+ completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index f4d2760ba5..be4997a5bf 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -222,10 +222,10 @@ class OpikDataTrace(BaseTraceInstance):
)
try:
- if outputs.get("usage"):
- total_tokens = outputs["usage"].get("total_tokens", 0)
- prompt_tokens = outputs["usage"].get("prompt_tokens", 0)
- completion_tokens = outputs["usage"].get("completion_tokens", 0)
+ usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
+ total_tokens = usage_data.get("total_tokens", 0)
+ prompt_tokens = usage_data.get("prompt_tokens", 0)
+ completion_tokens = usage_data.get("completion_tokens", 0)
except Exception:
logger.error("Failed to extract usage", exc_info=True)
@@ -241,7 +241,7 @@ class OpikDataTrace(BaseTraceInstance):
"trace_id": opik_trace_id,
"id": prepare_opik_uuid(created_at, node_execution_id),
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
- "name": node_type,
+ "name": node_name,
"type": run_type,
"start_time": created_at,
"end_time": finished_at,
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index e0dfe0c312..5c9b9d27b7 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -84,6 +84,36 @@ class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
"other_keys": ["project", "entity", "endpoint", "host"],
"trace_instance": WeaveDataTrace,
}
+ case TracingProviderEnum.ARIZE:
+ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
+ from core.ops.entities.config_entity import ArizeConfig
+
+ return {
+ "config_class": ArizeConfig,
+ "secret_keys": ["api_key", "space_id"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.PHOENIX:
+ from core.ops.arize_phoenix_trace.arize_phoenix_trace import ArizePhoenixDataTrace
+ from core.ops.entities.config_entity import PhoenixConfig
+
+ return {
+ "config_class": PhoenixConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "endpoint"],
+ "trace_instance": ArizePhoenixDataTrace,
+ }
+ case TracingProviderEnum.ALIYUN:
+ from core.ops.aliyun_trace.aliyun_trace import AliyunDataTrace
+ from core.ops.entities.config_entity import AliyunConfig
+
+ return {
+ "config_class": AliyunConfig,
+ "secret_keys": ["license_key"],
+ "other_keys": ["endpoint", "app_name"],
+ "trace_instance": AliyunDataTrace,
+ }
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
@@ -251,7 +281,7 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"],
)
- decrypt_trace_config_key = str(decrypt_trace_config)
+ decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
if tracing_instance is None:
# create new tracing_instance and update the cache if it absent
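
The cache key switches from str() on the decrypted config to json.dumps(..., sort_keys=True), so equivalent configs map to the same cache entry regardless of key order. A quick illustration, assuming the decrypted config is a plain dict:

```python
import json

a = {"api_key": "k", "host": "https://example.com"}
b = {"host": "https://example.com", "api_key": "k"}

# str() preserves insertion order, so equivalent configs can produce different keys:
print(str(a) == str(b))                                                 # False
# json.dumps with sort_keys=True is canonical, so equal configs share one cache entry:
print(json.dumps(a, sort_keys=True) == json.dumps(b, sort_keys=True))   # True
```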
diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py
index 8b06df1930..36d060afd2 100644
--- a/api/core/ops/utils.py
+++ b/api/core/ops/utils.py
@@ -1,6 +1,7 @@
from contextlib import contextmanager
from datetime import datetime
from typing import Optional, Union
+from urllib.parse import urlparse
from extensions.ext_database import db
from models.model import Message
@@ -60,3 +61,83 @@ def generate_dotted_order(
return current_segment
return f"{parent_dotted_order}.{current_segment}"
+
+
+def validate_url(url: str, default_url: str, allowed_schemes: tuple[str, ...] = ("https", "http")) -> str:
+ """
+ Validate and normalize URL with proper error handling
+
+ Args:
+ url: The URL to validate
+ default_url: Default URL to use if input is None or empty
+ allowed_schemes: Tuple of allowed URL schemes (default: https, http)
+
+ Returns:
+ Normalized URL string
+
+ Raises:
+ ValueError: If URL format is invalid or scheme not allowed
+ """
+ if not url or url.strip() == "":
+ return default_url
+
+ # Parse URL to validate format
+ parsed = urlparse(url)
+
+ # Check if scheme is allowed
+ if parsed.scheme not in allowed_schemes:
+ raise ValueError(f"URL scheme must be one of: {', '.join(allowed_schemes)}")
+
+ # Reconstruct URL with only scheme, netloc (removing path, query, fragment)
+ normalized_url = f"{parsed.scheme}://{parsed.netloc}"
+
+ return normalized_url
+
+
+def validate_url_with_path(url: str, default_url: str, required_suffix: str | None = None) -> str:
+ """
+ Validate URL that may include path components
+
+ Args:
+ url: The URL to validate
+ default_url: Default URL to use if input is None or empty
+ required_suffix: Optional suffix that URL must end with
+
+ Returns:
+ Validated URL string
+
+ Raises:
+ ValueError: If URL format is invalid or doesn't match required suffix
+ """
+ if not url or url.strip() == "":
+ return default_url
+
+ # Parse URL to validate format
+ parsed = urlparse(url)
+
+ # Check if scheme is allowed
+ if parsed.scheme not in ("https", "http"):
+ raise ValueError("URL must start with https:// or http://")
+
+ # Check required suffix if specified
+ if required_suffix and not url.endswith(required_suffix):
+ raise ValueError(f"URL should end with {required_suffix}")
+
+ return url
+
+
+def validate_project_name(project: str, default_name: str) -> str:
+ """
+ Validate and normalize project name
+
+ Args:
+ project: Project name to validate
+ default_name: Default name to use if input is None or empty
+
+ Returns:
+ Normalized project name
+ """
+ if not project or project.strip() == "":
+ return default_name
+
+ return project.strip()
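
A short usage sketch of the new helpers, based only on the signatures above; the endpoint URLs and project names are illustrative inputs:

```python
from core.ops.utils import validate_project_name, validate_url, validate_url_with_path

# Host-only endpoints are normalized to scheme + netloc, dropping path/query/fragment:
validate_url("https://cloud.langfuse.com/some/path?x=1", "https://cloud.langfuse.com")
# -> "https://cloud.langfuse.com"

# Empty input falls back to the default:
validate_url("", "https://cloud.langfuse.com")   # -> "https://cloud.langfuse.com"

# Endpoints that must keep a path can require a suffix:
validate_url_with_path("https://example.com/api/v1", "https://example.com/api/v1", "/api/v1")

# Project names are trimmed, with a default for blanks:
validate_project_name("  my-project  ", "default")   # -> "my-project"
```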
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
index 3917348a91..445c6a8741 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py
index 81a5d033a0..213f5c726a 100644
--- a/api/core/plugin/backwards_invocation/encrypt.py
+++ b/api/core/plugin/backwards_invocation/encrypt.py
@@ -1,16 +1,20 @@
+from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
- encrypter = ProviderConfigEncrypter(
+ encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
- provider_type=payload.namespace,
- provider_identity=payload.identity,
+ cache=SingletonProviderCredentialsCache(
+ tenant_id=tenant.id,
+ provider_type=payload.namespace,
+ provider_identity=payload.identity,
+ ),
)
if payload.opt == "encrypt":
@@ -22,7 +26,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data),
}
elif payload.opt == "clear":
- encrypter.delete_tool_credentials_cache()
+ cache.delete()
return {
"data": {},
}
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 072644e53b..d07ab3d0c4 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -2,8 +2,15 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
+from core.model_runtime.entities.llm_entities import (
+ LLMResult,
+ LLMResultChunk,
+ LLMResultChunkDelta,
+ LLMResultChunkWithStructuredOutput,
+ LLMResultWithStructuredOutput,
+)
from core.model_runtime.entities.message_entities import (
PromptMessage,
SystemPromptMessage,
@@ -12,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeLLM,
+ RequestInvokeLLMWithStructuredOutput,
RequestInvokeModeration,
RequestInvokeRerank,
RequestInvokeSpeech2Text,
@@ -81,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
return handle_non_streaming(response)
+ @classmethod
+ def invoke_llm_with_structured_output(
+ cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
+ ):
+ """
+ invoke llm with structured output
+ """
+ model_instance = ModelManager().get_model_instance(
+ tenant_id=tenant.id,
+ provider=payload.provider,
+ model_type=payload.model_type,
+ model=payload.model,
+ )
+
+ model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
+
+ if not model_schema:
+ raise ValueError(f"Model schema not found for {payload.model}")
+
+ response = invoke_llm_with_structured_output(
+ provider=payload.provider,
+ model_schema=model_schema,
+ model_instance=model_instance,
+ prompt_messages=payload.prompt_messages,
+ json_schema=payload.structured_output_schema,
+ tools=payload.tools,
+ stop=payload.stop,
+ stream=True if payload.stream is None else payload.stream,
+ user=user_id,
+ model_parameters=payload.completion_params,
+ )
+
+ if isinstance(response, Generator):
+
+ def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+ for chunk in response:
+ if chunk.delta.usage:
+ llm_utils.deduct_llm_quota(
+ tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+ )
+ chunk.prompt_messages = []
+ yield chunk
+
+ return handle()
+ else:
+ if response.usage:
+ llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+ def handle_non_streaming(
+ response: LLMResultWithStructuredOutput,
+ ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+ yield LLMResultChunkWithStructuredOutput(
+ model=response.model,
+ prompt_messages=[],
+ system_fingerprint=response.system_fingerprint,
+ structured_output=response.structured_output,
+ delta=LLMResultChunkDelta(
+ index=0,
+ message=response.message,
+ usage=response.usage,
+ finish_reason="",
+ ),
+ )
+
+ return handle_non_streaming(response)
+
@classmethod
def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
"""
diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py
index 1d62743f13..06773504d9 100644
--- a/api/core/plugin/backwards_invocation/tool.py
+++ b/api/core/plugin/backwards_invocation/tool.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from typing import Any
+from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
+ credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
@@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
- tool_type, tenant_id, provider, tool_name, tool_parameters
+ tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py
index a19a44aa3c..1c13a621d4 100644
--- a/api/core/plugin/entities/marketplace.py
+++ b/api/core/plugin/entities/marketplace.py
@@ -32,6 +32,13 @@ class MarketplacePluginDeclaration(BaseModel):
latest_package_identifier: str = Field(
..., description="Unique identifier for the latest package release of the plugin"
)
+ status: str = Field(..., description="Status of the marketplace plugin; one of `active`, `deleted`")
+ deprecated_reason: str = Field(
+ ..., description="Non-empty when status='deleted'; explains why this plugin was deleted (deprecated)"
+ )
+ alternative_plugin_id: str = Field(
+ ..., description="Optional; the alternative plugin users can switch to"
+ )
+ )
@model_validator(mode="before")
@classmethod
diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py
index 895dd0d0fc..47290ee613 100644
--- a/api/core/plugin/entities/parameters.py
+++ b/api/core/plugin/entities/parameters.py
@@ -5,11 +5,15 @@ from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.tools.entities.common_entities import I18nObject
+from core.workflow.nodes.base.entities import NumberType
class PluginParameterOption(BaseModel):
value: str = Field(..., description="The value of the option")
label: I18nObject = Field(..., description="The label of the option")
+ icon: Optional[str] = Field(
+ default=None, description="The icon of the option, can be a url or a base64 encoded image"
+ )
@field_validator("value", mode="before")
@classmethod
@@ -35,10 +39,25 @@ class PluginParameterType(enum.StrEnum):
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
+ ANY = CommonParameterType.ANY.value
+ DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
+ # MCP object and array type parameters
+ ARRAY = CommonParameterType.ARRAY.value
+ OBJECT = CommonParameterType.OBJECT.value
+
+
+class MCPServerParameterType(enum.StrEnum):
+ """
+ Complex parameter types used by MCP servers.
+ """
+
+ ARRAY = "array"
+ OBJECT = "object"
+
class PluginParameterAutoGenerate(BaseModel):
class Type(enum.StrEnum):
@@ -134,6 +153,38 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
if value and not isinstance(value, list):
raise ValueError("The tools selector must be a list.")
return value
+ case PluginParameterType.ANY:
+ if value and not isinstance(value, str | dict | list | NumberType):
+ raise ValueError("The var selector must be a string, dictionary, list or number.")
+ return value
+ case PluginParameterType.ARRAY:
+ if not isinstance(value, list):
+ # Try to parse JSON string for arrays
+ if isinstance(value, str):
+ try:
+ import json
+
+ parsed_value = json.loads(value)
+ if isinstance(parsed_value, list):
+ return parsed_value
+ except (json.JSONDecodeError, ValueError):
+ pass
+ return [value]
+ return value
+ case PluginParameterType.OBJECT:
+ if not isinstance(value, dict):
+ # Try to parse JSON string for objects
+ if isinstance(value, str):
+ try:
+ import json
+
+ parsed_value = json.loads(value)
+ if isinstance(parsed_value, dict):
+ return parsed_value
+ except (json.JSONDecodeError, ValueError):
+ pass
+ return {}
+ return value
case _:
return str(value)
except ValueError:
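
The new ARRAY and OBJECT branches accept either native containers or JSON strings, with a lenient fallback. A small standalone demonstration of the resulting behavior, mirroring the ARRAY logic above:

```python
import json

def cast_array(value):
    """Mirror of the ARRAY branch: parse JSON strings, else wrap the scalar."""
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        try:
            parsed = json.loads(value)
            if isinstance(parsed, list):
                return parsed
        except (json.JSONDecodeError, ValueError):
            pass
    return [value]

print(cast_array('[1, 2, 3]'))   # [1, 2, 3]      valid JSON array string
print(cast_array('not json'))    # ['not json']   wrapped as a single-item list
print(cast_array(42))            # [42]           non-string scalars are wrapped too
```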
diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py
index bdf7d5ce1f..a07b58d9ea 100644
--- a/api/core/plugin/entities/plugin.py
+++ b/api/core/plugin/entities/plugin.py
@@ -72,12 +72,14 @@ class PluginDeclaration(BaseModel):
class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
+ version: Optional[str] = Field(default=None)
version: str = Field(..., pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
author: Optional[str] = Field(..., pattern=r"^[a-zA-Z0-9_-]{1,64}$")
name: str = Field(..., pattern=r"^[a-z0-9_-]{1,128}$")
description: I18nObject
icon: str
+ icon_dark: Optional[str] = Field(default=None)
label: I18nObject
category: PluginCategory
created_at: datetime.datetime
@@ -133,17 +135,6 @@ class PluginEntity(PluginInstallation):
return self
-class GithubPackage(BaseModel):
- repo: str
- version: str
- package: str
-
-
-class GithubVersion(BaseModel):
- repo: str
- version: str
-
-
class GenericProviderID:
organization: str
plugin_name: str
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index e9275c31cc..00253b8a11 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -1,4 +1,4 @@
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import Any, Generic, Optional, TypeVar
@@ -9,6 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.base import BasePluginEntity
+from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.entities.plugin import PluginDeclaration, PluginEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin
@@ -52,6 +53,7 @@ class PluginAgentProviderEntity(BaseModel):
plugin_unique_identifier: str
plugin_id: str
declaration: AgentProviderEntityWithPlugin
+ meta: PluginDeclaration.Meta
class PluginBasicBooleanResponse(BaseModel):
@@ -156,9 +158,23 @@ class PluginInstallTaskStartResponse(BaseModel):
task_id: str = Field(description="The ID of the install task.")
-class PluginUploadResponse(BaseModel):
+class PluginVerification(BaseModel):
+ """
+ Verification of the plugin.
+ """
+
+ class AuthorizedCategory(StrEnum):
+ Langgenius = "langgenius"
+ Partner = "partner"
+ Community = "community"
+
+ authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")
+
+
+class PluginDecodeResponse(BaseModel):
unique_identifier: str = Field(description="The unique identifier of the plugin.")
manifest: PluginDeclaration
+ verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information")
class PluginOAuthAuthorizationUrlResponse(BaseModel):
@@ -172,3 +188,7 @@ class PluginOAuthCredentialsResponse(BaseModel):
class PluginListResponse(BaseModel):
list: list[PluginEntity]
total: int
+
+
+class PluginDynamicSelectOptionsResponse(BaseModel):
+ options: Sequence[PluginParameterOption] = Field(description="The options of the dynamic select.")
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index 1692020ec8..3a783dad3e 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -27,15 +27,30 @@ from core.workflow.nodes.question_classifier.entities import (
)
+class InvokeCredentials(BaseModel):
+ tool_credentials: dict[str, str] = Field(
+ default_factory=dict,
+ description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
+ )
+
+
+class PluginInvokeContext(BaseModel):
+ credentials: Optional[InvokeCredentials] = Field(
+ default_factory=InvokeCredentials,
+ description="Credentials context for the plugin invocation or backward invocation.",
+ )
+
+
class RequestInvokeTool(BaseModel):
"""
Request to invoke a tool
"""
- tool_type: Literal["builtin", "workflow", "api"]
+ tool_type: Literal["builtin", "workflow", "api", "mcp"]
provider: str
tool: str
tool_parameters: dict
+ credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel):
@@ -82,6 +97,16 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
return v
+class RequestInvokeLLMWithStructuredOutput(RequestInvokeLLM):
+ """
+ Request to invoke LLM with structured output
+ """
+
+ structured_output_schema: dict[str, Any] = Field(
+ default_factory=dict, description="The schema of the structured output in JSON schema format"
+ )
+
+
class RequestInvokeTextEmbedding(BaseRequestInvokeModel):
"""
Request to invoke text embedding
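
RequestInvokeLLMWithStructuredOutput (above) carries the target schema as a JSON Schema dict. A hypothetical example of what a caller might pass; the schema itself is illustrative, not part of the change:

```python
structured_output_schema = {
    "type": "object",
    "properties": {
        "title": {"type": "string"},
        "tags": {"type": "array", "items": {"type": "string"}},
        "confidence": {"type": "number", "minimum": 0, "maximum": 1},
    },
    "required": ["title", "tags"],
}
```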
diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py
index 66b77c7489..9575c57ac8 100644
--- a/api/core/plugin/impl/agent.py
+++ b/api/core/plugin/impl/agent.py
@@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
+from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
@@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
+ "context": context.model_dump() if context else {},
"data": {
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,
diff --git a/api/core/plugin/impl/dynamic_select.py b/api/core/plugin/impl/dynamic_select.py
new file mode 100644
index 0000000000..004412afd7
--- /dev/null
+++ b/api/core/plugin/impl/dynamic_select.py
@@ -0,0 +1,45 @@
+from collections.abc import Mapping
+from typing import Any
+
+from core.plugin.entities.plugin import GenericProviderID
+from core.plugin.entities.plugin_daemon import PluginDynamicSelectOptionsResponse
+from core.plugin.impl.base import BasePluginClient
+
+
+class DynamicSelectClient(BasePluginClient):
+ def fetch_dynamic_select_options(
+ self,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ action: str,
+ credentials: Mapping[str, Any],
+ parameter: str,
+ ) -> PluginDynamicSelectOptionsResponse:
+ """
+ Fetch dynamic select options for a plugin parameter.
+ """
+ response = self._request_with_plugin_daemon_response_stream(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/dynamic_select/fetch_parameter_options",
+ PluginDynamicSelectOptionsResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": GenericProviderID(provider).provider_name,
+ "credentials": credentials,
+ "provider_action": action,
+ "parameter": parameter,
+ },
+ },
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
+ },
+ )
+
+ for options in response:
+ return options
+
+ raise ValueError(f"Plugin service returned no options for parameter '{parameter}' in provider '{provider}'")
diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py
index 91774984c8..d73e5d9f9e 100644
--- a/api/core/plugin/impl/oauth.py
+++ b/api/core/plugin/impl/oauth.py
@@ -1,3 +1,4 @@
+import binascii
from collections.abc import Mapping
from typing import Any
@@ -14,24 +15,32 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
+ redirect_uri: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
- PluginOAuthAuthorizationUrlResponse,
- data={
- "user_id": user_id,
- "data": {
- "provider": provider,
- "system_credentials": system_credentials,
+ try:
+ response = self._request_with_plugin_daemon_response_stream(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
+ PluginOAuthAuthorizationUrlResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "redirect_uri": redirect_uri,
+ "system_credentials": system_credentials,
+ },
},
- },
- headers={
- "X-Plugin-ID": plugin_id,
- "Content-Type": "application/json",
- },
- )
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
+ },
+ )
+ for resp in response:
+ return resp
+ raise ValueError("No response received from plugin daemon for authorization URL request.")
+ except Exception as e:
+ raise ValueError(f"Error getting authorization URL: {e}")
def get_credentials(
self,
@@ -39,6 +48,7 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
+ redirect_uri: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
@@ -46,26 +56,33 @@ class OAuthHandler(BasePluginClient):
Get credentials from the given request.
"""
- # encode request to raw http request
- raw_request_bytes = self._convert_request_to_raw_data(request)
-
- return self._request_with_plugin_daemon_response(
- "POST",
- f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
- PluginOAuthCredentialsResponse,
- data={
- "user_id": user_id,
- "data": {
- "provider": provider,
- "system_credentials": system_credentials,
- "raw_request_bytes": raw_request_bytes,
+ try:
+ # encode request to raw http request
+ raw_request_bytes = self._convert_request_to_raw_data(request)
+ response = self._request_with_plugin_daemon_response_stream(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
+ PluginOAuthCredentialsResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "redirect_uri": redirect_uri,
+ "system_credentials": system_credentials,
+ # for json serialization
+ "raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
+ },
+ },
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
},
- },
- headers={
- "X-Plugin-ID": plugin_id,
- "Content-Type": "application/json",
- },
- )
+ )
+ for resp in response:
+ return resp
+ raise ValueError("No response received from plugin daemon for authorization URL request.")
+ except Exception as e:
+ raise ValueError(f"Error getting credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
@@ -79,7 +96,7 @@ class OAuthHandler(BasePluginClient):
"""
# Start with the request line
method = request.method
- path = request.path
+ path = request.full_path
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode()
diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py
index 1cd2dc1be7..04ac8c9649 100644
--- a/api/core/plugin/impl/plugin.py
+++ b/api/core/plugin/impl/plugin.py
@@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import (
+ PluginDecodeResponse,
PluginInstallTask,
PluginInstallTaskStartResponse,
PluginListResponse,
- PluginUploadResponse,
)
from core.plugin.impl.base import BasePluginClient
@@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/list",
PluginListResponse,
- params={"page": 1, "page_size": 256},
+ params={"page": 1, "page_size": 256, "response_type": "paged"},
)
return result.list
@@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/list",
PluginListResponse,
- params={"page": page, "page_size": page_size},
+ params={"page": page, "page_size": page_size, "response_type": "paged"},
)
def upload_pkg(
@@ -53,7 +53,7 @@ class PluginInstaller(BasePluginClient):
tenant_id: str,
pkg: bytes,
verify_signature: bool = False,
- ) -> PluginUploadResponse:
+ ) -> PluginDecodeResponse:
"""
Upload a plugin package and return the plugin unique identifier.
"""
@@ -68,7 +68,7 @@ class PluginInstaller(BasePluginClient):
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/management/install/upload/package",
- PluginUploadResponse,
+ PluginDecodeResponse,
files=body,
data=data,
)
@@ -176,6 +176,18 @@ class PluginInstaller(BasePluginClient):
params={"plugin_unique_identifier": plugin_unique_identifier},
)
+ def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse:
+ """
+ Decode a plugin from an identifier.
+ """
+ return self._request_with_plugin_daemon_response(
+ "GET",
+ f"plugin/{tenant_id}/management/decode/from_identifier",
+ PluginDecodeResponse,
+ data={"plugin_unique_identifier": plugin_unique_identifier},
+ headers={"Content-Type": "application/json"},
+ )
+
def fetch_plugin_installation_by_ids(
self, tenant_id: str, plugin_ids: Sequence[str]
) -> Sequence[PluginInstallation]:
diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py
index 19b26c8fe3..04225f95ee 100644
--- a/api/core/plugin/impl/tool.py
+++ b/api/core/plugin/impl/tool.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):
@@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
+ credential_type: CredentialType,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
@@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name,
"tool": tool_name,
"credentials": credentials,
+ "credential_type": credential_type,
"tool_parameters": tool_parameters,
},
},
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 25964ae063..0f0fe65f27 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform):
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
if self.with_variable_tmpl:
- vp = VariablePool()
+ vp = VariablePool.empty()
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 47808928f7..e19c6419ca 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
- @classmethod
- def value_of(cls, value: str) -> "ModelMode":
- """
- Get value of given mode.
-
- :param value: mode value
- :return: mode
- """
- for mode in cls:
- if mode.value == value:
- return mode
- raise ValueError(f"invalid mode value {value}")
-
prompt_file_contents: dict[str, Any] = {}
@@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
- model_mode = ModelMode.value_of(model_config.mode)
+ model_mode = ModelMode(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
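
The hand-written value_of helper is redundant because Enum already supports lookup by value, raising the same ValueError on bad input. A minimal check, assuming ModelMode keeps the members shown above:

```python
import enum

class ModelMode(enum.StrEnum):
    COMPLETION = "completion"
    CHAT = "chat"

assert ModelMode("chat") is ModelMode.CHAT   # value lookup via the constructor

try:
    ModelMode("invalid")
except ValueError as e:
    print(e)   # 'invalid' is not a valid ModelMode
```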
diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py
index f7aef76c87..4b883622a7 100644
--- a/api/core/prompt/utils/extract_thread_messages.py
+++ b/api/core/prompt/utils/extract_thread_messages.py
@@ -1,10 +1,11 @@
-from typing import Any
+from collections.abc import Sequence
from constants import UUID_NIL
+from models import Message
-def extract_thread_messages(messages: list[Any]):
- thread_messages = []
+def extract_thread_messages(messages: Sequence[Message]):
+ thread_messages: list[Message] = []
next_message = None
for message in messages:
diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py
index f49466db6d..de64c27a73 100644
--- a/api/core/prompt/utils/get_thread_messages_length.py
+++ b/api/core/prompt/utils/get_thread_messages_length.py
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
@@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int:
Get the number of thread messages based on the parent message id.
"""
# Fetch all messages related to the conversation
- query = (
- db.session.query(
- Message.id,
- Message.parent_message_id,
- Message.answer,
- )
- .filter(
- Message.conversation_id == conversation_id,
- )
- .order_by(Message.created_at.desc())
- )
+ stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
- messages = query.all()
+ messages = db.session.scalars(stmt).all()
# Extract thread messages
thread_messages = extract_thread_messages(messages)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
deleted file mode 100644
index 167a919e69..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.core import clean_extra_whitespace
-
- # Returns "ITEM 1A: RISK FACTORS"
- return clean_extra_whitespace(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
deleted file mode 100644
index 9c682d29db..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- import re
-
- from unstructured.cleaners.core import group_broken_paragraphs
-
- para_split_re = re.compile(r"(\s*\n\s*){3}")
-
- return group_broken_paragraphs(content, paragraph_split=para_split_re)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
deleted file mode 100644
index 0cdbb171e1..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.core import clean_non_ascii_chars
-
- # Returns "This text contains non-ascii characters!"
- return clean_non_ascii_chars(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
deleted file mode 100644
index 9f42044a2d..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """Replaces unicode quote characters, such as the \x91 character in a string."""
-
- from unstructured.cleaners.core import replace_unicode_quotes
-
- return replace_unicode_quotes(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
deleted file mode 100644
index 32ae7217e8..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredTranslateTextCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.translate import translate_text
-
- return translate_text(content)
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 2c5178241c..5a6903d3d5 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
-from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -144,7 +144,8 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
- return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ with Session(db.engine) as session:
+ return session.query(Dataset).filter(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 095752ea8e..6f3e15d166 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+ document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+ where_clause += f"metadata_->>'document_id' IN ({document_ids})"
+
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
- filter=None,
+ filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
@@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+ document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+ where_clause += f"metadata_->>'document_id' IN ({document_ids})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
- filter=None,
+ filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index 44cc5d3e98..ad39717183 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- query_str = {"match": {Field.CONTENT_KEY.value: query}}
+ query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
document_ids_filter = kwargs.get("document_ids_filter")
+
if document_ids_filter:
- query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore
+ query_str = {
+ "bool": {
+ "must": {"match": {Field.CONTENT_KEY.value: query}},
+ "filter": {"terms": {"metadata.document_id": document_ids_filter}},
+ }
+ }
+
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
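
The previous code attached "filter" as a sibling of "match" inside the query body, which Elasticsearch does not accept; a filter has to live inside a bool query. The two shapes side by side, as plain dicts with "page_content" standing in for Field.CONTENT_KEY.value:

```python
# Without a document filter: a bare match query is enough.
query_plain = {"match": {"page_content": "refund policy"}}

# With a filter: wrap both clauses in a bool query.
query_filtered = {
    "bool": {
        "must": {"match": {"page_content": "refund policy"}},
        "filter": {"terms": {"metadata.document_id": ["doc-1", "doc-2"]}},
    }
}
```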
diff --git a/api/core/rag/datasource/vdb/matrixone/__init__.py b/api/core/rag/datasource/vdb/matrixone/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py
new file mode 100644
index 0000000000..4894957382
--- /dev/null
+++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py
@@ -0,0 +1,233 @@
+import json
+import logging
+import uuid
+from functools import wraps
+from typing import Any, Optional
+
+from mo_vector.client import MoVectorClient # type: ignore
+from pydantic import BaseModel, model_validator
+
+from configs import dify_config
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+class MatrixoneConfig(BaseModel):
+ host: str = "localhost"
+ port: int = 6001
+ user: str = "dump"
+ password: str = "111"
+ database: str = "dify"
+ metric: str = "l2"
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_config(cls, values: dict) -> dict:
+ if not values["host"]:
+ raise ValueError("config host is required")
+ if not values["port"]:
+ raise ValueError("config port is required")
+ if not values["user"]:
+ raise ValueError("config user is required")
+ if not values["password"]:
+ raise ValueError("config password is required")
+ if not values["database"]:
+ raise ValueError("config database is required")
+ return values
+
+
+def ensure_client(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ if self.client is None:
+ self.client = self._get_client(None, False)
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+
+class MatrixoneVector(BaseVector):
+ """
+ Matrixone vector storage implementation.
+ """
+
+ def __init__(self, collection_name: str, config: MatrixoneConfig):
+ super().__init__(collection_name)
+ self.config = config
+ self.collection_name = collection_name.lower()
+ self.client = None
+
+ @property
+ def collection_name(self):
+ return self._collection_name
+
+ @collection_name.setter
+ def collection_name(self, value):
+ self._collection_name = value
+
+ def get_type(self) -> str:
+ return VectorType.MATRIXONE
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+ if self.client is None:
+ self.client = self._get_client(len(embeddings[0]), True)
+ return self.add_texts(texts, embeddings)
+
+ def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
+ """
+ Create a new client for the collection.
+
+ The collection will be created if it doesn't exist.
+ """
+ lock_name = f"vector_indexing_lock_{self._collection_name}"
+ with redis_client.lock(lock_name, timeout=20):
+ client = MoVectorClient(
+ connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}",
+ table_name=self.collection_name,
+ vector_dimension=dimension,
+ create_table=create_table,
+ )
+ collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+ if redis_client.get(collection_exist_cache_key):
+ return client
+ try:
+ client.create_full_text_index()
+ except Exception as e:
+ logger.exception("Failed to create full text index")
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+ return client
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ if self.client is None:
+ self.client = self._get_client(len(embeddings[0]), True)
+ assert self.client is not None
+ ids = []
+ for _, doc in enumerate(documents):
+ if doc.metadata is not None:
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ ids.append(doc_id)
+ self.client.insert(
+ texts=[doc.page_content for doc in documents],
+ embeddings=embeddings,
+ metadatas=[doc.metadata for doc in documents],
+ ids=ids,
+ )
+ return ids
+
+ @ensure_client
+ def text_exists(self, id: str) -> bool:
+ assert self.client is not None
+ result = self.client.get(ids=[id])
+ return len(result) > 0
+
+ @ensure_client
+ def delete_by_ids(self, ids: list[str]) -> None:
+ assert self.client is not None
+ if not ids:
+ return
+ self.client.delete(ids=ids)
+
+ @ensure_client
+ def get_ids_by_metadata_field(self, key: str, value: str):
+ assert self.client is not None
+ results = self.client.query_by_metadata(filter={key: value})
+ return [result.id for result in results]
+
+ @ensure_client
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ assert self.client is not None
+ self.client.delete(filter={key: value})
+
+ @ensure_client
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ assert self.client is not None
+ top_k = kwargs.get("top_k", 5)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filter = None
+ if document_ids_filter:
+ filter = {"document_id": {"$in": document_ids_filter}}
+
+ results = self.client.query(
+ query_vector=query_vector,
+ k=top_k,
+ filter=filter,
+ )
+
+ docs = []
+ # TODO: add the score threshold to the query
+ for result in results:
+ metadata = result.metadata
+ docs.append(
+ Document(
+ page_content=result.document,
+ metadata=metadata,
+ )
+ )
+ return docs
+
+ @ensure_client
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ assert self.client is not None
+ top_k = kwargs.get("top_k", 5)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filter = None
+ if document_ids_filter:
+ filter = {"document_id": {"$in": document_ids_filter}}
+ score_threshold = float(kwargs.get("score_threshold", 0.0))
+
+ results = self.client.full_text_query(
+ keywords=[query],
+ k=top_k,
+ filter=filter,
+ )
+
+ docs = []
+ for result in results:
+ metadata = result.metadata
+ if isinstance(metadata, str):
+ metadata = json.loads(metadata)
+ score = 1 - result.distance
+ if score >= score_threshold:
+ metadata["score"] = score
+ docs.append(
+ Document(
+ page_content=result.document,
+ metadata=metadata,
+ )
+ )
+ return docs
+
+ @ensure_client
+ def delete(self) -> None:
+ assert self.client is not None
+ self.client.delete()
+
+
+class MatrixoneVectorFactory(AbstractVectorFactory):
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
+ if dataset.index_struct_dict:
+ class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+ collection_name = class_prefix
+ else:
+ dataset_id = dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+ dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name))
+
+ config = MatrixoneConfig(
+ host=dify_config.MATRIXONE_HOST or "localhost",
+ port=dify_config.MATRIXONE_PORT or 6001,
+ user=dify_config.MATRIXONE_USER or "dump",
+ password=dify_config.MATRIXONE_PASSWORD or "111",
+ database=dify_config.MATRIXONE_DATABASE or "dify",
+ metric=dify_config.MATRIXONE_METRIC or "l2",
+ )
+ return MatrixoneVector(collection_name=collection_name, config=config)
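
The ensure_client decorator above lazily builds the Matrixone client the first time any decorated method runs, keeping __init__ connection-free. The pattern in isolation, as a generic sketch not tied to MoVectorClient:

```python
from functools import wraps

def ensure_client(func):
    """Create self.client on first use, then delegate to the wrapped method."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if self.client is None:
            self.client = self._connect()
        return func(self, *args, **kwargs)
    return wrapper

class LazyStore:
    def __init__(self):
        self.client = None          # nothing is opened yet

    def _connect(self):
        print("connecting once")
        return object()             # stand-in for a real client

    @ensure_client
    def ping(self):
        return self.client is not None

store = LazyStore()
print(store.ping())   # prints "connecting once" then True
print(store.ping())   # True (no second connect)
```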
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 8ce194c683..05fa73011a 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -47,6 +47,7 @@ class QdrantConfig(BaseModel):
grpc_port: int = 6334
prefer_grpc: bool = False
replication_factor: int = 1
+ write_consistency_factor: int = 1
def to_qdrant_params(self):
if self.endpoint and self.endpoint.startswith("path:"):
@@ -127,6 +128,7 @@ class QdrantVector(BaseVector):
hnsw_config=hnsw_config,
timeout=int(self._client_config.timeout),
replication_factor=self._client_config.replication_factor,
+ write_consistency_factor=self._client_config.write_consistency_factor,
)
# create group_id payload index
diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
index a124faa503..552068c99e 100644
--- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
+++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
import tablestore # type: ignore
from pydantic import BaseModel, model_validator
+from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -50,6 +51,29 @@ class TableStoreVector(BaseVector):
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
+ def create_collection(self, embeddings: list[list[float]], **kwargs):
+ dimension = len(embeddings[0])
+ self._create_collection(dimension)
+
+ def get_by_ids(self, ids: list[str]) -> list[Document]:
+ docs = []
+ request = BatchGetRowRequest()
+ columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
+ rows_to_get = [[("id", _id)] for _id in ids]
+ request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
+
+ result = self._tablestore_client.batch_get_row(request)
+ table_result = result.get_result_by_table(self._table_name)
+ for item in table_result:
+ if item.is_ok and item.row:
+ kv = {k: v for k, v, t in item.row.attribute_columns}
+ docs.append(
+ Document(
+ page_content=kv[Field.CONTENT_KEY.value], metadata=json.loads(kv[Field.METADATA_KEY.value])
+ )
+ )
+ return docs
+
def get_type(self) -> str:
return VectorType.TABLESTORE
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index d2bf3eb92a..84746d23ea 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -122,7 +122,6 @@ class TencentVector(BaseVector):
metric_type,
params,
)
- index_text = vdb_index.FilterIndex(self.field_text, enum.FieldType.String, enum.IndexType.FILTER)
index_metadate = vdb_index.FilterIndex(self.field_metadata, enum.FieldType.Json, enum.IndexType.FILTER)
index_sparse_vector = vdb_index.SparseIndex(
name="sparse_vector",
@@ -130,7 +129,7 @@ class TencentVector(BaseVector):
index_type=enum.IndexType.SPARSE_INVERTED,
metric_type=enum.MetricType.IP,
)
- indexes = [index_id, index_vector, index_text, index_metadate]
+ indexes = [index_id, index_vector, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
try:
@@ -149,7 +148,7 @@ class TencentVector(BaseVector):
index_metadate = vdb_index.FilterIndex(
self.field_metadata, enum.FieldType.String, enum.IndexType.FILTER
)
- indexes = [index_id, index_vector, index_text, index_metadate]
+ indexes = [index_id, index_vector, index_metadate]
if self._enable_hybrid_search:
indexes.append(index_sparse_vector)
self._client.create_collection(
@@ -207,9 +206,19 @@ class TencentVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
- self._client.delete(
- database_name=self._client_config.database, collection_name=self.collection_name, document_ids=ids
- )
+
+ total_count = len(ids)
+ batch_size = self._client_config.max_upsert_batch_size
+ batch = math.ceil(total_count / batch_size)
+
+ for j in range(batch):
+ start_idx = j * batch_size
+ end_idx = min(total_count, (j + 1) * batch_size)
+ batch_ids = ids[start_idx:end_idx]
+
+ self._client.delete(
+ database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids
+ )
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._client.delete(
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
deleted file mode 100644
index 1e62b3c589..0000000000
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-
-class ClusterEntity(BaseModel):
- """
- Model Config Entity.
- """
-
- name: str
- cluster_id: str
- displayName: str
- region: str
- spendingLimit: Optional[int] = 1000
- version: str
- createdBy: str
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 66e002312a..00080b0fae 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -1,3 +1,5 @@
+import logging
+import time
from abc import ABC, abstractmethod
from typing import Any, Optional
@@ -13,6 +15,8 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Whitelist
+logger = logging.getLogger(__name__)
+
class AbstractVectorFactory(ABC):
@abstractmethod
@@ -164,13 +168,29 @@ class Vector:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory
+ case VectorType.MATRIXONE:
+ from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
+
+ return MatrixoneVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: Optional[list] = None, **kwargs):
if texts:
- embeddings = self._embeddings.embed_documents([document.page_content for document in texts])
- self._vector_processor.create(texts=texts, embeddings=embeddings, **kwargs)
+ start = time.time()
+ logger.info(f"start embedding {len(texts)} texts {start}")
+ batch_size = 1000
+ total_batches = len(texts) + batch_size - 1
+ for i in range(0, len(texts), batch_size):
+ batch = texts[i : i + batch_size]
+ batch_start = time.time()
+ logger.info(f"Processing batch {i // batch_size + 1}/{total_batches} ({len(batch)} texts)")
+ batch_embeddings = self._embeddings.embed_documents([document.page_content for document in batch])
+ logger.info(
+ f"Embedding batch {i // batch_size + 1}/{total_batches} took {time.time() - batch_start:.3f}s"
+ )
+ self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
+ logger.info(f"Embedding {len(texts)} texts took {time.time() - start:.3f}s")
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py
index 7a81565e37..0d70947b72 100644
--- a/api/core/rag/datasource/vdb/vector_type.py
+++ b/api/core/rag/datasource/vdb/vector_type.py
@@ -29,3 +29,4 @@ class VectorType(StrEnum):
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud"
+ MATRIXONE = "matrixone"
diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
index c6cf0d2b27..7a8efb4068 100644
--- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
+++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py
@@ -45,7 +45,8 @@ class WeaviateVector(BaseVector):
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
# which does not contain the deprecation check.
- weaviate.connect.connection.PYPI_TIMEOUT = 0.001
+ if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
+ weaviate.connect.connection.PYPI_TIMEOUT = 0.001
try:
client = weaviate.Client(
diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py
index e46ab8b7fd..01003a13b6 100644
--- a/api/core/rag/extractor/blob/blob.py
+++ b/api/core/rag/extractor/blob/blob.py
@@ -9,8 +9,7 @@ from __future__ import annotations
import contextlib
import mimetypes
-from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Mapping
+from collections.abc import Generator, Mapping
from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import Any, Optional, Union
@@ -143,21 +142,3 @@ class Blob(BaseModel):
if self.source:
str_repr += f" {self.source}"
return str_repr
-
-
-class BlobLoader(ABC):
- """Abstract interface for blob loaders implementation.
-
- Implementer should be able to load raw content from a datasource system according
- to some criteria and return the raw content lazily as a stream of blobs.
- """
-
- @abstractmethod
- def yield_blobs(
- self,
- ) -> Iterable[Blob]:
- """A lazy loader for raw data represented by Blob object.
-
- Returns:
- A generator over blobs
- """
diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py
index 836a1398bf..83a4ac651f 100644
--- a/api/core/rag/extractor/firecrawl/firecrawl_app.py
+++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py
@@ -22,6 +22,7 @@ class FirecrawlApp:
"formats": ["markdown"],
"onlyMainContent": True,
"timeout": 30000,
+ "integration": "dify",
}
if params:
json_data.update(params)
@@ -39,7 +40,7 @@ class FirecrawlApp:
def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
headers = self._prepare_headers()
- json_data = {"url": url}
+ json_data = {"url": url, "integration": "dify"}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
@@ -49,7 +50,6 @@ class FirecrawlApp:
return cast(str, job_id)
else:
self._handle_error(response, "start crawl job")
- # FIXME: unreachable code for mypy
return "" # unreachable
def check_crawl_status(self, job_id) -> dict[str, Any]:
@@ -82,7 +82,6 @@ class FirecrawlApp:
)
else:
self._handle_error(response, "check crawl status")
- # FIXME: unreachable code for mypy
return {} # unreachable
def _format_crawl_status_response(
@@ -126,4 +125,31 @@ class FirecrawlApp:
def _handle_error(self, response, action) -> None:
error_message = response.json().get("error", "Unknown error occurred")
- raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}")
+ raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
+
+ def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
+ # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
+ headers = self._prepare_headers()
+ json_data = {
+ "query": query,
+ "limit": 5,
+ "lang": "en",
+ "country": "us",
+ "timeout": 60000,
+ "ignoreInvalidURLs": False,
+ "scrapeOptions": {},
+ "integration": "dify",
+ }
+ if params:
+ json_data.update(params)
+ response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
+ if response.status_code == 200:
+ response_data = response.json()
+ if not response_data.get("success"):
+ raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
+ return cast(dict[str, Any], response_data)
+ elif response.status_code in {402, 409, 500, 429, 408}:
+ self._handle_error(response, "perform search")
+ return {} # unreachable, _handle_error always raises
+ else:
+ raise Exception(f"Failed to perform search. Status code: {response.status_code}")
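Editor's note: the new search method mirrors Firecrawl's /v1/search endpoint with Dify-flavored defaults (limit 5, en/us, 60 s timeout, the "integration": "dify" marker). A hedged usage sketch, assuming the FirecrawlApp constructor takes api_key and base_url as the scrape/crawl paths do; the response field names follow Firecrawl's documented search response, not this diff:

    # Placeholder credentials; real values come from the datasource configuration.
    app = FirecrawlApp(api_key="fc-...", base_url="https://api.firecrawl.dev")

    # Override only the fields you need; the rest keep the defaults shown above.
    results = app.search("dify rag pipeline", params={"limit": 3})
    for item in results.get("data", []):
        print(item.get("url"), item.get("title"))
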
diff --git a/api/core/rag/extractor/helpers.py b/api/core/rag/extractor/helpers.py
index 69ca9d5d63..3d2fb55d9a 100644
--- a/api/core/rag/extractor/helpers.py
+++ b/api/core/rag/extractor/helpers.py
@@ -1,7 +1,6 @@
"""Document loader helpers."""
import concurrent.futures
-from pathlib import Path
from typing import NamedTuple, Optional, cast
@@ -16,7 +15,7 @@ class FileEncoding(NamedTuple):
"""The language of the file."""
-def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding]:
+def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1024 * 1024) -> list[FileEncoding]:
"""Try to detect the file encoding.
Returns a list of `FileEncoding` tuples with the detected encodings ordered
@@ -25,11 +24,16 @@ def detect_file_encodings(file_path: str, timeout: int = 5) -> list[FileEncoding
Args:
file_path: The path to the file to detect the encoding for.
timeout: The timeout in seconds for the encoding detection.
+ sample_size: The number of bytes to read for encoding detection. Default is 1MB.
+ For large files, reading only a sample is sufficient and prevents timeout.
"""
import chardet
def read_and_detect(file_path: str) -> list[dict]:
- rawdata = Path(file_path).read_bytes()
+ with open(file_path, "rb") as f:
+ # Read only a sample of the file for encoding detection
+ # This prevents timeout on large files while still providing accurate encoding detection
+ rawdata = f.read(sample_size)
return cast(list[dict], chardet.detect_all(rawdata))
with concurrent.futures.ThreadPoolExecutor() as executor:
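Editor's note: sampling only the first sample_size bytes keeps chardet's cost bounded on very large files while usually still detecting the encoding, since encoding signatures appear early in the byte stream. A small usage sketch (the path is illustrative, and the attribute names assume the FileEncoding NamedTuple fields defined just above this hunk):

    # Detect using a 64 KiB sample instead of the default 1 MiB.
    encodings = detect_file_encodings("/tmp/example.txt", timeout=5, sample_size=64 * 1024)
    for enc in encodings:
        print(enc.encoding, enc.confidence)
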
diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py
index 849852ac23..c97765b1dc 100644
--- a/api/core/rag/extractor/markdown_extractor.py
+++ b/api/core/rag/extractor/markdown_extractor.py
@@ -68,22 +68,17 @@ class MarkdownExtractor(BaseExtractor):
continue
header_match = re.match(r"^#+\s", line)
if header_match:
- if current_header is not None:
- markdown_tups.append((current_header, current_text))
-
+ markdown_tups.append((current_header, current_text))
current_header = line
current_text = ""
else:
current_text += line + "\n"
markdown_tups.append((current_header, current_text))
- if current_header is not None:
- # pass linting, assert keys are defined
- markdown_tups = [
- (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups
- ]
- else:
- markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups]
+ markdown_tups = [
+ (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
+ for key, value in markdown_tups
+ ]
return markdown_tups
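Editor's note: the rewritten loop now keeps text that appears before the first header, storing it under a None key instead of dropping it, and only strips "#" from keys that exist. A rough sketch of the expected output shape (not a test from the repository):

    sample = "intro line\n# Title\nbody text\n"
    # Expected tuples after the change, roughly:
    #   [(None, "intro line\n"), ("Title", "body text\n")]
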
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index 4e14800d0a..eca955ddd1 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -79,55 +79,71 @@ class NotionExtractor(BaseExtractor):
def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]:
"""Get all the pages from a Notion database."""
assert self._notion_access_token is not None, "Notion access token is required"
- res = requests.post(
- DATABASE_URL_TMPL.format(database_id=database_id),
- headers={
- "Authorization": "Bearer " + self._notion_access_token,
- "Content-Type": "application/json",
- "Notion-Version": "2022-06-28",
- },
- json=query_dict,
- )
-
- data = res.json()
database_content = []
- if "results" not in data or data["results"] is None:
- return []
- for result in data["results"]:
- properties = result["properties"]
- data = {}
- value: Any
- for property_name, property_value in properties.items():
- type = property_value["type"]
- if type == "multi_select":
- value = []
- multi_select_list = property_value[type]
- for multi_select in multi_select_list:
- value.append(multi_select["name"])
- elif type in {"rich_text", "title"}:
- if len(property_value[type]) > 0:
- value = property_value[type][0]["plain_text"]
+ next_cursor = None
+ has_more = True
+
+ while has_more:
+ current_query = query_dict.copy()
+ if next_cursor:
+ current_query["start_cursor"] = next_cursor
+
+ res = requests.post(
+ DATABASE_URL_TMPL.format(database_id=database_id),
+ headers={
+ "Authorization": "Bearer " + self._notion_access_token,
+ "Content-Type": "application/json",
+ "Notion-Version": "2022-06-28",
+ },
+ json=current_query,
+ )
+
+ response_data = res.json()
+
+ if "results" not in response_data or response_data["results"] is None:
+ break
+
+ for result in response_data["results"]:
+ properties = result["properties"]
+ data = {}
+ value: Any
+ for property_name, property_value in properties.items():
+ type = property_value["type"]
+ if type == "multi_select":
+ value = []
+ multi_select_list = property_value[type]
+ for multi_select in multi_select_list:
+ value.append(multi_select["name"])
+ elif type in {"rich_text", "title"}:
+ if len(property_value[type]) > 0:
+ value = property_value[type][0]["plain_text"]
+ else:
+ value = ""
+ elif type in {"select", "status"}:
+ if property_value[type]:
+ value = property_value[type]["name"]
+ else:
+ value = ""
else:
- value = ""
- elif type in {"select", "status"}:
- if property_value[type]:
- value = property_value[type]["name"]
+ value = property_value[type]
+ data[property_name] = value
+ row_dict = {k: v for k, v in data.items() if v}
+ row_content = ""
+ for key, value in row_dict.items():
+ if isinstance(value, dict):
+ value_dict = {k: v for k, v in value.items() if v}
+ value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
+ row_content = row_content + f"{key}:{value_content}\n"
else:
- value = ""
- else:
- value = property_value[type]
- data[property_name] = value
- row_dict = {k: v for k, v in data.items() if v}
- row_content = ""
- for key, value in row_dict.items():
- if isinstance(value, dict):
- value_dict = {k: v for k, v in value.items() if v}
- value_content = "".join(f"{k}:{v} " for k, v in value_dict.items())
- row_content = row_content + f"{key}:{value_content}\n"
- else:
- row_content = row_content + f"{key}:{value}\n"
- database_content.append(row_content)
+ row_content = row_content + f"{key}:{value}\n"
+ database_content.append(row_content)
+
+ has_more = response_data.get("has_more", False)
+ next_cursor = response_data.get("next_cursor")
+
+ if not database_content:
+ return []
return [Document(page_content="\n".join(database_content))]
diff --git a/api/core/rag/extractor/text_extractor.py b/api/core/rag/extractor/text_extractor.py
index b2b51d71d7..a00d328cb1 100644
--- a/api/core/rag/extractor/text_extractor.py
+++ b/api/core/rag/extractor/text_extractor.py
@@ -36,8 +36,12 @@ class TextExtractor(BaseExtractor):
break
except UnicodeDecodeError:
continue
+ else:
+ raise RuntimeError(
+ f"Decode failed: {self._file_path}, all detected encodings failed. Original error: {e}"
+ )
else:
- raise RuntimeError(f"Error loading {self._file_path}") from e
+ raise RuntimeError(f"Decode failed: {self._file_path}, specified encoding failed. Original error: {e}")
except Exception as e:
raise RuntimeError(f"Error loading {self._file_path}") from e
diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py
deleted file mode 100644
index dd8a979e70..0000000000
--- a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import logging
-
-from core.rag.extractor.extractor_base import BaseExtractor
-from core.rag.models.document import Document
-
-logger = logging.getLogger(__name__)
-
-
-class UnstructuredPDFExtractor(BaseExtractor):
- """Load pdf files.
-
-
- Args:
- file_path: Path to the file to load.
-
- api_url: Unstructured API URL
-
- api_key: Unstructured API Key
- """
-
- def __init__(self, file_path: str, api_url: str, api_key: str):
- """Initialize with file path."""
- self._file_path = file_path
- self._api_url = api_url
- self._api_key = api_key
-
- def extract(self) -> list[Document]:
- if self._api_url:
- from unstructured.partition.api import partition_via_api
-
- elements = partition_via_api(
- filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
- )
- else:
- from unstructured.partition.pdf import partition_pdf
-
- elements = partition_pdf(filename=self._file_path, strategy="auto")
-
- from unstructured.chunking.title import chunk_by_title
-
- chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
- documents = []
- for chunk in chunks:
- text = chunk.text.strip()
- documents.append(Document(page_content=text))
-
- return documents
diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py
deleted file mode 100644
index 22dfdd2075..0000000000
--- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import logging
-
-from core.rag.extractor.extractor_base import BaseExtractor
-from core.rag.models.document import Document
-
-logger = logging.getLogger(__name__)
-
-
-class UnstructuredTextExtractor(BaseExtractor):
- """Load msg files.
-
-
- Args:
- file_path: Path to the file to load.
- """
-
- def __init__(self, file_path: str, api_url: str):
- """Initialize with file path."""
- self._file_path = file_path
- self._api_url = api_url
-
- def extract(self) -> list[Document]:
- from unstructured.partition.text import partition_text
-
- elements = partition_text(filename=self._file_path)
- from unstructured.chunking.title import chunk_by_title
-
- chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
- documents = []
- for chunk in chunks:
- text = chunk.text.strip()
- documents.append(Document(page_content=text))
-
- return documents
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index bff0acc48f..14363de7d4 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
+ # Process drawing type images
drawing_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)
+ has_drawing = False
for drawing in drawing_elements:
blip_elements = drawing.findall(
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
@@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor):
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
+ has_drawing = True
+ paragraph_content.append(image_map[image_part])
+ # Process pict type images
+ shape_elements = run.element.findall(
+ ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
+ )
+ for shape in shape_elements:
+ # Find image data in VML
+ shape_image = shape.find(
+ ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData"
+ )
+ if shape_image is not None and shape_image.text:
+ image_id = shape_image.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+ )
+ if image_id and image_id in doc.part.rels:
+ image_part = doc.part.rels[image_id].target_part
+ if image_part in image_map and not has_drawing:
+ paragraph_content.append(image_map[image_part])
+ # Find imagedata element in VML
+ image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
+ if image_data is not None:
+ image_id = image_data.get("id") or image_data.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+ )
+ if image_id and image_id in doc.part.rels:
+ image_part = doc.part.rels[image_id].target_part
+ if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())
diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py
index dca84b9041..9b90bd2bb3 100644
--- a/api/core/rag/index_processor/processor/paragraph_index_processor.py
+++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py
@@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
vector.create(documents)
+ with_keywords = False
if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset)
@@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
vector.delete_by_ids(node_ids)
else:
vector.delete()
+ with_keywords = False
if with_keywords:
keyword = Keyword(dataset)
if node_ids:
diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py
index e778b2cec4..75f3153697 100644
--- a/api/core/rag/index_processor/processor/qa_index_processor.py
+++ b/api/core/rag/index_processor/processor/qa_index_processor.py
@@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type
- if not file.filename or not file.filename.endswith(".csv"):
+ if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed")
try:
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 6978860529..3d0f0f97bc 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
+from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@@ -496,6 +497,8 @@ class DatasetRetrieval:
all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
+ else:
+ all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id)
@@ -596,7 +599,8 @@ class DatasetRetrieval:
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ with Session(db.engine) as session:
+ dataset = session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
@@ -1008,6 +1012,9 @@ class DatasetRetrieval:
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
+ if value is None:
+ return
+
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:
@@ -1130,7 +1137,7 @@ class DatasetRetrieval:
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
- model_mode = ModelMode.value_of(mode)
+ model_mode = ModelMode(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 0fb1bcb2e0..bcaf299892 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = text.split()
else:
splits = text.split(separator)
+ splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
else:
splits = list(text)
splits = [s for s in splits if (s not in {"", "\n"})]
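Editor's note: re-attaching the separator to every piece except the last keeps the chunk text faithful to the original, so joining the splits reproduces the input. A quick worked example of the pattern, standalone and outside the splitter class:

    text = "alpha\n\nbeta\n\ngamma"
    separator = "\n\n"
    splits = text.split(separator)
    splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
    assert splits == ["alpha\n\n", "beta\n\n", "gamma"]
    assert "".join(splits) == text
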
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index b711e8434a..529d8ccd27 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -10,7 +10,6 @@ from typing import (
Any,
Literal,
Optional,
- TypedDict,
TypeVar,
Union,
)
@@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
raise NotImplementedError
-class CharacterTextSplitter(TextSplitter):
- """Splitting text that looks at characters."""
-
- def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
- """Create a new TextSplitter."""
- super().__init__(**kwargs)
- self._separator = separator
-
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- # First we naively split the large input into a bunch of smaller ones.
- splits = _split_text_with_regex(text, self._separator, self._keep_separator)
- _separator = "" if self._keep_separator else self._separator
- _good_splits_lengths = [] # cache the lengths of the splits
- if splits:
- _good_splits_lengths.extend(self._length_function(splits))
- return self._merge_splits(splits, _separator, _good_splits_lengths)
-
-
-class LineType(TypedDict):
- """Line type as typed dict."""
-
- metadata: dict[str, str]
- content: str
-
-
-class HeaderType(TypedDict):
- """Header type as typed dict."""
-
- level: int
- name: str
- data: str
-
-
-class MarkdownHeaderTextSplitter:
- """Splitting markdown files based on specified headers."""
-
- def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
- """Create a new MarkdownHeaderTextSplitter.
-
- Args:
- headers_to_split_on: Headers we want to track
- return_each_line: Return each line w/ associated headers
- """
- # Output line-by-line or aggregated into chunks w/ common headers
- self.return_each_line = return_each_line
- # Given the headers we want to split on,
- # (e.g., "#, ##, etc") order by length
- self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)
-
- def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
- """Combine lines with common metadata into chunks
- Args:
- lines: Line of text / associated header metadata
- """
- aggregated_chunks: list[LineType] = []
-
- for line in lines:
- if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
- # If the last line in the aggregated list
- # has the same metadata as the current line,
- # append the current content to the last lines's content
- aggregated_chunks[-1]["content"] += " \n" + line["content"]
- else:
- # Otherwise, append the current line to the aggregated list
- aggregated_chunks.append(line)
-
- return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
-
- def split_text(self, text: str) -> list[Document]:
- """Split markdown file
- Args:
- text: Markdown file"""
-
- # Split the input text by newline character ("\n").
- lines = text.split("\n")
- # Final output
- lines_with_metadata: list[LineType] = []
- # Content and metadata of the chunk currently being processed
- current_content: list[str] = []
- current_metadata: dict[str, str] = {}
- # Keep track of the nested header structure
- # header_stack: List[Dict[str, Union[int, str]]] = []
- header_stack: list[HeaderType] = []
- initial_metadata: dict[str, str] = {}
-
- for line in lines:
- stripped_line = line.strip()
- # Check each line against each of the header types (e.g., #, ##)
- for sep, name in self.headers_to_split_on:
- # Check if line starts with a header that we intend to split on
- if stripped_line.startswith(sep) and (
- # Header with no text OR header is followed by space
- # Both are valid conditions that sep is being used a header
- len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
- ):
- # Ensure we are tracking the header as metadata
- if name is not None:
- # Get the current header level
- current_header_level = sep.count("#")
-
- # Pop out headers of lower or same level from the stack
- while header_stack and header_stack[-1]["level"] >= current_header_level:
- # We have encountered a new header
- # at the same or higher level
- popped_header = header_stack.pop()
- # Clear the metadata for the
- # popped header in initial_metadata
- if popped_header["name"] in initial_metadata:
- initial_metadata.pop(popped_header["name"])
-
- # Push the current header to the stack
- header: HeaderType = {
- "level": current_header_level,
- "name": name,
- "data": stripped_line[len(sep) :].strip(),
- }
- header_stack.append(header)
- # Update initial_metadata with the current header
- initial_metadata[name] = header["data"]
-
- # Add the previous line to the lines_with_metadata
- # only if current_content is not empty
- if current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
-
- break
- else:
- if stripped_line:
- current_content.append(stripped_line)
- elif current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
-
- current_metadata = initial_metadata.copy()
-
- if current_content:
- lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})
-
- # lines_with_metadata has each line with associated header metadata
- # aggregate these into chunks based on common metadata
- if not self.return_each_line:
- return self.aggregate_lines_to_chunks(lines_with_metadata)
- else:
- return [
- Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
- ]
-
-
-# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py
index 6452317120..052ba1c2cb 100644
--- a/api/core/repositories/__init__.py
+++ b/api/core/repositories/__init__.py
@@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
+from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
+ "DifyCoreRepositoryFactory",
+ "RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",
]
diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py
new file mode 100644
index 0000000000..4118aa61c7
--- /dev/null
+++ b/api/core/repositories/factory.py
@@ -0,0 +1,224 @@
+"""
+Repository factory for dynamically creating repository instances based on configuration.
+
+This module provides a Django-like settings system for repository implementations,
+allowing users to configure different repository backends through string paths.
+"""
+
+import importlib
+import inspect
+import logging
+from typing import Protocol, Union
+
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from configs import dify_config
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models import Account, EndUser
+from models.enums import WorkflowRunTriggeredFrom
+from models.workflow import WorkflowNodeExecutionTriggeredFrom
+
+logger = logging.getLogger(__name__)
+
+
+class RepositoryImportError(Exception):
+ """Raised when a repository implementation cannot be imported or instantiated."""
+
+ pass
+
+
+class DifyCoreRepositoryFactory:
+ """
+ Factory for creating repository instances based on configuration.
+
+ This factory supports Django-like settings where repository implementations
+ are specified as module paths (e.g., 'module.submodule.ClassName').
+ """
+
+ @staticmethod
+ def _import_class(class_path: str) -> type:
+ """
+ Import a class from a module path string.
+
+ Args:
+ class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
+
+ Returns:
+ The imported class
+
+ Raises:
+ RepositoryImportError: If the class cannot be imported
+ """
+ try:
+ module_path, class_name = class_path.rsplit(".", 1)
+ module = importlib.import_module(module_path)
+ repo_class = getattr(module, class_name)
+ assert isinstance(repo_class, type)
+ return repo_class
+ except (ValueError, ImportError, AttributeError) as e:
+ raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
+
+ @staticmethod
+ def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore
+ """
+ Validate that a class implements the expected repository interface.
+
+ Args:
+ repository_class: The class to validate
+ expected_interface: The expected interface/protocol
+
+ Raises:
+ RepositoryImportError: If the class doesn't implement the interface
+ """
+ # Check if the class has all required methods from the protocol
+ required_methods = [
+ method
+ for method in dir(expected_interface)
+ if not method.startswith("_") and callable(getattr(expected_interface, method, None))
+ ]
+
+ missing_methods = []
+ for method_name in required_methods:
+ if not hasattr(repository_class, method_name):
+ missing_methods.append(method_name)
+
+ if missing_methods:
+ raise RepositoryImportError(
+ f"Repository class '{repository_class.__name__}' does not implement required methods "
+ f"{missing_methods} from interface '{expected_interface.__name__}'"
+ )
+
+ @staticmethod
+ def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
+ """
+ Validate that a repository class constructor accepts required parameters.
+
+ Args:
+ repository_class: The class to validate
+ required_params: List of required parameter names
+
+ Raises:
+ RepositoryImportError: If the constructor doesn't accept required parameters
+ """
+
+ try:
+ # MyPy may flag the line below with the following error:
+ #
+ # > Accessing "__init__" on an instance is unsound, since
+ # > instance.__init__ could be from an incompatible subclass.
+ #
+ # Despite this, we need to ensure that the constructor of `repository_class`
+ # has a compatible signature.
+ signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
+ param_names = list(signature.parameters.keys())
+
+ # Remove 'self' parameter
+ if "self" in param_names:
+ param_names.remove("self")
+
+ missing_params = [param for param in required_params if param not in param_names]
+ if missing_params:
+ raise RepositoryImportError(
+ f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
+ f"{missing_params}. Expected parameters: {required_params}"
+ )
+ except Exception as e:
+ raise RepositoryImportError(
+ f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
+ ) from e
+
+ @classmethod
+ def create_workflow_execution_repository(
+ cls,
+ session_factory: Union[sessionmaker, Engine],
+ user: Union[Account, EndUser],
+ app_id: str,
+ triggered_from: WorkflowRunTriggeredFrom,
+ ) -> WorkflowExecutionRepository:
+ """
+ Create a WorkflowExecutionRepository instance based on configuration.
+
+ Args:
+ session_factory: SQLAlchemy sessionmaker or engine
+ user: Account or EndUser object
+ app_id: Application ID
+ triggered_from: Source of the execution trigger
+
+ Returns:
+ Configured WorkflowExecutionRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be created
+ """
+ class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
+ logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
+ cls._validate_constructor_signature(
+ repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+
+ return repository_class( # type: ignore[no-any-return]
+ session_factory=session_factory,
+ user=user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ except RepositoryImportError:
+ # Re-raise our custom errors as-is
+ raise
+ except Exception as e:
+ logger.exception("Failed to create WorkflowExecutionRepository")
+ raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
+
+ @classmethod
+ def create_workflow_node_execution_repository(
+ cls,
+ session_factory: Union[sessionmaker, Engine],
+ user: Union[Account, EndUser],
+ app_id: str,
+ triggered_from: WorkflowNodeExecutionTriggeredFrom,
+ ) -> WorkflowNodeExecutionRepository:
+ """
+ Create a WorkflowNodeExecutionRepository instance based on configuration.
+
+ Args:
+ session_factory: SQLAlchemy sessionmaker or engine
+ user: Account or EndUser object
+ app_id: Application ID
+ triggered_from: Source of the execution trigger
+
+ Returns:
+ Configured WorkflowNodeExecutionRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be created
+ """
+ class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
+ logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
+ cls._validate_constructor_signature(
+ repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+ )
+
+ return repository_class( # type: ignore[no-any-return]
+ session_factory=session_factory,
+ user=user,
+ app_id=app_id,
+ triggered_from=triggered_from,
+ )
+ except RepositoryImportError:
+ # Re-raise our custom errors as-is
+ raise
+ except Exception as e:
+ logger.exception("Failed to create WorkflowNodeExecutionRepository")
+ raise RepositoryImportError(
+ f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
+ ) from e
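Editor's note: the factory resolves the dotted class paths stored in dify_config, so swapping the repository backend is a configuration change rather than a code change. A hedged sketch of the intended usage; the SQLAlchemy class path matches the implementation shipped in this package, the environment variable name mirroring the config field is an assumption, and APP_RUN is used only as an example trigger:

    # e.g. in deployment config (assumed variable name, mirroring the config field):
    # CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository

    from sqlalchemy.orm import sessionmaker

    from core.repositories import DifyCoreRepositoryFactory
    from extensions.ext_database import db
    from models.enums import WorkflowRunTriggeredFrom

    def build_workflow_execution_repo(user, app_id: str):
        # user is an Account or EndUser; app_id is the application's UUID string.
        return DifyCoreRepositoryFactory.create_workflow_execution_repository(
            session_factory=sessionmaker(bind=db.engine),
            user=user,
            app_id=app_id,
            triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
        )
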
diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
index e5ead9dc56..c579ff4028 100644
--- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union
-from sqlalchemy import func, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -16,6 +15,8 @@ from core.workflow.entities.workflow_execution import (
WorkflowType,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
@@ -66,7 +67,7 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
)
# Extract tenant_id from user
- tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
+ tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
@@ -146,26 +147,17 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from
- # Check if this is a new record
- with self._session_factory() as session:
- existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
- if not existing:
- # For new records, get the next sequence number
- stmt = select(func.max(WorkflowRun.sequence_number)).where(
- WorkflowRun.app_id == self._app_id,
- WorkflowRun.tenant_id == self._tenant_id,
- )
- max_sequence = session.scalar(stmt)
- db_model.sequence_number = (max_sequence or 0) + 1
- else:
- # For updates, keep the existing sequence number
- db_model.sequence_number = existing.sequence_number
+ # No sequence number generation needed anymore
db_model.type = domain_model.workflow_type
db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
- db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
+ db_model.outputs = (
+ json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
+ if domain_model.outputs
+ else None
+ )
db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens
@@ -213,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
-
- def get(self, execution_id: str) -> Optional[WorkflowExecution]:
- """
- Retrieve a WorkflowExecution by its ID.
-
- First checks the in-memory cache, and if not found, queries the database.
- If found in the database, adds it to the cache for future lookups.
-
- Args:
- execution_id: The workflow execution ID
-
- Returns:
- The WorkflowExecution instance if found, None otherwise
- """
- # First check the cache
- if execution_id in self._execution_cache:
- logger.debug(f"Cache hit for execution_id: {execution_id}")
- # Convert cached DB model to domain model
- cached_db_model = self._execution_cache[execution_id]
- return self._to_domain_model(cached_db_model)
-
- # If not in cache, query the database
- logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
- with self._session_factory() as session:
- stmt = select(WorkflowRun).where(
- WorkflowRun.id == execution_id,
- WorkflowRun.tenant_id == self._tenant_id,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowRun.app_id == self._app_id)
-
- db_model = session.scalar(stmt)
- if db_model:
- # Add DB model to cache
- self._execution_cache[execution_id] = db_model
-
- # Convert to domain model and return
- return self._to_domain_model(db_model)
-
- return None
diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
index 2f27442616..d4a31390f8 100644
--- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
@@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union
-from sqlalchemy import UnaryExpression, asc, delete, desc, select
+from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -19,6 +19,8 @@ from core.workflow.entities.workflow_node_execution import (
)
from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
+from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
+from libs.helper import extract_tenant_id
from models import (
Account,
CreatorUserRole,
@@ -69,7 +71,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
)
# Extract tenant_id from user
- tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
+ tenant_id = extract_tenant_id(user)
if not tenant_id:
raise ValueError("User must have a tenant_id or current_tenant_id")
self._tenant_id = tenant_id
@@ -146,6 +148,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor")
+ json_converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id
@@ -160,9 +163,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type
db_model.title = domain_model.title
- db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
- db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None
- db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None
+ db_model.inputs = (
+ json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
+ )
+ db_model.process_data = (
+ json.dumps(json_converter.to_json_encodable(domain_model.process_data))
+ if domain_model.process_data
+ else None
+ )
+ db_model.outputs = (
+ json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
+ )
db_model.status = domain_model.status
db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time
@@ -207,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
- def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
- """
- Retrieve a NodeExecution by its node_execution_id.
-
- First checks the in-memory cache, and if not found, queries the database.
- If found in the database, adds it to the cache for future lookups.
-
- Args:
- node_execution_id: The node execution ID
-
- Returns:
- The NodeExecution instance if found, None otherwise
- """
- # First check the cache
- if node_execution_id in self._node_execution_cache:
- logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
- # Convert cached DB model to domain model
- cached_db_model = self._node_execution_cache[node_execution_id]
- return self._to_domain_model(cached_db_model)
-
- # If not in cache, query the database
- logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
- with self._session_factory() as session:
- stmt = select(WorkflowNodeExecutionModel).where(
- WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
- WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- db_model = session.scalar(stmt)
- if db_model:
- # Add DB model to cache
- self._node_execution_cache[node_execution_id] = db_model
-
- # Convert to domain model and return
- return self._to_domain_model(db_model)
-
- return None
-
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@@ -333,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
domain_models.append(domain_model)
return domain_models
-
- def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
- """
- Retrieve all running NodeExecution instances for a specific workflow run.
-
- This method queries the database directly and updates the cache with any
- retrieved executions that have a node_execution_id.
-
- Args:
- workflow_run_id: The workflow run ID
-
- Returns:
- A list of running NodeExecution instances
- """
- with self._session_factory() as session:
- stmt = select(WorkflowNodeExecutionModel).where(
- WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
- WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
- WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
- WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- db_models = session.scalars(stmt).all()
- domain_models = []
-
- for model in db_models:
- # Update cache if node_execution_id is present
- if model.node_execution_id:
- self._node_execution_cache[model.node_execution_id] = model
-
- # Convert to domain model
- domain_model = self._to_domain_model(model)
- domain_models.append(domain_model)
-
- return domain_models
-
- def clear(self) -> None:
- """
- Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
-
- This method deletes all WorkflowNodeExecution records that match the tenant_id
- and app_id (if provided) associated with this repository instance.
- It also clears the in-memory cache.
- """
- with self._session_factory() as session:
- stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- result = session.execute(stmt)
- session.commit()
-
- deleted_count = result.rowcount
- logger.info(
- f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
- + (f" and app {self._app_id}" if self._app_id else "")
- )
-
- # Clear the in-memory cache
- self._node_execution_cache.clear()
- logger.info("Cleared in-memory node execution cache")
diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py
index c9e157cb77..ddec7b1329 100644
--- a/api/core/tools/__base/tool_runtime.py
+++ b/api/core/tools/__base/tool_runtime.py
@@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.entities.tool_entities import ToolInvokeFrom
+from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
+ credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py
index cf75bd3d7e..a70ded9efd 100644
--- a/api/core/tools/builtin_tool/provider.py
+++ b/api/core/tools/builtin_tool/provider.py
@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
+from core.tools.entities.tool_entities import (
+ CredentialType,
+ OAuthSchema,
+ ToolEntity,
+ ToolProviderEntity,
+ ToolProviderType,
+)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
+ oauth_schema = None
+ if provider_yaml.get("oauth_schema", None) is not None:
+ oauth_schema = OAuthSchema(
+ client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
+ credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
+ )
+
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
+ oauth_schema=oauth_schema,
),
)
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
- if not self.entity.credentials_schema:
- return []
+ return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
+
+ def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
+ """
+ returns the credentials schema of the provider
- return self.entity.credentials_schema.copy()
+ :param credential_type: the type of the credential
+ :return: the credentials schema of the provider
+ """
+ if credential_type == CredentialType.OAUTH2.value:
+ return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
+ if credential_type == CredentialType.API_KEY.value:
+ return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
+ raise ValueError(f"Invalid credential type: {credential_type}")
+
+ def get_oauth_client_schema(self) -> list[ProviderConfig]:
+ """
+ returns the oauth client schema of the provider
+
+ :return: the oauth client schema
+ """
+ return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else []
+
+ def get_supported_credential_types(self) -> list[str]:
+ """
+ returns the credential support type of the provider
+ """
+ types = []
+ if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0:
+ types.append(CredentialType.API_KEY.value)
+ if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0:
+ types.append(CredentialType.OAUTH2.value)
+ return types
def get_tools(self) -> list[BuiltinTool]:
"""
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
- return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
+ return (
+ self.entity.credentials_schema is not None
+ and len(self.entity.credentials_schema) != 0
+ or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
+ )
@property
def provider_type(self) -> ToolProviderType:
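Editor's note: with the oauth_schema block parsed above, a provider can advertise API-key and OAuth credentials side by side, and get_supported_credential_types reports whichever schemas are non-empty. An illustrative shape for the parsed provider YAML, shown as a Python dict; the field names inside the schemas are placeholders, not copied from a real provider package:

    provider_yaml = {
        "identity": {"name": "example_provider", "author": "example"},
        "credentials_for_provider": {   # non-empty -> API key credentials are supported
            "api_key": {"type": "secret-input", "required": True},
        },
        "oauth_schema": {               # non-empty credentials_schema -> OAuth2 credentials are supported
            "client_schema": [{"name": "client_id", "type": "text-input"}],
            "credentials_schema": [{"name": "access_token", "type": "secret-input"}],
        },
    }
    # A controller built from such a YAML would report both
    # CredentialType.API_KEY.value and CredentialType.OAUTH2.value.
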
diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
index 9b104b00f5..f191968812 100644
--- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py
+++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py
@@ -31,6 +31,14 @@ class TTSTool(BuiltinTool):
model_type=ModelType.TTS,
model=model,
)
+ if not voice:
+ voices = model_instance.get_tts_voices()
+ if voices:
+ voice = voices[0].get("value")
+ if not voice:
+ raise ValueError("Sorry, no voice available.")
+ else:
+ raise ValueError("Sorry, no voice available.")
tts = model_instance.invoke_tts(
content_text=tool_parameters.get("text"), # type: ignore
user=user_id,
diff --git a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py
index ab0e155b98..b4e650e0ed 100644
--- a/api/core/tools/builtin_tool/providers/code/tools/simple_code.py
+++ b/api/core/tools/builtin_tool/providers/code/tools/simple_code.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage
from core.tools.builtin_tool.tool import BuiltinTool
from core.tools.entities.tool_entities import ToolInvokeMessage
+from core.tools.errors import ToolInvokeError
class SimpleCode(BuiltinTool):
@@ -25,6 +26,8 @@ class SimpleCode(BuiltinTool):
if language not in {CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT}:
raise ValueError(f"Only python3 and javascript are supported, not {language}")
- result = CodeExecutor.execute_code(language, "", code)
-
- yield self.create_text_message(result)
+ try:
+ result = CodeExecutor.execute_code(language, "", code)
+ yield self.create_text_message(result)
+ except Exception as e:
+ raise ToolInvokeError(str(e))
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index 3137d32013..fbe1d79137 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -39,19 +39,22 @@ class ApiToolProviderController(ToolProviderController):
type=ProviderConfig.Type.SELECT,
options=[
ProviderConfig.Option(value="none", label=I18nObject(en_US="None", zh_Hans="无")),
- ProviderConfig.Option(value="api_key", label=I18nObject(en_US="api_key", zh_Hans="api_key")),
+ ProviderConfig.Option(value="api_key_header", label=I18nObject(en_US="Header", zh_Hans="请求头")),
+ ProviderConfig.Option(
+ value="api_key_query", label=I18nObject(en_US="Query Param", zh_Hans="查询参数")
+ ),
],
default="none",
help=I18nObject(en_US="The auth type of the api provider", zh_Hans="api provider 的认证类型"),
)
]
- if auth_type == ApiProviderAuthType.API_KEY:
+ if auth_type == ApiProviderAuthType.API_KEY_HEADER:
credentials_schema = [
*credentials_schema,
ProviderConfig(
name="api_key_header",
required=False,
- default="api_key",
+ default="Authorization",
type=ProviderConfig.Type.TEXT_INPUT,
help=I18nObject(en_US="The header name of the api key", zh_Hans="携带 api key 的 header 名称"),
),
@@ -74,6 +77,25 @@ class ApiToolProviderController(ToolProviderController):
],
),
]
+ elif auth_type == ApiProviderAuthType.API_KEY_QUERY:
+ credentials_schema = [
+ *credentials_schema,
+ ProviderConfig(
+ name="api_key_query_param",
+ required=False,
+ default="key",
+ type=ProviderConfig.Type.TEXT_INPUT,
+ help=I18nObject(
+ en_US="The query parameter name of the api key", zh_Hans="携带 api key 的查询参数名称"
+ ),
+ ),
+ ProviderConfig(
+ name="api_key_value",
+ required=True,
+ type=ProviderConfig.Type.SECRET_INPUT,
+ help=I18nObject(en_US="The api key", zh_Hans="api key 的值"),
+ ),
+ ]
elif auth_type == ApiProviderAuthType.NONE:
pass
diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py
index 2f5cc6d4c0..10653b9948 100644
--- a/api/core/tools/custom_tool/tool.py
+++ b/api/core/tools/custom_tool/tool.py
@@ -78,8 +78,8 @@ class ApiTool(Tool):
if "auth_type" not in credentials:
raise ToolProviderCredentialValidationError("Missing auth_type")
- if credentials["auth_type"] == "api_key":
- api_key_header = "api_key"
+ if credentials["auth_type"] in ("api_key_header", "api_key"): # backward compatibility:
+ api_key_header = "Authorization"
if "api_key_header" in credentials:
api_key_header = credentials["api_key_header"]
@@ -100,6 +100,11 @@ class ApiTool(Tool):
headers[api_key_header] = credentials["api_key_value"]
+ elif credentials["auth_type"] == "api_key_query":
+ # For query parameter authentication, we don't add anything to headers
+ # The query parameter will be added in do_http_request method
+ pass
+
needed_parameters = [parameter for parameter in (self.api_bundle.parameters or []) if parameter.required]
for parameter in needed_parameters:
if parameter.required and parameter.name not in parameters:
@@ -154,6 +159,15 @@ class ApiTool(Tool):
cookies = {}
files = []
+ # Add API key to query parameters if auth_type is api_key_query
+ if self.runtime and self.runtime.credentials:
+ credentials = self.runtime.credentials
+ if credentials.get("auth_type") == "api_key_query":
+ api_key_query_param = credentials.get("api_key_query_param", "key")
+ api_key_value = credentials.get("api_key_value")
+ if api_key_value:
+ params[api_key_query_param] = api_key_value
+
# check parameters
for parameter in self.api_bundle.openapi.get("parameters", []):
value = self.get_parameter_value(parameter, parameters)
@@ -213,7 +227,8 @@ class ApiTool(Tool):
elif "default" in property:
body[name] = property["default"]
else:
- body[name] = None
+ # omit optional parameters that weren't provided, instead of setting them to None
+ pass
break
# replace path parameters
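Editor's note: for query-parameter auth, credential validation intentionally leaves headers untouched; the key is injected into params right before the HTTP call, so it ends up in the request URL. A minimal sketch of the resulting request shape (credential values are placeholders):

    credentials = {"auth_type": "api_key_query", "api_key_query_param": "key", "api_key_value": "sk-placeholder"}
    params = {"q": "weather"}
    if credentials.get("auth_type") == "api_key_query" and credentials.get("api_key_value"):
        params[credentials.get("api_key_query_param", "key")] = credentials["api_key_value"]
    # -> GET https://example.com/search?q=weather&key=sk-placeholder
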
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index b96c994cff..27ce96b90e 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -1,11 +1,12 @@
-from typing import Literal, Optional
+from datetime import datetime
+from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@@ -18,7 +19,7 @@ class ToolApiEntity(BaseModel):
output_schema: Optional[dict] = None
-ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
+ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow", "mcp"]]
class ToolProviderApiEntity(BaseModel):
@@ -27,6 +28,7 @@ class ToolProviderApiEntity(BaseModel):
name: str # identifier
description: I18nObject
icon: str | dict
+ icon_dark: Optional[str | dict] = Field(default=None, description="The dark icon of the tool")
label: I18nObject # label
type: ToolProviderType
masked_credentials: Optional[dict] = None
@@ -37,6 +39,10 @@ class ToolProviderApiEntity(BaseModel):
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
tools: list[ToolApiEntity] = Field(default_factory=list)
labels: list[str] = Field(default_factory=list)
+ # MCP
+ server_url: Optional[str] = Field(default="", description="The server url of the tool")
+ updated_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
+ server_identifier: Optional[str] = Field(default="", description="The server identifier of the MCP tool")
@field_validator("tools", mode="before")
@classmethod
@@ -52,8 +58,13 @@ class ToolProviderApiEntity(BaseModel):
for parameter in tool.get("parameters"):
if parameter.get("type") == ToolParameter.ToolParameterType.SYSTEM_FILES.value:
parameter["type"] = "files"
+ if parameter.get("input_schema") is None:
+ parameter.pop("input_schema", None)
# -------------
-
+ optional_fields = self.optional_field("server_url", self.server_url)
+ if self.type == ToolProviderType.MCP.value:
+ optional_fields.update(self.optional_field("updated_at", self.updated_at))
+ optional_fields.update(self.optional_field("server_identifier", self.server_identifier))
return {
"id": self.id,
"author": self.author,
@@ -62,6 +73,7 @@ class ToolProviderApiEntity(BaseModel):
"plugin_unique_identifier": self.plugin_unique_identifier,
"description": self.description.to_dict(),
"icon": self.icon,
+ "icon_dark": self.icon_dark,
"label": self.label.to_dict(),
"type": self.type.value,
"team_credentials": self.masked_credentials,
@@ -69,4 +81,28 @@ class ToolProviderApiEntity(BaseModel):
"allow_delete": self.allow_delete,
"tools": tools,
"labels": self.labels,
+ **optional_fields,
}
+
+ def optional_field(self, key: str, value: Any) -> dict:
+ """Return dict with key-value if value is truthy, empty dict otherwise."""
+ return {key: value} if value else {}
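+ # illustration of the truthiness check above:
+ # optional_field("server_url", "") -> {}
+ # optional_field("server_url", "https://host/mcp") -> {"server_url": "https://host/mcp"}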
+
+
+class ToolProviderCredentialApiEntity(BaseModel):
+ id: str = Field(description="The unique id of the credential")
+ name: str = Field(description="The name of the credential")
+ provider: str = Field(description="The provider of the credential")
+ credential_type: CredentialType = Field(description="The type of the credential")
+ is_default: bool = Field(
+ default=False, description="Whether the credential is the default credential for the provider in the workspace"
+ )
+ credentials: dict = Field(description="The credentials of the provider")
+
+
+class ToolProviderCredentialInfoApiEntity(BaseModel):
+ supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
+ is_oauth_custom_client_enabled: bool = Field(
+ default=False, description="Whether the OAuth custom client is enabled for the provider"
+ )
+ credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index 03047c0545..5377cbbb69 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -8,6 +8,7 @@ from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_seriali
from core.entities.provider_entities import ProviderConfig
from core.plugin.entities.parameters import (
+ MCPServerParameterType,
PluginParameter,
PluginParameterOption,
PluginParameterType,
@@ -15,6 +16,7 @@ from core.plugin.entities.parameters import (
cast_parameter_value,
init_frontend_parameter,
)
+from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
@@ -49,6 +51,7 @@ class ToolProviderType(enum.StrEnum):
API = "api"
APP = "app"
DATASET_RETRIEVAL = "dataset-retrieval"
+ MCP = "mcp"
@classmethod
def value_of(cls, value: str) -> "ToolProviderType":
@@ -94,7 +97,8 @@ class ApiProviderAuthType(Enum):
"""
NONE = "none"
- API_KEY = "api_key"
+ API_KEY_HEADER = "api_key_header"
+ API_KEY_QUERY = "api_key_query"
@classmethod
def value_of(cls, value: str) -> "ApiProviderAuthType":
@@ -176,6 +180,10 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
+ class RetrieverResourceMessage(BaseModel):
+ retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
+ context: str = Field(..., description="context")
+
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
@@ -188,13 +196,22 @@ class ToolInvokeMessage(BaseModel):
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
+ RETRIEVER_RESOURCES = "retriever_resources"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: (
- JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+ JsonMessage
+ | TextMessage
+ | BlobChunkMessage
+ | BlobMessage
+ | LogMessage
+ | FileMessage
+ | None
+ | VariableMessage
+ | RetrieverResourceMessage
)
meta: dict[str, Any] | None = None
@@ -240,6 +257,12 @@ class ToolParameter(PluginParameter):
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
+ ANY = PluginParameterType.ANY.value
+ DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
+
+ # MCP object and array type parameters
+ ARRAY = MCPServerParameterType.ARRAY.value
+ OBJECT = MCPServerParameterType.OBJECT.value
# deprecated, should not use.
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
@@ -259,6 +282,8 @@ class ToolParameter(PluginParameter):
human_description: Optional[I18nObject] = Field(default=None, description="The description presented to the user")
form: ToolParameterForm = Field(..., description="The form of the parameter, schema/form/llm")
llm_description: Optional[str] = None
+ # MCP object and array type parameters use this field to store the schema
+ input_schema: Optional[dict] = None
@classmethod
def get_simple_instance(
@@ -308,6 +333,7 @@ class ToolProviderIdentity(BaseModel):
name: str = Field(..., description="The name of the tool")
description: I18nObject = Field(..., description="The description of the tool")
icon: str = Field(..., description="The icon of the tool")
+ icon_dark: Optional[str] = Field(default=None, description="The dark icon of the tool")
label: I18nObject = Field(..., description="The label of the tool")
tags: Optional[list[ToolLabelEnum]] = Field(
default=[],
@@ -344,10 +370,18 @@ class ToolEntity(BaseModel):
return v or []
+class OAuthSchema(BaseModel):
+ client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
+ credentials_schema: list[ProviderConfig] = Field(
+ default_factory=list, description="The schema of the OAuth credentials"
+ )
+
+
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
+ oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@@ -427,6 +461,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
+ credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@@ -434,3 +469,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
+
+
+class CredentialType(enum.StrEnum):
+ API_KEY = "api-key"
+ OAUTH2 = "oauth2"
+
+ def get_name(self):
+ if self == CredentialType.API_KEY:
+ return "API KEY"
+ elif self == CredentialType.OAUTH2:
+ return "AUTH"
+ else:
+ return self.value.replace("-", " ").upper()
+
+ def is_editable(self):
+ return self == CredentialType.API_KEY
+
+ def is_validate_allowed(self):
+ return self == CredentialType.API_KEY
+
+ @classmethod
+ def values(cls):
+ return [item.value for item in cls]
+
+ @classmethod
+ def of(cls, credential_type: str) -> "CredentialType":
+ type_name = credential_type.lower()
+ if type_name == "api-key":
+ return cls.API_KEY
+ elif type_name == "oauth2":
+ return cls.OAUTH2
+ else:
+ raise ValueError(f"Invalid credential type: {credential_type}")
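+
+
+# Usage sketch (derived from the definitions above):
+# CredentialType.of("api-key").get_name() -> "API KEY"
+# CredentialType.of("oauth2").is_editable() -> False
+# CredentialType.of("basic") -> raises ValueError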
diff --git a/api/core/tools/mcp_tool/provider.py b/api/core/tools/mcp_tool/provider.py
new file mode 100644
index 0000000000..93f003effe
--- /dev/null
+++ b/api/core/tools/mcp_tool/provider.py
@@ -0,0 +1,130 @@
+import json
+from typing import Any
+
+from core.mcp.types import Tool as RemoteMCPTool
+from core.tools.__base.tool_provider import ToolProviderController
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import (
+ ToolDescription,
+ ToolEntity,
+ ToolIdentity,
+ ToolProviderEntityWithPlugin,
+ ToolProviderIdentity,
+ ToolProviderType,
+)
+from core.tools.mcp_tool.tool import MCPTool
+from models.tools import MCPToolProvider
+from services.tools.tools_transform_service import ToolTransformService
+
+
+class MCPToolProviderController(ToolProviderController):
+ provider_id: str
+ entity: ToolProviderEntityWithPlugin
+
+ def __init__(self, entity: ToolProviderEntityWithPlugin, provider_id: str, tenant_id: str, server_url: str) -> None:
+ super().__init__(entity)
+ self.entity = entity
+ self.tenant_id = tenant_id
+ self.provider_id = provider_id
+ self.server_url = server_url
+
+ @property
+ def provider_type(self) -> ToolProviderType:
+ """
+ returns the type of the provider
+
+ :return: type of the provider
+ """
+ return ToolProviderType.MCP
+
+ @classmethod
+ def _from_db(cls, db_provider: MCPToolProvider) -> "MCPToolProviderController":
+ """
+ from db provider
+ """
+ tools_data = json.loads(db_provider.tools)
+ remote_mcp_tools = [RemoteMCPTool(**tool) for tool in tools_data]
+ user = db_provider.load_user()
+ tools = [
+ ToolEntity(
+ identity=ToolIdentity(
+ author=user.name if user else "Anonymous",
+ name=remote_mcp_tool.name,
+ label=I18nObject(en_US=remote_mcp_tool.name, zh_Hans=remote_mcp_tool.name),
+ provider=db_provider.server_identifier,
+ icon=db_provider.icon,
+ ),
+ parameters=ToolTransformService.convert_mcp_schema_to_parameter(remote_mcp_tool.inputSchema),
+ description=ToolDescription(
+ human=I18nObject(
+ en_US=remote_mcp_tool.description or "", zh_Hans=remote_mcp_tool.description or ""
+ ),
+ llm=remote_mcp_tool.description or "",
+ ),
+ output_schema=None,
+ has_runtime_parameters=len(remote_mcp_tool.inputSchema) > 0,
+ )
+ for remote_mcp_tool in remote_mcp_tools
+ ]
+
+ return cls(
+ entity=ToolProviderEntityWithPlugin(
+ identity=ToolProviderIdentity(
+ author=user.name if user else "Anonymous",
+ name=db_provider.name,
+ label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
+ description=I18nObject(en_US="", zh_Hans=""),
+ icon=db_provider.icon,
+ ),
+ plugin_id=None,
+ credentials_schema=[],
+ tools=tools,
+ ),
+ provider_id=db_provider.server_identifier or "",
+ tenant_id=db_provider.tenant_id or "",
+ server_url=db_provider.decrypted_server_url,
+ )
+
+ def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
+ """
+ validate the credentials of the provider
+ """
+ pass
+
+ def get_tool(self, tool_name: str) -> MCPTool: # type: ignore
+ """
+ return tool with given name
+ """
+ tool_entity = next(
+ (tool_entity for tool_entity in self.entity.tools if tool_entity.identity.name == tool_name), None
+ )
+
+ if not tool_entity:
+ raise ValueError(f"Tool with name {tool_name} not found")
+
+ return MCPTool(
+ entity=tool_entity,
+ runtime=ToolRuntime(tenant_id=self.tenant_id),
+ tenant_id=self.tenant_id,
+ icon=self.entity.identity.icon,
+ server_url=self.server_url,
+ provider_id=self.provider_id,
+ )
+
+ def get_tools(self) -> list[MCPTool]: # type: ignore
+ """
+ get all tools
+ """
+ return [
+ MCPTool(
+ entity=tool_entity,
+ runtime=ToolRuntime(tenant_id=self.tenant_id),
+ tenant_id=self.tenant_id,
+ icon=self.entity.identity.icon,
+ server_url=self.server_url,
+ provider_id=self.provider_id,
+ )
+ for tool_entity in self.entity.tools
+ ]
diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py
new file mode 100644
index 0000000000..d1bacbc735
--- /dev/null
+++ b/api/core/tools/mcp_tool/tool.py
@@ -0,0 +1,92 @@
+import base64
+import json
+from collections.abc import Generator
+from typing import Any, Optional
+
+from core.mcp.error import MCPAuthError, MCPConnectionError
+from core.mcp.mcp_client import MCPClient
+from core.mcp.types import ImageContent, TextContent
+from core.tools.__base.tool import Tool
+from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
+
+
+class MCPTool(Tool):
+ tenant_id: str
+ icon: str
+ runtime_parameters: Optional[list[ToolParameter]]
+ server_url: str
+ provider_id: str
+
+ def __init__(
+ self, entity: ToolEntity, runtime: ToolRuntime, tenant_id: str, icon: str, server_url: str, provider_id: str
+ ) -> None:
+ super().__init__(entity, runtime)
+ self.tenant_id = tenant_id
+ self.icon = icon
+ self.runtime_parameters = None
+ self.server_url = server_url
+ self.provider_id = provider_id
+
+ def tool_provider_type(self) -> ToolProviderType:
+ return ToolProviderType.MCP
+
+ def _invoke(
+ self,
+ user_id: str,
+ tool_parameters: dict[str, Any],
+ conversation_id: Optional[str] = None,
+ app_id: Optional[str] = None,
+ message_id: Optional[str] = None,
+ ) -> Generator[ToolInvokeMessage, None, None]:
+ from core.tools.errors import ToolInvokeError
+
+ try:
+ with MCPClient(self.server_url, self.provider_id, self.tenant_id, authed=True) as mcp_client:
+ tool_parameters = self._handle_none_parameter(tool_parameters)
+ result = mcp_client.invoke_tool(tool_name=self.entity.identity.name, tool_args=tool_parameters)
+ except MCPAuthError as e:
+ raise ToolInvokeError("Please auth the tool first") from e
+ except MCPConnectionError as e:
+ raise ToolInvokeError(f"Failed to connect to MCP server: {e}") from e
+ except Exception as e:
+ raise ToolInvokeError(f"Failed to invoke tool: {e}") from e
+
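+ # Map the MCP result content to Dify messages: text that parses as JSON becomes one or more
+ # json messages, other text stays plain text, and image content is emitted as a blob with its MIME type.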
+ for content in result.content:
+ if isinstance(content, TextContent):
+ try:
+ content_json = json.loads(content.text)
+ if isinstance(content_json, dict):
+ yield self.create_json_message(content_json)
+ elif isinstance(content_json, list):
+ for item in content_json:
+ yield self.create_json_message(item)
+ else:
+ yield self.create_text_message(content.text)
+ except json.JSONDecodeError:
+ yield self.create_text_message(content.text)
+
+ elif isinstance(content, ImageContent):
+ yield self.create_blob_message(
+ blob=base64.b64decode(content.data), meta={"mime_type": content.mimeType}
+ )
+
+ def fork_tool_runtime(self, runtime: ToolRuntime) -> "MCPTool":
+ return MCPTool(
+ entity=self.entity,
+ runtime=runtime,
+ tenant_id=self.tenant_id,
+ icon=self.icon,
+ server_url=self.server_url,
+ provider_id=self.provider_id,
+ )
+
+ def _handle_none_parameter(self, parameter: dict[str, Any]) -> dict[str, Any]:
+ """
+ drop parameters whose value is None or an empty string before invoking the MCP tool
+ """
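+ # e.g. {"q": "dify", "page": None, "lang": ""} -> {"q": "dify"}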
+ return {
+ key: value
+ for key, value in parameter.items()
+ if value is not None and not (isinstance(value, str) and value.strip() == "")
+ }
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index d21e3d7d1c..aef2677c36 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
+ credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
index e80005d7bf..5cdf473542 100644
--- a/api/core/tools/signature.py
+++ b/api/core/tools/signature.py
@@ -9,9 +9,10 @@ from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
"""
- sign file to get a temporary url
+ sign file to get a temporary url for plugin access
"""
- base_url = dify_config.FILES_URL
+ # Use internal URL for plugin/tool file access in Docker environments
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
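+ # For instance, in a docker-compose deployment INTERNAL_FILES_URL might point at the API
+ # container directly (e.g. http://api:5001, illustrative value) while FILES_URL stays the
+ # externally reachable address used by browsers.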
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index b849f51064..ece02f9d59 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -35,9 +35,10 @@ class ToolFileManager:
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
- sign file to get a temporary url
+ sign file to get a temporary url for plugin access
"""
- base_url = dify_config.FILES_URL
+ # Use internal URL for plugin/tool file access in Docker environments
+ base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 0bfe6329b1..7822bc389c 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -4,23 +4,28 @@ import mimetypes
from collections.abc import Generator
from os import listdir, path
from threading import Lock
-from typing import TYPE_CHECKING, Any, Union, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from yarl import URL
import contexts
+from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
+from core.tools.mcp_tool.provider import MCPToolProviderController
+from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
+from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
+from core.workflow.entities.variable_pool import VariablePool
+from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
-
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -37,19 +42,20 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
+ CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
-from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
+from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
- ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
+from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
-from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
+from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
@@ -64,8 +70,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
+
get the hardcoded provider
+
"""
+
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -109,7 +118,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
+ plugin_tool_providers = contexts.plugin_tool_providers.get()
+ if provider in plugin_tool_providers:
+ return plugin_tool_providers[provider]
+
with contexts.plugin_tool_providers_lock.get():
+ # double check
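+ # (the cache is re-read after acquiring the lock because another thread may have built
+ # and cached the controller while this one was waiting)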
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -127,25 +141,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
-
- return controller
-
- @classmethod
- def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
- """
- get the builtin tool
-
- :param provider: the name of the provider
- :param tool_name: the name of the tool
- :param tenant_id: the id of the tenant
- :return: the provider, the tool
- """
- provider_controller = cls.get_builtin_provider(provider, tenant_id)
- tool = provider_controller.get_tool(tool_name)
- if tool is None:
- raise ToolNotFoundError(f"tool {tool_name} not found")
-
- return tool
+ return controller
@classmethod
def get_tool_runtime(
@@ -156,7 +152,8 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
- ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool]:
+ credential_id: Optional[str] = None,
+ ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -166,6 +163,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
+ :param credential_id: the credential id
:return: the tool
"""
@@ -189,49 +187,70 @@ class ToolManager:
)
),
)
-
+ builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
- # get credentials
- builtin_provider: BuiltinToolProvider | None = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == str(provider_id_entity))
- | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
- )
- .first()
- )
-
+ # get specific credentials
+ if is_valid_uuid(credential_id):
+ try:
+ builtin_provider = (
+ db.session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
+ except Exception as e:
+ builtin_provider = None
+ logger.info(f"Error getting builtin provider {credential_id}: {e}", exc_info=True)
+ # if the provider has been deleted, raise an error
+ if builtin_provider is None:
+ raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
+
+ # fallback to the default provider
if builtin_provider is None:
- raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
+ # use the default provider
+ builtin_provider = (
+ db.session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ (BuiltinToolProvider.provider == str(provider_id_entity))
+ | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
+ )
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
+ .first()
+ )
+ if builtin_provider is None:
+ raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
.filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
- # decrypt the credentials
- credentials = builtin_provider.credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
+ ],
+ cache=ToolProviderCredentialsCache(
+ tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
+ ),
)
-
- decrypted_credentials = tool_configuration.decrypt(credentials)
-
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
- credentials=decrypted_credentials,
+ credentials=encrypter.decrypt(builtin_provider.credentials),
+ credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@@ -241,22 +260,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
-
- # decrypt the credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
- provider_type=api_provider.provider_type.value,
- provider_identity=api_provider.entity.identity.name,
+ controller=api_provider,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
-
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
- credentials=decrypted_credentials,
+ credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@@ -292,6 +305,8 @@ class ToolManager:
raise NotImplementedError("app provider not implemented")
elif provider_type == ToolProviderType.PLUGIN:
return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name)
+ elif provider_type == ToolProviderType.MCP:
+ return cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name)
else:
raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found")
@@ -302,6 +317,7 @@ class ToolManager:
app_id: str,
agent_tool: AgentToolEntity,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
+ variable_pool: Optional[VariablePool] = None,
) -> Tool:
"""
get the agent tool runtime
@@ -313,27 +329,13 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
+ credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
- for parameter in parameters:
- # check file types
- if (
- parameter.type
- in {
- ToolParameter.ToolParameterType.SYSTEM_FILES,
- ToolParameter.ToolParameterType.FILE,
- ToolParameter.ToolParameterType.FILES,
- }
- and parameter.required
- ):
- raise ValueError(f"file type parameter {parameter.name} not supported in agent")
-
- if parameter.form == ToolParameter.ToolParameterForm.FORM:
- # save tool parameter to tool entity memory
- value = parameter.init_frontend_parameter(agent_tool.tool_parameters.get(parameter.name))
- runtime_parameters[parameter.name] = value
-
+ runtime_parameters = cls._convert_tool_parameters_type(
+ parameters, variable_pool, agent_tool.tool_parameters, typ="agent"
+ )
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
@@ -357,10 +359,12 @@ class ToolManager:
node_id: str,
workflow_tool: "ToolEntity",
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
+ variable_pool: Optional[VariablePool] = None,
) -> Tool:
"""
get the workflow tool runtime
"""
+
tool_runtime = cls.get_tool_runtime(
provider_type=workflow_tool.provider_type,
provider_id=workflow_tool.provider_id,
@@ -368,16 +372,13 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
+ credential_id=workflow_tool.credential_id,
)
- runtime_parameters = {}
- parameters = tool_runtime.get_merged_runtime_parameters()
-
- for parameter in parameters:
- # save tool parameter to tool entity memory
- if parameter.form == ToolParameter.ToolParameterForm.FORM:
- value = parameter.init_frontend_parameter(workflow_tool.tool_configurations.get(parameter.name))
- runtime_parameters[parameter.name] = value
+ parameters = tool_runtime.get_merged_runtime_parameters()
+ runtime_parameters = cls._convert_tool_parameters_type(
+ parameters, variable_pool, workflow_tool.tool_configurations, typ="workflow"
+ )
# decrypt runtime parameters
encryption_manager = ToolParameterConfigurationManager(
tenant_id=tenant_id,
@@ -401,6 +402,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
+ credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@@ -412,6 +414,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
+ credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -561,6 +564,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
+ @classmethod
+ def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
+ """
+ list all the builtin providers
+ """
+ # a provider may have multiple credentials: prefer the one with is_default=True,
+ # then fall back to the oldest created_at, for compatibility with the old single-credential behavior
+ sql = """
+ SELECT DISTINCT ON (tenant_id, provider) id
+ FROM tool_builtin_providers
+ WHERE tenant_id = :tenant_id
+ ORDER BY tenant_id, provider, is_default DESC, created_at ASC
+ """
+ ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
+ return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all()
+
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -569,27 +588,19 @@ class ToolManager:
filters = []
if not typ:
- filters.extend(["builtin", "api", "workflow"])
+ filters.extend(["builtin", "api", "workflow", "mcp"])
else:
filters.append(typ)
with db.session.no_autoflush:
if "builtin" in filters:
- # get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
- # get db builtin providers
- db_builtin_providers: list[BuiltinToolProvider] = (
- db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
- )
-
- # rewrite db_builtin_providers
- for db_provider in db_builtin_providers:
- tool_provider_id = str(ToolProviderID(db_provider.provider))
- db_provider.provider = tool_provider_id
-
- def find_db_builtin_provider(provider):
- return next((x for x in db_builtin_providers if x.provider == provider), None)
+ # key: provider name, value: provider
+ db_builtin_providers = {
+ str(ToolProviderID(provider.provider)): provider
+ for provider in cls.list_default_builtin_providers(tenant_id)
+ }
# append builtin providers
for provider in builtin_providers:
@@ -601,10 +612,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
-
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
- db_provider=find_db_builtin_provider(provider.entity.identity.name),
+ db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -614,7 +624,6 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
-
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
@@ -663,6 +672,10 @@ class ToolManager:
labels=labels.get(provider_controller.provider_id, []),
)
result_providers[f"workflow_provider.{user_provider.name}"] = user_provider
+ if "mcp" in filters:
+ mcp_providers = MCPToolManageService.retrieve_mcp_tools(tenant_id, for_list=True)
+ for mcp_provider in mcp_providers:
+ result_providers[f"mcp_provider.{mcp_provider.name}"] = mcp_provider
return BuiltinToolProviderSort.sort(list(result_providers.values()))
@@ -690,14 +703,47 @@ class ToolManager:
if provider is None:
raise ToolProviderNotFoundError(f"api provider {provider_id} not found")
+ auth_type = ApiProviderAuthType.NONE
+ provider_auth_type = provider.credentials.get("auth_type")
+ if provider_auth_type in ("api_key_header", "api_key"): # backward compatibility
+ auth_type = ApiProviderAuthType.API_KEY_HEADER
+ elif provider_auth_type == "api_key_query":
+ auth_type = ApiProviderAuthType.API_KEY_QUERY
+
controller = ApiToolProviderController.from_db(
provider,
- ApiProviderAuthType.API_KEY if provider.credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
+ auth_type,
)
controller.load_bundled_tools(provider.tools)
return controller, provider.credentials
+ @classmethod
+ def get_mcp_provider_controller(cls, tenant_id: str, provider_id: str) -> MCPToolProviderController:
+ """
+ get the MCP provider controller
+
+ :param tenant_id: the id of the tenant
+ :param provider_id: the id of the provider
+
+ :return: the MCP provider controller
+ """
+ provider: MCPToolProvider | None = (
+ db.session.query(MCPToolProvider)
+ .filter(
+ MCPToolProvider.server_identifier == provider_id,
+ MCPToolProvider.tenant_id == tenant_id,
+ )
+ .first()
+ )
+
+ if provider is None:
+ raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
+
+ controller = MCPToolProviderController._from_db(provider)
+
+ return controller
+
@classmethod
def user_get_api_provider(cls, provider: str, tenant_id: str) -> dict:
"""
@@ -725,20 +771,24 @@ class ToolManager:
credentials = {}
# package tool provider controller
+ auth_type = ApiProviderAuthType.NONE
+ credentials_auth_type = credentials.get("auth_type")
+ if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
+ auth_type = ApiProviderAuthType.API_KEY_HEADER
+ elif credentials_auth_type == "api_key_query":
+ auth_type = ApiProviderAuthType.API_KEY_QUERY
+
controller = ApiToolProviderController.from_db(
provider_obj,
- ApiProviderAuthType.API_KEY if credentials["auth_type"] == "api_key" else ApiProviderAuthType.NONE,
+ auth_type,
)
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
- provider_type=controller.provider_type.value,
- provider_identity=controller.entity.identity.name,
+ controller=controller,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+ masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)
@@ -826,6 +876,22 @@ class ToolManager:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
+ @classmethod
+ def generate_mcp_tool_icon_url(cls, tenant_id: str, provider_id: str) -> dict[str, str] | str:
+ try:
+ mcp_provider: MCPToolProvider | None = (
+ db.session.query(MCPToolProvider)
+ .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
+ .first()
+ )
+
+ if mcp_provider is None:
+ raise ToolProviderNotFoundError(f"mcp provider {provider_id} not found")
+
+ return mcp_provider.provider_icon
+ except Exception:
+ return {"background": "#252525", "content": "\ud83d\ude01"}
+
@classmethod
def get_tool_icon(
cls,
@@ -863,8 +929,61 @@ class ToolManager:
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
raise ValueError(f"plugin provider {provider_id} not found")
+ elif provider_type == ToolProviderType.MCP:
+ return cls.generate_mcp_tool_icon_url(tenant_id, provider_id)
else:
raise ValueError(f"provider type {provider_type} not found")
+ @classmethod
+ def _convert_tool_parameters_type(
+ cls,
+ parameters: list[ToolParameter],
+ variable_pool: Optional[VariablePool],
+ tool_configurations: dict[str, Any],
+ typ: Literal["agent", "workflow", "tool"] = "workflow",
+ ) -> dict[str, Any]:
+ """
+ Convert tool parameters type
+ """
+ from core.workflow.nodes.tool.entities import ToolNodeData
+ from core.workflow.nodes.tool.exc import ToolParameterError
+
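+ # When a variable_pool is given (workflow/agent execution), each FORM entry in tool_configurations
+ # is expected to be a ToolNodeData.ToolInput-shaped dict, roughly (shapes shown for illustration):
+ # {"type": "constant", "value": "fixed text or {{#node.var#}} template"}
+ # {"type": "variable", "value": ["node_id", "output"]}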
+ runtime_parameters = {}
+ for parameter in parameters:
+ if (
+ parameter.type
+ in {
+ ToolParameter.ToolParameterType.SYSTEM_FILES,
+ ToolParameter.ToolParameterType.FILE,
+ ToolParameter.ToolParameterType.FILES,
+ }
+ and parameter.required
+ and typ == "agent"
+ ):
+ raise ValueError(f"file type parameter {parameter.name} not supported in agent")
+ # save tool parameter to tool entity memory
+ if parameter.form == ToolParameter.ToolParameterForm.FORM:
+ if variable_pool:
+ config = tool_configurations.get(parameter.name, {})
+ if not (config and isinstance(config, dict) and config.get("value") is not None):
+ continue
+ tool_input = ToolNodeData.ToolInput(**tool_configurations.get(parameter.name, {}))
+ if tool_input.type == "variable":
+ variable = variable_pool.get(tool_input.value)
+ if variable is None:
+ raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ parameter_value = variable.value
+ elif tool_input.type in {"mixed", "constant"}:
+ segment_group = variable_pool.convert_template(str(tool_input.value))
+ parameter_value = segment_group.text
+ else:
+ raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
+ runtime_parameters[parameter.name] = parameter_value
+
+ else:
+ value = parameter.init_frontend_parameter(tool_configurations.get(parameter.name))
+ runtime_parameters[parameter.name] = value
+ return runtime_parameters
+
ToolManager.load_hardcoded_providers_cache()
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index 6a5fba65bd..aceba6e69f 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
-from pydantic import BaseModel
-
-from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
-from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@@ -14,108 +10,6 @@ from core.tools.entities.tool_entities import (
)
-class ProviderConfigEncrypter(BaseModel):
- tenant_id: str
- config: list[BasicProviderConfig]
- provider_type: str
- provider_identity: str
-
- def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
- """
- deep copy data
- """
- return deepcopy(data)
-
- def encrypt(self, data: dict[str, str]) -> dict[str, str]:
- """
- encrypt tool credentials with tenant id
-
- return a deep copy of credentials with encrypted values
- """
- data = self._deep_copy(data)
-
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
- data[field_name] = encrypted
-
- return data
-
- def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
- """
- mask tool credentials
-
- return a deep copy of credentials with masked values
- """
- data = self._deep_copy(data)
-
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- if len(data[field_name]) > 6:
- data[field_name] = (
- data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
- )
- else:
- data[field_name] = "*" * len(data[field_name])
-
- return data
-
- def decrypt(self, data: dict[str, str]) -> dict[str, str]:
- """
- decrypt tool credentials with tenant id
-
- return a deep copy of credentials with decrypted values
- """
- cache = ToolProviderCredentialsCache(
- tenant_id=self.tenant_id,
- identity_id=f"{self.provider_type}.{self.provider_identity}",
- cache_type=ToolProviderCredentialsCacheType.PROVIDER,
- )
- cached_credentials = cache.get()
- if cached_credentials:
- return cached_credentials
- data = self._deep_copy(data)
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- try:
- # if the value is None or empty string, skip decrypt
- if not data[field_name]:
- continue
-
- data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
- except Exception:
- pass
-
- cache.set(data)
- return data
-
- def delete_tool_credentials_cache(self):
- cache = ToolProviderCredentialsCache(
- tenant_id=self.tenant_id,
- identity_id=f"{self.provider_type}.{self.provider_identity}",
- cache_type=ToolProviderCredentialsCacheType.PROVIDER,
- )
- cache.delete()
-
-
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py
new file mode 100644
index 0000000000..5fdfd3b9d1
--- /dev/null
+++ b/api/core/tools/utils/encryption.py
@@ -0,0 +1,142 @@
+from copy import deepcopy
+from typing import Any, Optional, Protocol
+
+from core.entities.provider_entities import BasicProviderConfig
+from core.helper import encrypter
+from core.helper.provider_cache import SingletonProviderCredentialsCache
+from core.tools.__base.tool_provider import ToolProviderController
+
+
+class ProviderConfigCache(Protocol):
+ """
+ Interface for provider configuration cache operations
+ """
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider configuration"""
+ ...
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider configuration"""
+ ...
+
+ def delete(self) -> None:
+ """Delete cached provider configuration"""
+ ...
+
+
+class ProviderConfigEncrypter:
+ tenant_id: str
+ config: list[BasicProviderConfig]
+ provider_config_cache: ProviderConfigCache
+
+ def __init__(
+ self,
+ tenant_id: str,
+ config: list[BasicProviderConfig],
+ provider_config_cache: ProviderConfigCache,
+ ):
+ self.tenant_id = tenant_id
+ self.config = config
+ self.provider_config_cache = provider_config_cache
+
+ def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
+ """
+ deep copy data
+ """
+ return deepcopy(data)
+
+ def encrypt(self, data: dict[str, str]) -> dict[str, str]:
+ """
+ encrypt tool credentials with tenant id
+
+ return a deep copy of credentials with encrypted values
+ """
+ data = self._deep_copy(data)
+
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
+ data[field_name] = encrypted
+
+ return data
+
+ def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
+ """
+ mask tool credentials
+
+ return a deep copy of credentials with masked values
+ """
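+ # e.g. a SECRET_INPUT value "abcdefgh" is masked to "ab****gh"; values of 6 characters or
+ # fewer are replaced entirely with "*"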
+ data = self._deep_copy(data)
+
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ if len(data[field_name]) > 6:
+ data[field_name] = (
+ data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
+ )
+ else:
+ data[field_name] = "*" * len(data[field_name])
+
+ return data
+
+ def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
+ """
+ decrypt tool credentials with tenant id
+
+ return a deep copy of credentials with decrypted values
+ """
+ cached_credentials = self.provider_config_cache.get()
+ if cached_credentials:
+ return cached_credentials
+
+ data = self._deep_copy(data)
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ try:
+ # if the value is None or empty string, skip decrypt
+ if not data[field_name]:
+ continue
+
+ data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
+ except Exception:
+ pass
+
+ self.provider_config_cache.set(data)
+ return data
+
+
+def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
+ return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
+
+
+def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
+ cache = SingletonProviderCredentialsCache(
+ tenant_id=tenant_id,
+ provider_type=controller.provider_type.value,
+ provider_identity=controller.entity.identity.name,
+ )
+ encrypt = ProviderConfigEncrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
+ provider_config_cache=cache,
+ )
+ return encrypt, cache
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 3f844e8234..a3c84615ca 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -1,5 +1,4 @@
import re
-import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
- path = str(uuid.uuid4())
+ path = ""
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py
new file mode 100644
index 0000000000..f3c946b95f
--- /dev/null
+++ b/api/core/tools/utils/system_oauth_encryption.py
@@ -0,0 +1,187 @@
+import base64
+import hashlib
+import logging
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from Crypto.Cipher import AES
+from Crypto.Random import get_random_bytes
+from Crypto.Util.Padding import pad, unpad
+from pydantic import TypeAdapter
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class OAuthEncryptionError(Exception):
+ """OAuth encryption/decryption specific error"""
+
+ pass
+
+
+class SystemOAuthEncrypter:
+ """
+ A simple OAuth parameters encrypter using AES-CBC encryption.
+
+ This class provides methods to encrypt and decrypt OAuth parameters
+ using AES-CBC mode with a key derived from the application's SECRET_KEY.
+ """
+
+ def __init__(self, secret_key: Optional[str] = None):
+ """
+ Initialize the OAuth encrypter.
+
+ Args:
+ secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
+
+ Raises:
+ ValueError: If SECRET_KEY is not configured or empty
+ """
+ secret_key = secret_key or dify_config.SECRET_KEY or ""
+
+ # Generate a fixed 256-bit key using SHA-256
+ self.key = hashlib.sha256(secret_key.encode()).digest()
+
+ def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
+ """
+ Encrypt OAuth parameters.
+
+ Args:
+ oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
+
+ Returns:
+ Base64-encoded encrypted string
+
+ Raises:
+ OAuthEncryptionError: If encryption fails
+ ValueError: If oauth_params is invalid
+ """
+
+ try:
+ # Generate random IV (16 bytes)
+ iv = get_random_bytes(16)
+
+ # Create AES cipher (CBC mode)
+ cipher = AES.new(self.key, AES.MODE_CBC, iv)
+
+ # Encrypt data
+ padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
+ encrypted_data = cipher.encrypt(padded_data)
+
+ # Combine IV and encrypted data
+ combined = iv + encrypted_data
+
+ # Return base64 encoded string
+ return base64.b64encode(combined).decode()
+
+ except Exception as e:
+ raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
+
+ def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
+ """
+ Decrypt OAuth parameters.
+
+ Args:
+ encrypted_data: Base64-encoded encrypted string
+
+ Returns:
+ Decrypted OAuth parameters dictionary
+
+ Raises:
+ OAuthEncryptionError: If decryption fails
+ ValueError: If encrypted_data is invalid
+ """
+ if not isinstance(encrypted_data, str):
+ raise ValueError("encrypted_data must be a string")
+
+ if not encrypted_data:
+ raise ValueError("encrypted_data cannot be empty")
+
+ try:
+ # Base64 decode
+ combined = base64.b64decode(encrypted_data)
+
+ # Check minimum length (IV + at least one AES block)
+ if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
+ raise ValueError("Invalid encrypted data format")
+
+ # Separate IV and encrypted data
+ iv = combined[:16]
+ encrypted_data_bytes = combined[16:]
+
+ # Create AES cipher
+ cipher = AES.new(self.key, AES.MODE_CBC, iv)
+
+ # Decrypt data
+ decrypted_data = cipher.decrypt(encrypted_data_bytes)
+ unpadded_data = unpad(decrypted_data, AES.block_size)
+
+ # Parse JSON
+ oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
+
+ if not isinstance(oauth_params, dict):
+ raise ValueError("Decrypted data is not a valid dictionary")
+
+ return oauth_params
+
+ except Exception as e:
+ raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
+
+
+# Factory function for creating encrypter instances
+def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
+ """
+ Create an OAuth encrypter instance.
+
+ Args:
+ secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
+
+ Returns:
+ SystemOAuthEncrypter instance
+ """
+ return SystemOAuthEncrypter(secret_key=secret_key)
+
+
+# Global encrypter instance (for backward compatibility)
+_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
+
+
+def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
+ """
+ Get the global OAuth encrypter instance.
+
+ Returns:
+ SystemOAuthEncrypter instance
+ """
+ global _oauth_encrypter
+ if _oauth_encrypter is None:
+ _oauth_encrypter = SystemOAuthEncrypter()
+ return _oauth_encrypter
+
+
+# Convenience functions for backward compatibility
+def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
+ """
+ Encrypt OAuth parameters using the global encrypter.
+
+ Args:
+ oauth_params: OAuth parameters dictionary
+
+ Returns:
+ Base64-encoded encrypted string
+ """
+ return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
+
+
+def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
+ """
+ Decrypt OAuth parameters using the global encrypter.
+
+ Args:
+ encrypted_data: Base64-encoded encrypted string
+
+ Returns:
+ Decrypted OAuth parameters dictionary
+ """
+ return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
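+
+
+# Round-trip sketch using the helpers above:
+# token = encrypt_system_oauth_params({"client_id": "abc", "client_secret": "xyz"})
+# assert decrypt_system_oauth_params(token) == {"client_id": "abc", "client_secret": "xyz"}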
diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py
index 3046c08c89..bdcc33259d 100644
--- a/api/core/tools/utils/uuid_utils.py
+++ b/api/core/tools/utils/uuid_utils.py
@@ -1,7 +1,9 @@
import uuid
-def is_valid_uuid(uuid_str: str) -> bool:
+def is_valid_uuid(uuid_str: str | None) -> bool:
+ if uuid_str is None or len(uuid_str) == 0:
+ return False
try:
uuid.UUID(uuid_str)
return True
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 57c93d1d45..10bf8ca640 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -8,7 +8,12 @@ from flask_login import current_user
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
-from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolParameter, ToolProviderType
+from core.tools.entities.tool_entities import (
+ ToolEntity,
+ ToolInvokeMessage,
+ ToolParameter,
+ ToolProviderType,
+)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 64ba16c367..13274f4e0e 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -1,9 +1,9 @@
import json
import sys
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import Annotated, Any, TypeAlias
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
@@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel):
+ """Segment is runtime type used during the execution of workflow.
+
+ Note: this class is abstract, you should use subclasses of this class instead.
+ """
+
model_config = ConfigDict(frozen=True)
value_type: SegmentType
@@ -73,12 +78,26 @@ class StringSegment(Segment):
class FloatSegment(Segment):
- value_type: SegmentType = SegmentType.NUMBER
+ value_type: SegmentType = SegmentType.FLOAT
value: float
+ # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
+ # The following tests cannot pass.
+ #
+ # def test_float_segment_and_nan():
+ # nan = float("nan")
+ # assert nan != nan
+ #
+ # f1 = FloatSegment(value=float("nan"))
+ # f2 = FloatSegment(value=float("nan"))
+ # assert f1 != f2
+ #
+ # f3 = FloatSegment(value=nan)
+ # f4 = FloatSegment(value=nan)
+ # assert f3 != f4
class IntegerSegment(Segment):
- value_type: SegmentType = SegmentType.NUMBER
+ value_type: SegmentType = SegmentType.INTEGER
value: int
@@ -167,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
@property
def text(self) -> str:
return ""
+
+
+def get_segment_discriminator(v: Any) -> SegmentType | None:
+ if isinstance(v, Segment):
+ return v.value_type
+ elif isinstance(v, dict):
+ value_type = v.get("value_type")
+ if value_type is None:
+ return None
+ try:
+ seg_type = SegmentType(value_type)
+ except ValueError:
+ return None
+ return seg_type
+ else:
+ # return None if the discriminator value isn't found
+ return None
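+
+# e.g. get_segment_discriminator({"value_type": "string", "value": "hi"}) -> SegmentType.STRING
+# get_segment_discriminator({"value": "hi"}) -> None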
+
+
+# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
+# Use `Segment` for type hinting when serialization is not required.
+#
+# Note:
+# - All variants in `SegmentUnion` must inherit from the `Segment` class.
+# - The union must include all non-abstract subclasses of `Segment`, except:
+# - `SegmentGroup`, which is not added to the variable pool.
+# - `Variable` and its subclasses, which are handled by `VariableUnion`.
+SegmentUnion: TypeAlias = Annotated[
+ (
+ Annotated[NoneSegment, Tag(SegmentType.NONE)]
+ | Annotated[StringSegment, Tag(SegmentType.STRING)]
+ | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+ | Annotated[FileSegment, Tag(SegmentType.FILE)]
+ | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
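+
+
+# Deserialization sketch (assumes pydantic's TypeAdapter, not imported in this module):
+# TypeAdapter(SegmentUnion).validate_python({"value_type": "string", "value": "hi"})
+# returns a StringSegment, selected via get_segment_discriminator.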
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index 4387e9693e..e79b2410bf 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -1,8 +1,27 @@
+from collections.abc import Mapping
from enum import StrEnum
+from typing import Any, Optional
+
+from core.file.models import File
+
+
+class ArrayValidation(StrEnum):
+ """Strategy for validating array elements"""
+
+ # Skip element validation (only check array container)
+ NONE = "none"
+
+ # Validate the first element (if array is non-empty)
+ FIRST = "first"
+
+ # Validate all elements in the array.
+ ALL = "all"
class SegmentType(StrEnum):
NUMBER = "number"
+ INTEGER = "integer"
+ FLOAT = "float"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
@@ -18,3 +37,140 @@ class SegmentType(StrEnum):
NONE = "none"
GROUP = "group"
+
+ def is_array_type(self) -> bool:
+ return self in _ARRAY_TYPES
+
+ @classmethod
+ def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+ """
+ Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
+
+ Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
+ For example, this may occur if the input is a generic Python object of type `object`.
+ """
+
+ if isinstance(value, list):
+ elem_types: set[SegmentType] = set()
+ for i in value:
+ segment_type = cls.infer_segment_type(i)
+ if segment_type is None:
+ return None
+
+ elem_types.add(segment_type)
+
+ if len(elem_types) != 1:
+ if elem_types.issubset(_NUMERICAL_TYPES):
+ return SegmentType.ARRAY_NUMBER
+ return SegmentType.ARRAY_ANY
+ elif all(i.is_array_type() for i in elem_types):
+ return SegmentType.ARRAY_ANY
+ match elem_types.pop():
+ case SegmentType.STRING:
+ return SegmentType.ARRAY_STRING
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
+ return SegmentType.ARRAY_NUMBER
+ case SegmentType.OBJECT:
+ return SegmentType.ARRAY_OBJECT
+ case SegmentType.FILE:
+ return SegmentType.ARRAY_FILE
+ case SegmentType.NONE:
+ return SegmentType.ARRAY_ANY
+ case _:
+ # This should be unreachable.
+ raise ValueError(f"not supported value {value}")
+ if value is None:
+ return SegmentType.NONE
+ elif isinstance(value, int) and not isinstance(value, bool):
+ return SegmentType.INTEGER
+ elif isinstance(value, float):
+ return SegmentType.FLOAT
+ elif isinstance(value, str):
+ return SegmentType.STRING
+ elif isinstance(value, dict):
+ return SegmentType.OBJECT
+ elif isinstance(value, File):
+ return SegmentType.FILE
+ else:
+ return None
+
+ def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
+ if not isinstance(value, list):
+ return False
+ # Skip element validation if array is empty
+ if len(value) == 0:
+ return True
+ if self == SegmentType.ARRAY_ANY:
+ return True
+ element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
+
+ if array_validation == ArrayValidation.NONE:
+ return True
+ elif array_validation == ArrayValidation.FIRST:
+ return element_type.is_valid(value[0])
+ else:
+ return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
+
+ def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
+ """
+ Check if a value matches the segment type.
+ Users of `SegmentType` should call this method, instead of using
+ `isinstance` manually.
+
+ Args:
+ value: The value to validate
+ array_validation: Validation strategy for array types (ignored for non-array types)
+
+ Returns:
+ True if the value matches the type under the given validation strategy
+ """
+ if self.is_array_type():
+ return self._validate_array(value, array_validation)
+ elif self == SegmentType.NUMBER:
+ return isinstance(value, (int, float))
+ elif self == SegmentType.STRING:
+ return isinstance(value, str)
+ elif self == SegmentType.OBJECT:
+ return isinstance(value, dict)
+ elif self == SegmentType.SECRET:
+ return isinstance(value, str)
+ elif self == SegmentType.FILE:
+ return isinstance(value, File)
+ elif self == SegmentType.NONE:
+ return value is None
+ else:
+ raise AssertionError("this statement should be unreachable.")
+
+ def exposed_type(self) -> "SegmentType":
+ """Returns the type exposed to the frontend.
+
+ The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
+ """
+ if self in (SegmentType.INTEGER, SegmentType.FLOAT):
+ return SegmentType.NUMBER
+ return self
+
+
+_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
+ # ARRAY_ANY does not have a corresponding element type.
+ SegmentType.ARRAY_STRING: SegmentType.STRING,
+ SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
+ SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
+ SegmentType.ARRAY_FILE: SegmentType.FILE,
+}
+
+_ARRAY_TYPES = frozenset(
+ list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ + [
+ SegmentType.ARRAY_ANY,
+ ]
+)
+
+
+_NUMERICAL_TYPES = frozenset(
+ [
+ SegmentType.NUMBER,
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ ]
+)
diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py
index e5d222af7d..692db3502e 100644
--- a/api/core/variables/utils.py
+++ b/api/core/variables/utils.py
@@ -1,8 +1,26 @@
+import json
from collections.abc import Iterable, Sequence
+from .segment_group import SegmentGroup
+from .segments import ArrayFileSegment, FileSegment, Segment
+
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
selectors = [node_id, name]
if paths:
selectors.extend(paths)
return selectors
+
+
+class SegmentJSONEncoder(json.JSONEncoder):
+ def default(self, o):
+ if isinstance(o, ArrayFileSegment):
+ return [v.model_dump() for v in o.value]
+ elif isinstance(o, FileSegment):
+ return o.value.model_dump()
+ elif isinstance(o, SegmentGroup):
+ return [self.default(seg) for seg in o.value]
+ elif isinstance(o, Segment):
+ return o.value
+ else:
+ return super().default(o)
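+
+
+# Illustrative usage sketch (comments only; assumes `variable_factory.build_segment`
+# from the factories package). The encoder flattens segments down to plain JSON values:
+#
+#     from factories import variable_factory
+#
+#     segment = variable_factory.build_segment({"answer": 42})
+#     json.dumps(segment, cls=SegmentJSONEncoder)  # -> '{"answer": 42}'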
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index b650b1682e..a31ebc848e 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -1,8 +1,8 @@
from collections.abc import Sequence
-from typing import cast
+from typing import Annotated, TypeAlias, cast
from uuid import uuid4
-from pydantic import Field
+from pydantic import Discriminator, Field, Tag
from core.helper import encrypter
@@ -20,6 +20,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
+ get_segment_discriminator,
)
from .types import SegmentType
@@ -27,6 +28,10 @@ from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
+
+ It is mainly used to store segments and their selectors in the VariablePool.
+
+ Note: this class is abstract; use one of its subclasses instead.
"""
id: str = Field(
@@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
+
+
+# The `VariableUnion` type is used to enable serialization and deserialization with Pydantic.
+# Use `Variable` for type hinting when serialization is not required.
+#
+# Note:
+# - All variants in `VariableUnion` must inherit from the `Variable` class.
+# - The union must include all non-abstract subclasses of `Variable`.
+VariableUnion: TypeAlias = Annotated[
+ (
+ Annotated[NoneVariable, Tag(SegmentType.NONE)]
+ | Annotated[StringVariable, Tag(SegmentType.STRING)]
+ | Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
+ | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
+ | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
+ | Annotated[FileVariable, Tag(SegmentType.FILE)]
+ | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
+ | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
+ | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
+ | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
+ | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
+ | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
+ ),
+ Discriminator(get_segment_discriminator),
+]
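+
+
+# A minimal round-trip sketch (illustrative comments only; `_Pool` is a hypothetical
+# model mirroring how `VariablePool` declares `Sequence[VariableUnion]` fields):
+#
+#     from pydantic import BaseModel
+#
+#     class _Pool(BaseModel):
+#         variables: list[VariableUnion] = []
+#
+#     var = StringVariable(name="greeting", value="hi", selector=["node", "greeting"])
+#     pool = _Pool(variables=[var])
+#     restored = _Pool.model_validate(pool.model_dump())
+#     assert isinstance(restored.variables[0], StringVariable)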
diff --git a/api/core/workflow/callbacks/workflow_logging_callback.py b/api/core/workflow/callbacks/workflow_logging_callback.py
index e6813a3997..12b5203ca3 100644
--- a/api/core/workflow/callbacks/workflow_logging_callback.py
+++ b/api/core/workflow/callbacks/workflow_logging_callback.py
@@ -232,14 +232,14 @@ class WorkflowLoggingCallback(WorkflowCallback):
Publish loop started
"""
self.print_text("\n[LoopRunStartedEvent]", color="blue")
- self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
+ self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def on_workflow_loop_next(self, event: LoopRunNextEvent) -> None:
"""
Publish loop next
"""
self.print_text("\n[LoopRunNextEvent]", color="blue")
- self.print_text(f"Loop Node ID: {event.loop_id}", color="blue")
+ self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
self.print_text(f"Loop Index: {event.index}", color="blue")
def on_workflow_loop_completed(self, event: LoopRunSucceededEvent | LoopRunFailedEvent) -> None:
@@ -250,7 +250,7 @@ class WorkflowLoggingCallback(WorkflowCallback):
"\n[LoopRunSucceededEvent]" if isinstance(event, LoopRunSucceededEvent) else "\n[LoopRunFailedEvent]",
color="blue",
)
- self.print_text(f"Node ID: {event.loop_id}", color="blue")
+ self.print_text(f"Loop Node ID: {event.loop_node_id}", color="blue")
def print_text(self, text: str, color: Optional[str] = None, end: str = "\n") -> None:
"""Print text with highlighting and no end characters."""
diff --git a/api/core/workflow/conversation_variable_updater.py b/api/core/workflow/conversation_variable_updater.py
new file mode 100644
index 0000000000..84e99bb582
--- /dev/null
+++ b/api/core/workflow/conversation_variable_updater.py
@@ -0,0 +1,39 @@
+import abc
+from typing import Protocol
+
+from core.variables import Variable
+
+
+class ConversationVariableUpdater(Protocol):
+ """
+ ConversationVariableUpdater defines an abstraction for updating conversation variable values.
+
+ It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
+ conversation variables.
+
+ Implementations may choose to batch updates. If batching is used, the `flush` method
+ should be implemented to persist buffered changes, and `update`
+ should handle buffering accordingly.
+
+ Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
+ are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
+ """
+
+ @abc.abstractmethod
+ def update(self, conversation_id: str, variable: "Variable") -> None:
+ """
+ Updates the value of the specified conversation variable in the underlying storage.
+
+ :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
+ :param variable: The `Variable` instance containing the updated value.
+ """
+ pass
+
+ @abc.abstractmethod
+ def flush(self):
+ """
+ Flushes all pending updates to the underlying storage system.
+
+ If the implementation does not buffer updates, this method can be a no-op.
+ """
+ pass
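+
+
+# A buffering implementation might look roughly like the sketch below (hypothetical
+# class, shown only to illustrate the intended `update`/`flush` contract; a real
+# implementation would persist to its own storage):
+#
+#     class _BufferedUpdater:
+#         def __init__(self) -> None:
+#             self._pending: list[tuple[str, Variable]] = []
+#
+#         def update(self, conversation_id: str, variable: Variable) -> None:
+#             self._pending.append((conversation_id, variable))
+#
+#         def flush(self) -> None:
+#             for conversation_id, variable in self._pending:
+#                 ...  # write the updated value to storage here
+#             self._pending.clear()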
diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index af26864c01..fbb8df6b01 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -1,18 +1,19 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
-from typing import Any, Union
+from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
+from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
+from core.variables.variables import VariableUnion
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.workflow.system_variable import SystemVariable
from factories import variable_factory
-from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from ..enums import SystemVariableKey
-
VariableValue = Union[str, int, float, dict, list, File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@@ -23,50 +24,31 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: dict[str, dict[int, Segment]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
- # TODO: This user inputs is not used for pool.
+
+ # `user_inputs` is only used when constructing the inputs for the `StartNode`; it is not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
+ default_factory=dict,
)
- system_variables: Mapping[SystemVariableKey, Any] = Field(
+ system_variables: SystemVariable = Field(
description="System variables",
)
- environment_variables: Sequence[Variable] = Field(
+ environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
)
- conversation_variables: Sequence[Variable] = Field(
+ conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
)
- def __init__(
- self,
- *,
- system_variables: Mapping[SystemVariableKey, Any] | None = None,
- user_inputs: Mapping[str, Any] | None = None,
- environment_variables: Sequence[Variable] | None = None,
- conversation_variables: Sequence[Variable] | None = None,
- **kwargs,
- ):
- environment_variables = environment_variables or []
- conversation_variables = conversation_variables or []
- user_inputs = user_inputs or {}
- system_variables = system_variables or {}
-
- super().__init__(
- system_variables=system_variables,
- user_inputs=user_inputs,
- environment_variables=environment_variables,
- conversation_variables=conversation_variables,
- **kwargs,
- )
-
- for key, value in self.system_variables.items():
- self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
+ def model_post_init(self, context: Any, /) -> None:
+ # Add system variables to the variable pool
+ self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@@ -91,19 +73,33 @@ class VariablePool(BaseModel):
Returns:
None
"""
- if len(selector) < 2:
+ if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector")
if isinstance(value, Variable):
variable = value
- if isinstance(value, Segment):
+ elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else:
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
- hash_key = hash(tuple(selector[1:]))
- self.variable_dictionary[selector[0]][hash_key] = variable
+ key, hash_key = self._selector_to_keys(selector)
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
+
+ @classmethod
+ def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
+ return selector[0], hash(tuple(selector[1:]))
+
+ def _has(self, selector: Sequence[str]) -> bool:
+ key, hash_key = self._selector_to_keys(selector)
+ if key not in self.variable_dictionary:
+ return False
+ if hash_key not in self.variable_dictionary[key]:
+ return False
+ return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@@ -118,11 +114,11 @@ class VariablePool(BaseModel):
Raises:
ValueError: If the selector is invalid.
"""
- if len(selector) < 2:
+ if len(selector) < MIN_SELECTORS_LENGTH:
return None
- hash_key = hash(tuple(selector[1:]))
- value = self.variable_dictionary[selector[0]].get(hash_key)
+ key, hash_key = self._selector_to_keys(selector)
+ value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
@@ -155,8 +151,8 @@ class VariablePool(BaseModel):
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
- hash_key = hash(tuple(selector[1:]))
- self.variable_dictionary[selector[0]].pop(hash_key, None)
+ key, hash_key = self._selector_to_keys(selector)
+ self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
@@ -173,3 +169,20 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment):
return segment
return None
+
+ def _add_system_variables(self, system_variable: SystemVariable):
+ sys_var_mapping = system_variable.to_dict()
+ for key, value in sys_var_mapping.items():
+ if value is None:
+ continue
+ selector = (SYSTEM_VARIABLE_NODE_ID, key)
+ # If the system variable already exists, do not add it again.
+ # This keeps the ids of existing system variables intact.
+ if self._has(selector):
+ continue
+ self.add(selector, value) # type: ignore
+
+ @classmethod
+ def empty(cls) -> "VariablePool":
+ """Create an empty variable pool."""
+ return cls(system_variables=SystemVariable.empty())
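+
+
+# Illustrative sketch of the selector-based API (comments only, not executed):
+#
+#     pool = VariablePool.empty()
+#     pool.add(("node_1", "answer"), "hello")
+#     segment = pool.get(("node_1", "answer"))
+#     assert segment is not None and segment.value == "hello"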
diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py
deleted file mode 100644
index 8896416f12..0000000000
--- a/api/core/workflow/entities/workflow_entities.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode
-from models.enums import UserFrom
-from models.workflow import Workflow, WorkflowType
-
-from .node_entities import NodeRunResult
-from .variable_pool import VariablePool
-
-
-class WorkflowNodeAndResult:
- node: BaseNode
- result: Optional[NodeRunResult] = None
-
- def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None):
- self.node = node
- self.result = result
-
-
-class WorkflowRunState:
- tenant_id: str
- app_id: str
- workflow_id: str
- workflow_type: WorkflowType
- user_id: str
- user_from: UserFrom
- invoke_from: InvokeFrom
-
- workflow_call_depth: int
-
- start_at: float
- variable_pool: VariablePool
-
- total_tokens: int = 0
-
- workflow_nodes_and_results: list[WorkflowNodeAndResult]
-
- class NodeRun(BaseModel):
- node_id: str
- iteration_node_id: str
- loop_node_id: str
-
- workflow_node_runs: list[NodeRun]
- workflow_node_steps: int
-
- current_iteration_state: Optional[BaseIterationState]
- current_loop_state: Optional[BaseLoopState]
-
- def __init__(
- self,
- workflow: Workflow,
- start_at: float,
- variable_pool: VariablePool,
- user_id: str,
- user_from: UserFrom,
- invoke_from: InvokeFrom,
- workflow_call_depth: int,
- ):
- self.workflow_id = workflow.id
- self.tenant_id = workflow.tenant_id
- self.app_id = workflow.app_id
- self.workflow_type = WorkflowType.value_of(workflow.type)
- self.user_id = user_id
- self.user_from = user_from
- self.invoke_from = invoke_from
- self.workflow_call_depth = workflow_call_depth
-
- self.start_at = start_at
- self.variable_pool = variable_pool
-
- self.total_tokens = 0
-
- self.workflow_node_steps = 1
- self.workflow_node_runs = []
- self.current_iteration_state = None
- self.current_loop_state = None
diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py
index 773f5b777b..09a408f4d7 100644
--- a/api/core/workflow/entities/workflow_node_execution.py
+++ b/api/core/workflow/entities/workflow_node_execution.py
@@ -66,11 +66,21 @@ class WorkflowNodeExecution(BaseModel):
but they are not stored in the model.
"""
- # Core identification fields
- id: str # Unique identifier for this execution record
- node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing
+ # --------- Core identification fields ---------
+
+ # Unique identifier for this execution record, used when persisting to storage.
+ # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
+ id: str
+
+ # Optional secondary ID for cross-referencing purposes.
+ #
+ # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
+ # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
+ # In most scenarios, `id` should be used as the primary identifier.
+ node_execution_id: Optional[str] = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging)
+ # --------- Core identification fields ends ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py
index bd4ccc1072..594bb2b32e 100644
--- a/api/core/workflow/errors.py
+++ b/api/core/workflow/errors.py
@@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception):
- def __init__(self, node_instance: BaseNode, error: str):
- self.node_instance = node_instance
- self.error = error
- super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
+ def __init__(self, node: BaseNode, err_msg: str):
+ self._node = node
+ self._error = err_msg
+ super().__init__(f"Node {node.title} run failed: {err_msg}")
diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py
index 2fee3d7fad..12e1de464b 100644
--- a/api/core/workflow/graph_engine/__init__.py
+++ b/api/core/workflow/graph_engine/__init__.py
@@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
+from .graph_engine import GraphEngine
-__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py
index 9a4939502e..e57e9e4d64 100644
--- a/api/core/workflow/graph_engine/entities/event.py
+++ b/api/core/workflow/graph_engine/entities/event.py
@@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration"""
in_loop_id: Optional[str] = None
"""loop id if node is in loop"""
+ # The version of the node, or "1" if not specified.
+ node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent):
diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py
index 8e5b1e7142..362777a199 100644
--- a/api/core/workflow/graph_engine/entities/graph.py
+++ b/api/core/workflow/graph_engine/entities/graph.py
@@ -334,7 +334,7 @@ class Graph(BaseModel):
parallel = GraphParallel(
start_from_node_id=start_node_id,
- parent_parallel_id=parent_parallel.id if parent_parallel else None,
+ parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None,
)
parallel_mapping[parallel.id] = parallel
diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
index afc09bfac5..a62ffe46c9 100644
--- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py
+++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
@@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
+
+ # The `outputs` field stores the final output values generated by executing workflows or chatflows.
+ #
+ # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
+ # after a serialization and deserialization round trip.
outputs: dict[str, Any] = {}
- """outputs"""
node_run_steps: int = 0
"""node run steps"""
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 875cee17e6..b315129763 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -48,11 +48,10 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
-from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
from models.workflow import WorkflowType
@@ -102,7 +101,7 @@ class GraphEngine:
call_depth: int,
graph: Graph,
graph_config: Mapping[str, Any],
- variable_pool: VariablePool,
+ graph_runtime_state: GraphRuntimeState,
max_execution_steps: int,
max_execution_time: int,
thread_pool_id: Optional[str] = None,
@@ -139,7 +138,7 @@ class GraphEngine:
call_depth=call_depth,
)
- self.graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+ self.graph_runtime_state = graph_runtime_state
self.max_execution_steps = max_execution_steps
self.max_execution_time = max_execution_time
@@ -259,12 +258,16 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
+
+ # Import here to avoid circular import
+ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
- node_instance = node_cls( # type: ignore
+ node = node_cls(
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
@@ -273,11 +276,11 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
- node_instance = cast(BaseNode[BaseNodeData], node_instance)
+ node.init_node_data(node_config.get("data", {}))
try:
# run node
generator = self._run_node(
- node_instance=node_instance,
+ node=node,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
@@ -305,15 +308,16 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
- id=node_instance.id,
+ id=node.id,
node_id=next_node_id,
node_type=node_type,
- node_data=node_instance.node_data,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
raise e
@@ -335,7 +339,7 @@ class GraphEngine:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
- and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
+ and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
@@ -411,8 +415,8 @@ class GraphEngine:
next_node_id = final_node_id
elif (
- node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
- and node_instance.should_continue_on_error
+ node.continue_on_error
+ and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
@@ -595,7 +599,7 @@ class GraphEngine:
def _run_node(
self,
- node_instance: BaseNode[BaseNodeData],
+ node: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@@ -609,28 +613,29 @@ class GraphEngine:
# trigger node run start event
agent_strategy = (
AgentNodeStrategyInit(
- name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
- icon=cast(AgentNode, node_instance).agent_strategy_icon,
+ name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
+ icon=cast(AgentNode, node).agent_strategy_icon,
)
- if node_instance.node_type == NodeType.AGENT
+ if node.type_ == NodeType.AGENT
else None
)
yield NodeRunStartedEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
- predecessor_node_id=node_instance.previous_node_id,
+ predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
+ node_version=node.version(),
)
- max_retries = node_instance.node_data.retry_config.max_retries
- retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+ max_retries = node.retry_config.max_retries
+ retry_interval = node.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
@@ -639,44 +644,37 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads
time.sleep(0.001)
- generator = node_instance.run()
- for item in generator:
- if isinstance(item, GraphEngineEvent):
- if isinstance(item, BaseIterationEvent):
- # add parallel info to iteration event
- item.parallel_id = parallel_id
- item.parallel_start_node_id = parallel_start_node_id
- item.parent_parallel_id = parent_parallel_id
- item.parent_parallel_start_node_id = parent_parallel_start_node_id
- elif isinstance(item, BaseLoopEvent):
- # add parallel info to loop event
- item.parallel_id = parallel_id
- item.parallel_start_node_id = parallel_start_node_id
- item.parent_parallel_id = parent_parallel_id
- item.parent_parallel_start_node_id = parent_parallel_start_node_id
-
- yield item
+ event_stream = node.run()
+ for event in event_stream:
+ if isinstance(event, GraphEngineEvent):
+ # add parallel info to iteration event
+ if isinstance(event, BaseIterationEvent | BaseLoopEvent):
+ event.parallel_id = parallel_id
+ event.parallel_start_node_id = parallel_start_node_id
+ event.parent_parallel_id = parent_parallel_id
+ event.parent_parallel_start_node_id = parent_parallel_start_node_id
+ yield event
else:
- if isinstance(item, RunCompletedEvent):
- run_result = item.run_result
+ if isinstance(event, RunCompletedEvent):
+ run_result = event.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
- and node_instance.node_type == NodeType.HTTP_REQUEST
+ and node.type_ == NodeType.HTTP_REQUEST
and run_result.outputs
- and not node_instance.should_continue_on_error
+ and not node.continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
- if node_instance.should_retry and retries < max_retries:
+ if node.retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=str(uuid.uuid4()),
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
- predecessor_node_id=node_instance.previous_node_id,
+ predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
@@ -684,17 +682,18 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
+ node_version=node.version(),
)
time.sleep(retry_interval)
break
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
- if node_instance.should_continue_on_error:
+ if node.continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
- node_instance,
- item.run_result,
+ node,
+ event.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
)
@@ -704,42 +703,44 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
- node_id=node_instance.node_id,
+ node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if (
- node_instance.should_continue_on_error
- and self.graph.edge_mapping.get(node_instance.node_id)
- and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
+ node.continue_on_error
+ and self.graph.edge_mapping.get(node.node_id)
+ and node.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(
@@ -759,7 +760,7 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
- node_id=node_instance.node_id,
+ node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
@@ -784,46 +785,49 @@ class GraphEngine:
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
should_continue_retry = False
break
- elif isinstance(item, RunStreamChunkEvent):
+ elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
- chunk_content=item.chunk_content,
- from_variable_selector=item.from_variable_selector,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
+ chunk_content=event.chunk_content,
+ from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
- elif isinstance(item, RunRetrieverResourceEvent):
+ elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
- retriever_resources=item.retriever_resources,
- context=item.context,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
+ retriever_resources=event.retriever_resources,
+ context=event.context,
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@@ -831,19 +835,20 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
+ node_version=node.version(),
)
return
except Exception as e:
- logger.exception(f"Node {node_instance.node_data.title} run failed")
+ logger.exception(f"Node {node.title} run failed")
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
@@ -854,16 +859,12 @@ class GraphEngine:
:param variable_value: variable value
:return:
"""
- self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value)
-
- # if variable_value is a dict, then recursively append variables
- if isinstance(variable_value, dict):
- for key, value in variable_value.items():
- # construct new key list
- new_key_list = variable_key_list + [key]
- self._append_variables_recursively(
- node_id=node_id, variable_key_list=new_key_list, variable_value=value
- )
+ variable_utils.append_variables_recursively(
+ self.graph_runtime_state.variable_pool,
+ node_id,
+ variable_key_list,
+ variable_value,
+ )
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
"""
@@ -887,22 +888,14 @@ class GraphEngine:
def _handle_continue_on_error(
self,
- node_instance: BaseNode[BaseNodeData],
+ node: BaseNode,
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
- """
- handle continue on error when self._should_continue_on_error is True
-
-
- :param error_result (NodeRunResult): error run result
- :param variable_pool (VariablePool): variable pool
- :return: excption run result
- """
# add error message and error type to variable pool
- variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
- variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
+ variable_pool.add([node.node_id, "error_message"], error_result.error)
+ variable_pool.add([node.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
@@ -910,21 +903,21 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
- WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
+ WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
},
}
- if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+ if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
- **node_instance.node_data.default_value_dict,
+ **node.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
- elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
- if self.graph.edge_mapping.get(node_instance.node_id):
+ elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
+ if self.graph.edge_mapping.get(node.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 22c564c1fc..8cf33ac81e 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -2,54 +2,99 @@ import json
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
+from packaging.version import Version
+from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
+from core.agent.strategy.plugin import PluginAgentStrategy
+from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
-from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
+from core.tools.entities.tool_entities import (
+ ToolIdentity,
+ ToolInvokeMessage,
+ ToolParameter,
+ ToolProviderType,
+)
from core.tools.tool_manager import ToolManager
-from core.variables.segments import StringSegment
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
-from core.workflow.nodes.base.entities import BaseNodeData
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import RunCompletedEvent
-from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
+from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
+from models import ToolFile
from models.model import Conversation
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+from .exc import (
+ AgentInputTypeError,
+ AgentInvocationError,
+ AgentMessageTransformError,
+ AgentVariableNotFoundError,
+ AgentVariableTypeError,
+ ToolFileNotFoundError,
+)
-class AgentNode(ToolNode):
+
+class AgentNode(BaseNode):
"""
Agent Node
"""
- _node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
+ _node_data: AgentNodeData
- def _run(self) -> Generator:
- """
- Run the agent node
- """
- node_data = cast(AgentNodeData, self.node_data)
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = AgentNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
+ def _run(self) -> Generator:
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
- agent_strategy_provider_name=node_data.agent_strategy_provider_name,
- agent_strategy_name=node_data.agent_strategy_name,
+ agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
+ agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
@@ -67,14 +112,17 @@ class AgentNode(ToolNode):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=node_data,
+ node_data=self._node_data,
+ strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=node_data,
+ node_data=self._node_data,
for_log=True,
+ strategy=strategy,
)
+ credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -85,34 +133,42 @@ class AgentNode(ToolNode):
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
+ credentials=credentials,
)
except Exception as e:
+ error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- error=f"Failed to invoke agent: {str(e)}",
+ error=str(error),
)
)
return
try:
- # convert tool messages
-
yield from self._transform_message(
- message_stream,
- {
+ messages=message_stream,
+ tool_info={
"icon": self.agent_strategy_icon,
- "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
+ "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
- parameters_for_log,
+ parameters_for_log=parameters_for_log,
+ user_id=self.user_id,
+ tenant_id=self.tenant_id,
+ node_type=self.type_,
+ node_id=self.node_id,
+ node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
+ transform_error = AgentMessageTransformError(
+ f"Failed to transform agent message: {str(e)}", original_error=e
+ )
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- error=f"Failed to transform agent message: {str(e)}",
+ error=str(transform_error),
)
)
@@ -123,6 +179,7 @@ class AgentNode(ToolNode):
variable_pool: VariablePool,
node_data: AgentNodeData,
for_log: bool = False,
+ strategy: PluginAgentStrategy,
) -> dict[str, Any]:
"""
Generate parameters based on the given tool parameters, variable pool, and node data.
@@ -148,13 +205,16 @@ class AgentNode(ToolNode):
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
- raise ValueError(f"Variable {agent_input.value} does not exist")
+ raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
# but if passing a dict, convert to JSON string first before rendering
try:
- parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
+ if not isinstance(agent_input.value, str):
+ parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
+ else:
+ parameter_value = str(agent_input.value)
except TypeError:
parameter_value = str(agent_input.value)
segment_group = variable_pool.convert_template(parameter_value)
@@ -162,16 +222,17 @@ class AgentNode(ToolNode):
# variable_pool.convert_template returns a string,
# so we need to convert it back to a dictionary
try:
- parameter_value = json.loads(parameter_value)
+ if not isinstance(agent_input.value, str):
+ parameter_value = json.loads(parameter_value)
except json.JSONDecodeError:
parameter_value = parameter_value
else:
- raise ValueError(f"Unknown agent input type '{agent_input.type}'")
+ raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
-
+ value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
@@ -205,12 +266,20 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
+ credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
+ # This caused problems before: logically, we should not rely on the
+ # node_data.version field for this check, but the check is kept here for
+ # backward compatibility with historical data.
+ runtime_variable_pool: VariablePool | None = None
+ if node_data.version != "1" or node_data.tool_node_version != "1":
+ runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
- self.tenant_id, self.app_id, entity, self.invoke_from
+ self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
if tool_runtime.entity.description:
tool_runtime.entity.description.llm = (
@@ -235,6 +304,7 @@ class AgentNode(ToolNode):
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
+ "credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
@@ -264,25 +334,41 @@ class AgentNode(ToolNode):
return result
+ def _generate_credentials(
+ self,
+ parameters: dict[str, Any],
+ ) -> InvokeCredentials:
+ """
+ Generate credentials based on the given agent parameters.
+ """
+
+ credentials = InvokeCredentials()
+
+ # generate credentials for tools selector
+ credentials.tool_credentials = {}
+ for tool in parameters.get("tools", []):
+ if tool.get("credential_id"):
+ try:
+ identity = ToolIdentity.model_validate(tool.get("identity", {}))
+ credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
+ except ValidationError:
+ continue
+ return credentials
+
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: BaseNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- node_data = cast(AgentNodeData, node_data)
+ # Create typed NodeData from dict
+ typed_node_data = AgentNodeData.model_validate(node_data)
+
result: dict[str, Any] = {}
- for parameter_name in node_data.agent_parameters:
- input = node_data.agent_parameters[parameter_name]
+ for parameter_name in typed_node_data.agent_parameters:
+ input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
@@ -307,7 +393,7 @@ class AgentNode(ToolNode):
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
- == cast(AgentNodeData, self.node_data).agent_strategy_provider_name
+ == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
@@ -362,3 +448,249 @@ class AgentNode(ToolNode):
except ValueError:
model_schema.features.remove(feature)
return model_schema
+
+ def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
+ """
+ Filter out MCP-type tools for strategies whose meta_version does not support them.
+ :param strategy: plugin agent strategy
+ :param tools: list of tool dicts
+ :return: filtered list of tool dicts
+ """
+ meta_version = strategy.meta_version
+ if meta_version and Version(meta_version) > Version("0.0.1"):
+ return tools
+ else:
+ return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
+
+ def _transform_message(
+ self,
+ messages: Generator[ToolInvokeMessage, None, None],
+ tool_info: Mapping[str, Any],
+ parameters_for_log: dict[str, Any],
+ user_id: str,
+ tenant_id: str,
+ node_type: NodeType,
+ node_id: str,
+ node_execution_id: str,
+ ) -> Generator:
+ """
+ Convert ToolInvokeMessages into tuple[plain_text, files]
+ """
+ # transform message and handle file storage
+ message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+ messages=messages,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=None,
+ )
+
+ text = ""
+ files: list[File] = []
+ json_list: list[dict] = []
+
+ agent_logs: list[AgentLogEvent] = []
+ agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
+ llm_usage: LLMUsage | None = None
+ variables: dict[str, Any] = {}
+
+ for message in message_stream:
+ if message.type in {
+ ToolInvokeMessage.MessageType.IMAGE_LINK,
+ ToolInvokeMessage.MessageType.BINARY_LINK,
+ ToolInvokeMessage.MessageType.IMAGE,
+ }:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
+ url = message.message.text
+ if message.meta:
+ transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+ else:
+ transfer_method = FileTransferMethod.TOOL_FILE
+
+ tool_file_id = str(url).split("/")[-1].split(".")[0]
+
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileNotFoundError(tool_file_id)
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+ "transfer_method": transfer_method,
+ "url": url,
+ }
+ file = file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ files.append(file)
+ elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ # get tool file id
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert message.meta
+
+ tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileNotFoundError(tool_file_id)
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "transfer_method": FileTransferMethod.TOOL_FILE,
+ }
+
+ files.append(
+ file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ )
+ elif message.type == ToolInvokeMessage.MessageType.TEXT:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ text += message.message.text
+ yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
+ elif message.type == ToolInvokeMessage.MessageType.JSON:
+ assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+ if node_type == NodeType.AGENT:
+ msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+ llm_usage = LLMUsage.from_metadata(msg_metadata)
+ agent_execution_metadata = {
+ WorkflowNodeExecutionMetadataKey(key): value
+ for key, value in msg_metadata.items()
+ if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+ }
+ if message.message.json_object is not None:
+ json_list.append(message.message.json_object)
+ elif message.type == ToolInvokeMessage.MessageType.LINK:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ stream_text = f"Link: {message.message.text}\n"
+ text += stream_text
+ yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
+ elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+ assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+ variable_name = message.message.variable_name
+ variable_value = message.message.variable_value
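+ # Streamed variable values must be strings and are accumulated chunk by chunk; non-streamed values are stored as-is.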
+ if message.message.stream:
+ if not isinstance(variable_value, str):
+ raise AgentVariableTypeError(
+ "When 'stream' is True, 'variable_value' must be a string.",
+ variable_name=variable_name,
+ expected_type="str",
+ actual_type=type(variable_value).__name__,
+ )
+ if variable_name not in variables:
+ variables[variable_name] = ""
+ variables[variable_name] += variable_value
+
+ yield RunStreamChunkEvent(
+ chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
+ )
+ else:
+ variables[variable_name] = variable_value
+ elif message.type == ToolInvokeMessage.MessageType.FILE:
+ assert message.meta is not None
+ assert isinstance(message.meta["file"], File)
+ files.append(message.meta["file"])
+ elif message.type == ToolInvokeMessage.MessageType.LOG:
+ assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+ if message.message.metadata:
+ icon = tool_info.get("icon", "")
+ dict_metadata = dict(message.message.metadata)
+ if dict_metadata.get("provider"):
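+ # Resolve the provider's icon (plugin or builtin tool) so the frontend can render it alongside the agent log.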
+ manager = PluginInstaller()
+ plugins = manager.list_plugins(tenant_id)
+ try:
+ current_plugin = next(
+ plugin
+ for plugin in plugins
+ if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
+ )
+ icon = current_plugin.declaration.icon
+ except StopIteration:
+ pass
+ icon_dark = None
+ try:
+ builtin_tool = next(
+ provider
+ for provider in BuiltinToolManageService.list_builtin_tools(
+ user_id,
+ tenant_id,
+ )
+ if provider.name == dict_metadata["provider"]
+ )
+ icon = builtin_tool.icon
+ icon_dark = builtin_tool.icon_dark
+ except StopIteration:
+ pass
+
+ dict_metadata["icon"] = icon
+ dict_metadata["icon_dark"] = icon_dark
+ message.message.metadata = dict_metadata
+ agent_log = AgentLogEvent(
+ id=message.message.id,
+ node_execution_id=node_execution_id,
+ parent_id=message.message.parent_id,
+ error=message.message.error,
+ status=message.message.status.value,
+ data=message.message.data,
+ label=message.message.label,
+ metadata=message.message.metadata,
+ node_id=node_id,
+ )
+
+ # check if the agent log is already in the list
+ for log in agent_logs:
+ if log.id == agent_log.id:
+ # update the log
+ log.data = agent_log.data
+ log.status = agent_log.status
+ log.error = agent_log.error
+ log.label = agent_log.label
+ log.metadata = agent_log.metadata
+ break
+ else:
+ agent_logs.append(agent_log)
+
+ yield agent_log
+
+ # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
+ json_output: list[dict[str, Any]] = []
+
+ # Step 1: append each agent log as its own dict.
+ if agent_logs:
+ for log in agent_logs:
+ json_output.append(
+ {
+ "id": log.id,
+ "parent_id": log.parent_id,
+ "error": log.error,
+ "status": log.status,
+ "data": log.data,
+ "label": log.label,
+ "metadata": log.metadata,
+ "node_id": log.node_id,
+ }
+ )
+ # Step 2: append the collected JSON objects; if none were produced, fall back to an empty {"data": []} entry.
+ if json_list:
+ json_output.extend(json_list)
+ else:
+ json_output.append({"data": []})
+
+ yield RunCompletedEvent(
+ run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
+ metadata={
+ **agent_execution_metadata,
+ WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
+ WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
+ },
+ inputs=parameters_for_log,
+ llm_usage=llm_usage,
+ )
+ )
diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
index 075a41fb2f..11b11068e7 100644
--- a/api/core/workflow/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData):
agent_strategy_name: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
+ # The version of the tool parameter format.
+ # If this value is None, the data comes from a previous version
+ # and must be parsed with the legacy parameter parsing rules.
+ tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py
new file mode 100644
index 0000000000..d5955bdd7d
--- /dev/null
+++ b/api/core/workflow/nodes/agent/exc.py
@@ -0,0 +1,124 @@
+from typing import Optional
+
+
+class AgentNodeError(Exception):
+ """Base exception for all agent node errors."""
+
+ def __init__(self, message: str):
+ self.message = message
+ super().__init__(self.message)
+
+
+class AgentStrategyError(AgentNodeError):
+ """Exception raised when there's an error with the agent strategy."""
+
+ def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
+ self.strategy_name = strategy_name
+ self.provider_name = provider_name
+ super().__init__(message)
+
+
+class AgentStrategyNotFoundError(AgentStrategyError):
+ """Exception raised when the specified agent strategy is not found."""
+
+ def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
+ super().__init__(
+ f"Agent strategy '{strategy_name}' not found"
+ + (f" for provider '{provider_name}'" if provider_name else ""),
+ strategy_name,
+ provider_name,
+ )
+
+
+class AgentInvocationError(AgentNodeError):
+ """Exception raised when there's an error invoking the agent."""
+
+ def __init__(self, message: str, original_error: Optional[Exception] = None):
+ self.original_error = original_error
+ super().__init__(message)
+
+
+class AgentParameterError(AgentNodeError):
+ """Exception raised when there's an error with agent parameters."""
+
+ def __init__(self, message: str, parameter_name: Optional[str] = None):
+ self.parameter_name = parameter_name
+ super().__init__(message)
+
+
+class AgentVariableError(AgentNodeError):
+ """Exception raised when there's an error with variables in the agent node."""
+
+ def __init__(self, message: str, variable_name: Optional[str] = None):
+ self.variable_name = variable_name
+ super().__init__(message)
+
+
+class AgentVariableNotFoundError(AgentVariableError):
+ """Exception raised when a variable is not found in the variable pool."""
+
+ def __init__(self, variable_name: str):
+ super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
+
+
+class AgentInputTypeError(AgentNodeError):
+ """Exception raised when an unknown agent input type is encountered."""
+
+ def __init__(self, input_type: str):
+ super().__init__(f"Unknown agent input type '{input_type}'")
+
+
+class ToolFileError(AgentNodeError):
+ """Exception raised when there's an error with a tool file."""
+
+ def __init__(self, message: str, file_id: Optional[str] = None):
+ self.file_id = file_id
+ super().__init__(message)
+
+
+class ToolFileNotFoundError(ToolFileError):
+ """Exception raised when a tool file is not found."""
+
+ def __init__(self, file_id: str):
+ super().__init__(f"Tool file '{file_id}' does not exist", file_id)
+
+
+class AgentMessageTransformError(AgentNodeError):
+ """Exception raised when there's an error transforming agent messages."""
+
+ def __init__(self, message: str, original_error: Optional[Exception] = None):
+ self.original_error = original_error
+ super().__init__(message)
+
+
+class AgentModelError(AgentNodeError):
+ """Exception raised when there's an error with the model used by the agent."""
+
+ def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
+ self.model_name = model_name
+ self.provider = provider
+ super().__init__(message)
+
+
+class AgentMemoryError(AgentNodeError):
+ """Exception raised when there's an error with the agent's memory."""
+
+ def __init__(self, message: str, conversation_id: Optional[str] = None):
+ self.conversation_id = conversation_id
+ super().__init__(message)
+
+
+class AgentVariableTypeError(AgentNodeError):
+ """Exception raised when a variable has an unexpected type."""
+
+ def __init__(
+ self,
+ message: str,
+ variable_name: Optional[str] = None,
+ expected_type: Optional[str] = None,
+ actual_type: Optional[str] = None,
+ ):
+ self.variable_name = variable_name
+ self.expected_type = expected_type
+ self.actual_type = actual_type
+ super().__init__(message)
diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py
index aa030870e2..84bbabca73 100644
--- a/api/core/workflow/nodes/answer/answer_node.py
+++ b/api/core/workflow/nodes/answer/answer_node.py
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
@@ -12,13 +12,40 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
-class AnswerNode(BaseNode[AnswerNodeData]):
- _node_data_cls = AnswerNodeData
- _node_type: NodeType = NodeType.ANSWER
+class AnswerNode(BaseNode):
+ _node_type = NodeType.ANSWER
+
+ _node_data: AnswerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = AnswerNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
def _run(self) -> NodeRunResult:
"""
@@ -26,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
:return:
"""
# generate routes
- generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
+ generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
@@ -45,7 +72,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part)
answer += part.text
- return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files})
+ return NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
+ )
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -53,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: AnswerNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- variable_template_parser = VariableTemplateParser(template=node_data.answer)
+ # Create typed NodeData from dict
+ typed_node_data = AnswerNodeData.model_validate(node_data)
+
+ variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
diff --git a/api/core/workflow/nodes/answer/answer_stream_processor.py b/api/core/workflow/nodes/answer/answer_stream_processor.py
index ba6ba16e36..97666fad05 100644
--- a/api/core/workflow/nodes/answer/answer_stream_processor.py
+++ b/api/core/workflow/nodes/answer/answer_stream_processor.py
@@ -2,7 +2,6 @@ import logging
from collections.abc import Generator
from typing import cast
-from core.file import FILE_MODEL_IDENTITY, File
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import (
GraphEngineEvent,
@@ -109,6 +108,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"],
+ node_version=event.node_version,
)
else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@@ -134,6 +134,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
+ node_version=event.node_version,
)
self.route_position[answer_node_id] += 1
@@ -199,44 +200,3 @@ class AnswerStreamProcessor(StreamProcessor):
stream_out_answer_node_ids.append(answer_node_id)
return stream_out_answer_node_ids
-
- @classmethod
- def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
- """
- Fetch files from variable value
- :param value: variable value
- :return:
- """
- if not value:
- return []
-
- files = []
- if isinstance(value, list):
- for item in value:
- file_var = cls._get_file_var_from_value(item)
- if file_var:
- files.append(file_var)
- elif isinstance(value, dict):
- file_var = cls._get_file_var_from_value(value)
- if file_var:
- files.append(file_var)
-
- return files
-
- @classmethod
- def _get_file_var_from_value(cls, value: dict | list):
- """
- Get file var from value
- :param value: variable value
- :return:
- """
- if not value:
- return None
-
- if isinstance(value, dict):
- if "dify_model_identity" in value and value["dify_model_identity"] == FILE_MODEL_IDENTITY:
- return value
- elif isinstance(value, File):
- return value.to_dict()
-
- return None
diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py
index 6671ff0746..09d5464d7a 100644
--- a/api/core/workflow/nodes/answer/base_stream_processor.py
+++ b/api/core/workflow/nodes/answer/base_stream_processor.py
@@ -57,7 +57,6 @@ class StreamProcessor(ABC):
# The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included.
- reachable_node_ids.append(edge.target_node_id)
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids)
else:
@@ -74,6 +73,8 @@ class StreamProcessor(ABC):
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
+ if node_id not in self.rest_node_ids:
+ self.rest_node_ids.append(node_id)
node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id:
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index d853eb71be..dcfed5eed2 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
+ version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
- version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
- def default_value_dict(self):
+ def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index 7da0c19740..fb5ec55453 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -1,29 +1,23 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
-from .entities import BaseNodeData
-
if TYPE_CHECKING:
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
- from core.workflow.graph_engine.entities.graph import Graph
- from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
- from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
-GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
-
-class BaseNode(Generic[GenericNodeData]):
- _node_data_cls: type[GenericNodeData]
- _node_type: NodeType
+class BaseNode:
+ _node_type: ClassVar[NodeType]
def __init__(
self,
@@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id
- node_data = self._node_data_cls.model_validate(config.get("data", {}))
- self.node_data = node_data
+ @abstractmethod
+ def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@@ -90,8 +84,38 @@ class BaseNode(Generic[GenericNodeData]):
graph_config: Mapping[str, Any],
config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
+ """Extracts references variable selectors from node configuration.
+
+ The `config` parameter represents the configuration for a specific node type and corresponds
+ to the `data` field in the node definition object.
+
+ The returned mapping has the following structure:
+
+ {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
+
+ For loop and iteration nodes, the mapping may look like this:
+
+ {
+ "1748332301644.input_selector": ["1748332363630", "result"],
+ "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
+ }
+
+ where `1748332301644` is the ID of the loop / iteration node,
+ and `1748332325079` is the ID of the node inside the loop or iteration node.
+
+ Here, the key consists of two parts: the current node ID (provided as the `node_id`
+ parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
+ enclosed in `#` symbols. These two parts are separated by a dot (`.`).
+
+ The value is a list of strings representing the variable selector, where the first element is the node ID
+ of the referenced variable, and the second element is the variable name within that node.
+
+ The meaning of the above mapping is:
+
+ The node with ID `1747829548239` references the variable `result` from the node with
+ ID `1747829667553`. For example, if `1747829548239` is an LLM node, its prompt may contain a
+ reference to the `result` output variable of node `1747829667553`.
+
:param graph_config: graph config
:param config: node config
:return:
@@ -100,10 +124,11 @@ class BaseNode(Generic[GenericNodeData]):
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
- node_data = cls._node_data_cls(**config.get("data", {}))
- return cls._extract_variable_selector_to_variable_mapping(
- graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
+ # Pass raw dict data instead of creating NodeData instance
+ data = cls._extract_variable_selector_to_variable_mapping(
+ graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
)
+ return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -111,48 +136,91 @@ class BaseNode(Generic[GenericNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: GenericNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
- """
- Get default config of node.
- :param filters: filter by node config parameters.
- :return:
- """
return {}
@property
- def node_type(self) -> NodeType:
- """
- Get node type
- :return:
- """
+ def type_(self) -> NodeType:
return self._node_type
+ @classmethod
+ @abstractmethod
+ def version(cls) -> str:
+ """`node_version` returns the version of current node type."""
+ # NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
+ #
+ # If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
+ # in `api/core/workflow/nodes/__init__.py`.
+ raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
+
@property
- def should_continue_on_error(self) -> bool:
- """judge if should continue on error
+ def continue_on_error(self) -> bool:
+ return False
- Returns:
- bool: if should continue on error
- """
- return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+ @property
+ def retry(self) -> bool:
+ return False
+
+ # Abstract methods that subclasses must implement to provide access
+ # to BaseNodeData properties in a type-safe way
+
+ @abstractmethod
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ """Get the error strategy for this node."""
+ ...
+
+ @abstractmethod
+ def _get_retry_config(self) -> RetryConfig:
+ """Get the retry configuration for this node."""
+ ...
+
+ @abstractmethod
+ def _get_title(self) -> str:
+ """Get the node title."""
+ ...
+
+ @abstractmethod
+ def _get_description(self) -> Optional[str]:
+ """Get the node description."""
+ ...
+ @abstractmethod
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ """Get the default values dictionary for this node."""
+ ...
+
+ @abstractmethod
+ def get_base_node_data(self) -> BaseNodeData:
+ """Get the BaseNodeData object for this node."""
+ ...
+
+ # Public interface properties that delegate to abstract methods
@property
- def should_retry(self) -> bool:
- """judge if should retry
+ def error_strategy(self) -> Optional[ErrorStrategy]:
+ """Get the error strategy for this node."""
+ return self._get_error_strategy()
- Returns:
- bool: if should retry
- """
- return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
+ @property
+ def retry_config(self) -> RetryConfig:
+ """Get the retry configuration for this node."""
+ return self._get_retry_config()
+
+ @property
+ def title(self) -> str:
+ """Get the node title."""
+ return self._get_title()
+
+ @property
+ def description(self) -> Optional[str]:
+ """Get the node description."""
+ return self._get_description()
+
+ @property
+ def default_value_dict(self) -> dict[str, Any]:
+ """Get the default values dictionary for this node."""
+ return self._get_default_value_dict()
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 61c08a7d71..fdf3932827 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
+from decimal import Decimal
from typing import Any, Optional
from configs import dify_config
@@ -10,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@@ -20,10 +22,32 @@ from .exc import (
)
-class CodeNode(BaseNode[CodeNodeData]):
- _node_data_cls = CodeNodeData
+class CodeNode(BaseNode):
_node_type = NodeType.CODE
+ _node_data: CodeNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = CodeNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -40,14 +64,18 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config()
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
# Get code language
- code_language = self.node_data.code_language
- code = self.node_data.code
+ code_language = self._node_data.code_language
+ code = self._node_data.code
# Get variables
variables = {}
- for variable_selector in self.node_data.variables:
+ for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@@ -63,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
- result = self._transform_result(result=result, output_schema=self.node_data.outputs)
+ result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@@ -110,8 +138,10 @@ class CodeNode(BaseNode[CodeNodeData]):
)
if isinstance(value, float):
+ decimal_value = Decimal(str(value)).normalize()
+ precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
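+ # e.g. 3.14150 normalizes to 3.1415 (exponent -4), giving a precision of 4; whole-number floats normalize to a non-negative exponent and yield 0.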
# raise error if precision is too high
- if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
+ if precision > dify_config.CODE_MAX_PRECISION:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
@@ -126,6 +156,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "",
depth: int = 1,
):
+ # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
+ # Note that `_transform_result` may produce lists containing `None` values,
+ # which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
@@ -324,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: CodeNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = CodeNodeData.model_validate(node_data)
+
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
- for variable_selector in node_data.variables
+ for variable_selector in typed_node_data.variables
}
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 429fed2d04..ab5964ebd4 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
import chardet
import docx
@@ -24,11 +24,12 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment
-from core.variables.segments import FileSegment
+from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@@ -36,17 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
-class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
+class DocumentExtractorNode(BaseNode):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
- _node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
+ _node_data: DocumentExtractorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = DocumentExtractorNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self):
- variable_selector = self.node_data.variable_selector
+ variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:
@@ -67,7 +94,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
- outputs={"text": extracted_text_list},
+ outputs={"text": ArrayStringSegment(value=extracted_text_list)},
)
elif isinstance(value, File):
extracted_text = _extract_text_from_file(value)
@@ -93,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: DocumentExtractorNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- return {node_id + ".files": node_data.variable_selector}
+ # Create typed NodeData from dict
+ typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
+
+ return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
@@ -447,7 +470,7 @@ def _extract_text_from_excel(file_content: bytes) -> str:
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
# Combine multi-line text in column names into a single line
- df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns])
+ df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
# Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n"
diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py
index 0e9756b243..f86f2e8129 100644
--- a/api/core/workflow/nodes/end/end_node.py
+++ b/api/core/workflow/nodes/end/end_node.py
@@ -1,20 +1,50 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.end.entities import EndNodeData
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
-class EndNode(BaseNode[EndNodeData]):
- _node_data_cls = EndNodeData
+class EndNode(BaseNode):
_node_type = NodeType.END
+ _node_data: EndNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = EndNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
"""
Run node
:return:
"""
- output_variables = self.node_data.outputs
+ output_variables = self._node_data.outputs
outputs = {}
for variable_selector in output_variables:
diff --git a/api/core/workflow/nodes/end/end_stream_processor.py b/api/core/workflow/nodes/end/end_stream_processor.py
index 3ae5af7137..a6fb2ffc18 100644
--- a/api/core/workflow/nodes/end/end_stream_processor.py
+++ b/api/core/workflow/nodes/end/end_stream_processor.py
@@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state,
parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id,
+ node_version=event.node_version,
)
self.route_position[end_node_id] += 1
diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py
index 73b43eeaf7..7cf9ab9107 100644
--- a/api/core/workflow/nodes/enums.py
+++ b/api/core/workflow/nodes/enums.py
@@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum):
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
-
-
-CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
-RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 2c83b00d4a..8ac1ae8526 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -8,6 +8,7 @@ from typing import Any, Literal
from urllib.parse import urlencode, urlparse
import httpx
+from json_repair import repair_json
from configs import dify_config
from core.file import file_manager
@@ -178,7 +179,8 @@ class Executor:
raise RequestBodyError("json body type should have exactly one item")
json_string = self.variable_pool.convert_template(data[0].value).text
try:
- json_object = json.loads(json_string, strict=False)
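+ # Repair common JSON issues (single quotes, trailing commas, unquoted keys) before parsing, instead of failing outright.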
+ repaired = repair_json(json_string)
+ json_object = json.loads(repaired, strict=False)
except json.JSONDecodeError as e:
raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e
self.json = json_object
@@ -333,7 +335,7 @@ class Executor:
try:
response = getattr(ssrf_proxy, self.method.lower())(**request_args)
except (ssrf_proxy.MaxRetriesExceededError, httpx.RequestError) as e:
- raise HttpRequestNodeError(str(e))
+ raise HttpRequestNodeError(str(e)) from e
# FIXME: fix type ignore, this maybe httpx type issue
return response # type: ignore
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 6b1ac57c06..6799d5c63c 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -6,11 +6,13 @@ from typing import Any, Optional
from configs import dify_config
from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager
+from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
@@ -31,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
-class HttpRequestNode(BaseNode[HttpRequestNodeData]):
- _node_data_cls = HttpRequestNodeData
+class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
+ _node_data: HttpRequestNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = HttpRequestNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
@@ -60,12 +84,16 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
},
}
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
process_data = {}
try:
http_executor = Executor(
- node_data=self.node_data,
- timeout=self._get_request_timeout(self.node_data),
+ node_data=self._node_data,
+ timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@@ -73,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
- if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
+ if not response.response.is_success and (self.continue_on_error or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@@ -92,7 +120,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={
"status_code": response.status_code,
- "body": response.text if not files else "",
+ "body": response.text if not files.value else "",
"headers": response.headers,
"files": files,
},
@@ -126,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: HttpRequestNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = HttpRequestNodeData.model_validate(node_data)
+
selectors: list[VariableSelector] = []
- selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
- selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
- selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
- if node_data.body:
- body_type = node_data.body.type
- data = node_data.body.data
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
+ if typed_node_data.body:
+ body_type = typed_node_data.body.type
+ data = typed_node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
@@ -166,7 +197,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
return mapping
- def extract_files(self, url: str, response: Response) -> list[File]:
+ def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
"""
Extract files from response by checking both Content-Type header and URL
"""
@@ -178,7 +209,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None
if not is_file:
- return files
+ return ArrayFileSegment(value=[])
if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename()
@@ -211,4 +242,12 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
)
files.append(file)
- return files
+ return ArrayFileSegment(value=files)
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py
index 976922f75d..86e703dc68 100644
--- a/api/core/workflow/nodes/if_else/if_else_node.py
+++ b/api/core/workflow/nodes/if_else/if_else_node.py
@@ -1,4 +1,5 @@
-from typing import Literal
+from collections.abc import Mapping, Sequence
+from typing import Any, Literal, Optional
from typing_extensions import deprecated
@@ -6,16 +7,43 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
-class IfElseNode(BaseNode[IfElseNodeData]):
- _node_data_cls = IfElseNodeData
+class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
+ _node_data: IfElseNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IfElseNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
"""
Run node
@@ -31,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
- if self.node_data.cases:
- for case in self.node_data.cases:
+ if self._node_data.cases:
+ for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@@ -58,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
- conditions=self.node_data.conditions or [],
- operator=self.node_data.logical_operator or "and",
+ conditions=self._node_data.conditions or [],
+ operator=self._node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"
@@ -87,6 +115,25 @@ class IfElseNode(BaseNode[IfElseNodeData]):
return data
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(
+ cls,
+ *,
+ graph_config: Mapping[str, Any],
+ node_id: str,
+ node_data: Mapping[str, Any],
+ ) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = IfElseNodeData.model_validate(node_data)
+
+ var_mapping: dict[str, list[str]] = {}
+ for case in typed_node_data.cases or []:
+ for condition in case.conditions:
+ key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
+ var_mapping[key] = condition.variable_selector
+
+ return var_mapping
+
@deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function(
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index 42b6795fb0..5842c8d64b 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -1,5 +1,6 @@
import contextvars
import logging
+import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from concurrent.futures import Future, wait
@@ -11,6 +12,7 @@ from flask import Flask, current_app
from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable
+from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import (
NodeRunResult,
)
@@ -34,9 +36,11 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
+from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts
from .exc import (
@@ -53,14 +57,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class IterationNode(BaseNode[IterationNodeData]):
+class IterationNode(BaseNode):
"""
Iteration Node.
"""
- _node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
+ _node_data: IterationNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IterationNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
@@ -72,23 +98,34 @@ class IterationNode(BaseNode[IterationNodeData]):
},
}
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""
Run the node.
"""
- variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
- raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
+ raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0:
+ # Try our best to preserve the type information.
+ if isinstance(variable, ArraySegment):
+ output = variable.model_copy(update={"value": []})
+ else:
+ output = ArrayAnySegment(value=[])
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
- outputs={"output": []},
+ # TODO(QuantumGhost): is it possible to compute the type of `output`
+ # from graph definition?
+ outputs={"output": output},
)
)
return
@@ -102,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):
graph_config = self.graph_config
- if not self.node_data.start_node_id:
+ if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
- root_node_id = self.node_data.start_node_id
+ root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@@ -120,8 +157,11 @@ class IterationNode(BaseNode[IterationNodeData]):
variable_pool.add([self.node_id, "item"], iterator_list_value[0])
# init graph engine
+ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
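+ # Build the runtime state (variable pool + start time) for the iteration's inner graph engine.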
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@@ -133,7 +173,7 @@ class IterationNode(BaseNode[IterationNodeData]):
call_depth=self.workflow_call_depth,
graph=iteration_graph,
graph_config=graph_config,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
@@ -144,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
@@ -155,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
@@ -164,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
- if self.node_data.is_parallel:
+ if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
- max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
+ max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
@@ -225,18 +265,19 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
- if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+ if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist]
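+ # Wrap the collected outputs in a typed segment so downstream nodes see a properly typed array.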
+ output_segment = build_segment(outputs)
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -247,7 +288,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
- outputs={"output": outputs},
+ outputs={"output": output_segment},
metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -260,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -287,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: IterationNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = IterationNodeData.model_validate(node_data)
+
variable_mapping: dict[str, Sequence[str]] = {
- f"{node_id}.input_selector": node_data.iterator_selector,
+ f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
# init graph
- iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
+ iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@@ -357,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
if not isinstance(event, BaseNodeEvent):
return event
- if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
+ if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = {
@@ -420,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
- if self.node_data.is_parallel:
+ if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@@ -438,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -460,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
- if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
+ if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -473,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
- elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+ elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -494,30 +531,64 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
- elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
- yield IterationRunFailedEvent(
- iteration_id=self.id,
- iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
- start_at=start_at,
- inputs=inputs,
- outputs={"output": None},
- steps=len(iterator_list_value),
- metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
- error=event.error,
+ elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
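+                            # Mark this step as failed, emit the iteration failure events, and abort the run.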
+ yield NodeInIterationFailedEvent(
+ **metadata_event.model_dump(),
)
+ outputs[current_index] = None
+
+                            # clean up node resources
+ for node_id in iteration_graph.node_ids:
+ variable_pool.remove([node_id])
+
+ # iteration run failed
+ if self._node_data.is_parallel:
+ yield IterationRunFailedEvent(
+ iteration_id=self.id,
+ iteration_node_id=self.node_id,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
+ parallel_mode_run_id=parallel_mode_run_id,
+ start_at=start_at,
+ inputs=inputs,
+ outputs={"output": outputs},
+ steps=len(iterator_list_value),
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+ error=event.error,
+ )
+ else:
+ yield IterationRunFailedEvent(
+ iteration_id=self.id,
+ iteration_node_id=self.node_id,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
+ start_at=start_at,
+ inputs=inputs,
+ outputs={"output": outputs},
+ steps=len(iterator_list_value),
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+ error=event.error,
+ )
+
+            # stop the iteration and report the failure
+ yield RunCompletedEvent(
+ run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=event.error,
+ )
+ )
+ return
yield metadata_event
- current_output_segment = variable_pool.get(self.node_data.output_selector)
+ current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
@@ -536,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
@@ -549,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py
index bee481ebdb..b82c29291a 100644
--- a/api/core/workflow/nodes/iteration/iteration_start_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_start_node.py
@@ -1,18 +1,48 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
-class IterationStartNode(BaseNode[IterationStartNodeData]):
+class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
- _node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
+ _node_data: IterationStartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IterationStartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
"""
Run the node.
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index 19bdee4fe2..f1767bdf9e 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -1,10 +1,10 @@
from collections.abc import Sequence
-from typing import Any, Literal, Optional
+from typing import Literal, Optional
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
-from core.workflow.nodes.llm.entities import VisionConfig
+from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
@@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None
-class ModelConfig(BaseModel):
- """
- Model Config.
- """
-
- provider: str
- name: str
- mode: str
- completion_params: dict[str, Any] = {}
-
-
class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 5cf5848d54..5f092dc2f1 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
@@ -15,19 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.entities.message_entities import (
+ PromptMessageRole,
+)
+from core.model_runtime.entities.model_entities import (
+ ModelFeature,
+ ModelType,
+)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.variables import StringSegment
+from core.variables import (
+ StringSegment,
+)
+from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import (
+ ModelInvokeCompletedEvent,
+)
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@@ -37,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
-from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
+from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -45,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
-from .entities import KnowledgeRetrievalNodeData, ModelConfig
+from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@@ -55,6 +68,10 @@ from .exc import (
ModelQuotaExceededError,
)
+if TYPE_CHECKING:
+ from core.file.models import File
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+
logger = logging.getLogger(__name__)
default_retrieval_model = {
@@ -66,14 +83,76 @@ default_retrieval_model = {
}
-class KnowledgeRetrievalNode(LLMNode):
- _node_data_cls = KnowledgeRetrievalNodeData # type: ignore
+class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
+ _node_data: KnowledgeRetrievalNodeData
+
+    # Instance attributes carried over from LLMNode.
+ # Output variable for file
+ _file_outputs: list["File"]
+
+ _llm_file_saver: LLMFileSaver
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ *,
+ llm_file_saver: LLMFileSaver | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+        # LLM file outputs, used for multimodal outputs.
+        self._file_outputs: list["File"] = []
+
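+        # Fall back to the built-in FileSaverImpl scoped to this tenant and user when none is injected.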
+ if llm_file_saver is None:
+ llm_file_saver = FileSaverImpl(
+ user_id=graph_init_params.user_id,
+ tenant_id=graph_init_params.tenant_id,
+ )
+ self._llm_file_saver = llm_file_saver
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+    def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult: # type: ignore
- node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
- variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -114,10 +193,13 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
- results = self._fetch_dataset_retriever(node_data=node_data, query=query)
- outputs = {"result": results}
+ results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
+ outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
- status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ inputs=variables,
+ process_data=None,
+ outputs=outputs, # type: ignore
)
except KnowledgeRetrievalNodeError as e:
@@ -136,6 +218,8 @@ class KnowledgeRetrievalNode(LLMNode):
error=str(e),
error_type=type(e).__name__,
)
+ finally:
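+            # Always release the DB session once retrieval has finished.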
+ db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
@@ -163,6 +247,9 @@ class KnowledgeRetrievalNode(LLMNode):
.all()
)
+ # avoid blocking at retrieval
+ db.session.close()
+
for dataset in results:
# pass if dataset is not available
if not dataset:
@@ -422,20 +509,17 @@ class KnowledgeRetrievalNode(LLMNode):
# get all metadata field
metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
- # get metadata model config
- metadata_model_config = node_data.metadata_model_config
- if metadata_model_config is None:
+ if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
- # get metadata model instance
- # fetch model config
- model_instance, model_config = self.get_model_config(metadata_model_config)
+ # get metadata model instance and fetch model config
+ model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
@@ -445,16 +529,23 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
+ tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
- generator = self._invoke_llm(
- node_data_model=node_data.metadata_model_config, # type: ignore
+ generator = LLMNode.invoke_llm(
+ node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=self._node_data.structured_output_enabled,
+ structured_output=None,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
for event in generator:
@@ -482,6 +573,9 @@ class KnowledgeRetrievalNode(LLMNode):
def _process_metadata_filter_func(
self, sequence: int, condition: str, metadata_name: str, value: Optional[Any], filters: list
):
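+        # Guard: skip conditions that were provided without a value.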
+ if value is None:
+ return
+
key = f"{metadata_name}_{sequence}"
key_value = f"{metadata_name}_{sequence}_value"
match condition:
@@ -541,17 +635,13 @@ class KnowledgeRetrievalNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: KnowledgeRetrievalNodeData, # type: ignore
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
+
variable_mapping = {}
- variable_mapping[node_id + ".query"] = node_data.query_variable_selector
+ variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
@@ -613,7 +703,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
- model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
+ model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []
diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py
index e698d3f5d8..ae9401b056 100644
--- a/api/core/workflow/nodes/list_operator/node.py
+++ b/api/core/workflow/nodes/list_operator/node.py
@@ -1,36 +1,68 @@
-from collections.abc import Callable, Sequence
-from typing import Any, Literal, Union
+from collections.abc import Callable, Mapping, Sequence
+from typing import Any, Literal, Optional, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
+from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
-class ListOperatorNode(BaseNode[ListOperatorNodeData]):
- _node_data_cls = ListOperatorNodeData
+class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
+ _node_data: ListOperatorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ListOperatorNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self):
inputs: dict[str, list] = {}
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
- variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None:
- error_message = f"Variable not found for selector: {self.node_data.variable}"
+ error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
if not variable.value:
inputs = {"variable": []}
process_data = {"variable": []}
- outputs = {"result": [], "first_record": None, "last_record": None}
+ if isinstance(variable, ArraySegment):
+ result = variable.model_copy(update={"value": []})
+ else:
+ result = ArrayAnySegment(value=[])
+ outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
@@ -39,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
- f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
+ f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
@@ -55,23 +87,23 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
try:
# Filter
- if self.node_data.filter_by.enabled:
+ if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
- if self.node_data.extract_by.enabled:
+ if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
- if self.node_data.order_by.enabled:
+ if self._node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
- if self.node_data.limit.enabled:
+ if self._node_data.limit.enabled:
variable = self._apply_slice(variable)
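+            # Keep the full segment (not just its value) as the result so its element type is preserved.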
outputs = {
- "result": variable.value,
+ "result": variable,
"first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None,
}
@@ -95,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
- for condition in self.node_data.filter_by.conditions:
+ for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@@ -128,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
- result = _order_string(order=self.node_data.order_by.value, array=variable.value)
+ result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
- result = _order_number(order=self.node_data.order_by.value, array=variable.value)
+ result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
- order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
+ order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
@@ -143,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
- result = variable.value[: self.node_data.limit.size]
+ result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
- value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
+ value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
value -= 1
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 36d0688807..4bb62d35a2 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
@@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
- structured_output: dict | None = None
+ structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py
index c85baade03..a4b45ce652 100644
--- a/api/core/workflow/nodes/llm/file_saver.py
+++ b/api/core/workflow/nodes/llm/file_saver.py
@@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
size=len(data),
related_id=tool_file.id,
url=url,
- # TODO(QuantumGhost): how should I set the following key?
- # What's the difference between `remote_url` and `url`?
- # What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index d27124d62c..91e7312805 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -5,11 +5,11 @@ import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Optional, cast
-import json_repair
-
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.file import FileType, file_manager
from core.helper.code_executor import CodeExecutor, CodeLanguage
+from core.llm_generator.output_parser.errors import OutputParserError
+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
@@ -18,7 +18,13 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
-from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
+from core.model_runtime.entities.llm_entities import (
+ LLMResult,
+ LLMResultChunk,
+ LLMResultChunkWithStructuredOutput,
+ LLMStructuredOutput,
+ LLMUsage,
+)
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
@@ -31,7 +37,6 @@ from core.model_runtime.entities.model_entities import (
ModelFeature,
ModelPropertyKey,
ModelType,
- ParameterRule,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -54,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
@@ -62,11 +68,6 @@ from core.workflow.nodes.event import (
RunRetrieverResourceEvent,
RunStreamChunkEvent,
)
-from core.workflow.utils.structured_output.entities import (
- ResponseFormat,
- SpecialModelType,
-)
-from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from . import llm_utils
@@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
- from core.workflow.graph_engine.entities.graph import Graph
- from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
- from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
-class LLMNode(BaseNode[LLMNodeData]):
- _node_data_cls = LLMNodeData
+class LLMNode(BaseNode):
_node_type = NodeType.LLM
+ _node_data: LLMNodeData
+
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@@ -138,13 +138,32 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
- def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
- def process_structured_output(text: str) -> Optional[dict[str, Any]]:
- """Process structured output if enabled"""
- if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
- return None
- return self._parse_structured_output(text)
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LLMNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+        return self._node_data.title
+
+    def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
+ def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
node_inputs: Optional[dict[str, Any]] = None
process_data = None
result_text = ""
@@ -154,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
try:
# init messages template
- self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+ self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool
- inputs = self._fetch_inputs(node_data=self.node_data)
+ inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs
- jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
+ jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs
inputs.update(jinja_inputs)
@@ -171,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
- selector=self.node_data.vision.configs.variable_selector,
+ selector=self._node_data.vision.configs.variable_selector,
)
- if self.node_data.vision.enabled
+ if self._node_data.vision.enabled
else []
)
@@ -181,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
- generator = self._fetch_context(node_data=self.node_data)
+ generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@@ -191,55 +210,58 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#context#"] = context
# fetch model config
- model_instance, model_config = self._fetch_model_config(self.node_data.model)
+ model_instance, model_config = LLMNode._fetch_model_config(
+ node_data_model=self._node_data.model,
+ tenant_id=self.tenant_id,
+ )
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
- node_data_memory=self.node_data.memory,
+ node_data_memory=self._node_data.memory,
model_instance=model_instance,
)
query = None
- if self.node_data.memory:
- query = self.node_data.memory.query_prompt_template
+ if self._node_data.memory:
+ query = self._node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
- prompt_template=self.node_data.prompt_template,
- memory_config=self.node_data.memory,
- vision_enabled=self.node_data.vision.enabled,
- vision_detail=self.node_data.vision.configs.detail,
+ prompt_template=self._node_data.prompt_template,
+ memory_config=self._node_data.memory,
+ vision_enabled=self._node_data.vision.enabled,
+ vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool,
- jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+ jinja2_variables=self._node_data.prompt_config.jinja2_variables,
+ tenant_id=self.tenant_id,
)
- process_data = {
- "model_mode": model_config.mode,
- "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
- model_mode=model_config.mode, prompt_messages=prompt_messages
- ),
- "model_provider": model_config.provider,
- "model_name": model_config.model,
- }
-
# handle invoke result
- generator = self._invoke_llm(
- node_data_model=self.node_data.model,
+ generator = LLMNode.invoke_llm(
+ node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=self._node_data.structured_output_enabled,
+ structured_output=self._node_data.structured_output,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
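+            # Captured from the event stream below when the model emits a structured-output event.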
+ structured_output: LLMStructuredOutput | None = None
+
for event in generator:
if isinstance(event, RunStreamChunkEvent):
yield event
@@ -250,12 +272,25 @@ class LLMNode(BaseNode[LLMNodeData]):
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
+ elif isinstance(event, LLMStructuredOutput):
+ structured_output = event
+
+ process_data = {
+ "model_mode": model_config.mode,
+ "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
+ model_mode=model_config.mode, prompt_messages=prompt_messages
+ ),
+ "usage": jsonable_encoder(usage),
+ "finish_reason": finish_reason,
+ "model_provider": model_config.provider,
+ "model_name": model_config.model,
+ }
+
outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason}
- structured_output = process_structured_output(result_text)
if structured_output:
- outputs["structured_output"] = structured_output
+ outputs["structured_output"] = structured_output.structured_output
if self._file_outputs is not None:
- outputs["files"] = self._file_outputs
+ outputs["files"] = ArrayFileSegment(value=self._file_outputs)
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -292,29 +327,72 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
- def _invoke_llm(
- self,
+ @staticmethod
+ def invoke_llm(
+ *,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
- ) -> Generator[NodeEvent, None, None]:
- invoke_result = model_instance.invoke_llm(
- prompt_messages=list(prompt_messages),
- model_parameters=node_data_model.completion_params,
- stop=list(stop or []),
- stream=True,
- user=self.user_id,
+ user_id: str,
+ structured_output_enabled: bool,
+ structured_output: Optional[Mapping[str, Any]] = None,
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
+ node_id: str,
+ ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
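+        # Resolve the model schema up front; it is required by the structured-output helper below.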
+ model_schema = model_instance.model_type_instance.get_model_schema(
+ node_data_model.name, model_instance.credentials
)
+ if not model_schema:
+ raise ValueError(f"Model schema not found for {node_data_model.name}")
- return self._handle_invoke_result(invoke_result=invoke_result)
+ if structured_output_enabled:
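+            # Resolve the JSON schema from the node data and invoke the model through the structured-output helper.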
+ output_schema = LLMNode.fetch_structured_output_schema(
+ structured_output=structured_output or {},
+ )
+ invoke_result = invoke_llm_with_structured_output(
+ provider=model_instance.provider,
+ model_schema=model_schema,
+ model_instance=model_instance,
+ prompt_messages=prompt_messages,
+ json_schema=output_schema,
+ model_parameters=node_data_model.completion_params,
+ stop=list(stop or []),
+ stream=True,
+ user=user_id,
+ )
+ else:
+ invoke_result = model_instance.invoke_llm(
+ prompt_messages=list(prompt_messages),
+ model_parameters=node_data_model.completion_params,
+ stop=list(stop or []),
+ stream=True,
+ user=user_id,
+ )
- def _handle_invoke_result(
- self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
- ) -> Generator[NodeEvent, None, None]:
+ return LLMNode.handle_invoke_result(
+ invoke_result=invoke_result,
+ file_saver=file_saver,
+ file_outputs=file_outputs,
+ node_id=node_id,
+ )
+
+ @staticmethod
+ def handle_invoke_result(
+ *,
+ invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
+ node_id: str,
+ ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
- event = self._handle_blocking_result(invoke_result=invoke_result)
+ event = LLMNode.handle_blocking_result(
+ invoke_result=invoke_result,
+ saver=file_saver,
+ file_outputs=file_outputs,
+ )
yield event
return
@@ -325,27 +403,39 @@ class LLMNode(BaseNode[LLMNodeData]):
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
- for result in invoke_result:
- contents = result.delta.message.content
- for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
- full_text_buffer.write(text_part)
- yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
-
- # Update the whole metadata
- if not model and result.model:
- model = result.model
- if len(prompt_messages) == 0:
- # TODO(QuantumGhost): it seems that this update has no visable effect.
- # What's the purpose of the line below?
- prompt_messages = list(result.prompt_messages)
- if usage.prompt_tokens == 0 and result.delta.usage:
- usage = result.delta.usage
- if finish_reason is None and result.delta.finish_reason:
- finish_reason = result.delta.finish_reason
+        # Consume the invoke result and handle generator exceptions
+ try:
+ for result in invoke_result:
+ if isinstance(result, LLMResultChunkWithStructuredOutput):
+ yield result
+ if isinstance(result, LLMResultChunk):
+ contents = result.delta.message.content
+ for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+ contents=contents,
+ file_saver=file_saver,
+ file_outputs=file_outputs,
+ ):
+ full_text_buffer.write(text_part)
+ yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
+
+ # Update the whole metadata
+ if not model and result.model:
+ model = result.model
+ if len(prompt_messages) == 0:
+                    # TODO(QuantumGhost): it seems that this update has no visible effect.
+ # What's the purpose of the line below?
+ prompt_messages = list(result.prompt_messages)
+ if usage.prompt_tokens == 0 and result.delta.usage:
+ usage = result.delta.usage
+ if finish_reason is None and result.delta.finish_reason:
+ finish_reason = result.delta.finish_reason
+ except OutputParserError as e:
+ raise LLMNodeError(f"Failed to parse structured output: {e}")
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
- def _image_file_to_markdown(self, file: "File", /):
+ @staticmethod
+ def _image_file_to_markdown(file: "File", /):
text_chunk = f"})"
return text_chunk
@@ -506,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
+ @staticmethod
def _fetch_model_config(
- self, node_data_model: ModelConfig
+ *,
+ node_data_model: ModelConfig,
+ tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
- tenant_id=self.tenant_id, node_data_model=node_data_model
+ tenant_id=tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters
@@ -518,19 +611,13 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
- if self.node_data.structured_output_enabled:
- if model_schema.support_structure_output:
- completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules)
- else:
- # Set appropriate response format based on model capabilities
- self._set_response_format(completion_params, model_schema.parameter_rules)
model_config_with_cred.parameters = completion_params
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
node_data_model.completion_params = completion_params
return model, model_config_with_cred
- def _fetch_prompt_messages(
- self,
+ @staticmethod
+ def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
@@ -543,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
+ tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
- self._handle_list_messages(
+ LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@@ -575,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
edition_type="basic",
)
prompt_messages.extend(
- self._handle_list_messages(
+ LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@@ -704,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
model = ModelManager().get_model_instance(
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
@@ -715,42 +803,20 @@ class LLMNode(BaseNode[LLMNodeData]):
)
if not model_schema:
raise ModelNotExistError(f"Model {model_config.model} not exist.")
- if self.node_data.structured_output_enabled:
- if not model_schema.support_structure_output:
- filtered_prompt_messages = self._handle_prompt_based_schema(
- prompt_messages=filtered_prompt_messages,
- )
return filtered_prompt_messages, model_config.stop
- def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
- structured_output: dict[str, Any] = {}
- try:
- parsed = json.loads(result_text)
- if not isinstance(parsed, dict):
- raise LLMNodeError(f"Failed to parse structured output: {result_text}")
- structured_output = parsed
- except json.JSONDecodeError as e:
- # if the result_text is not a valid json, try to repair it
- parsed = json_repair.loads(result_text)
- if not isinstance(parsed, dict):
- # handle reasoning model like deepseek-r1 got '\n\n\n' prefix
- if isinstance(parsed, list):
- parsed = next((item for item in parsed if isinstance(item, dict)), {})
- else:
- raise LLMNodeError(f"Failed to parse structured output: {result_text}")
- structured_output = parsed
- return structured_output
-
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: LLMNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- prompt_template = node_data.prompt_template
+ # Create typed NodeData from dict
+ typed_node_data = LLMNodeData.model_validate(node_data)
+ prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@@ -770,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
- memory = node_data.memory
+ memory = typed_node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@@ -778,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
- if node_data.context.enabled:
- variable_mapping["#context#"] = node_data.context.variable_selector
+ if typed_node_data.context.enabled:
+ variable_mapping["#context#"] = typed_node_data.context.variable_selector
- if node_data.vision.enabled:
- variable_mapping["#files#"] = node_data.vision.configs.variable_selector
+ if typed_node_data.vision.enabled:
+ variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
- if node_data.memory:
+ if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
- if node_data.prompt_config:
+ if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
@@ -800,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
enable_jinja = True
if enable_jinja:
- for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -832,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
- def _handle_list_messages(
- self,
+ @staticmethod
+ def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
@@ -846,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
- jinjia2_variables=jinja2_variables,
+ jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
@@ -894,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
- def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+ @staticmethod
+ def handle_blocking_result(
+ *,
+ invoke_result: LLMResult,
+ saver: LLMFileSaver,
+ file_outputs: list["File"],
+ ) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
- for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
+ for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+ contents=invoke_result.message.content,
+ file_saver=saver,
+ file_outputs=file_outputs,
+ ):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
@@ -905,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
finish_reason=None,
)
- def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
+ @staticmethod
+ def save_multimodal_image_output(
+ *,
+ content: ImagePromptMessageContent,
+ file_saver: LLMFileSaver,
+ ) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@@ -915,124 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported.
"""
- # Inject the saver somehow...
- _saver = self._llm_file_saver
-
- # If this
if content.url != "":
- saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
+ saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
- saved_file = _saver.save_binary_string(
+ saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
- self._file_outputs.append(saved_file)
return saved_file
- def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
- """
- Handle structured output for models with native JSON schema support.
-
- :param model_parameters: Model parameters to update
- :param rules: Model parameter rules
- :return: Updated model parameters with JSON schema configuration
- """
- # Process schema according to model requirements
- schema = self._fetch_structured_output_schema()
- schema_json = self._prepare_schema_for_model(schema)
-
- # Set JSON schema in parameters
- model_parameters["json_schema"] = json.dumps(schema_json, ensure_ascii=False)
-
- # Set appropriate response format if required by the model
- for rule in rules:
- if rule.name == "response_format" and ResponseFormat.JSON_SCHEMA.value in rule.options:
- model_parameters["response_format"] = ResponseFormat.JSON_SCHEMA.value
-
- return model_parameters
-
- def _handle_prompt_based_schema(self, prompt_messages: Sequence[PromptMessage]) -> list[PromptMessage]:
- """
- Handle structured output for models without native JSON schema support.
- This function modifies the prompt messages to include schema-based output requirements.
-
- Args:
- prompt_messages: Original sequence of prompt messages
-
- Returns:
- list[PromptMessage]: Updated prompt messages with structured output requirements
- """
- # Convert schema to string format
- schema_str = json.dumps(self._fetch_structured_output_schema(), ensure_ascii=False)
-
- # Find existing system prompt with schema placeholder
- system_prompt = next(
- (prompt for prompt in prompt_messages if isinstance(prompt, SystemPromptMessage)),
- None,
- )
- structured_output_prompt = STRUCTURED_OUTPUT_PROMPT.replace("{{schema}}", schema_str)
- # Prepare system prompt content
- system_prompt_content = (
- structured_output_prompt + "\n\n" + system_prompt.content
- if system_prompt and isinstance(system_prompt.content, str)
- else structured_output_prompt
- )
- system_prompt = SystemPromptMessage(content=system_prompt_content)
-
- # Extract content from the last user message
-
- filtered_prompts = [prompt for prompt in prompt_messages if not isinstance(prompt, SystemPromptMessage)]
- updated_prompt = [system_prompt] + filtered_prompts
-
- return updated_prompt
-
- def _set_response_format(self, model_parameters: dict, rules: list) -> None:
- """
- Set the appropriate response format parameter based on model rules.
-
- :param model_parameters: Model parameters to update
- :param rules: Model parameter rules
- """
- for rule in rules:
- if rule.name == "response_format":
- if ResponseFormat.JSON.value in rule.options:
- model_parameters["response_format"] = ResponseFormat.JSON.value
- elif ResponseFormat.JSON_OBJECT.value in rule.options:
- model_parameters["response_format"] = ResponseFormat.JSON_OBJECT.value
-
- def _prepare_schema_for_model(self, schema: dict) -> dict:
- """
- Prepare JSON schema based on model requirements.
-
- Different models have different requirements for JSON schema formatting.
- This function handles these differences.
-
- :param schema: The original JSON schema
- :return: Processed schema compatible with the current model
- """
-
- # Deep copy to avoid modifying the original schema
- processed_schema = schema.copy()
-
- # Convert boolean types to string types (common requirement)
- convert_boolean_to_string(processed_schema)
-
- # Apply model-specific transformations
- if SpecialModelType.GEMINI in self.node_data.model.name:
- remove_additional_properties(processed_schema)
- return processed_schema
- elif SpecialModelType.OLLAMA in self.node_data.model.provider:
- return processed_schema
- else:
- # Default format with name field
- return {"schema": processed_schema, "name": "llm_response"}
-
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
- model_name = self.node_data.model.name
+ model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@@ -1043,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
- def _fetch_structured_output_schema(self) -> dict[str, Any]:
+ @staticmethod
+ def fetch_structured_output_schema(
+ *,
+ structured_output: Mapping[str, Any],
+ ) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
- if not self.node_data.structured_output:
+ if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
- structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+ structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
@@ -1064,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
+ @staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
- self,
+ *,
contents: str | list[PromptMessageContentUnionTypes] | None,
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
@@ -1089,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
- file = self._save_multimodal_image_output(item)
- self._file_outputs.append(file)
- yield self._image_file_to_markdown(file)
+ file = LLMNode.save_multimodal_image_output(
+ content=item,
+ file_saver=file_saver,
+ )
+ file_outputs.append(file)
+ yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@@ -1099,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
+
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
@@ -1116,20 +1112,20 @@ def _combine_message_content_with_role(
def _render_jinja2_message(
*,
template: str,
- jinjia2_variables: Sequence[VariableSelector],
+ jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
- jinjia2_inputs = {}
- for jinja2_variable in jinjia2_variables:
+ jinja2_inputs = {}
+ for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
- jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+ jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
- inputs=jinjia2_inputs,
+ inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
@@ -1225,7 +1221,7 @@ def _handle_completion_template(
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
- jinjia2_variables=jinja2_variables,
+ jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
@@ -1239,49 +1235,3 @@ def _handle_completion_template(
)
prompt_messages.append(prompt_message)
return prompt_messages
-
-
-def remove_additional_properties(schema: dict) -> None:
- """
- Remove additionalProperties fields from JSON schema.
- Used for models like Gemini that don't support this property.
-
- :param schema: JSON schema to modify in-place
- """
- if not isinstance(schema, dict):
- return
-
- # Remove additionalProperties at current level
- schema.pop("additionalProperties", None)
-
- # Process nested structures recursively
- for value in schema.values():
- if isinstance(value, dict):
- remove_additional_properties(value)
- elif isinstance(value, list):
- for item in value:
- if isinstance(item, dict):
- remove_additional_properties(item)
-
-
-def convert_boolean_to_string(schema: dict) -> None:
- """
- Convert boolean type specifications to string in JSON schema.
-
- :param schema: JSON schema to modify in-place
- """
- if not isinstance(schema, dict):
- return
-
- # Check for boolean type at current level
- if schema.get("type") == "boolean":
- schema["type"] = "string"
-
- # Process nested dictionaries and lists recursively
- for value in schema.values():
- if isinstance(value, dict):
- convert_boolean_to_string(value)
- elif isinstance(value, list):
- for item in value:
- if isinstance(item, dict):
- convert_boolean_to_string(item)
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py
index 3f4a5edab9..d04e0bfae1 100644
--- a/api/core/workflow/nodes/loop/entities.py
+++ b/api/core/workflow/nodes/loop/entities.py
@@ -1,11 +1,29 @@
from collections.abc import Mapping
-from typing import Any, Literal, Optional
+from typing import Annotated, Any, Literal, Optional
-from pydantic import BaseModel, Field
+from pydantic import AfterValidator, BaseModel, Field
+from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
+_VALID_VAR_TYPE = frozenset(
+ [
+ SegmentType.STRING,
+ SegmentType.NUMBER,
+ SegmentType.OBJECT,
+ SegmentType.ARRAY_STRING,
+ SegmentType.ARRAY_NUMBER,
+ SegmentType.ARRAY_OBJECT,
+ ]
+)
+
+
+def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
+ if seg_type not in _VALID_VAR_TYPE:
+        raise ValueError(f"Invalid loop variable type: {seg_type}")
+ return seg_type
+
class LoopVariableData(BaseModel):
"""
@@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
"""
label: str
- var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+ var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
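For reference, a minimal, self-contained sketch of the `Annotated` + `AfterValidator` pattern used for `var_type` above. `VarType` and `LoopVar` are stand-ins, not Dify types; the real code validates `SegmentType` members against `_VALID_VAR_TYPE`.

```python
from enum import Enum
from typing import Annotated

from pydantic import AfterValidator, BaseModel, ValidationError

class VarType(str, Enum):  # stand-in for SegmentType
    STRING = "string"
    NUMBER = "number"
    SECRET = "secret"

_ALLOWED = frozenset({VarType.STRING, VarType.NUMBER})

def _check_allowed(value: VarType) -> VarType:
    if value not in _ALLOWED:
        raise ValueError(f"{value} is not a valid loop variable type")
    return value

class LoopVar(BaseModel):  # stand-in for LoopVariableData
    label: str
    var_type: Annotated[VarType, AfterValidator(_check_allowed)]

print(LoopVar(label="count", var_type="number"))   # accepted
try:
    LoopVar(label="token", var_type="secret")       # rejected by the AfterValidator
except ValidationError as exc:
    print(exc.errors()[0]["msg"])
```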
diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py
index 327b9e234b..53cadc5251 100644
--- a/api/core/workflow/nodes/loop/loop_end_node.py
+++ b/api/core/workflow/nodes/loop/loop_end_node.py
@@ -1,18 +1,48 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
-class LoopEndNode(BaseNode[LoopEndNodeData]):
+class LoopEndNode(BaseNode):
"""
Loop End Node.
"""
- _node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
+ _node_data: LoopEndNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopEndNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
"""
Run the node.
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index fafa205386..655de9362f 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -1,19 +1,15 @@
import json
import logging
+import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
- ArrayNumberSegment,
- ArrayObjectSegment,
- ArrayStringSegment,
IntegerSegment,
- ObjectSegment,
Segment,
SegmentType,
- StringSegment,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -34,10 +30,12 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
+from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@@ -46,28 +44,54 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class LoopNode(BaseNode[LoopNodeData]):
+class LoopNode(BaseNode):
"""
Loop Node.
"""
- _node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
+ _node_data: LoopNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
- loop_count = self.node_data.loop_count
- break_conditions = self.node_data.break_conditions
- logical_operator = self.node_data.logical_operator
+ loop_count = self._node_data.loop_count
+ break_conditions = self._node_data.break_conditions
+ logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count}
- if not self.node_data.start_node_id:
+ if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
- loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
+ loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -77,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
# Initialize loop variables
loop_variable_selectors = {}
- if self.node_data.loop_variables:
- for loop_variable in self.node_data.loop_variables:
+ if self._node_data.loop_variables:
+ for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@@ -97,8 +121,11 @@ class LoopNode(BaseNode[LoopNodeData]):
loop_variable_selectors[loop_variable.label] = variable_selector
inputs[loop_variable.label] = processed_segment.value
+ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
+
graph_engine = GraphEngine(
tenant_id=self.tenant_id,
app_id=self.app_id,
@@ -110,7 +137,7 @@ class LoopNode(BaseNode[LoopNodeData]):
call_depth=self.workflow_call_depth,
graph=loop_graph,
graph_config=self.graph_config,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=self.thread_pool_id,
@@ -123,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@@ -180,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
- outputs=self.node_data.outputs,
+ outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -202,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
- outputs=self.node_data.outputs,
+ outputs=self._node_data.outputs,
inputs=inputs,
)
)
@@ -213,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@@ -316,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -347,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -384,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
- self.node_data.outputs = _outputs
+ self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
@@ -396,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
index=next_index,
- pre_loop_output=self.node_data.outputs,
+ pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
@@ -434,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: LoopNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = LoopNodeData.model_validate(node_data)
+
variable_mapping = {}
# init graph
- loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
+ loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -482,6 +505,13 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
+ for loop_variable in typed_node_data.loop_variables or []:
+ if loop_variable.value_type == "variable":
+ assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
+ # add loop variable to variable mapping
+ selector = loop_variable.value
+ variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
+
# remove variable out from loop
variable_mapping = {
key: value for key, value in variable_mapping.items() if value[0] not in loop_graph.node_ids
@@ -490,23 +520,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping
@staticmethod
- def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
+ def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
- segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
- "string": (StringSegment, SegmentType.STRING),
- "number": (IntegerSegment, SegmentType.NUMBER),
- "object": (ObjectSegment, SegmentType.OBJECT),
- "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
- "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
- "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
- }
if var_type in ["array[string]", "array[number]", "array[object]"]:
- if value:
+ if value and isinstance(value, str):
value = json.loads(value)
else:
value = []
- segment_info = segment_mapping.get(var_type)
- if not segment_info:
- raise ValueError(f"Invalid variable type: {var_type}")
- segment_class, value_type = segment_info
- return segment_class(value=value, value_type=value_type)
+ try:
+ return build_segment_with_type(var_type, value)
+ except TypeMismatchError as type_exc:
+ # Attempt to parse the value as a JSON-encoded string, if applicable.
+ if not isinstance(value, str):
+ raise
+ try:
+ value = json.loads(value)
+ except ValueError:
+ raise type_exc
+ return build_segment_with_type(var_type, value)
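The constant handling above now funnels everything through `build_segment_with_type` and retries once after JSON-decoding string values. Below is a hedged, self-contained sketch of that retry pattern; `TypeMismatch` and `build_typed` are simplified stand-ins for Dify's `TypeMismatchError` and `build_segment_with_type`.

```python
import json
from typing import Any

class TypeMismatch(Exception):  # stand-in for factories.variable_factory.TypeMismatchError
    pass

def build_typed(expected: str, value: Any) -> Any:  # stand-in for build_segment_with_type
    if expected.startswith("array") and not isinstance(value, list):
        raise TypeMismatch(f"expected a list for {expected}, got {type(value).__name__}")
    return value

def constant_to_value(expected: str, value: Any) -> Any:
    # Mirrors _get_segment_for_constant's fallback: on a type mismatch,
    # try interpreting a string constant as JSON exactly once, then rebuild.
    try:
        return build_typed(expected, value)
    except TypeMismatch as type_exc:
        if not isinstance(value, str):
            raise
        try:
            value = json.loads(value)
        except ValueError:
            raise type_exc
        return build_typed(expected, value)

print(constant_to_value("array[number]", "[1, 2, 3]"))  # the JSON string is decoded, then accepted
```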
diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py
index 5a15f36044..29b45ea0c3 100644
--- a/api/core/workflow/nodes/loop/loop_start_node.py
+++ b/api/core/workflow/nodes/loop/loop_start_node.py
@@ -1,18 +1,48 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
-class LoopStartNode(BaseNode[LoopStartNodeData]):
+class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""
- _node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
+ _node_data: LoopStartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopStartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
"""
Run the node.
diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py
index 1f1be59542..294b47670b 100644
--- a/api/core/workflow/nodes/node_mapping.py
+++ b/api/core/workflow/nodes/node_mapping.py
@@ -25,6 +25,11 @@ from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as Var
LATEST_VERSION = "latest"
+# NOTE(QuantumGhost): Keep this mapping in sync with the subclasses of BaseNode.
+# If you introduce a new node type, register it here.
+#
+# TODO(QuantumGhost): This could be automated with a metaclass or an `__init_subclass__`
+# hook to avoid duplicating node information.
NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
NodeType.START: {
LATEST_VERSION: StartNode,
@@ -68,6 +73,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
+ # Ideally, two different versions should not map to the same class,
+ # but version "2" is kept pointing at ToolNode for compatibility with historical data.
+ "2": ToolNode,
"1": ToolNode,
},
NodeType.VARIABLE_AGGREGATOR: {
@@ -117,6 +126,10 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
+ # Ideally, two different versions should not map to the same class,
+ # but version "2" is kept pointing at AgentNode for compatibility with historical data.
+ "2": AgentNode,
"1": AgentNode,
},
}
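The TODO in the hunk above suggests auto-registering node classes instead of maintaining the mapping by hand. A rough sketch of how an `__init_subclass__` hook could populate such a registry; the class and attribute names here are illustrative, not the actual Dify API.

```python
from typing import ClassVar

class NodeBase:
    # (node_type, version) -> node class, filled automatically when subclasses are defined
    registry: ClassVar[dict[tuple[str, str], type["NodeBase"]]] = {}
    node_type: ClassVar[str]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        NodeBase.registry[(cls.node_type, cls.version())] = cls

    @classmethod
    def version(cls) -> str:
        return "1"

class StartNode(NodeBase):  # illustrative subclass, not Dify's StartNode
    node_type = "start"

class ToolNodeV2(NodeBase):
    node_type = "tool"

    @classmethod
    def version(cls) -> str:
        return "2"

print(NodeBase.registry[("tool", "2")].__name__)  # ToolNodeV2
```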
diff --git a/api/core/workflow/nodes/parameter_extractor/entities.py b/api/core/workflow/nodes/parameter_extractor/entities.py
index 369eb13b04..916778d167 100644
--- a/api/core/workflow/nodes/parameter_extractor/entities.py
+++ b/api/core/workflow/nodes/parameter_extractor/entities.py
@@ -7,6 +7,10 @@ from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
+class _ParameterConfigError(Exception):
+ pass
+
+
class ParameterConfig(BaseModel):
"""
Parameter Config.
@@ -27,6 +31,19 @@ class ParameterConfig(BaseModel):
raise ValueError("Invalid parameter name, __reason and __is_success are reserved")
return str(value)
+ def is_array_type(self) -> bool:
+ return self.type in ("array[string]", "array[number]", "array[object]")
+
+ def element_type(self) -> Literal["string", "number", "object"]:
+ if self.type == "array[number]":
+ return "number"
+ elif self.type == "array[string]":
+ return "string"
+ elif self.type == "array[object]":
+ return "object"
+ else:
+ raise _ParameterConfigError(f"{self.type} is not array type.")
+
class ParameterExtractorNodeData(BaseNodeData):
"""
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 2552784762..a23d284626 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -25,13 +25,16 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
+from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
+from factories.variable_factory import build_segment_with_type
from .entities import ParameterExtractorNodeData
from .exc import (
@@ -89,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node.
"""
- # FIXME: figure out why here is different from super class
- _node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
+ _node_data: ParameterExtractorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ParameterExtractorNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@@ -109,11 +133,15 @@ class ParameterExtractorNode(BaseNode):
}
}
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self):
"""
Run the node.
"""
- node_data = cast(ParameterExtractorNodeData, self.node_data)
+ node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
@@ -247,7 +275,12 @@ class ParameterExtractorNode(BaseNode):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
- outputs={"__is_success": 1 if not error else 0, "__reason": error, **result},
+ outputs={
+ "__is_success": 1 if not error else 0,
+ "__reason": error,
+ "__usage": jsonable_encoder(usage),
+ **result,
+ },
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
@@ -387,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
"""
Generate prompt engineering prompt.
"""
- model_mode = ModelMode.value_of(data.model.mode)
+ model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
@@ -584,28 +617,30 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
if isinstance(result[parameter.name], str):
transformed_result[parameter.name] = result[parameter.name]
- elif parameter.type.startswith("array"):
+ elif parameter.is_array_type():
if isinstance(result[parameter.name], list):
- nested_type = parameter.type[6:-1]
- transformed_result[parameter.name] = []
+ nested_type = parameter.element_type()
+ assert nested_type is not None
+ segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[])
+ transformed_result[parameter.name] = segment_value
for item in result[parameter.name]:
if nested_type == "number":
if isinstance(item, int | float):
- transformed_result[parameter.name].append(item)
+ segment_value.value.append(item)
elif isinstance(item, str):
try:
if "." in item:
- transformed_result[parameter.name].append(float(item))
+ segment_value.value.append(float(item))
else:
- transformed_result[parameter.name].append(int(item))
+ segment_value.value.append(int(item))
except ValueError:
pass
elif nested_type == "string":
if isinstance(item, str):
- transformed_result[parameter.name].append(item)
+ segment_value.value.append(item)
elif nested_type == "object":
if isinstance(item, dict):
- transformed_result[parameter.name].append(item)
+ segment_value.value.append(item)
if parameter.name not in transformed_result:
if parameter.type == "number":
@@ -615,7 +650,9 @@ class ParameterExtractorNode(BaseNode):
elif parameter.type in {"string", "select"}:
transformed_result[parameter.name] = ""
elif parameter.type.startswith("array"):
- transformed_result[parameter.name] = []
+ transformed_result[parameter.name] = build_segment_with_type(
+ segment_type=SegmentType(parameter.type), value=[]
+ )
return transformed_result
@@ -679,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -706,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -812,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ParameterExtractorNodeData, # type: ignore
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
+ # Create typed NodeData from dict
+ typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
+
+ variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
- if node_data.instruction:
- selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
+ if typed_node_data.instruction:
+ selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
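The array handling in `_transform_result` above now builds a typed segment and appends only items that match the declared element type. A simplified, stand-alone sketch of that per-item coercion, using plain lists rather than Dify segments:

```python
from typing import Any

def coerce_array_items(element_type: str, items: list[Any]) -> list[Any]:
    # Keep items that match the declared element type, coercing numeric strings where possible.
    out: list[Any] = []
    for item in items:
        if element_type == "number":
            if isinstance(item, (int, float)):
                out.append(item)
            elif isinstance(item, str):
                try:
                    out.append(float(item) if "." in item else int(item))
                except ValueError:
                    pass  # drop values that cannot be parsed as numbers
        elif element_type == "string" and isinstance(item, str):
            out.append(item)
        elif element_type == "object" and isinstance(item, dict):
            out.append(item)
    return out

print(coerce_array_items("number", [1, "2", "3.5", "oops"]))  # [1, 2, 3.5]
```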
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 1f50700c7e..15012fa48d 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.base.node import BaseNode
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
@@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
+from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -35,13 +39,77 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
+if TYPE_CHECKING:
+ from core.file.models import File
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
-class QuestionClassifierNode(LLMNode):
- _node_data_cls = QuestionClassifierNodeData # type: ignore
+
+class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
+ _node_data: QuestionClassifierNodeData
+
+ _file_outputs: list["File"]
+ _llm_file_saver: LLMFileSaver
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ *,
+ llm_file_saver: LLMFileSaver | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+ # LLM file outputs, used for MultiModal outputs.
+ self._file_outputs: list[File] = []
+
+ if llm_file_saver is None:
+ llm_file_saver = FileSaverImpl(
+ user_id=graph_init_params.user_id,
+ tenant_id=graph_init_params.tenant_id,
+ )
+ self._llm_file_saver = llm_file_saver
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = QuestionClassifierNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls):
+ return "1"
+
def _run(self):
- node_data = cast(QuestionClassifierNodeData, self.node_data)
+ node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
@@ -49,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
- model_instance, model_config = self._fetch_model_config(node_data.model)
+ model_instance, model_config = LLMNode._fetch_model_config(
+ node_data_model=node_data.model,
+ tenant_id=self.tenant_id,
+ )
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -87,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@@ -97,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
+ tenant_id=self.tenant_id,
)
result_text = ""
@@ -105,11 +177,17 @@ class QuestionClassifierNode(LLMNode):
try:
# handle invoke result
- generator = self._invoke_llm(
+ generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=False,
+ structured_output=None,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
for event in generator:
@@ -141,7 +219,11 @@ class QuestionClassifierNode(LLMNode):
"model_provider": model_config.provider,
"model_name": model_config.model,
}
- outputs = {"class_name": category_name, "class_id": category_id}
+ outputs = {
+ "class_name": category_name,
+ "class_id": category_id,
+ "usage": jsonable_encoder(usage),
+ }
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -175,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: Any,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- node_data = cast(QuestionClassifierNodeData, node_data)
- variable_mapping = {"query": node_data.query_variable_selector}
- variable_selectors = []
- if node_data.instruction:
- variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+ # Create typed NodeData from dict
+ typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
+
+ variable_mapping = {"query": typed_node_data.query_variable_selector}
+ variable_selectors: list[VariableSelector] = []
+ if typed_node_data.instruction:
+ variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
- variable_mapping[variable_selector.variable] = variable_selector.value_selector
+ variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -257,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
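`QuestionClassifierNode` now receives its `LLMFileSaver` through the constructor and falls back to `FileSaverImpl` when none is injected. A minimal sketch of that optional-collaborator pattern, with hypothetical stand-in classes, which makes it easy to swap in a test double:

```python
from typing import Protocol

class FileSaver(Protocol):  # stand-in for LLMFileSaver
    def save(self, data: bytes) -> str: ...

class DefaultFileSaver:  # stand-in for FileSaverImpl
    def save(self, data: bytes) -> str:
        return f"stored {len(data)} bytes"

class ClassifierNode:  # illustrative, not the real node
    def __init__(self, *, file_saver: FileSaver | None = None) -> None:
        # Fall back to a concrete implementation when no collaborator is injected,
        # mirroring the llm_file_saver handling in QuestionClassifierNode.__init__.
        self._file_saver = file_saver or DefaultFileSaver()

class FakeFileSaver:
    def save(self, data: bytes) -> str:
        return "fake"

print(ClassifierNode()._file_saver.save(b"abc"))                          # production default
print(ClassifierNode(file_saver=FakeFileSaver())._file_saver.save(b""))   # injected test double
```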
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 8839aec9d6..9e401e76bb 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,22 +1,53 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData
-class StartNode(BaseNode[StartNodeData]):
- _node_data_cls = StartNodeData
+class StartNode(BaseNode):
_node_type = NodeType.START
+ _node_data: StartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = StartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
- system_inputs = self.graph_runtime_state.variable_pool.system_variables
+ system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
for var in system_inputs:
node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var]
+ outputs = dict(node_inputs)
- return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=node_inputs)
+ return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs)
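`StartNode._run` now reads system variables through `to_dict()` and exposes them under the `sys.` prefix alongside user inputs. A small sketch of that key-prefixing step, assuming plain dicts in place of the variable pool:

```python
SYSTEM_VARIABLE_NODE_ID = "sys"  # same constant name as in core.workflow.constants

def build_start_outputs(user_inputs: dict[str, object], system_inputs: dict[str, object]) -> dict[str, object]:
    # Copy user inputs, then expose each system variable under a "sys." key.
    node_inputs = dict(user_inputs)
    for var, value in system_inputs.items():
        node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{var}"] = value
    return node_inputs

print(build_start_outputs({"query": "hi"}, {"user_id": "u-1", "files": []}))
# {'query': 'hi', 'sys.user_id': 'u-1', 'sys.files': []}
```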
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index 476cf7eee4..1962c82db1 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
-class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
- _node_data_cls = TemplateTransformNodeData
+class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
+ _node_data: TemplateTransformNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = TemplateTransformNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -28,17 +51,21 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
"config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"},
}
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
- for variable_selector in self.node_data.variables:
+ for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
- language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
+ language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@@ -56,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
- cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
+ cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = TemplateTransformNodeData.model_validate(node_data)
+
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
- for variable_selector in node_data.variables
+ for variable_selector in typed_node_data.variables
}
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
index 21023d4ab7..f0a44d919b 100644
--- a/api/core/workflow/nodes/tool/entities.py
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
+ credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@@ -41,6 +42,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
+
+ if value is None:
+ return typ
+
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
@@ -54,3 +59,26 @@ class ToolNodeData(BaseNodeData, ToolEntity):
return typ
tool_parameters: dict[str, ToolInput]
+ # The version of the tool node's parameter format.
+ # None indicates a node created before versioning was introduced,
+ # which must be parsed with the legacy parameter rules.
+ tool_node_version: str | None = None
+
+ @field_validator("tool_parameters", mode="before")
+ @classmethod
+ def filter_none_tool_inputs(cls, value):
+ if not isinstance(value, dict):
+ return value
+
+ return {
+ key: tool_input
+ for key, tool_input in value.items()
+ if tool_input is not None and cls._has_valid_value(tool_input)
+ }
+
+ @staticmethod
+ def _has_valid_value(tool_input):
+ """Check if the value is valid"""
+ if isinstance(tool_input, dict):
+ return tool_input.get("value") is not None
+ return getattr(tool_input, "value", None) is not None
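The `filter_none_tool_inputs` validator above drops tool inputs whose value is missing before Pydantic parses them. A self-contained sketch of the same `mode="before"` filtering on a stand-in model, simplified to dict-shaped inputs only:

```python
from typing import Any

from pydantic import BaseModel, field_validator

class Settings(BaseModel):
    # stand-in for the tool_parameters mapping cleaned up above
    params: dict[str, Any]

    @field_validator("params", mode="before")
    @classmethod
    def drop_empty_entries(cls, value):
        if not isinstance(value, dict):
            return value
        # Keep only entries whose nested "value" is present.
        return {k: v for k, v in value.items() if isinstance(v, dict) and v.get("value") is not None}

print(Settings(params={"a": {"value": 1}, "b": {"value": None}, "c": None}).params)
# {'a': {'value': 1}}
```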
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index aaecc7b989..140fe71f60 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -1,5 +1,5 @@
from collections.abc import Generator, Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -12,15 +12,15 @@ from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
from core.tools.utils.message_transformer import ToolFileMessageTransformer
-from core.variables.segments import ArrayAnySegment
+from core.variables.segments import ArrayAnySegment, ArrayFileSegment
from core.variables.variables import ArrayAnyVariable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
-from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@@ -36,20 +36,28 @@ from .exc import (
)
-class ToolNode(BaseNode[ToolNodeData]):
+class ToolNode(BaseNode):
"""
Tool Node
"""
- _node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
+ _node_data: ToolNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ToolNodeData.model_validate(data)
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> Generator:
"""
Run the tool node
"""
- node_data = cast(ToolNodeData, self.node_data)
+ node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon
tool_info = {
@@ -62,8 +70,15 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
from core.tools.tool_manager import ToolManager
+ # Ideally, this decision should not hinge on the node_data version fields,
+ # but the check is kept for backward compatibility with historical data.
+ variable_pool: VariablePool | None = None
+ if node_data.version != "1" or node_data.tool_node_version != "1":
+ variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
- self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from
+ self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@@ -82,15 +97,14 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=self.node_data,
+ node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=self.node_data,
+ node_data=self._node_data,
for_log=True,
)
-
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -119,7 +133,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
- yield from self._transform_message(message_stream, tool_info, parameters_for_log)
+ yield from self._transform_message(
+ messages=message_stream,
+ tool_info=tool_info,
+ parameters_for_log=parameters_for_log,
+ user_id=self.user_id,
+ tenant_id=self.tenant_id,
+ node_id=self.node_id,
+ )
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -163,7 +184,9 @@ class ToolNode(BaseNode[ToolNodeData]):
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
if variable is None:
- raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ if parameter.required:
+ raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+ continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
segment_group = variable_pool.convert_template(str(tool_input.value))
@@ -184,6 +207,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
+ user_id: str,
+ tenant_id: str,
+ node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -191,8 +217,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
- user_id=self.user_id,
- tenant_id=self.tenant_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
conversation_id=None,
)
@@ -200,9 +226,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = []
json: list[dict] = []
- agent_logs: list[AgentLogEvent] = []
- agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
-
variables: dict[str, Any] = {}
for message in message_stream:
@@ -235,7 +258,7 @@ class ToolNode(BaseNode[ToolNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@@ -258,48 +281,42 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append(
file_factory.build_from_mapping(
mapping=mapping,
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
- yield RunStreamChunkEvent(
- chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
- )
+ yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
- if self.node_type == NodeType.AGENT:
- msg_metadata = message.message.json_object.pop("execution_metadata", {})
- agent_execution_metadata = {
- key: value
- for key, value in msg_metadata.items()
- if key in WorkflowNodeExecutionMetadataKey.__members__.values()
- }
- json.append(message.message.json_object)
+ # JSON message handling for tool node
+ if message.message.json_object is not None:
+ json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
- yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
+ yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
- raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+ raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
- chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
+ chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
+ assert isinstance(message.meta, File)
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@@ -308,7 +325,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
- plugins = manager.list_plugins(self.tenant_id)
+ plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
@@ -318,56 +335,40 @@ class ToolNode(BaseNode[ToolNodeData]):
icon = current_plugin.declaration.icon
except StopIteration:
pass
+ icon_dark = None
try:
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
- self.user_id,
- self.tenant_id,
+ user_id,
+ tenant_id,
)
if provider.name == dict_metadata["provider"]
)
icon = builtin_tool.icon
+ icon_dark = builtin_tool.icon_dark
except StopIteration:
pass
dict_metadata["icon"] = icon
+ dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
- agent_log = AgentLogEvent(
- id=message.message.id,
- node_execution_id=self.id,
- parent_id=message.message.parent_id,
- error=message.message.error,
- status=message.message.status.value,
- data=message.message.data,
- label=message.message.label,
- metadata=message.message.metadata,
- node_id=self.node_id,
- )
- # check if the agent log is already in the list
- for log in agent_logs:
- if log.id == agent_log.id:
- # update the log
- log.data = agent_log.data
- log.status = agent_log.status
- log.error = agent_log.error
- log.label = agent_log.label
- log.metadata = agent_log.metadata
- break
- else:
- agent_logs.append(agent_log)
+ # Collect JSON messages for outputs["json"] so the frontend can still access the thinking process.
+ json_output: list[dict[str, Any]] = []
- yield agent_log
+ # Normalize the JSON output into a non-empty list[dict]; fall back to [{"data": []}] when no JSON messages arrived.
+ if json:
+ json_output.extend(json)
+ else:
+ json_output.append({"data": []})
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
- outputs={"text": text, "files": files, "json": json, **variables},
+ outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
- **agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
- WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
)
@@ -379,7 +380,7 @@ class ToolNode(BaseNode[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ToolNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -388,9 +389,12 @@ class ToolNode(BaseNode[ToolNodeData]):
:param node_data: node data
:return:
"""
+ # Create typed NodeData from dict
+ typed_node_data = ToolNodeData.model_validate(node_data)
+
result = {}
- for parameter_name in node_data.tool_parameters:
- input = node_data.tool_parameters[parameter_name]
+ for parameter_name in typed_node_data.tool_parameters:
+ input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
@@ -404,3 +408,29 @@ class ToolNode(BaseNode[ToolNodeData]):
result = {node_id + "." + key: value for key, value in result.items()}
return result
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
index db3e25b015..98127bbeb6 100644
--- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
+++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
@@ -1,34 +1,65 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
-class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
- _node_data_cls = VariableAssignerNodeData
+class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
+ _node_data: VariableAssignerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
def _run(self) -> NodeRunResult:
# Get variables
- outputs = {}
+ outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
- if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
- for selector in self.node_data.variables:
+ if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
+ for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
- outputs = {"output": variable.to_object()}
+ outputs = {"output": variable}
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
- for group in self.node_data.advanced_settings.groups:
+ for group in self._node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
- outputs[group.group_name] = {"output": variable.to_object()}
+ outputs[group.group_name] = {"output": variable}
inputs[".".join(selector[1:])] = variable.to_object()
break
diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py
index 8031b57fa8..0d2822233e 100644
--- a/api/core/workflow/nodes/variable_assigner/common/helpers.py
+++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py
@@ -1,19 +1,55 @@
-from sqlalchemy import select
-from sqlalchemy.orm import Session
+from collections.abc import Mapping, MutableMapping, Sequence
+from typing import Any, TypeVar
-from core.variables import Variable
-from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
-from extensions.ext_database import db
-from models import ConversationVariable
+from pydantic import BaseModel
+from core.variables import Segment
+from core.variables.consts import MIN_SELECTORS_LENGTH
+from core.variables.types import SegmentType
-def update_conversation_variable(conversation_id: str, variable: Variable):
- stmt = select(ConversationVariable).where(
- ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+# Use double underscore (`__`) prefix for internal variables
+# to minimize risk of collision with user-defined variable names.
+_UPDATED_VARIABLES_KEY = "__updated_variables"
+
+
+class UpdatedVariable(BaseModel):
+ name: str
+ selector: Sequence[str]
+ value_type: SegmentType
+ new_value: Any
+
+
+_T = TypeVar("_T", bound=MutableMapping[str, Any])
+
+
+def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable:
+ if len(selector) < MIN_SELECTORS_LENGTH:
+ raise Exception("selector too short")
+ node_id, var_name = selector[:2]
+ return UpdatedVariable(
+ name=var_name,
+ selector=list(selector[:2]),
+ value_type=seg.value_type,
+ new_value=seg.value,
)
- with Session(db.engine) as session:
- row = session.scalar(stmt)
- if not row:
- raise VariableOperatorNodeError("conversation variable not found in the database")
- row.data = variable.model_dump_json()
- session.commit()
+
+
+def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T:
+ m[_UPDATED_VARIABLES_KEY] = updates
+ return m
+
+
+def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None:
+ updated_values = m.get(_UPDATED_VARIABLES_KEY, None)
+ if updated_values is None:
+ return None
+ result = []
+ for items in updated_values:
+ if isinstance(items, UpdatedVariable):
+ result.append(items)
+ elif isinstance(items, dict):
+ items = UpdatedVariable.model_validate(items)
+ result.append(items)
+ else:
+ raise TypeError(f"Invalid updated variable: {items}, type={type(items)}")
+ return result
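
To make the intended round trip concrete, here is a small standalone sketch of how a node can stash its updated variables in `process_data` and how a downstream consumer can read them back. The trimmed `UpdatedVariable` model and the plain-dict `process_data` are stand-ins for the real types in `helpers.py`, not the actual module.

```python
from typing import Any

from pydantic import BaseModel

_UPDATED_VARIABLES_KEY = "__updated_variables"


class UpdatedVariable(BaseModel):  # trimmed stand-in for the helpers.py model
    name: str
    selector: list[str]
    value_type: str  # SegmentType in the real code
    new_value: Any


def set_updated_variables(m: dict[str, Any], updates: list[UpdatedVariable]) -> dict[str, Any]:
    m[_UPDATED_VARIABLES_KEY] = updates
    return m


def get_updated_variables(m: dict[str, Any]) -> list[UpdatedVariable] | None:
    raw = m.get(_UPDATED_VARIABLES_KEY)
    if raw is None:
        return None
    return [v if isinstance(v, UpdatedVariable) else UpdatedVariable.model_validate(v) for v in raw]


# A node records what it changed in its process_data ...
process_data = set_updated_variables(
    {}, [UpdatedVariable(name="count", selector=["conversation", "count"], value_type="number", new_value=42)]
)
# ... and a later consumer can read it back, even after a JSON round trip
# has turned the models into plain dicts.
serialized = {_UPDATED_VARIABLES_KEY: [u.model_dump() for u in process_data[_UPDATED_VARIABLES_KEY]]}
restored = get_updated_variables(serialized)
assert restored is not None and restored[0].new_value == 42
```
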
diff --git a/api/core/workflow/nodes/variable_assigner/common/impl.py b/api/core/workflow/nodes/variable_assigner/common/impl.py
new file mode 100644
index 0000000000..8f7a44bb62
--- /dev/null
+++ b/api/core/workflow/nodes/variable_assigner/common/impl.py
@@ -0,0 +1,38 @@
+from sqlalchemy import Engine, select
+from sqlalchemy.orm import Session
+
+from core.variables.variables import Variable
+from models.engine import db
+from models.workflow import ConversationVariable
+
+from .exc import VariableOperatorNodeError
+
+
+class ConversationVariableUpdaterImpl:
+ _engine: Engine | None
+
+ def __init__(self, engine: Engine | None = None) -> None:
+ self._engine = engine
+
+ def _get_engine(self) -> Engine:
+ if self._engine:
+ return self._engine
+ return db.engine
+
+ def update(self, conversation_id: str, variable: Variable):
+ stmt = select(ConversationVariable).where(
+ ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
+ )
+ with Session(self._get_engine()) as session:
+ row = session.scalar(stmt)
+ if not row:
+ raise VariableOperatorNodeError("conversation variable not found in the database")
+ row.data = variable.model_dump_json()
+ session.commit()
+
+ def flush(self):
+ pass
+
+
+def conversation_variable_updater_factory() -> ConversationVariableUpdaterImpl:
+ return ConversationVariableUpdaterImpl()
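
The optional `engine` argument and the factory indirection mainly exist so callers (and tests) can substitute a different updater. Below is a minimal sketch of a fake that satisfies the same duck-typed `update`/`flush` interface; the class name and the wiring shown in the comment are illustrative assumptions, not part of the codebase.

```python
from dataclasses import dataclass, field
from typing import Any


@dataclass
class FakeConversationVariableUpdater:
    """Records update calls instead of writing ConversationVariable rows."""

    calls: list[tuple[str, Any]] = field(default_factory=list)

    def update(self, conversation_id: str, variable: Any) -> None:
        self.calls.append((conversation_id, variable))

    def flush(self) -> None:
        pass  # nothing is buffered in this fake


fake = FakeConversationVariableUpdater()
# A test could pass `lambda: fake` as `conv_var_updater_factory` when constructing
# VariableAssignerNode (see v1/node.py below) instead of touching the database.
fake.update("conv-123", {"id": "var-1", "value": "hello"})
fake.flush()
assert fake.calls == [("conv-123", {"id": "var-1", "value": "hello"})]
```
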
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index 835e1d77b5..51383fa588 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -1,34 +1,120 @@
+from collections.abc import Callable, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, Optional, TypeAlias
+
from core.variables import SegmentType, Variable
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
+from ..common.impl import conversation_variable_updater_factory
from .node_data import VariableAssignerData, WriteMode
+if TYPE_CHECKING:
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+
+
+_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-class VariableAssignerNode(BaseNode[VariableAssignerData]):
- _node_data_cls = VariableAssignerData
+
+class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
+ _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
+
+ _node_data: VariableAssignerData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY = conversation_variable_updater_factory,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+ self._conv_var_updater_factory = conv_var_updater_factory
+
+ @classmethod
+ def version(cls) -> str:
+ return "1"
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(
+ cls,
+ *,
+ graph_config: Mapping[str, Any],
+ node_id: str,
+ node_data: Mapping[str, Any],
+ ) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = VariableAssignerData.model_validate(node_data)
+
+ mapping = {}
+ assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
+ if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
+ selector_key = ".".join(typed_node_data.assigned_variable_selector)
+ key = f"{node_id}.#{selector_key}#"
+ mapping[key] = typed_node_data.assigned_variable_selector
+
+ selector_key = ".".join(typed_node_data.input_variable_selector)
+ key = f"{node_id}.#{selector_key}#"
+ mapping[key] = typed_node_data.input_variable_selector
+ return mapping
def _run(self) -> NodeRunResult:
+ assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
- original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector)
+ original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
- match self.node_data.write_mode:
+ match self._node_data.write_mode:
case WriteMode.OVER_WRITE:
- income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
+ income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
- income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
+ income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@@ -41,27 +127,36 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
- raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
+ raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
- self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable)
+ self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
# TODO: Move database operation to the pipeline.
# Update conversation variable.
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"])
if not conversation_id:
raise VariableOperatorNodeError("conversation_id not found")
- common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable)
+ conv_var_updater = self._conv_var_updater_factory()
+ conv_var_updater.update(conversation_id=conversation_id.text, variable=updated_variable)
+ conv_var_updater.flush()
+ updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)]
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={
"value": income_value.to_object(),
},
+ # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`,
+ # we still record the updated variables as a list so that the `process_data` schema stays
+ # compatible with `v2.VariableAssignerNode`.
+ process_data=common_helpers.set_updated_variables({}, updated_variables),
+ outputs={},
)
def get_zero_value(t: SegmentType):
+ # TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
@@ -69,6 +164,10 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
+ case SegmentType.INTEGER:
+ return variable_factory.build_segment(0)
+ case SegmentType.FLOAT:
+ return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
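
As a concrete illustration of the keys that `_extract_variable_selector_to_variable_mapping` produces, here is a standalone re-implementation of just the key-building rule (`"{node_id}.#{selector joined with dots}#"`); the node id and selectors are invented examples, and the constant value mirrors `core.workflow.constants`.

```python
CONVERSATION_VARIABLE_NODE_ID = "conversation"  # mirrors core.workflow.constants


def build_v1_mapping(node_id: str, assigned_selector: list[str], input_selector: list[str]) -> dict[str, list[str]]:
    mapping: dict[str, list[str]] = {}
    # The assigned (target) variable is only tracked when it is a conversation variable.
    if assigned_selector[0] == CONVERSATION_VARIABLE_NODE_ID:
        mapping[f"{node_id}.#{'.'.join(assigned_selector)}#"] = assigned_selector
    # The input (source) variable is always tracked.
    mapping[f"{node_id}.#{'.'.join(input_selector)}#"] = input_selector
    return mapping


print(build_v1_mapping("assigner_1", ["conversation", "topic"], ["llm_1", "text"]))
# {'assigner_1.#conversation.topic#': ['conversation', 'topic'],
#  'assigner_1.#llm_1.text#': ['llm_1', 'text']}
```
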
diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py
index 3797bfa77a..7f760e5baa 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/constants.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py
@@ -1,5 +1,6 @@
from core.variables import SegmentType
+# Note: This mapping duplicates the logic in `get_zero_value`. Consider refactoring to remove the redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py
index 01df33b6d4..d93affcd15 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/entities.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py
@@ -12,6 +12,12 @@ class VariableOperationItem(BaseModel):
variable_selector: Sequence[str]
input_type: InputType
operation: Operation
+ # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context:
+ #
+ # 1. For CONSTANT input_type: Contains the literal value to be used in the operation.
+ # 2. For VARIABLE input_type: Initially contains the selector of the source variable.
+ # 3. During the variable updating procedure: The `value` field is reassigned to hold
+ # the resolved actual value that will be applied to the target variable.
value: Any | None = None
diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py
index b67af6d73c..fd6c304a9a 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/exc.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py
@@ -29,3 +29,8 @@ class InvalidInputValueError(VariableOperatorNodeError):
class ConversationIDNotFoundError(VariableOperatorNodeError):
def __init__(self):
super().__init__("conversation_id not found")
+
+
+class InvalidDataError(VariableOperatorNodeError):
+ def __init__(self, message: str) -> None:
+ super().__init__(message)
diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
index 8fb2a27388..7a20975b15 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
@@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
- return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
+ return variable_type in {
+ SegmentType.OBJECT,
+ SegmentType.STRING,
+ SegmentType.NUMBER,
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ }
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
- return variable_type == SegmentType.NUMBER
+ return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
@@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
@@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:
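
Since `NUMBER` now sits alongside the more specific `INTEGER` and `FLOAT` kinds, a quick sanity check of the widened behaviour can be run as below. This sketch assumes the repository's `api/` directory is on the import path (e.g. when executed from `api/`).

```python
from core.variables import SegmentType
from core.workflow.nodes.variable_assigner.v2 import helpers
from core.workflow.nodes.variable_assigner.v2.enums import Operation

for seg_type in (SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT):
    # Arithmetic operations are supported for every numeric kind ...
    assert helpers.is_operation_supported(variable_type=seg_type, operation=Operation.ADD)
    # ... and numeric constants are accepted as input values.
    assert helpers.is_input_value_valid(variable_type=seg_type, operation=Operation.ADD, value=1.5)

# Dividing by a constant zero is still rejected for all numeric kinds.
assert not helpers.is_input_value_valid(variable_type=SegmentType.INTEGER, operation=Operation.DIVIDE, value=0)
```
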
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 8759a55b34..c0215cae71 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,42 +1,116 @@
import json
-from collections.abc import Sequence
-from typing import Any, cast
+from collections.abc import Mapping, MutableMapping, Sequence
+from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
+from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
+from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
+from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
from . import helpers
from .constants import EMPTY_VALUE_MAPPING
-from .entities import VariableAssignerNodeData
+from .entities import VariableAssignerNodeData, VariableOperationItem
from .enums import InputType, Operation
from .exc import (
ConversationIDNotFoundError,
InputTypeNotSupportedError,
+ InvalidDataError,
InvalidInputValueError,
OperationNotSupportedError,
VariableNotFoundError,
)
-class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
- _node_data_cls = VariableAssignerNodeData
+def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
+ selector_node_id = item.variable_selector[0]
+ if selector_node_id != CONVERSATION_VARIABLE_NODE_ID:
+ return
+ selector_str = ".".join(item.variable_selector)
+ key = f"{node_id}.#{selector_str}#"
+ mapping[key] = item.variable_selector
+
+
+def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
+ # Keep this in sync with the corresponding logic in the `_run` method.
+ if item.input_type != InputType.VARIABLE:
+ return
+ selector = item.value
+ if not isinstance(selector, list):
+ raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}")
+ if len(selector) < MIN_SELECTORS_LENGTH:
+ raise InvalidDataError(f"selector too short, {node_id=}, {item=}")
+ selector_str = ".".join(selector)
+ key = f"{node_id}.#{selector_str}#"
+ mapping[key] = selector
+
+
+class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
+ _node_data: VariableAssignerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
+ return conversation_variable_updater_factory()
+
+ @classmethod
+ def version(cls) -> str:
+ return "2"
+
+ @classmethod
+ def _extract_variable_selector_to_variable_mapping(
+ cls,
+ *,
+ graph_config: Mapping[str, Any],
+ node_id: str,
+ node_data: Mapping[str, Any],
+ ) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = VariableAssignerNodeData.model_validate(node_data)
+
+ var_mapping: dict[str, Sequence[str]] = {}
+ for item in typed_node_data.items:
+ _target_mapping_from_item(var_mapping, node_id, item)
+ _source_mapping_from_item(var_mapping, node_id, item)
+ return var_mapping
+
def _run(self) -> NodeRunResult:
- inputs = self.node_data.model_dump()
+ inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
- for item in self.node_data.items:
+ for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
@@ -114,6 +188,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# remove the duplicated items first.
updated_variable_selectors = list(set(map(tuple, updated_variable_selectors)))
+ conv_var_updater = self._conv_var_updater_factory()
# Update variables
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
@@ -128,15 +203,23 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
raise ConversationIDNotFoundError
else:
conversation_id = conversation_id.value
- common_helpers.update_conversation_variable(
+ conv_var_updater.update(
conversation_id=cast(str, conversation_id),
variable=variable,
)
+ conv_var_updater.flush()
+ updated_variables = [
+ common_helpers.variable_to_processed_data(selector, seg)
+ for selector in updated_variable_selectors
+ if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None
+ ]
+ process_data = common_helpers.set_updated_variables(process_data, updated_variables)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs,
process_data=process_data,
+ outputs={},
)
def _handle_item(
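
To see how the two mapping helpers above combine, here is a standalone sketch of the v2 mapping extraction using plain dicts for the operation items; the node id and selectors are invented, and the strings `"variable"`/`"constant"` stand in for the `InputType` enum values.

```python
CONVERSATION_VARIABLE_NODE_ID = "conversation"  # mirrors core.workflow.constants


def build_v2_mapping(node_id: str, items: list[dict]) -> dict[str, list[str]]:
    mapping: dict[str, list[str]] = {}
    for item in items:
        # Target side: only conversation variables are tracked.
        target = item["variable_selector"]
        if target[0] == CONVERSATION_VARIABLE_NODE_ID:
            mapping[f"{node_id}.#{'.'.join(target)}#"] = target
        # Source side: for VARIABLE inputs, `value` initially holds a selector.
        if item["input_type"] == "variable":
            source = item["value"]
            mapping[f"{node_id}.#{'.'.join(source)}#"] = source
    return mapping


print(
    build_v2_mapping(
        "assigner_2",
        [
            {"variable_selector": ["conversation", "topic"], "input_type": "variable", "value": ["llm_1", "text"]},
            {"variable_selector": ["conversation", "count"], "input_type": "constant", "value": 0},
        ],
    )
)
# {'assigner_2.#conversation.topic#': ['conversation', 'topic'],
#  'assigner_2.#llm_1.text#': ['llm_1', 'text'],
#  'assigner_2.#conversation.count#': ['conversation', 'count']}
```
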
diff --git a/api/core/workflow/repositories/draft_variable_repository.py b/api/core/workflow/repositories/draft_variable_repository.py
new file mode 100644
index 0000000000..cadc23f845
--- /dev/null
+++ b/api/core/workflow/repositories/draft_variable_repository.py
@@ -0,0 +1,32 @@
+import abc
+from collections.abc import Mapping
+from typing import Any, Protocol
+
+from sqlalchemy.orm import Session
+
+from core.workflow.nodes.enums import NodeType
+
+
+class DraftVariableSaver(Protocol):
+ @abc.abstractmethod
+ def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
+ pass
+
+
+class DraftVariableSaverFactory(Protocol):
+ @abc.abstractmethod
+ def __call__(
+ self,
+ session: Session,
+ app_id: str,
+ node_id: str,
+ node_type: NodeType,
+ node_execution_id: str,
+ enclosing_node_id: str | None = None,
+ ) -> "DraftVariableSaver":
+ pass
+
+
+class NoopDraftVariableSaver(DraftVariableSaver):
+ def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
+ pass
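
For reference, here is a minimal in-memory pair that satisfies both protocols above and can serve as a test double. The class and function names are illustrative only and are not part of the codebase.

```python
from collections.abc import Mapping
from typing import Any


class InMemoryDraftVariableSaver:
    """Test double: records what would have been persisted as draft variables."""

    def __init__(self) -> None:
        self.saved: list[tuple[Mapping[str, Any] | None, Mapping[str, Any] | None]] = []

    def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None):
        self.saved.append((process_data, outputs))


def in_memory_saver_factory(session, app_id, node_id, node_type, node_execution_id, enclosing_node_id=None):
    # Ignores the session and identifiers; a real factory would build a saver bound to them.
    return InMemoryDraftVariableSaver()


saver = in_memory_saver_factory(None, "app-1", "node-1", None, "exec-1")
saver.save({"__updated_variables": []}, {"answer": "hi"})
assert len(saver.saved) == 1
```
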
diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py
index 5917310c8b..bcbd253392 100644
--- a/api/core/workflow/repositories/workflow_execution_repository.py
+++ b/api/core/workflow/repositories/workflow_execution_repository.py
@@ -1,4 +1,4 @@
-from typing import Optional, Protocol
+from typing import Protocol
from core.workflow.entities.workflow_execution import WorkflowExecution
@@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
execution: The WorkflowExecution instance to save or update
"""
...
-
- def get(self, execution_id: str) -> Optional[WorkflowExecution]:
- """
- Retrieve a WorkflowExecution by its ID.
-
- Args:
- execution_id: The workflow execution ID
-
- Returns:
- The WorkflowExecution instance if found, None otherwise
- """
- ...
diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py
index 1908a6b190..8bf81f5442 100644
--- a/api/core/workflow/repositories/workflow_node_execution_repository.py
+++ b/api/core/workflow/repositories/workflow_node_execution_repository.py
@@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
- def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
- """
- Retrieve a NodeExecution by its node_execution_id.
-
- Args:
- node_execution_id: The node execution ID
-
- Returns:
- The NodeExecution instance if found, None otherwise
- """
- ...
-
def get_by_workflow_run(
self,
workflow_run_id: str,
@@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
A list of NodeExecution instances
"""
...
-
- def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
- """
- Retrieve all running NodeExecution instances for a specific workflow run.
-
- Args:
- workflow_run_id: The workflow run ID
-
- Returns:
- A list of running NodeExecution instances
- """
- ...
-
- def clear(self) -> None:
- """
- Clear all NodeExecution records based on implementation-specific criteria.
-
- This method is intended to be used for bulk deletion operations, such as removing
- all records associated with a specific app_id and tenant_id in multi-tenant implementations.
- """
- ...
diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py
new file mode 100644
index 0000000000..df90c16596
--- /dev/null
+++ b/api/core/workflow/system_variable.py
@@ -0,0 +1,89 @@
+from collections.abc import Sequence
+from typing import Any
+
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
+
+from core.file.models import File
+from core.workflow.enums import SystemVariableKey
+
+
+class SystemVariable(BaseModel):
+ """A model for managing system variables.
+
+ Fields with a value of `None` are treated as absent and will not be included
+ in the variable pool.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ serialize_by_alias=True,
+ validate_by_alias=True,
+ )
+
+ user_id: str | None = None
+
+ # Ideally, `app_id` and `workflow_id` should be required and not `None`.
+ # However, there are scenarios in the codebase where these fields are not set.
+ # To maintain compatibility, they are marked as optional here.
+ app_id: str | None = None
+ workflow_id: str | None = None
+
+ files: Sequence[File] = Field(default_factory=list)
+
+ # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
+ # To maintain compatibility with existing workflows, it must be serialized
+ # as `workflow_run_id` in dictionaries or JSON objects, and also referenced
+ # as `workflow_run_id` in the variable pool.
+ workflow_execution_id: str | None = Field(
+ validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
+ serialization_alias="workflow_run_id",
+ default=None,
+ )
+ # Chatflow related fields.
+ query: str | None = None
+ conversation_id: str | None = None
+ dialogue_count: int | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_json_fields(cls, data):
+ if isinstance(data, dict):
+ # Accept either alias on input; when both are present, prefer `workflow_run_id`.
+ if "workflow_execution_id" in data and "workflow_run_id" not in data:
+ # Only `workflow_execution_id` is present (likely direct instantiation); allow it as-is.
+ return data
+ elif "workflow_execution_id" in data and "workflow_run_id" in data:
+ # Both present, remove workflow_execution_id
+ data = data.copy()
+ data.pop("workflow_execution_id")
+ return data
+ return data
+
+ @classmethod
+ def empty(cls) -> "SystemVariable":
+ return cls()
+
+ def to_dict(self) -> dict[SystemVariableKey, Any]:
+ # NOTE: This method is provided for compatibility with legacy code.
+ # New code should use the `SystemVariable` object directly instead of converting
+ # it to a dictionary, as this conversion results in the loss of type information
+ # for each key, making static analysis more difficult.
+
+ d: dict[SystemVariableKey, Any] = {
+ SystemVariableKey.FILES: self.files,
+ }
+ if self.user_id is not None:
+ d[SystemVariableKey.USER_ID] = self.user_id
+ if self.app_id is not None:
+ d[SystemVariableKey.APP_ID] = self.app_id
+ if self.workflow_id is not None:
+ d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
+ if self.workflow_execution_id is not None:
+ d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
+ if self.query is not None:
+ d[SystemVariableKey.QUERY] = self.query
+ if self.conversation_id is not None:
+ d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
+ if self.dialogue_count is not None:
+ d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
+ return d
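
The alias handling for `workflow_execution_id` is the subtle part of this model. A trimmed standalone sketch below shows that both input names are accepted while serialization always uses `workflow_run_id`; the `MiniSystemVariable` class is a stand-in for illustration, not the real model.

```python
from pydantic import AliasChoices, BaseModel, Field


class MiniSystemVariable(BaseModel):  # trimmed stand-in for SystemVariable
    workflow_execution_id: str | None = Field(
        validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
        serialization_alias="workflow_run_id",
        default=None,
    )


# Accepted under either name on input ...
legacy = MiniSystemVariable.model_validate({"workflow_run_id": "run-123"})
direct = MiniSystemVariable(workflow_execution_id="run-123")
assert legacy.workflow_execution_id == direct.workflow_execution_id == "run-123"

# ... but always serialized as `workflow_run_id`.
assert legacy.model_dump(by_alias=True) == {"workflow_run_id": "run-123"}
```
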
diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py
deleted file mode 100644
index 6491042bfe..0000000000
--- a/api/core/workflow/utils/structured_output/entities.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from enum import StrEnum
-
-
-class ResponseFormat(StrEnum):
- """Constants for model response formats"""
-
- JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
- JSON = "JSON" # model's json mode. some model like claude support this mode.
- JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
-
-
-class SpecialModelType(StrEnum):
- """Constants for identifying model types"""
-
- GEMINI = "gemini"
- OLLAMA = "ollama"
diff --git a/api/core/workflow/utils/structured_output/prompt.py b/api/core/workflow/utils/structured_output/prompt.py
deleted file mode 100644
index 06d9b2056e..0000000000
--- a/api/core/workflow/utils/structured_output/prompt.py
+++ /dev/null
@@ -1,17 +0,0 @@
-STRUCTURED_OUTPUT_PROMPT = """You’re a helpful AI assistant. You could answer questions and output in JSON format.
-constraints:
- - You must output in JSON format.
- - Do not output boolean value, use string type instead.
- - Do not output integer or float value, use number type instead.
-eg:
- Here is the JSON schema:
- {"additionalProperties": false, "properties": {"age": {"type": "number"}, "name": {"type": "string"}}, "required": ["name", "age"], "type": "object"}
-
- Here is the user's question:
- My name is John Doe and I am 30 years old.
-
- output:
- {"name": "John Doe", "age": 30}
-Here is the JSON schema:
-{{schema}}
-""" # noqa: E501
diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py
new file mode 100644
index 0000000000..868868315b
--- /dev/null
+++ b/api/core/workflow/utils/variable_utils.py
@@ -0,0 +1,29 @@
+from core.variables.segments import ObjectSegment, Segment
+from core.workflow.entities.variable_pool import VariablePool, VariableValue
+
+
+def append_variables_recursively(
+ pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
+):
+ """
+ Append variables recursively
+ :param pool: variable pool to append variables to
+ :param node_id: node id
+ :param variable_key_list: variable key list
+ :param variable_value: variable value
+ :return:
+ """
+ pool.add([node_id] + variable_key_list, variable_value)
+
+ # if variable_value is a dict, then recursively append variables
+ if isinstance(variable_value, ObjectSegment):
+ variable_dict = variable_value.value
+ elif isinstance(variable_value, dict):
+ variable_dict = variable_value
+ else:
+ return
+
+ for key, value in variable_dict.items():
+ # construct new key list
+ new_key_list = variable_key_list + [key]
+ append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)
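
Here is a standalone sketch of the flattening behaviour, with a plain dict standing in for `VariablePool` so the example runs without the rest of the codebase; the node id and values are invented.

```python
def append_recursively(pool: dict[tuple[str, ...], object], node_id: str, key_list: list[str], value) -> None:
    # Mirrors the recursion above: add the value itself, then one entry per nested path.
    pool[tuple([node_id, *key_list])] = value
    if isinstance(value, dict):
        for key, child in value.items():
            append_recursively(pool, node_id, key_list + [key], child)


pool: dict[tuple[str, ...], object] = {}
append_recursively(pool, "llm_1", ["structured_output"], {"user": {"name": "Ada"}, "score": 0.9})
# Produces entries for the object itself and for every nested path:
#   ('llm_1', 'structured_output'), ('llm_1', 'structured_output', 'user'),
#   ('llm_1', 'structured_output', 'user', 'name'), ('llm_1', 'structured_output', 'score')
assert pool[("llm_1", "structured_output", "user", "name")] == "Ada"
```
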
diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py
new file mode 100644
index 0000000000..1e13871d0a
--- /dev/null
+++ b/api/core/workflow/variable_loader.py
@@ -0,0 +1,84 @@
+import abc
+from collections.abc import Mapping, Sequence
+from typing import Any, Protocol
+
+from core.variables import Variable
+from core.variables.consts import MIN_SELECTORS_LENGTH
+from core.workflow.entities.variable_pool import VariablePool
+from core.workflow.utils import variable_utils
+
+
+class VariableLoader(Protocol):
+ """Interface for loading variables based on selectors.
+
+ A `VariableLoader` is responsible for retrieving additional variables that are required during the
+ execution of a single node but are not provided as user inputs.
+
+ NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same
+ application and share the same `app_id`. However, this interface does not enforce that constraint,
+ and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of
+ concerns and allow for flexible implementations.
+
+ Implementations of `VariableLoader` should almost always have an `app_id` parameter in
+ their constructor.
+
+ TODO(QuantumGhost): this is a temporary workaround. If we can move the creation of the node instance into
+ `WorkflowService.single_step_run`, we may be able to get rid of this interface.
+ """
+
+ @abc.abstractmethod
+ def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ """Load variables based on the provided selectors. If the selectors are empty,
+ this method should return an empty list.
+
+ The order of the returned variables is not guaranteed. If the caller wants to ensure
+ a specific order, they should sort the returned list themselves.
+
+ :param selectors: a list of string lists; each inner list should have at least two elements:
+ - the first element is the node ID,
+ - the second element is the variable name.
+ :return: a list of Variable objects that match the provided selectors.
+ """
+ pass
+
+
+class _DummyVariableLoader(VariableLoader):
+ """A dummy implementation of VariableLoader that does not load any variables.
+ Serves as a placeholder when no variable loading is needed.
+ """
+
+ def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ return []
+
+
+DUMMY_VARIABLE_LOADER = _DummyVariableLoader()
+
+
+def load_into_variable_pool(
+ variable_loader: VariableLoader,
+ variable_pool: VariablePool,
+ variable_mapping: Mapping[str, Sequence[str]],
+ user_inputs: Mapping[str, Any],
+):
+ # Load any missing variables from draft variables and set them into the
+ # variable_pool.
+ variables_to_load: list[list[str]] = []
+ for key, selector in variable_mapping.items():
+ # NOTE(QuantumGhost): this logic needs to be in sync with
+ # `WorkflowEntry.mapping_user_inputs_to_variable_pool`.
+ node_variable_list = key.split(".")
+ if len(node_variable_list) < 1:
+ raise ValueError(f"Invalid variable key: {key}. It should have at least one element.")
+ if key in user_inputs:
+ continue
+ node_variable_key = ".".join(node_variable_list[1:])
+ if node_variable_key in user_inputs:
+ continue
+ if variable_pool.get(selector) is None:
+ variables_to_load.append(list(selector))
+ loaded = variable_loader.load_variables(variables_to_load)
+ for var in loaded:
+ assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
+ variable_utils.append_variables_recursively(
+ variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
+ )
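
As an example of what an implementation might look like, here is a hypothetical in-memory loader; the class name is an assumption, and real implementations would load persisted draft variables for a specific app. It assumes the repository's `api/` directory is on the import path.

```python
from core.variables import Variable
from core.workflow.variable_loader import VariableLoader


class InMemoryVariableLoader(VariableLoader):
    """Hypothetical loader backed by a pre-built list of Variable objects."""

    def __init__(self, variables: list[Variable]) -> None:
        # Index by the first two selector parts: (node_id, variable_name).
        self._by_selector = {tuple(v.selector[:2]): v for v in variables}

    def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
        return [v for s in selectors if (v := self._by_selector.get(tuple(s[:2]))) is not None]


# Usage sketch: pass an instance as `variable_loader` to WorkflowEntry.single_step_run
# (see workflow_entry.py below) so missing selectors are filled from this loader.
```
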
diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py
index b88f9edd03..f844aada95 100644
--- a/api/core/workflow/workflow_cycle_manager.py
+++ b/api/core/workflow/workflow_cycle_manager.py
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from dataclasses import dataclass
-from datetime import UTC, datetime
+from datetime import datetime
from typing import Any, Optional, Union
from uuid import uuid4
@@ -26,7 +26,9 @@ from core.workflow.entities.workflow_node_execution import (
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
+from libs.datetime_utils import naive_utc_now
@dataclass
@@ -42,7 +44,7 @@ class WorkflowCycleManager:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
- workflow_system_variables: dict[SystemVariableKey, Any],
+ workflow_system_variables: SystemVariable,
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
@@ -53,19 +55,15 @@ class WorkflowCycleManager:
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
- def handle_workflow_run_start(self) -> WorkflowExecution:
- inputs = {**self._application_generate_entity.inputs}
- for key, value in (self._workflow_system_variables or {}).items():
- if key.value == "conversation":
- continue
- inputs[f"sys.{key.value}"] = value
+ # Initialize caches for workflow execution cycle
+ # These caches avoid redundant repository calls during a single workflow execution
+ self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
+ self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
- # handle special values
- inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
+ def handle_workflow_run_start(self) -> WorkflowExecution:
+ inputs = self._prepare_workflow_inputs()
+ execution_id = self._get_or_generate_execution_id()
- # init workflow run
- # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
- execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
@@ -73,12 +71,10 @@ class WorkflowCycleManager:
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
- self._workflow_execution_repository.save(execution)
-
- return execution
+ return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
@@ -92,23 +88,15 @@ class WorkflowCycleManager:
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
- outputs = WorkflowEntry.handle_special_values(outputs)
-
- workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
- workflow_execution.outputs = outputs or {}
- workflow_execution.total_tokens = total_tokens
- workflow_execution.total_steps = total_steps
- workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
+ self._update_workflow_execution_completion(
+ workflow_execution,
+ status=WorkflowExecutionStatus.SUCCEEDED,
+ outputs=outputs,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ )
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=workflow_execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -125,24 +113,17 @@ class WorkflowCycleManager:
trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
- outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
- execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
- execution.outputs = outputs or {}
- execution.total_tokens = total_tokens
- execution.total_steps = total_steps
- execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
- execution.exceptions_count = exceptions_count
+ self._update_workflow_execution_completion(
+ execution,
+ status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+ outputs=outputs,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ exceptions_count=exceptions_count,
+ )
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._add_trace_task_if_needed(trace_manager, execution, conversation_id)
self._workflow_execution_repository.save(execution)
return execution
@@ -160,41 +141,20 @@ class WorkflowCycleManager:
exceptions_count: int = 0,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
-
- workflow_execution.status = WorkflowExecutionStatus(status.value)
- workflow_execution.error_message = error_message
- workflow_execution.total_tokens = total_tokens
- workflow_execution.total_steps = total_steps
- workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
- workflow_execution.exceptions_count = exceptions_count
-
- # Use the instance repository to find running executions for a workflow run
- running_node_executions = self._workflow_node_execution_repository.get_running_executions(
- workflow_run_id=workflow_execution.id_
+ now = naive_utc_now()
+
+ self._update_workflow_execution_completion(
+ workflow_execution,
+ status=status,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ error_message=error_message,
+ exceptions_count=exceptions_count,
+ finished_at=now,
)
- # Update the domain models
- now = datetime.now(UTC).replace(tzinfo=None)
- for node_execution in running_node_executions:
- if node_execution.node_execution_id:
- # Update the domain model
- node_execution.status = WorkflowNodeExecutionStatus.FAILED
- node_execution.error = error_message
- node_execution.finished_at = now
- node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
-
- # Update the repository with the domain model
- self._workflow_node_execution_repository.save(node_execution)
-
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=workflow_execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._fail_running_node_executions(workflow_execution.id_, error_message, now)
+ self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -207,65 +167,24 @@ class WorkflowCycleManager:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
- # Create a domain model
- created_at = datetime.now(UTC).replace(tzinfo=None)
- metadata = {
- WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
- WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
- WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
- }
-
- domain_execution = WorkflowNodeExecution(
- id=str(uuid4()),
- workflow_id=workflow_execution.workflow_id,
- workflow_execution_id=workflow_execution.id_,
- predecessor_node_id=event.predecessor_node_id,
- index=event.node_run_index,
- node_execution_id=event.node_execution_id,
- node_id=event.node_id,
- node_type=event.node_type,
- title=event.node_data.title,
+ domain_execution = self._create_node_execution_from_event(
+ workflow_execution=workflow_execution,
+ event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
- metadata=metadata,
- created_at=created_at,
)
- # Use the instance repository to save the domain model
- self._workflow_node_execution_repository.save(domain_execution)
-
- return domain_execution
+ return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
- # Get the domain model from repository
- domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
- if not domain_execution:
- raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
-
- # Process data
- inputs = WorkflowEntry.handle_special_values(event.inputs)
- process_data = WorkflowEntry.handle_special_values(event.process_data)
- outputs = WorkflowEntry.handle_special_values(event.outputs)
-
- # Convert metadata keys to strings
- execution_metadata_dict = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - event.start_at).total_seconds()
+ domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
- # Update domain model
- domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
- domain_execution.update_from_mapping(
- inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+ self._update_node_execution_completion(
+ domain_execution,
+ event=event,
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
- domain_execution.finished_at = finished_at
- domain_execution.elapsed_time = elapsed_time
- # Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
-
return domain_execution
def handle_workflow_node_execution_failed(
@@ -281,96 +200,251 @@ class WorkflowCycleManager:
:param event: queue node failed event
:return:
"""
- # Get the domain model from repository
- domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
- if not domain_execution:
- raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
-
- # Process data
- inputs = WorkflowEntry.handle_special_values(event.inputs)
- process_data = WorkflowEntry.handle_special_values(event.process_data)
- outputs = WorkflowEntry.handle_special_values(event.outputs)
-
- # Convert metadata keys to strings
- execution_metadata_dict = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - event.start_at).total_seconds()
+ domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
- # Update domain model
- domain_execution.status = (
- WorkflowNodeExecutionStatus.FAILED
- if not isinstance(event, QueueNodeExceptionEvent)
- else WorkflowNodeExecutionStatus.EXCEPTION
+ status = (
+ WorkflowNodeExecutionStatus.EXCEPTION
+ if isinstance(event, QueueNodeExceptionEvent)
+ else WorkflowNodeExecutionStatus.FAILED
)
- domain_execution.error = event.error
- domain_execution.update_from_mapping(
- inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+
+ self._update_node_execution_completion(
+ domain_execution,
+ event=event,
+ status=status,
+ error=event.error,
+ handle_special_values=True,
)
- domain_execution.finished_at = finished_at
- domain_execution.elapsed_time = elapsed_time
- # Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
-
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
- created_at = event.start_at
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - created_at).total_seconds()
+
+ domain_execution = self._create_node_execution_from_event(
+ workflow_execution=workflow_execution,
+ event=event,
+ status=WorkflowNodeExecutionStatus.RETRY,
+ error=event.error,
+ created_at=event.start_at,
+ )
+
+ # Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
- outputs = WorkflowEntry.handle_special_values(event.outputs)
+ outputs = event.outputs
+ metadata = self._merge_event_metadata(event)
- # Convert metadata keys to strings
- origin_metadata = {
- WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
+ domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
+
+ return self._save_and_cache_node_execution(domain_execution)
+
+ def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
+ # Check cache first
+ if id in self._workflow_execution_cache:
+ return self._workflow_execution_cache[id]
+
+ raise WorkflowRunNotFoundError(id)
+
+ def _prepare_workflow_inputs(self) -> dict[str, Any]:
+ """Prepare workflow inputs by merging application inputs with system variables."""
+ inputs = {**self._application_generate_entity.inputs}
+
+ if self._workflow_system_variables:
+ for field_name, value in self._workflow_system_variables.to_dict().items():
+ if field_name != SystemVariableKey.CONVERSATION_ID:
+ inputs[f"sys.{field_name}"] = value
+
+ return dict(WorkflowEntry.handle_special_values(inputs) or {})
+
+ def _get_or_generate_execution_id(self) -> str:
+ """Get execution ID from system variables or generate a new one."""
+ if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
+ return str(self._workflow_system_variables.workflow_execution_id)
+ return str(uuid4())
+
+ def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
+ """Save workflow execution to repository and cache it."""
+ self._workflow_execution_repository.save(execution)
+ self._workflow_execution_cache[execution.id_] = execution
+ return execution
+
+ def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
+ """Save node execution to repository and cache it if it has an ID."""
+ self._workflow_node_execution_repository.save(execution)
+ if execution.node_execution_id:
+ self._node_execution_cache[execution.node_execution_id] = execution
+ return execution
+
+ def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
+ """Get node execution from cache or raise error if not found."""
+ domain_execution = self._node_execution_cache.get(node_execution_id)
+ if not domain_execution:
+ raise ValueError(f"Domain node execution not found: {node_execution_id}")
+ return domain_execution
+
+ def _update_workflow_execution_completion(
+ self,
+ execution: WorkflowExecution,
+ *,
+ status: WorkflowExecutionStatus,
+ total_tokens: int,
+ total_steps: int,
+ outputs: Mapping[str, Any] | None = None,
+ error_message: Optional[str] = None,
+ exceptions_count: int = 0,
+ finished_at: Optional[datetime] = None,
+ ) -> None:
+ """Update workflow execution with completion data."""
+ execution.status = status
+ execution.outputs = outputs or {}
+ execution.total_tokens = total_tokens
+ execution.total_steps = total_steps
+ execution.finished_at = finished_at or naive_utc_now()
+ execution.exceptions_count = exceptions_count
+ if error_message:
+ execution.error_message = error_message
+
+ def _add_trace_task_if_needed(
+ self,
+ trace_manager: Optional[TraceQueueManager],
+ workflow_execution: WorkflowExecution,
+ conversation_id: Optional[str],
+ ) -> None:
+ """Add trace task if trace manager is provided."""
+ if trace_manager:
+ trace_manager.add_trace_task(
+ TraceTask(
+ TraceTaskName.WORKFLOW_TRACE,
+ workflow_execution=workflow_execution,
+ conversation_id=conversation_id,
+ user_id=trace_manager.user_id,
+ )
+ )
+
+ def _fail_running_node_executions(
+ self,
+ workflow_execution_id: str,
+ error_message: str,
+ now: datetime,
+ ) -> None:
+ """Fail all running node executions for a workflow."""
+ running_node_executions = [
+ node_exec
+ for node_exec in self._node_execution_cache.values()
+ if node_exec.workflow_execution_id == workflow_execution_id
+ and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
+ ]
+
+ for node_execution in running_node_executions:
+ if node_execution.node_execution_id:
+ node_execution.status = WorkflowNodeExecutionStatus.FAILED
+ node_execution.error = error_message
+ node_execution.finished_at = now
+ node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
+ self._workflow_node_execution_repository.save(node_execution)
+
+ def _create_node_execution_from_event(
+ self,
+ *,
+ workflow_execution: WorkflowExecution,
+ event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
+ status: WorkflowNodeExecutionStatus,
+ error: Optional[str] = None,
+ created_at: Optional[datetime] = None,
+ ) -> WorkflowNodeExecution:
+ """Create a node execution from an event."""
+ now = naive_utc_now()
+ created_at = created_at or now
+
+ metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+ WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
- # Convert execution metadata keys to strings
- execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
-
- # Create a domain model
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
+ index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
- status=WorkflowNodeExecutionStatus.RETRY,
+ status=status,
+ metadata=metadata,
created_at=created_at,
- finished_at=finished_at,
- elapsed_time=elapsed_time,
- error=event.error,
- index=event.node_run_index,
+ error=error,
)
- # Update with mappings
- domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
-
- # Use the instance repository to save the domain model
- self._workflow_node_execution_repository.save(domain_execution)
+ if status == WorkflowNodeExecutionStatus.RETRY:
+ domain_execution.finished_at = now
+ domain_execution.elapsed_time = (now - created_at).total_seconds()
return domain_execution
- def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
- execution = self._workflow_execution_repository.get(id)
- if not execution:
- raise WorkflowRunNotFoundError(id)
- return execution
+ def _update_node_execution_completion(
+ self,
+ domain_execution: WorkflowNodeExecution,
+ *,
+ event: Union[
+ QueueNodeSucceededEvent,
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ],
+ status: WorkflowNodeExecutionStatus,
+ error: Optional[str] = None,
+ handle_special_values: bool = False,
+ ) -> None:
+ """Update node execution with completion data."""
+ finished_at = naive_utc_now()
+ elapsed_time = (finished_at - event.start_at).total_seconds()
+
+ # Process data
+ if handle_special_values:
+ inputs = WorkflowEntry.handle_special_values(event.inputs)
+ process_data = WorkflowEntry.handle_special_values(event.process_data)
+ else:
+ inputs = event.inputs
+ process_data = event.process_data
+
+ outputs = event.outputs
+
+ # Convert metadata
+ execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
+ if event.execution_metadata:
+ execution_metadata_dict.update(event.execution_metadata)
+
+ # Update domain model
+ domain_execution.status = status
+ domain_execution.update_from_mapping(
+ inputs=inputs,
+ process_data=process_data,
+ outputs=outputs,
+ metadata=execution_metadata_dict,
+ )
+ domain_execution.finished_at = finished_at
+ domain_execution.elapsed_time = elapsed_time
+
+ if error:
+ domain_execution.error = error
+
+ def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
+ """Merge event metadata with origin metadata."""
+ origin_metadata = {
+ WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
+ WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+ WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
+ }
+
+ execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
+ if event.execution_metadata:
+ execution_metadata_dict.update(event.execution_metadata)
+
+ return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
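
One behavioural consequence of replacing the repository lookups (`get`, `get_by_node_execution_id`, `get_running_executions`) with per-cycle caches: completion events can only resolve executions that were registered earlier in the same `WorkflowCycleManager` instance via its start handlers. A minimal sketch of that pattern, independent of the real classes:

```python
class CycleCache:  # stand-in for the _workflow_execution_cache / _node_execution_cache dicts
    def __init__(self) -> None:
        self._by_id: dict[str, dict] = {}

    def register(self, execution_id: str, execution: dict) -> dict:
        # mirrors _save_and_cache_workflow_execution / _save_and_cache_node_execution
        self._by_id[execution_id] = execution
        return execution

    def get_or_raise(self, execution_id: str) -> dict:
        try:
            return self._by_id[execution_id]
        except KeyError:
            # mirrors _get_workflow_execution_or_raise_error / _get_node_execution_from_cache
            raise ValueError(f"execution not registered in this cycle: {execution_id}") from None


cache = CycleCache()
cache.register("node-exec-1", {"status": "running"})
assert cache.get_or_raise("node-exec-1")["status"] == "running"
```
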
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 7648947fca..d2375da39c 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
@@ -21,6 +21,8 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
+from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
from models.workflow import (
@@ -68,6 +70,7 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
+ graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
@@ -79,7 +82,7 @@ class WorkflowEntry:
call_depth=call_depth,
graph=graph,
graph_config=graph_config,
- variable_pool=variable_pool,
+ graph_runtime_state=graph_runtime_state,
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
thread_pool_id=thread_pool_id,
@@ -119,7 +122,9 @@ class WorkflowEntry:
workflow: Workflow,
node_id: str,
user_id: str,
- user_inputs: dict,
+ user_inputs: Mapping[str, Any],
+ variable_pool: VariablePool,
+ variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]:
"""
Single step run workflow node
@@ -129,34 +134,19 @@ class WorkflowEntry:
:param user_inputs: user inputs
:return:
"""
- # fetch node info from workflow graph
- workflow_graph = workflow.graph_dict
- if not workflow_graph:
- raise ValueError("workflow graph not found")
-
- nodes = workflow_graph.get("nodes")
- if not nodes:
- raise ValueError("nodes not found in workflow graph")
-
- # fetch node config from node id
- try:
- node_config = next(filter(lambda node: node["id"] == node_id, nodes))
- except StopIteration:
- raise ValueError("node id not found in workflow graph")
+ node_config = workflow.get_node_config_by_id(node_id)
+ node_config_data = node_config.get("data", {})
# Get node class
- node_type = NodeType(node_config.get("data", {}).get("type"))
- node_version = node_config.get("data", {}).get("version", "1")
+ node_type = NodeType(node_config_data.get("type"))
+ node_version = node_config_data.get("version", "1")
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
- # init variable pool
- variable_pool = VariablePool(environment_variables=workflow.environment_variables)
-
# init graph
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
- node_instance = node_cls(
+ node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -182,18 +172,29 @@ class WorkflowEntry:
except NotImplementedError:
variable_mapping = {}
+ # Loading missing variable from draft var here, and set it into
+ # variable_pool.
+ load_into_variable_pool(
+ variable_loader=variable_loader,
+ variable_pool=variable_pool,
+ variable_mapping=variable_mapping,
+ user_inputs=user_inputs,
+ )
+
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=workflow.tenant_id,
)
+
try:
# run node
- generator = node_instance.run()
+ generator = node.run()
except Exception as e:
- raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
- return node_instance, generator
+ logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
+ return node, generator
@classmethod
def run_free_node(
@@ -248,14 +249,14 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
- node_instance: BaseNode = node_cls(
+ node: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -290,14 +291,19 @@ class WorkflowEntry:
)
# run node
- generator = node_instance.run()
+ generator = node.run()
- return node_instance, generator
+ return node, generator
except Exception as e:
- raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
+ logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
+ # NOTE(QuantumGhost): Avoid using this function in new code.
+ # Keep values structured as long as possible and only convert to dict
+ # immediately before serialization (e.g., JSON serialization) to maintain
+ # data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
@@ -324,10 +330,17 @@ class WorkflowEntry:
cls,
*,
variable_mapping: Mapping[str, Sequence[str]],
- user_inputs: dict,
+ user_inputs: Mapping[str, Any],
variable_pool: VariablePool,
tenant_id: str,
) -> None:
+ # NOTE(QuantumGhost): This logic should remain synchronized with
+ # the implementation of `load_into_variable_pool`, specifically the logic about
+ # variable existence checking.
+
+ # WARNING(QuantumGhost): The semantics of this method are not clearly defined,
+ # and multiple parts of the codebase depend on its current behavior.
+ # Modify with caution.
for node_variable, variable_selector in variable_mapping.items():
# fetch node id and variable key from node_variable
node_variable_list = node_variable.split(".")
diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py
new file mode 100644
index 0000000000..2c634d25ec
--- /dev/null
+++ b/api/core/workflow/workflow_type_encoder.py
@@ -0,0 +1,36 @@
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic import BaseModel
+
+from core.file.models import File
+from core.variables import Segment
+
+
+class WorkflowRuntimeTypeConverter:
+ def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
+ result = self._to_json_encodable_recursive(value)
+ return result if isinstance(result, Mapping) or result is None else dict(result)
+
+ def _to_json_encodable_recursive(self, value: Any) -> Any:
+ if value is None:
+ return value
+ if isinstance(value, (bool, int, str, float)):
+ return value
+ if isinstance(value, Segment):
+ return self._to_json_encodable_recursive(value.value)
+ if isinstance(value, File):
+ return value.to_dict()
+ if isinstance(value, BaseModel):
+ return value.model_dump(mode="json")
+ if isinstance(value, dict):
+ res = {}
+ for k, v in value.items():
+ res[k] = self._to_json_encodable_recursive(v)
+ return res
+ if isinstance(value, list):
+ res_list = []
+ for item in value:
+ res_list.append(self._to_json_encodable_recursive(item))
+ return res_list
+ return value
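A rough usage sketch (the `outputs` mapping below is illustrative, not taken from the patch): the converter is meant to run immediately before JSON serialization so Segment, File, and pydantic values become plain encodable structures.

    import json

    from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter

    outputs = {"text": "hello", "score": 0.9}  # stand-in for a node's output mapping
    encodable = WorkflowRuntimeTypeConverter().to_json_encodable(outputs)
    print(json.dumps(encodable))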
diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py
index 1d6ad35333..ebc55d5ef8 100644
--- a/api/events/event_handlers/__init__.py
+++ b/api/events/event_handlers/__init__.py
@@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
from .create_document_index import handle
from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle
-from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
-from .update_provider_last_used_at_when_message_created import handle
+
+# Consolidated handler replaces both deduct_quota_when_message_created and
+# update_provider_last_used_at_when_message_created
+from .update_provider_when_message_created import handle
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 8a677f6b6f..cb48bd92a0 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -1,4 +1,3 @@
-import datetime
import logging
import time
@@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Document
@@ -33,7 +33,7 @@ def handle(sender, **kwargs):
raise NotFound("Document not found")
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py
deleted file mode 100644
index b8e7019446..0000000000
--- a/api/events/event_handlers/deduct_quota_when_message_created.py
+++ /dev/null
@@ -1,65 +0,0 @@
-from datetime import UTC, datetime
-
-from configs import dify_config
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from core.entities.provider_entities import QuotaUnit
-from core.plugin.entities.plugin import ModelProviderID
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from models.provider import Provider, ProviderType
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
- message = sender
- application_generate_entity = kwargs.get("application_generate_entity")
-
- if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
- return
-
- model_config = application_generate_entity.model_conf
- provider_model_bundle = model_config.provider_model_bundle
- provider_configuration = provider_model_bundle.configuration
-
- if provider_configuration.using_provider_type != ProviderType.SYSTEM:
- return
-
- system_configuration = provider_configuration.system_configuration
-
- if not system_configuration.current_quota_type:
- return
-
- quota_unit = None
- for quota_configuration in system_configuration.quota_configurations:
- if quota_configuration.quota_type == system_configuration.current_quota_type:
- quota_unit = quota_configuration.quota_unit
-
- if quota_configuration.quota_limit == -1:
- return
-
- break
-
- used_quota = None
- if quota_unit:
- if quota_unit == QuotaUnit.TOKENS:
- used_quota = message.message_tokens + message.answer_tokens
- elif quota_unit == QuotaUnit.CREDITS:
- used_quota = dify_config.get_model_credits(model_config.model)
- else:
- used_quota = 1
-
- if used_quota is not None and system_configuration.current_quota_type is not None:
- db.session.query(Provider).filter(
- Provider.tenant_id == application_generate_entity.app_config.tenant_id,
- # TODO: Use provider name with prefix after the data migration.
- Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
- Provider.provider_type == ProviderType.SYSTEM.value,
- Provider.quota_type == system_configuration.current_quota_type.value,
- Provider.quota_limit > Provider.quota_used,
- ).update(
- {
- "quota_used": Provider.quota_used + used_quota,
- "last_used": datetime.now(tz=UTC).replace(tzinfo=None),
- }
- )
- db.session.commit()
diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
index 249bd14429..6c9fc0bf1d 100644
--- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
+++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
@@ -20,6 +20,7 @@ def handle(sender, **kwargs):
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
+ credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,
diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py
deleted file mode 100644
index 59412cf87c..0000000000
--- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py
+++ /dev/null
@@ -1,20 +0,0 @@
-from datetime import UTC, datetime
-
-from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
-from events.message_event import message_was_created
-from extensions.ext_database import db
-from models.provider import Provider
-
-
-@message_was_created.connect
-def handle(sender, **kwargs):
- application_generate_entity = kwargs.get("application_generate_entity")
-
- if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
- return
-
- db.session.query(Provider).filter(
- Provider.tenant_id == application_generate_entity.app_config.tenant_id,
- Provider.provider_name == application_generate_entity.model_conf.provider,
- ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
- db.session.commit()
diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py
new file mode 100644
index 0000000000..d3943f2eda
--- /dev/null
+++ b/api/events/event_handlers/update_provider_when_message_created.py
@@ -0,0 +1,234 @@
+import logging
+import time as time_module
+from datetime import datetime
+from typing import Any, Optional
+
+from pydantic import BaseModel
+from sqlalchemy import update
+from sqlalchemy.orm import Session
+
+from configs import dify_config
+from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
+from core.entities.provider_entities import QuotaUnit, SystemConfiguration
+from core.plugin.entities.plugin import ModelProviderID
+from events.message_event import message_was_created
+from extensions.ext_database import db
+from libs import datetime_utils
+from models.model import Message
+from models.provider import Provider, ProviderType
+
+logger = logging.getLogger(__name__)
+
+
+class _ProviderUpdateFilters(BaseModel):
+ """Filters for identifying Provider records to update."""
+
+ tenant_id: str
+ provider_name: str
+ provider_type: Optional[str] = None
+ quota_type: Optional[str] = None
+
+
+class _ProviderUpdateAdditionalFilters(BaseModel):
+ """Additional filters for Provider updates."""
+
+ quota_limit_check: bool = False
+
+
+class _ProviderUpdateValues(BaseModel):
+ """Values to update in Provider records."""
+
+ last_used: Optional[datetime] = None
+ quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
+
+
+class _ProviderUpdateOperation(BaseModel):
+ """A single Provider update operation."""
+
+ filters: _ProviderUpdateFilters
+ values: _ProviderUpdateValues
+ additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
+ description: str = "unknown"
+
+
+@message_was_created.connect
+def handle(sender: Message, **kwargs):
+ """
+ Consolidated handler for Provider updates when a message is created.
+
+ This handler replaces both:
+ - update_provider_last_used_at_when_message_created
+ - deduct_quota_when_message_created
+
+ By performing all Provider updates in a single transaction, we ensure
+ consistency and efficiency when updating Provider records.
+ """
+ message = sender
+ application_generate_entity = kwargs.get("application_generate_entity")
+
+ if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
+ return
+
+ tenant_id = application_generate_entity.app_config.tenant_id
+ provider_name = application_generate_entity.model_conf.provider
+ current_time = datetime_utils.naive_utc_now()
+
+ # Prepare updates for both scenarios
+ updates_to_perform: list[_ProviderUpdateOperation] = []
+
+ # 1. Always update last_used for the provider
+ basic_update = _ProviderUpdateOperation(
+ filters=_ProviderUpdateFilters(
+ tenant_id=tenant_id,
+ provider_name=provider_name,
+ ),
+ values=_ProviderUpdateValues(last_used=current_time),
+ description="basic_last_used_update",
+ )
+ updates_to_perform.append(basic_update)
+
+ # 2. Check if we need to deduct quota (system provider only)
+ model_config = application_generate_entity.model_conf
+ provider_model_bundle = model_config.provider_model_bundle
+ provider_configuration = provider_model_bundle.configuration
+
+ if (
+ provider_configuration.using_provider_type == ProviderType.SYSTEM
+ and provider_configuration.system_configuration
+ and provider_configuration.system_configuration.current_quota_type is not None
+ ):
+ system_configuration = provider_configuration.system_configuration
+
+ # Calculate quota usage
+ used_quota = _calculate_quota_usage(
+ message=message,
+ system_configuration=system_configuration,
+ model_name=model_config.model,
+ )
+
+ if used_quota is not None:
+ quota_update = _ProviderUpdateOperation(
+ filters=_ProviderUpdateFilters(
+ tenant_id=tenant_id,
+ provider_name=ModelProviderID(model_config.provider).provider_name,
+ provider_type=ProviderType.SYSTEM.value,
+ quota_type=provider_configuration.system_configuration.current_quota_type.value,
+ ),
+ values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
+ additional_filters=_ProviderUpdateAdditionalFilters(
+ quota_limit_check=True # Provider.quota_limit > Provider.quota_used
+ ),
+ description="quota_deduction_update",
+ )
+ updates_to_perform.append(quota_update)
+
+ # Execute all updates
+ start_time = time_module.perf_counter()
+ try:
+ _execute_provider_updates(updates_to_perform)
+
+ # Log successful completion with timing
+ duration = time_module.perf_counter() - start_time
+
+ logger.info(
+ f"Provider updates completed successfully. "
+ f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
+ f"Tenant: {tenant_id}, Provider: {provider_name}"
+ )
+
+ except Exception as e:
+ # Log failure with timing and context
+ duration = time_module.perf_counter() - start_time
+
+ logger.exception(
+ f"Provider updates failed after {duration:.3f}s. "
+ f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
+ f"Provider: {provider_name}"
+ )
+ raise
+
+
+def _calculate_quota_usage(
+ *, message: Message, system_configuration: SystemConfiguration, model_name: str
+) -> Optional[int]:
+ """Calculate quota usage based on message tokens and quota type."""
+ quota_unit = None
+ for quota_configuration in system_configuration.quota_configurations:
+ if quota_configuration.quota_type == system_configuration.current_quota_type:
+ quota_unit = quota_configuration.quota_unit
+ if quota_configuration.quota_limit == -1:
+ return None
+ break
+ if quota_unit is None:
+ return None
+
+ try:
+ if quota_unit == QuotaUnit.TOKENS:
+ tokens = message.message_tokens + message.answer_tokens
+ return tokens
+        elif quota_unit == QuotaUnit.CREDITS:
+            credits_used = dify_config.get_model_credits(model_name)
+            return credits_used

+ elif quota_unit == QuotaUnit.TIMES:
+ return 1
+ return None
+ except Exception as e:
+ logger.exception("Failed to calculate quota usage")
+ return None
+
+
+def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
+ """Execute all Provider updates in a single transaction."""
+ if not updates_to_perform:
+ return
+
+    # Run all updates in a session-scoped transaction: session.begin() commits
+    # on success and rolls back automatically if any update raises.
+    with Session(db.engine) as session, session.begin():
+ # Use a single transaction for all updates
+ for update_operation in updates_to_perform:
+ filters = update_operation.filters
+ values = update_operation.values
+ additional_filters = update_operation.additional_filters
+ description = update_operation.description
+
+ # Build the where conditions
+ where_conditions = [
+ Provider.tenant_id == filters.tenant_id,
+ Provider.provider_name == filters.provider_name,
+ ]
+
+ # Add additional filters if specified
+ if filters.provider_type is not None:
+ where_conditions.append(Provider.provider_type == filters.provider_type)
+ if filters.quota_type is not None:
+ where_conditions.append(Provider.quota_type == filters.quota_type)
+ if additional_filters.quota_limit_check:
+ where_conditions.append(Provider.quota_limit > Provider.quota_used)
+
+ # Prepare values dict for SQLAlchemy update
+ update_values = {}
+ if values.last_used is not None:
+ update_values["last_used"] = values.last_used
+ if values.quota_used is not None:
+ update_values["quota_used"] = values.quota_used
+
+ # Build and execute the update statement
+ stmt = update(Provider).where(*where_conditions).values(**update_values)
+ result = session.execute(stmt)
+ rows_affected = result.rowcount
+
+ logger.debug(
+ f"Provider update ({description}): {rows_affected} rows affected. "
+ f"Filters: {filters.model_dump()}, Values: {update_values}"
+ )
+
+ # If no rows were affected for quota updates, log a warning
+ if rows_affected == 0 and description == "quota_deduction_update":
+ logger.warning(
+ f"No Provider rows updated for quota deduction. "
+ f"This may indicate quota limit exceeded or provider not found. "
+ f"Filters: {filters.model_dump()}"
+ )
+
+ logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py
index b7d412d68d..56a69a1862 100644
--- a/api/extensions/ext_app_metrics.py
+++ b/api/extensions/ext_app_metrics.py
@@ -12,14 +12,14 @@ def init_app(app: DifyApp):
@app.after_request
def after_request(response):
"""Add Version headers to the response."""
- response.headers.add("X-Version", dify_config.CURRENT_VERSION)
+ response.headers.add("X-Version", dify_config.project.version)
response.headers.add("X-Env", dify_config.DEPLOY_ENV)
return response
@app.route("/health")
def health():
return Response(
- json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}),
+ json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.project.version}),
status=200,
content_type="application/json",
)
diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py
index 316be12f5c..a4d013ffc0 100644
--- a/api/extensions/ext_blueprints.py
+++ b/api/extensions/ext_blueprints.py
@@ -10,6 +10,7 @@ def init_app(app: DifyApp):
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
from controllers.inner_api import bp as inner_api_bp
+ from controllers.mcp import bp as mcp_bp
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
@@ -46,3 +47,4 @@ def init_app(app: DifyApp):
app.register_blueprint(files_bp)
app.register_blueprint(inner_api_bp)
+ app.register_blueprint(mcp_bp)
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index a837552007..6279b1ad36 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -21,6 +21,7 @@ def init_app(app: DifyApp) -> Celery:
"master_name": dify_config.CELERY_SENTINEL_MASTER_NAME,
"sentinel_kwargs": {
"socket_timeout": dify_config.CELERY_SENTINEL_SOCKET_TIMEOUT,
+ "password": dify_config.CELERY_SENTINEL_PASSWORD,
},
}
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index ddc2158a02..600e336c19 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
+ setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
)
@@ -40,6 +41,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
+ setup_system_tool_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py
index 3b4d787d01..11d1856ac4 100644
--- a/api/extensions/ext_login.py
+++ b/api/extensions/ext_login.py
@@ -10,7 +10,7 @@ from dify_app import DifyApp
from extensions.ext_database import db
from libs.passport import PassportService
from models.account import Account, Tenant, TenantAccountJoin
-from models.model import EndUser
+from models.model import AppMCPServer, EndUser
from services.account_service import AccountService
login_manager = flask_login.LoginManager()
@@ -74,6 +74,21 @@ def load_user_from_request(request_from_flask_login):
if not end_user:
raise NotFound("End user not found.")
return end_user
+ elif request.blueprint == "mcp":
+ server_code = request.view_args.get("server_code") if request.view_args else None
+ if not server_code:
+ raise Unauthorized("Invalid Authorization token.")
+ app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
+ if not app_mcp_server:
+ raise NotFound("App MCP server not found.")
+ end_user = (
+ db.session.query(EndUser)
+ .filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
+ .first()
+ )
+ if not end_user:
+ raise NotFound("End user not found.")
+ return end_user
@user_logged_in.connect
diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py
index 84bc12eca0..df5d8a9c11 100644
--- a/api/extensions/ext_mail.py
+++ b/api/extensions/ext_mail.py
@@ -54,6 +54,15 @@ class Mail:
use_tls=dify_config.SMTP_USE_TLS,
opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS,
)
+ case "sendgrid":
+ from libs.sendgrid import SendGridClient
+
+ if not dify_config.SENDGRID_API_KEY:
+ raise ValueError("SENDGRID_API_KEY is required for SendGrid mail type")
+
+ self._client = SendGridClient(
+ sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or ""
+ )
case _:
raise ValueError("Unsupported mail type {}".format(mail_type))
diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py
index 6dcfa7bec6..b027a165f9 100644
--- a/api/extensions/ext_otel.py
+++ b/api/extensions/ext_otel.py
@@ -12,6 +12,7 @@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore
from configs import dify_config
from dify_app import DifyApp
+from libs.helper import extract_tenant_id
from models import Account, EndUser
@@ -24,11 +25,8 @@ def on_user_loaded(_sender, user: Union["Account", "EndUser"]):
if user:
try:
current_span = get_current_span()
- if isinstance(user, Account) and user.current_tenant_id:
- tenant_id = user.current_tenant_id
- elif isinstance(user, EndUser):
- tenant_id = user.tenant_id
- else:
+ tenant_id = extract_tenant_id(user)
+ if not tenant_id:
return
if current_span:
current_span.set_attribute("service.tenant.id", tenant_id)
@@ -49,7 +47,7 @@ def init_app(app: DifyApp):
logging.getLogger().addHandler(exception_handler)
def init_flask_instrumentor(app: DifyApp):
- meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
+ meter = get_meter("http_metrics", version=dify_config.project.version)
_http_response_counter = meter.create_counter(
"http.server.response.count",
description="Total number of HTTP responses by status code, method and target",
@@ -163,7 +161,7 @@ def init_app(app: DifyApp):
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
- ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
+ ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
ResourceAttributes.PROCESS_PID: os.getpid(),
ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
ResourceAttributes.HOST_NAME: socket.gethostname(),
@@ -195,13 +193,22 @@ def init_app(app: DifyApp):
insecure=True,
)
else:
+ headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None
+
+ trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT
+ if not trace_endpoint:
+ trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces"
exporter = HTTPSpanExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
+ endpoint=trace_endpoint,
+ headers=headers,
)
+
+ metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT
+ if not metric_endpoint:
+ metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics"
metric_exporter = HTTPMetricExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
+ endpoint=metric_endpoint,
+ headers=headers,
)
else:
exporter = ConsoleSpanExporter()
diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py
index c283b1b7ca..be2f6115f7 100644
--- a/api/extensions/ext_redis.py
+++ b/api/extensions/ext_redis.py
@@ -1,6 +1,10 @@
+import functools
+import logging
+from collections.abc import Callable
from typing import Any, Union
import redis
+from redis import RedisError
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
@@ -9,6 +13,8 @@ from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
+logger = logging.getLogger(__name__)
+
class RedisClientWrapper:
"""
@@ -115,3 +121,25 @@ def init_app(app: DifyApp):
redis_client.initialize(redis.Redis(connection_pool=pool))
app.extensions["redis"] = redis_client
+
+
+def redis_fallback(default_return: Any = None):
+ """
+    Decorator that handles Redis operation exceptions and returns a default value when Redis is unavailable.
+
+ Args:
+ default_return: The value to return when a Redis operation fails. Defaults to None.
+ """
+
+ def decorator(func: Callable):
+ @functools.wraps(func)
+ def wrapper(*args, **kwargs):
+ try:
+ return func(*args, **kwargs)
+ except RedisError as e:
+ logger.warning(f"Redis operation failed in {func.__name__}: {str(e)}", exc_info=True)
+ return default_return
+
+ return wrapper
+
+ return decorator
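A minimal sketch of how the decorator might be applied; the cached-counter helper below is hypothetical.

    from extensions.ext_redis import redis_client, redis_fallback

    @redis_fallback(default_return=0)
    def get_cached_counter(key: str) -> int:
        # If Redis is unreachable, the RedisError is logged and 0 is returned
        # instead of propagating to the caller.
        value = redis_client.get(key)
        return int(value) if value is not None else 0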
diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py
index 3a74aace6a..82aed0d98d 100644
--- a/api/extensions/ext_sentry.py
+++ b/api/extensions/ext_sentry.py
@@ -35,6 +35,6 @@ def init_app(app: DifyApp):
traces_sample_rate=dify_config.SENTRY_TRACES_SAMPLE_RATE,
profiles_sample_rate=dify_config.SENTRY_PROFILES_SAMPLE_RATE,
environment=dify_config.DEPLOY_ENV,
- release=f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
+ release=f"dify-{dify_config.project.version}-{dify_config.COMMIT_SHA}",
before_send=before_send,
)
diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py
index 7448fd4a6b..81eec94da4 100644
--- a/api/extensions/storage/azure_blob_storage.py
+++ b/api/extensions/storage/azure_blob_storage.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
from typing import Optional
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
@@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc
from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
+from libs.datetime_utils import naive_utc_now
class AzureBlobStorage(BaseStorage):
@@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage):
account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
- expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
+ expiry=naive_utc_now() + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py
index 4b12afb528..2570bc22f1 100644
--- a/api/factories/agent_factory.py
+++ b/api/factories/agent_factory.py
@@ -10,6 +10,6 @@ def get_plugin_agent_strategy(
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
for agent_strategy in agent_provider.declaration.strategies:
if agent_strategy.identity.name == agent_strategy_name:
- return PluginAgentStrategy(tenant_id, agent_strategy)
+ return PluginAgentStrategy(tenant_id, agent_strategy, agent_provider.meta.version)
raise ValueError(f"Agent strategy {agent_strategy_name} not found")
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 52f119936f..c974dbb700 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -5,6 +5,7 @@ from typing import Any, cast
import httpx
from sqlalchemy import select
+from sqlalchemy.orm import Session
from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS
from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers
@@ -91,6 +92,8 @@ def build_from_mappings(
tenant_id: str,
strict_type_validation: bool = False,
) -> Sequence[File]:
+ # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query.
+ # Implement batch processing to reduce database load when handling multiple files.
files = [
build_from_mapping(
mapping=mapping,
@@ -145,9 +148,7 @@ def _build_from_local_file(
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
- file_type = (
- FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
- )
+ file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),
@@ -196,9 +197,7 @@ def _build_from_remote_url(
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
- FileType(specified_type)
- if specified_type and specified_type != FileType.CUSTOM.value
- else detected_file_type
+ FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
)
return File(
@@ -283,9 +282,7 @@ def _build_from_tool_file(
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
- file_type = (
- FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
- )
+ file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),
@@ -377,3 +374,75 @@ def _get_file_type_by_mimetype(mime_type: str) -> FileType | None:
def get_file_type_by_mime_type(mime_type: str) -> FileType:
return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM
+
+
+class StorageKeyLoader:
+ """FileKeyLoader load the storage key from database for a list of files.
+ This loader is batched, the database query count is constant regardless of the input size.
+ """
+
+ def __init__(self, session: Session, tenant_id: str) -> None:
+ self._session = session
+ self._tenant_id = tenant_id
+
+ def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]:
+ stmt = select(UploadFile).where(
+ UploadFile.id.in_(upload_file_ids),
+ UploadFile.tenant_id == self._tenant_id,
+ )
+
+ return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
+
+ def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]:
+ stmt = select(ToolFile).where(
+ ToolFile.id.in_(tool_file_ids),
+ ToolFile.tenant_id == self._tenant_id,
+ )
+ return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)}
+
+ def load_storage_keys(self, files: Sequence[File]):
+ """Loads storage keys for a sequence of files by retrieving the corresponding
+ `UploadFile` or `ToolFile` records from the database based on their transfer method.
+
+ This method doesn't modify the input sequence structure but updates the `_storage_key`
+ property of each file object by extracting the relevant key from its database record.
+
+ Performance note: This is a batched operation where database query count remains constant
+ regardless of input size. However, for optimal performance, input sequences should contain
+ fewer than 1000 files. For larger collections, split into smaller batches and process each
+ batch separately.
+ """
+
+ upload_file_ids: list[uuid.UUID] = []
+ tool_file_ids: list[uuid.UUID] = []
+ for file in files:
+ related_model_id = file.related_id
+ if file.related_id is None:
+ raise ValueError("file id should not be None.")
+ if file.tenant_id != self._tenant_id:
+ err_msg = (
+ f"invalid file, expected tenant_id={self._tenant_id}, "
+ f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}"
+ )
+ raise ValueError(err_msg)
+ model_id = uuid.UUID(related_model_id)
+
+ if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
+ upload_file_ids.append(model_id)
+ elif file.transfer_method == FileTransferMethod.TOOL_FILE:
+ tool_file_ids.append(model_id)
+
+ tool_files = self._load_tool_files(tool_file_ids)
+ upload_files = self._load_upload_files(upload_file_ids)
+ for file in files:
+ model_id = uuid.UUID(file.related_id)
+ if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL):
+ upload_file_row = upload_files.get(model_id)
+ if upload_file_row is None:
+ raise ValueError(f"Upload file not found for id: {model_id}")
+ file._storage_key = upload_file_row.key
+ elif file.transfer_method == FileTransferMethod.TOOL_FILE:
+ tool_file_row = tool_files.get(model_id)
+ if tool_file_row is None:
+ raise ValueError(f"Tool file not found for id: {model_id}")
+ file._storage_key = tool_file_row.file_key
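A hedged usage sketch for StorageKeyLoader; `files` stands in for File objects that already carry `tenant_id` and `related_id`, and the wrapper function is illustrative.

    from collections.abc import Sequence

    from sqlalchemy.orm import Session

    from core.file import File
    from extensions.ext_database import db
    from factories.file_factory import StorageKeyLoader

    def attach_storage_keys(files: Sequence[File], tenant_id: str) -> None:
        with Session(db.engine) as session:
            # Two queries total (upload files + tool files), regardless of len(files).
            StorageKeyLoader(session, tenant_id).load_storage_keys(files)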
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index a41ef4ae4e..39ebd009d5 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -43,6 +43,10 @@ class UnsupportedSegmentTypeError(Exception):
pass
+class TypeMismatchError(Exception):
+ pass
+
+
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
StringSegment: StringVariable,
@@ -87,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
- case SegmentType.NUMBER if isinstance(value, int):
+ case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
+ mapping = dict(mapping)
+ mapping["value_type"] = SegmentType.INTEGER
result = IntegerVariable.model_validate(mapping)
- case SegmentType.NUMBER if isinstance(value, float):
+ case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
+ mapping = dict(mapping)
+ mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
@@ -110,7 +118,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
return cast(Variable, result)
+def infer_segment_type_from_value(value: Any, /) -> SegmentType:
+ return build_segment(value).value_type
+
+
def build_segment(value: Any, /) -> Segment:
+ # NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
+ # below
if value is None:
return NoneSegment()
if isinstance(value, str):
@@ -126,12 +140,17 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, list):
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
- if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
+ if all(isinstance(item, ArraySegment) for item in items):
+ return ArrayAnySegment(value=value)
+ elif len(types) != 1:
+ if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
+ return ArrayNumberSegment(value=value)
return ArrayAnySegment(value=value)
+
match types.pop():
case SegmentType.STRING:
return ArrayStringSegment(value=value)
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
@@ -140,10 +159,100 @@ def build_segment(value: Any, /) -> Segment:
case SegmentType.NONE:
return ArrayAnySegment(value=value)
case _:
+ # This should be unreachable.
raise ValueError(f"not supported value {value}")
raise ValueError(f"not supported value {value}")
+_segment_factory: Mapping[SegmentType, type[Segment]] = {
+ SegmentType.NONE: NoneSegment,
+ SegmentType.STRING: StringSegment,
+ SegmentType.INTEGER: IntegerSegment,
+ SegmentType.FLOAT: FloatSegment,
+ SegmentType.FILE: FileSegment,
+ SegmentType.OBJECT: ObjectSegment,
+ # Array types
+ SegmentType.ARRAY_ANY: ArrayAnySegment,
+ SegmentType.ARRAY_STRING: ArrayStringSegment,
+ SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
+ SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
+ SegmentType.ARRAY_FILE: ArrayFileSegment,
+}
+
+
+def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
+ """
+ Build a segment with explicit type checking.
+
+ This function creates a segment from a value while enforcing type compatibility
+ with the specified segment_type. It provides stricter type validation compared
+ to the standard build_segment function.
+
+ Args:
+ segment_type: The expected SegmentType for the resulting segment
+ value: The value to be converted into a segment
+
+ Returns:
+ Segment: A segment instance of the appropriate type
+
+ Raises:
+ TypeMismatchError: If the value type doesn't match the expected segment_type
+
+ Special Cases:
+ - For empty list [] values, if segment_type is array[*], returns the corresponding array type
+ - Type validation is performed before segment creation
+
+ Examples:
+ >>> build_segment_with_type(SegmentType.STRING, "hello")
+ StringSegment(value="hello")
+
+ >>> build_segment_with_type(SegmentType.ARRAY_STRING, [])
+ ArrayStringSegment(value=[])
+
+ >>> build_segment_with_type(SegmentType.STRING, 123)
+ # Raises TypeMismatchError
+ """
+ # Handle None values
+ if value is None:
+ if segment_type == SegmentType.NONE:
+ return NoneSegment()
+ else:
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
+
+ # Handle empty list special case for array types
+ if isinstance(value, list) and len(value) == 0:
+ if segment_type == SegmentType.ARRAY_ANY:
+ return ArrayAnySegment(value=value)
+ elif segment_type == SegmentType.ARRAY_STRING:
+ return ArrayStringSegment(value=value)
+ elif segment_type == SegmentType.ARRAY_NUMBER:
+ return ArrayNumberSegment(value=value)
+ elif segment_type == SegmentType.ARRAY_OBJECT:
+ return ArrayObjectSegment(value=value)
+ elif segment_type == SegmentType.ARRAY_FILE:
+ return ArrayFileSegment(value=value)
+ else:
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
+
+ inferred_type = SegmentType.infer_segment_type(value)
+ # Type compatibility checking
+ if inferred_type is None:
+ raise TypeMismatchError(
+ f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
+ )
+ if inferred_type == segment_type:
+ segment_class = _segment_factory[segment_type]
+ return segment_class(value_type=segment_type, value=value)
+ elif segment_type == SegmentType.NUMBER and inferred_type in (
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ ):
+ segment_class = _segment_factory[inferred_type]
+ return segment_class(value_type=inferred_type, value=value)
+ else:
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
+
+
def segment_to_variable(
*,
segment: Segment,
@@ -169,6 +278,6 @@ def segment_to_variable(
name=name,
description=description,
value=segment.value,
- selector=selector,
+ selector=list(selector),
),
)
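A small sketch of the stricter typing path added above; the assertions reflect the NUMBER-to-INTEGER/FLOAT narrowing described in `build_segment_with_type` and assume `infer_segment_type` resolves Python ints and floats to those types.

    from core.variables import SegmentType
    from factories.variable_factory import build_segment_with_type

    seg = build_segment_with_type(SegmentType.NUMBER, 42)
    assert seg.value_type == SegmentType.INTEGER  # ints are narrowed to INTEGER

    seg = build_segment_with_type(SegmentType.NUMBER, 3.14)
    assert seg.value_type == SegmentType.FLOAT  # floats are narrowed to FLOAT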
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
new file mode 100644
index 0000000000..8288bd54a3
--- /dev/null
+++ b/api/fields/_value_type_serializer.py
@@ -0,0 +1,15 @@
+from typing import TypedDict
+
+from core.variables.segments import Segment
+from core.variables.types import SegmentType
+
+
+class _VarTypedDict(TypedDict, total=False):
+ value_type: SegmentType
+
+
+def serialize_value_type(v: _VarTypedDict | Segment) -> str:
+ if isinstance(v, Segment):
+ return v.value_type.exposed_type().value
+ else:
+ return v["value_type"].exposed_type().value
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index 500ca47c7e..b6d85e0e24 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -1,8 +1,21 @@
+import json
+
from flask_restful import fields
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
+
+class JsonStringField(fields.Raw):
+ def format(self, value):
+ if isinstance(value, str):
+ try:
+ return json.loads(value)
+ except (json.JSONDecodeError, TypeError):
+ return value
+ return value
+
+
app_detail_kernel_fields = {
"id": fields.String,
"name": fields.String,
@@ -175,6 +188,7 @@ app_detail_fields_with_site = {
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
+ "max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
@@ -218,3 +232,14 @@ app_import_fields = {
app_import_check_dependencies_fields = {
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
}
+
+app_server_fields = {
+ "id": fields.String,
+ "name": fields.String,
+ "server_code": fields.String,
+ "description": fields.String,
+ "status": fields.String,
+ "parameters": JsonStringField,
+ "created_at": TimestampField,
+ "updated_at": TimestampField,
+}
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index 71785e7d67..c5a0c9a49d 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -2,10 +2,12 @@ from flask_restful import fields
from libs.helper import TimestampField
+from ._value_type_serializer import serialize_value_type
+
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
- "value_type": fields.String(attribute="value_type.value"),
+ "value_type": fields.String(attribute=serialize_value_type),
"value": fields.String,
"description": fields.String,
"created_at": TimestampField,
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 9f1bef3b36..930e59cc1c 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
+from ._value_type_serializer import serialize_value_type
+
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
@@ -17,16 +19,23 @@ class EnvironmentVariableField(fields.Raw):
"name": value.name,
"value": encrypter.obfuscated_token(value.value),
"value_type": value.value_type.value,
+ "description": value.description,
}
if isinstance(value, Variable):
return {
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.value,
+ "value_type": value.value_type.exposed_type().value,
+ "description": value.description,
}
if isinstance(value, dict):
- value_type = value.get("value_type")
+ value_type_str = value.get("value_type")
+ if not isinstance(value_type_str, str):
+ raise TypeError(
+ f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
+ )
+ value_type = SegmentType(value_type_str).exposed_type()
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value
@@ -35,7 +44,7 @@ class EnvironmentVariableField(fields.Raw):
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
- "value_type": fields.String(attribute="value_type.value"),
+ "value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw,
"description": fields.String,
}
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index 74fdf8bd97..a106728e9c 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -19,7 +19,6 @@ workflow_run_for_log_fields = {
workflow_run_for_list_fields = {
"id": fields.String,
- "sequence_number": fields.Integer,
"version": fields.String,
"status": fields.String,
"elapsed_time": fields.Float,
@@ -36,7 +35,6 @@ advanced_chat_workflow_run_for_list_fields = {
"id": fields.String,
"conversation_id": fields.String,
"message_id": fields.String,
- "sequence_number": fields.Integer,
"version": fields.String,
"status": fields.String,
"elapsed_time": fields.Float,
@@ -63,7 +61,6 @@ workflow_run_pagination_fields = {
workflow_run_detail_fields = {
"id": fields.String,
- "sequence_number": fields.Integer,
"version": fields.String,
"graph": fields.Raw(attribute="graph_dict"),
"inputs": fields.Raw(attribute="inputs_dict"),
diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py
new file mode 100644
index 0000000000..e576a34629
--- /dev/null
+++ b/api/libs/datetime_utils.py
@@ -0,0 +1,22 @@
+import abc
+import datetime
+from typing import Protocol
+
+
+class _NowFunction(Protocol):
+ @abc.abstractmethod
+ def __call__(self, tz: datetime.timezone | None) -> datetime.datetime:
+ pass
+
+
+# _now_func is a callable with the _NowFunction signature.
+# Its sole purpose is to abstract time retrieval, enabling
+# developers to mock this behavior in tests and time-dependent scenarios.
+_now_func: _NowFunction = datetime.datetime.now
+
+
+def naive_utc_now() -> datetime.datetime:
+ """Return a naive datetime object (without timezone information)
+ representing current UTC time.
+ """
+ return _now_func(datetime.UTC).replace(tzinfo=None)
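A hedged test sketch showing the point of the `_now_func` indirection: a test can pin the clock by swapping the module-level callable (pytest's `monkeypatch` fixture is assumed here).

    import datetime

    from libs import datetime_utils

    def test_naive_utc_now_is_deterministic(monkeypatch):
        fixed = datetime.datetime(2024, 1, 1, 12, 0, 0, tzinfo=datetime.UTC)
        monkeypatch.setattr(datetime_utils, "_now_func", lambda tz: fixed)
        assert datetime_utils.naive_utc_now() == fixed.replace(tzinfo=None)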
diff --git a/api/libs/file_utils.py b/api/libs/file_utils.py
new file mode 100644
index 0000000000..982b2cc1ac
--- /dev/null
+++ b/api/libs/file_utils.py
@@ -0,0 +1,30 @@
+from pathlib import Path
+
+
+def search_file_upwards(
+ base_dir_path: Path,
+ target_file_name: str,
+ max_search_parent_depth: int,
+) -> Path:
+ """
+ Find a target file in the current directory or its parent directories up to a specified depth.
+ :param base_dir_path: Starting directory path to search from.
+ :param target_file_name: Name of the file to search for.
+ :param max_search_parent_depth: Maximum number of parent directories to search upwards.
+    :return: Path of the file if found.
+    :raises ValueError: If the file is not found within the given search depth.
+ """
+ current_path = base_dir_path.resolve()
+ for _ in range(max_search_parent_depth):
+ candidate_path = current_path / target_file_name
+ if candidate_path.is_file():
+ return candidate_path
+ parent_path = current_path.parent
+ if parent_path == current_path: # reached the root directory
+ break
+ else:
+ current_path = parent_path
+
+ raise ValueError(
+ f"File '{target_file_name}' not found in the directory '{base_dir_path.resolve()}' or its parent directories"
+ f" in depth of {max_search_parent_depth}."
+ )
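A usage sketch; the target filename and search depth are illustrative.

    from pathlib import Path

    from libs.file_utils import search_file_upwards

    # Look for a pyproject.toml at most three directory levels above this file;
    # raises ValueError if nothing is found within that depth.
    config_path = search_file_upwards(Path(__file__).parent, "pyproject.toml", max_search_parent_depth=3)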
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 3f2a630956..00772d530a 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -25,6 +25,31 @@ from extensions.ext_redis import redis_client
if TYPE_CHECKING:
from models.account import Account
+ from models.model import EndUser
+
+
+def extract_tenant_id(user: Union["Account", "EndUser"]) -> str | None:
+ """
+ Extract tenant_id from Account or EndUser object.
+
+ Args:
+ user: Account or EndUser object
+
+ Returns:
+ tenant_id string if available, None otherwise
+
+ Raises:
+ ValueError: If user is neither Account nor EndUser
+ """
+ from models.account import Account
+ from models.model import EndUser
+
+ if isinstance(user, Account):
+ return user.current_tenant_id
+ elif isinstance(user, EndUser):
+ return user.tenant_id
+ else:
+ raise ValueError(f"Invalid user type: {type(user)}. Expected Account or EndUser.")
def run(script):
@@ -123,25 +148,6 @@ class StrLen:
return value
-class FloatRange:
- """Restrict input to an float in a range (inclusive)"""
-
- def __init__(self, low, high, argument="argument"):
- self.low = low
- self.high = high
- self.argument = argument
-
- def __call__(self, value):
- value = _get_float(value)
- if value < self.low or value > self.high:
- error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
- arg=self.argument, val=value, lo=self.low, hi=self.high
- )
- raise ValueError(error)
-
- return value
-
-
class DatetimeString:
def __init__(self, format, argument="argument"):
self.format = format
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 218109522d..78f827584c 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -1,4 +1,3 @@
-import datetime
import urllib.parse
from typing import Any
@@ -6,6 +5,7 @@ import requests
from flask_login import current_user
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@@ -115,7 +115,7 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@@ -154,7 +154,7 @@ class NotionOAuth(OAuthDataSource):
}
data_source_binding.source_info = new_source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
diff --git a/api/libs/passport.py b/api/libs/passport.py
index 8df4f529bc..fe8fc33b5f 100644
--- a/api/libs/passport.py
+++ b/api/libs/passport.py
@@ -14,9 +14,11 @@ class PassportService:
def verify(self, token):
try:
return jwt.decode(token, self.sk, algorithms=["HS256"])
+ except jwt.exceptions.ExpiredSignatureError:
+ raise Unauthorized("Token has expired.")
except jwt.exceptions.InvalidSignatureError:
raise Unauthorized("Invalid token signature.")
except jwt.exceptions.DecodeError:
raise Unauthorized("Invalid token.")
- except jwt.exceptions.ExpiredSignatureError:
- raise Unauthorized("Token has expired.")
+ except jwt.exceptions.PyJWTError: # Catch-all for other JWT errors
+ raise Unauthorized("Invalid token.")
diff --git a/api/libs/rsa.py b/api/libs/rsa.py
index 637bcc4a1d..da279eb32b 100644
--- a/api/libs/rsa.py
+++ b/api/libs/rsa.py
@@ -1,4 +1,5 @@
import hashlib
+from typing import Union
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
@@ -9,7 +10,7 @@ from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher
-def generate_key_pair(tenant_id):
+def generate_key_pair(tenant_id: str) -> str:
private_key = RSA.generate(2048)
public_key = private_key.publickey()
@@ -26,7 +27,7 @@ def generate_key_pair(tenant_id):
prefix_hybrid = b"HYBRID:"
-def encrypt(text, public_key):
+def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
if isinstance(public_key, str):
public_key = public_key.encode()
@@ -38,14 +39,14 @@ def encrypt(text, public_key):
rsa_key = RSA.import_key(public_key)
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
- enc_aes_key = cipher_rsa.encrypt(aes_key)
+ enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
return prefix_hybrid + encrypted_data
-def get_decrypt_decoding(tenant_id):
+def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
@@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id):
return rsa_key, cipher_rsa
-def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
+def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
if encrypted_text.startswith(prefix_hybrid):
encrypted_text = encrypted_text[len(prefix_hybrid) :]
@@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
return decrypted_text.decode()
-def decrypt(encrypted_text, tenant_id):
+def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
- return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa)
+ return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
class PrivkeyNotFoundError(Exception):
diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py
new file mode 100644
index 0000000000..5409e3eeeb
--- /dev/null
+++ b/api/libs/sendgrid.py
@@ -0,0 +1,45 @@
+import logging
+
+import sendgrid # type: ignore
+from python_http_client.exceptions import ForbiddenError, UnauthorizedError
+from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore
+
+
+class SendGridClient:
+ def __init__(self, sendgrid_api_key: str, _from: str):
+ self.sendgrid_api_key = sendgrid_api_key
+ self._from = _from
+
+ def send(self, mail: dict):
+ logging.debug("Sending email with SendGrid")
+
+ try:
+ _to = mail["to"]
+
+ if not _to:
+ raise ValueError("SendGridClient: Cannot send email: recipient address is missing.")
+
+ sg = sendgrid.SendGridAPIClient(api_key=self.sendgrid_api_key)
+ from_email = Email(self._from)
+ to_email = To(_to)
+ subject = mail["subject"]
+ content = Content("text/html", mail["html"])
+ mail = Mail(from_email, to_email, subject, content)
+ mail_json = mail.get() # type: ignore
+ response = sg.client.mail.send.post(request_body=mail_json)
+ logging.debug(response.status_code)
+ logging.debug(response.body)
+ logging.debug(response.headers)
+
+ except TimeoutError as e:
+ logging.exception("SendGridClient Timeout occurred while sending email")
+ raise
+ except (UnauthorizedError, ForbiddenError) as e:
+ logging.exception(
+ "SendGridClient Authentication failed. "
+ "Verify that your credentials and the 'from' email address are correct"
+ )
+ raise
+ except Exception as e:
+ logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}")
+ raise
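
A usage sketch for the new client; the API key, sender, and recipient below are placeholders, and in practice the client is presumably constructed from Dify's mail configuration rather than instantiated ad hoc:

```python
from libs.sendgrid import SendGridClient

# Placeholder values for illustration only; real values come from configuration.
client = SendGridClient(sendgrid_api_key="SG.xxxxxxxx", _from="noreply@example.com")

client.send(
    {
        "to": "user@example.com",
        "subject": "Welcome to Dify",
        "html": "<p>Hello!</p>",
    }
)
```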
diff --git a/api/libs/smtp.py b/api/libs/smtp.py
index 35561f071c..b94386660e 100644
--- a/api/libs/smtp.py
+++ b/api/libs/smtp.py
@@ -22,7 +22,11 @@ class SMTPClient:
if self.use_tls:
if self.opportunistic_tls:
smtp = smtplib.SMTP(self.server, self.port, timeout=10)
+                    # Send EHLO with the server address as the HELO domain name
+ smtp.ehlo(self.server)
smtp.starttls()
+                    # Re-send EHLO after STARTTLS to re-identify the client over the TLS session
+ smtp.ehlo(self.server)
else:
smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10)
else:
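
The double `ehlo()` follows the STARTTLS flow from RFC 3207: the client identifies itself, upgrades the connection, then must re-identify because the pre-TLS capabilities no longer apply. A standalone sketch of the same sequence (host and port are placeholders):

```python
import smtplib

# Placeholder server details, for illustration only.
with smtplib.SMTP("smtp.example.com", 587, timeout=10) as smtp:
    smtp.ehlo("smtp.example.com")  # identify before negotiating TLS
    smtp.starttls()                # upgrade the connection to TLS
    smtp.ehlo("smtp.example.com")  # re-identify; pre-TLS capabilities are discarded
    # smtp.login(...) / smtp.sendmail(...) would follow here
```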
diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py
new file mode 100644
index 0000000000..a8190011ed
--- /dev/null
+++ b/api/libs/uuid_utils.py
@@ -0,0 +1,164 @@
+import secrets
+import struct
+import time
+import uuid
+
+# Reference for UUIDv7 specification:
+# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7
+
+# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
+#
+# For details on the `struct.pack` format, refer to:
+# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
+_PACK_TIMESTAMP = ">Q"
+
+# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
+# into an unsigned 16-bit integer (big-endian).
+_PACK_RAND_A = ">H"
+
+
+def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
+ """Create UUIDv7 byte structure with given timestamp and random bytes.
+
+ This is a private helper function that handles the common logic for creating
+ UUIDv7 byte structure according to RFC 9562 specification.
+
+ UUIDv7 Structure:
+ - 48 bits: timestamp (milliseconds since Unix epoch)
+ - 12 bits: random data A (with version bits)
+ - 62 bits: random data B (with variant bits)
+
+ The function performs the following operations:
+ 1. Creates a 128-bit (16-byte) UUID structure
+ 2. Packs the timestamp into the first 48 bits (6 bytes)
+ 3. Sets the version bits to 7 (0111) in the correct position
+ 4. Sets the variant bits to 10 (binary) in the correct position
+ 5. Fills the remaining bits with the provided random bytes
+
+ Args:
+ timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits).
+ random_bytes: Random bytes to use for the random portions (must be 10 bytes).
+ First 2 bytes are used for random data A (12 bits after version).
+ Last 8 bytes are used for random data B (62 bits after variant).
+
+ Returns:
+ A 16-byte bytes object representing the complete UUIDv7 structure.
+
+ Note:
+ This function assumes the random_bytes parameter is exactly 10 bytes.
+ The caller is responsible for providing appropriate random data.
+ """
+ # Create the 128-bit UUID structure
+ uuid_bytes = bytearray(16)
+
+ # Pack timestamp (48 bits) into first 6 bytes
+ uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian
+
+ # Next 16 bits: random data A (12 bits) + version (4 bits)
+ # Take first 2 random bytes and set version to 7
+ rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
+ # Clear the highest 4 bits to make room for the version field
+ # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111).
+ rand_a = rand_a & 0x0FFF
+ # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000).
+ rand_a = rand_a | 0x7000
+ uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)
+
+ # Last 64 bits: random data B (62 bits) + variant (2 bits)
+ # Use remaining 8 random bytes and set variant to 10 (binary)
+ uuid_bytes[8:16] = random_bytes[2:10]
+
+ # Set variant bits (first 2 bits of byte 8 should be '10')
+ uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx
+
+ return bytes(uuid_bytes)
+
+
+def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
+ """Generate a UUID version 7 according to RFC 9562 specification.
+
+ UUIDv7 features a time-ordered value field derived from the widely
+ implemented and well known Unix Epoch timestamp source, the number of
+ milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded.
+
+ Structure:
+ - 48 bits: timestamp (milliseconds since Unix epoch)
+ - 12 bits: random data A (with version bits)
+ - 62 bits: random data B (with variant bits)
+
+ Args:
+        timestamp_ms: The timestamp used when generating the UUID; defaults to the current time if unspecified.
+ Should be an integer representing milliseconds since Unix epoch.
+
+ Returns:
+ A UUID object representing a UUIDv7.
+
+ Example:
+ >>> import time
+ >>> # Generate UUIDv7 with current time
+ >>> uuid_current = uuidv7()
+ >>> # Generate UUIDv7 with specific timestamp
+ >>> uuid_specific = uuidv7(int(time.time() * 1000))
+ """
+ if timestamp_ms is None:
+ timestamp_ms = int(time.time() * 1000)
+
+ # Generate 10 random bytes for the random portions
+ random_bytes = secrets.token_bytes(10)
+
+ # Create UUIDv7 bytes using the helper function
+ uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes)
+
+ return uuid.UUID(bytes=uuid_bytes)
+
+
+def uuidv7_timestamp(id_: uuid.UUID) -> int:
+ """Extract the timestamp from a UUIDv7.
+
+ UUIDv7 contains a 48-bit timestamp field representing milliseconds since
+ the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and
+ returns that timestamp as an integer representing milliseconds since the epoch.
+
+ Args:
+ id_: A UUID object that should be a UUIDv7 (version 7).
+
+ Returns:
+ The timestamp as an integer representing milliseconds since Unix epoch.
+
+ Raises:
+ ValueError: If the provided UUID is not version 7.
+
+ Example:
+ >>> uuid_v7 = uuidv7()
+ >>> timestamp = uuidv7_timestamp(uuid_v7)
+ >>> print(f"UUID was created at: {timestamp} ms")
+ """
+ # Verify this is a UUIDv7
+ if id_.version != 7:
+ raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}")
+
+ # Extract the UUID bytes
+ uuid_bytes = id_.bytes
+
+ # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds
+ # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long)
+ timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6]
+ ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0]
+
+ # Return timestamp directly in milliseconds as integer
+ assert isinstance(ts_in_ms, int)
+ return ts_in_ms
+
+
+def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
+ """Generate a non-random uuidv7 with the given timestamp (first 48 bits) and
+ all random bits to 0. As the smallest possible uuidv7 for that timestamp,
+ it may be used as a boundary for partitions.
+ """
+ # Use zero bytes for all random portions
+ zero_random_bytes = b"\x00" * 10
+
+ # Create UUIDv7 bytes using the helper function
+ uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
+
+ return uuid.UUID(bytes=uuid_bytes)
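
Putting the three helpers together; the timestamp below is just the current time, chosen for illustration:

```python
import time

from libs.uuid_utils import uuidv7, uuidv7_boundary, uuidv7_timestamp

now_ms = int(time.time() * 1000)

uid = uuidv7(now_ms)
assert uid.version == 7
assert uuidv7_timestamp(uid) == now_ms

# The smallest possible UUIDv7 for this millisecond; useful as a partition boundary.
lower_bound = uuidv7_boundary(now_ms)
assert lower_bound <= uid
```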
diff --git a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py
new file mode 100644
index 0000000000..29fef77798
--- /dev/null
+++ b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py
@@ -0,0 +1,66 @@
+"""remove sequence_number from workflow_runs
+
+Revision ID: 0ab65e1cc7fa
+Revises: 4474872b0ee6
+Create Date: 2025-06-19 16:33:13.377215
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '0ab65e1cc7fa'
+down_revision = '4474872b0ee6'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.drop_index(batch_op.f('workflow_run_tenant_app_sequence_idx'))
+ batch_op.drop_column('sequence_number')
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ # WARNING: This downgrade CANNOT recover the original sequence_number values!
+ # The original sequence numbers are permanently lost after the upgrade.
+ # This downgrade will regenerate sequence numbers based on created_at order,
+ # which may result in different values than the original sequence numbers.
+ #
+ # If you need to preserve original sequence numbers, use the alternative
+ # migration approach that creates a backup table before removal.
+
+ # Step 1: Add sequence_number column as nullable first
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('sequence_number', sa.INTEGER(), autoincrement=False, nullable=True))
+
+ # Step 2: Populate sequence_number values based on created_at order within each app
+ # NOTE: This recreates sequence numbering logic but values will be different
+ # from the original sequence numbers that were removed in the upgrade
+ connection = op.get_bind()
+ connection.execute(sa.text("""
+ UPDATE workflow_runs
+ SET sequence_number = subquery.row_num
+ FROM (
+ SELECT id, ROW_NUMBER() OVER (
+ PARTITION BY tenant_id, app_id
+ ORDER BY created_at, id
+ ) as row_num
+ FROM workflow_runs
+ ) subquery
+ WHERE workflow_runs.id = subquery.id
+ """))
+
+ # Step 3: Make the column NOT NULL and add the index
+ with op.batch_alter_table('workflow_runs', schema=None) as batch_op:
+ batch_op.alter_column('sequence_number', nullable=False)
+ batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False)
+
+ # ### end Alembic commands ###
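
Because the downgrade regenerates numbering from `created_at` order, a quick sanity check is that the values are dense (1..N) within each `(tenant_id, app_id)` pair. A hedged sketch of such a check, assuming an already-configured SQLAlchemy engine:

```python
import sqlalchemy as sa


def check_sequence_numbers(engine: sa.Engine) -> None:
    # Verify the regenerated sequence numbers form a dense 1..N range per app.
    query = sa.text(
        """
        SELECT tenant_id, app_id,
               COUNT(*) AS runs,
               MIN(sequence_number) AS min_seq,
               MAX(sequence_number) AS max_seq
        FROM workflow_runs
        GROUP BY tenant_id, app_id
        """
    )
    with engine.connect() as conn:
        for row in conn.execute(query):
            assert row.min_seq == 1 and row.max_seq == row.runs
```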
diff --git a/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py
new file mode 100644
index 0000000000..0548bf05ef
--- /dev/null
+++ b/api/migrations/versions/2025_06_25_0936-58eb7bdb93fe_add_mcp_server_tool_and_app_server.py
@@ -0,0 +1,64 @@
+"""add mcp server tool and app server
+
+Revision ID: 58eb7bdb93fe
+Revises: 0ab65e1cc7fa
+Create Date: 2025-06-25 09:36:07.510570
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '58eb7bdb93fe'
+down_revision = '0ab65e1cc7fa'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('app_mcp_servers',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('app_id', models.types.StringUUID(), nullable=False),
+ sa.Column('name', sa.String(length=255), nullable=False),
+ sa.Column('description', sa.String(length=255), nullable=False),
+ sa.Column('server_code', sa.String(length=255), nullable=False),
+ sa.Column('status', sa.String(length=255), server_default=sa.text("'normal'::character varying"), nullable=False),
+ sa.Column('parameters', sa.Text(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='app_mcp_server_pkey'),
+ sa.UniqueConstraint('tenant_id', 'app_id', name='unique_app_mcp_server_tenant_app_id'),
+ sa.UniqueConstraint('server_code', name='unique_app_mcp_server_server_code')
+ )
+ op.create_table('tool_mcp_providers',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('name', sa.String(length=40), nullable=False),
+ sa.Column('server_identifier', sa.String(length=24), nullable=False),
+ sa.Column('server_url', sa.Text(), nullable=False),
+ sa.Column('server_url_hash', sa.String(length=64), nullable=False),
+ sa.Column('icon', sa.String(length=255), nullable=True),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('user_id', models.types.StringUUID(), nullable=False),
+ sa.Column('encrypted_credentials', sa.Text(), nullable=True),
+ sa.Column('authed', sa.Boolean(), nullable=False),
+ sa.Column('tools', sa.Text(), nullable=False),
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_mcp_provider_pkey'),
+ sa.UniqueConstraint('tenant_id', 'name', name='unique_mcp_provider_name'),
+ sa.UniqueConstraint('tenant_id', 'server_identifier', name='unique_mcp_provider_server_identifier'),
+ sa.UniqueConstraint('tenant_id', 'server_url_hash', name='unique_mcp_provider_server_url')
+ )
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.drop_table('tool_mcp_providers')
+ op.drop_table('app_mcp_servers')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py
new file mode 100644
index 0000000000..2bbbb3d28e
--- /dev/null
+++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py
@@ -0,0 +1,86 @@
+"""add uuidv7 function in SQL
+
+Revision ID: 1c9ba48be8e4
+Revises: 58eb7bdb93fe
+Create Date: 2025-07-02 23:32:38.484499
+
+"""
+
+"""
+The functions in this file come from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications.
+
+LICENSE:
+
+# Copyright and License
+
+Copyright (c) 2024, Daniel Vérité
+
+Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies.
+
+In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage.
+
+Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications.
+"""
+
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1c9ba48be8e4'
+down_revision = '58eb7bdb93fe'
+branch_labels: None = None
+depends_on: None = None
+
+
+def upgrade():
+ # This implementation differs slightly from the original uuidv7 function in
+ # https://github.com/dverite/postgres-uuidv7-sql/.
+ # The ability to specify source timestamp has been removed because its type signature is incompatible with
+ # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
+ # generated and controlled within the application layer.
+ op.execute(sa.text(r"""
+/* Main function to generate a uuidv7 value with millisecond precision */
+CREATE FUNCTION uuidv7() RETURNS uuid
+AS
+$$
+ -- Replace the first 48 bits of a uuidv4 with the current
+ -- number of milliseconds since 1970-01-01 UTC
+ -- and set the "ver" field to 7 by setting additional bits
+SELECT encode(
+ set_bit(
+ set_bit(
+ overlay(uuid_send(gen_random_uuid()) placing
+ substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
+ 3)
+ from 1 for 6),
+ 52, 1),
+ 53, 1), 'hex')::uuid;
+$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
+
+COMMENT ON FUNCTION uuidv7 IS
+ 'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
+"""))
+
+ op.execute(sa.text(r"""
+CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
+AS
+$$
+ /* uuid fields: version=0b0111, variant=0b10 */
+SELECT encode(
+ overlay('\x00000000000070008000000000000000'::bytea
+ placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
+ from 1 for 6),
+ 'hex')::uuid;
+$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
+
+COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
+ 'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
+"""
+))
+
+
+def downgrade():
+ op.execute(sa.text("DROP FUNCTION uuidv7"))
+ op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
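
Once this migration has run, the SQL-side generator can be cross-checked against the Python helper in `api/libs/uuid_utils.py`. A small hedged sketch, again assuming a configured engine pointed at the migrated database:

```python
import uuid

import sqlalchemy as sa


def sample_db_uuidv7(engine: sa.Engine) -> uuid.UUID:
    # Ask PostgreSQL for a uuidv7() value and confirm it really is version 7.
    with engine.connect() as conn:
        value = conn.execute(sa.text("SELECT uuidv7()")).scalar_one()
    result = value if isinstance(value, uuid.UUID) else uuid.UUID(str(value))
    assert result.version == 7
    return result
```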
diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py
new file mode 100644
index 0000000000..df4fbf0a0e
--- /dev/null
+++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py
@@ -0,0 +1,62 @@
+"""tool oauth
+
+Revision ID: 71f5020c6470
+Revises: 4474872b0ee6
+Create Date: 2025-06-24 17:05:43.118647
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '71f5020c6470'
+down_revision = '1c9ba48be8e4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tool_oauth_system_clients',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('plugin_id', sa.String(length=512), nullable=False),
+ sa.Column('provider', sa.String(length=255), nullable=False),
+ sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
+ sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
+ )
+ op.create_table('tool_oauth_tenant_clients',
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('plugin_id', sa.String(length=512), nullable=False),
+ sa.Column('provider', sa.String(length=255), nullable=False),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
+ sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
+ )
+
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
+ batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+ batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
+ batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
+ batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+ batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
+ batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])
+ batch_op.drop_column('credential_type')
+ batch_op.drop_column('is_default')
+ batch_op.drop_column('name')
+
+ op.drop_table('tool_oauth_tenant_clients')
+ op.drop_table('tool_oauth_system_clients')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
new file mode 100644
index 0000000000..3bdbafda7c
--- /dev/null
+++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
@@ -0,0 +1,51 @@
+"""update models
+
+Revision ID: 1a83934ad6d1
+Revises: 71f5020c6470
+Create Date: 2025-07-21 09:35:48.774794
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1a83934ad6d1'
+down_revision = '71f5020c6470'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+ batch_op.alter_column('server_identifier',
+ existing_type=sa.VARCHAR(length=24),
+ type_=sa.String(length=64),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.alter_column('tool_name',
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=128),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.alter_column('tool_name',
+ existing_type=sa.String(length=128),
+ type_=sa.VARCHAR(length=40),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+ batch_op.alter_column('server_identifier',
+ existing_type=sa.String(length=64),
+ type_=sa.VARCHAR(length=24),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
diff --git a/api/models/__init__.py b/api/models/__init__.py
index 83b50eb099..1b4bdd32e4 100644
--- a/api/models/__init__.py
+++ b/api/models/__init__.py
@@ -34,6 +34,7 @@ from .model import (
App,
AppAnnotationHitHistory,
AppAnnotationSetting,
+ AppMCPServer,
AppMode,
AppModelConfig,
Conversation,
@@ -103,6 +104,7 @@ __all__ = [
"AppAnnotationHitHistory",
"AppAnnotationSetting",
"AppDatasetJoin",
+    "AppMCPServer",
"AppMode",
"AppModelConfig",
"BuiltinToolProvider",
diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py
new file mode 100644
index 0000000000..f6271bda47
--- /dev/null
+++ b/api/models/_workflow_exc.py
@@ -0,0 +1,20 @@
+"""All these exceptions are not meant to be caught by callers."""
+
+
+class WorkflowDataError(Exception):
+ """Base class for all workflow data related exceptions.
+
+ This should be used to indicate issues with workflow data integrity, such as
+ no `graph` configuration, missing `nodes` field in `graph` configuration, or
+ similar issues.
+ """
+
+ pass
+
+
+class NodeNotFoundError(WorkflowDataError):
+ """Raised when a node with the specified ID is not found in the workflow."""
+
+ def __init__(self, node_id: str):
+ super().__init__(f"Node with ID '{node_id}' not found in the workflow.")
+ self.node_id = node_id
diff --git a/api/models/account.py b/api/models/account.py
index 7ffeefa980..1af571bc01 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -196,7 +196,7 @@ class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False)
encrypt_public_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
diff --git a/api/models/dataset.py b/api/models/dataset.py
index ad43d6f371..57e54b72a7 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -140,7 +140,7 @@ class Dataset(Base):
def word_count(self):
return (
db.session.query(Document)
- .with_entities(func.coalesce(func.sum(Document.word_count)))
+ .with_entities(func.coalesce(func.sum(Document.word_count), 0))
.filter(Document.dataset_id == self.id)
.scalar()
)
@@ -255,7 +255,7 @@ class Dataset(Base):
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
- return f"Vector_index_{normalized_dataset_id}_Node"
+ return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
class DatasetProcessRule(Base):
@@ -448,7 +448,7 @@ class Document(Base):
def hit_count(self):
return (
db.session.query(DocumentSegment)
- .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
+ .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
.filter(DocumentSegment.document_id == self.id)
.scalar()
)
diff --git a/api/models/enums.py b/api/models/enums.py
index 4434c3fec8..cc9f28a7bb 100644
--- a/api/models/enums.py
+++ b/api/models/enums.py
@@ -21,3 +21,12 @@ class DraftVariableType(StrEnum):
NODE = "node"
SYS = "sys"
CONVERSATION = "conversation"
+
+
+class MessageStatus(StrEnum):
+ """
+ Message Status Enum
+ """
+
+ NORMAL = "normal"
+ ERROR = "error"
diff --git a/api/models/model.py b/api/models/model.py
index 229e77134e..2377aeed8a 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -10,7 +10,6 @@ from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import sign_tool_file
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
-from services.plugin.plugin_service import PluginService
if TYPE_CHECKING:
from models.workflow import Workflow
@@ -51,7 +50,6 @@ class AppMode(StrEnum):
CHAT = "chat"
ADVANCED_CHAT = "advanced-chat"
AGENT_CHAT = "agent-chat"
- CHANNEL = "channel"
@classmethod
def value_of(cls, value: str) -> "AppMode":
@@ -169,6 +167,7 @@ class App(Base):
@property
def deleted_tools(self) -> list:
from core.tools.tool_manager import ToolManager
+ from services.plugin.plugin_service import PluginService
# get agent mode tools
app_model_config = self.app_model_config
@@ -632,7 +631,14 @@ class Conversation(Base):
system_instruction = db.Column(db.Text)
system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
status = db.Column(db.String(255), nullable=False)
+
+    # The `invoke_from` field records how the conversation was created.
+ #
+ # Its value corresponds to the members of `InvokeFrom`.
+ # (api/core/app/entities/app_invoke_entities.py)
invoke_from = db.Column(db.String(255), nullable=True)
+
+ # ref: ConversationSource.
from_source = db.Column(db.String(255), nullable=False)
from_end_user_id = db.Column(StringUUID)
from_account_id = db.Column(StringUUID)
@@ -661,7 +667,7 @@ class Conversation(Base):
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
- elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+ elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
@@ -671,7 +677,7 @@ class Conversation(Base):
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
- elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+ elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
@@ -703,7 +709,6 @@ class Conversation(Base):
if "model" in override_model_configs:
app_model_config = AppModelConfig()
app_model_config = app_model_config.from_model_config_dict(override_model_configs)
- assert app_model_config is not None, "app model config not found"
model_config = app_model_config.to_dict()
else:
model_config["configs"] = override_model_configs
@@ -817,7 +822,12 @@ class Conversation(Base):
@property
def first_message(self):
- return db.session.query(Message).filter(Message.conversation_id == self.id).first()
+ return (
+ db.session.query(Message)
+ .filter(Message.conversation_id == self.id)
+ .order_by(Message.created_at.asc())
+ .first()
+ )
@property
def app(self):
@@ -894,11 +904,11 @@ class Message(Base):
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
query: Mapped[str] = db.Column(db.Text, nullable=False)
message = db.Column(db.JSON, nullable=False)
- message_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+ message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
answer: Mapped[str] = db.Column(db.Text, nullable=False)
- answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
+ answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
parent_message_id = db.Column(StringUUID, nullable=True)
@@ -915,7 +925,7 @@ class Message(Base):
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- workflow_run_id = db.Column(StringUUID)
+ workflow_run_id: Mapped[str] = db.Column(StringUUID)
@property
def inputs(self):
@@ -927,7 +937,7 @@ class Message(Base):
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
if value["transfer_method"] == FileTransferMethod.TOOL_FILE:
value["tool_file_id"] = value["related_id"]
- elif value["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+ elif value["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
value["upload_file_id"] = value["related_id"]
inputs[key] = file_factory.build_from_mapping(mapping=value, tenant_id=value["tenant_id"])
elif isinstance(value, list) and all(
@@ -937,7 +947,7 @@ class Message(Base):
for item in value:
if item["transfer_method"] == FileTransferMethod.TOOL_FILE:
item["tool_file_id"] = item["related_id"]
- elif item["transfer_method"] == FileTransferMethod.LOCAL_FILE:
+ elif item["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]:
item["upload_file_id"] = item["related_id"]
inputs[key].append(file_factory.build_from_mapping(mapping=item, tenant_id=item["tenant_id"]))
return inputs
@@ -1437,6 +1447,39 @@ class EndUser(Base, UserMixin):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+class AppMCPServer(Base):
+ __tablename__ = "app_mcp_servers"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="app_mcp_server_pkey"),
+ db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
+ db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
+ )
+ id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = db.Column(StringUUID, nullable=False)
+ app_id = db.Column(StringUUID, nullable=False)
+ name = db.Column(db.String(255), nullable=False)
+ description = db.Column(db.String(255), nullable=False)
+ server_code = db.Column(db.String(255), nullable=False)
+ status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+ parameters = db.Column(db.Text, nullable=False)
+
+ created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+
+ @staticmethod
+ def generate_server_code(n):
+ while True:
+ result = generate_string(n)
+ while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0:
+ result = generate_string(n)
+
+ return result
+
+ @property
+ def parameters_dict(self) -> dict[str, Any]:
+ return cast(dict[str, Any], json.loads(self.parameters))
+
+
class Site(Base):
__tablename__ = "sites"
__table_args__ = (
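
A sketch of how the new `AppMCPServer` model is expected to be used; the IDs and parameter payload are placeholders, an application context is assumed, and the service layer presumably handles persistence in practice:

```python
import json

from models import AppMCPServer
from models.engine import db  # assumed import path, mirroring the models package layout

# Placeholder identifiers, for illustration only.
server = AppMCPServer(
    tenant_id="00000000-0000-0000-0000-000000000000",
    app_id="00000000-0000-0000-0000-000000000001",
    name="My MCP server",
    description="Exposes this app as an MCP tool",
    server_code=AppMCPServer.generate_server_code(16),  # regenerates until the code is unique
    parameters=json.dumps({"query": {"type": "string"}}),
)
db.session.add(server)
db.session.commit()

print(server.parameters_dict)  # -> {'query': {'type': 'string'}}
```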
diff --git a/api/models/task.py b/api/models/task.py
index d853c1dd9a..1a4b606ff5 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -1,7 +1,6 @@
-from datetime import UTC, datetime
-
from celery import states # type: ignore
+from libs.datetime_utils import naive_utc_now
from models.base import Base
from .engine import db
@@ -18,8 +17,8 @@ class CeleryTask(Base):
result = db.Column(db.PickleType, nullable=True)
date_done = db.Column(
db.DateTime,
- default=lambda: datetime.now(UTC).replace(tzinfo=None),
- onupdate=lambda: datetime.now(UTC).replace(tzinfo=None),
+ default=lambda: naive_utc_now(),
+ onupdate=lambda: naive_utc_now(),
nullable=True,
)
traceback = db.Column(db.Text, nullable=True)
@@ -39,4 +38,4 @@ class CeleryTaskSet(Base):
id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True)
taskset_id = db.Column(db.String(155), unique=True)
result = db.Column(db.PickleType, nullable=True)
- date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True)
+ date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
diff --git a/api/models/tools.py b/api/models/tools.py
index 03fbc3acb1..f5fae8b796 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,12 +1,16 @@
import json
from datetime import datetime
from typing import Any, cast
+from urllib.parse import urlparse
import sqlalchemy as sa
from deprecated import deprecated
from sqlalchemy import ForeignKey, func
from sqlalchemy.orm import Mapped, mapped_column
+from core.file import helpers as file_helpers
+from core.helper import encrypter
+from core.mcp.types import Tool
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration
@@ -17,6 +21,43 @@ from .model import Account, App, Tenant
from .types import StringUUID
+# system level tool oauth client params (client_id, client_secret, etc.)
+class ToolOAuthSystemClient(Base):
+ __tablename__ = "tool_oauth_system_clients"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
+ db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
+ provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ # oauth params of the tool provider
+ encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
+
+
+# tenant level tool oauth client params (client_id, client_secret, etc.)
+class ToolOAuthTenantClient(Base):
+ __tablename__ = "tool_oauth_tenant_clients"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
+ db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ # tenant id
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
+ provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ # oauth params of the tool provider
+ encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
+
+ @property
+ def oauth_params(self) -> dict:
+ return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
+
+
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@@ -25,12 +66,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
- # one tenant can only have one tool provider with the same name
- db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
+ db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ name: Mapped[str] = mapped_column(
+ db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
+ )
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@@ -45,6 +88,11 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
+ is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ # credential type, e.g., "api-key", "oauth2"
+ credential_type: Mapped[str] = mapped_column(
+ db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
+ )
@property
def credentials(self) -> dict:
@@ -64,7 +112,7 @@ class ApiToolProvider(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
- name = db.Column(db.String(255), nullable=False)
+ name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon
icon = db.Column(db.String(255), nullable=False)
# original schema
@@ -189,6 +237,109 @@ class WorkflowToolProvider(Base):
return db.session.query(App).filter(App.id == self.app_id).first()
+class MCPToolProvider(Base):
+ """
+    This table stores the MCP tool providers.
+ """
+
+ __tablename__ = "tool_mcp_providers"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="tool_mcp_provider_pkey"),
+ db.UniqueConstraint("tenant_id", "server_url_hash", name="unique_mcp_provider_server_url"),
+ db.UniqueConstraint("tenant_id", "name", name="unique_mcp_provider_name"),
+ db.UniqueConstraint("tenant_id", "server_identifier", name="unique_mcp_provider_server_identifier"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ # name of the mcp provider
+ name: Mapped[str] = mapped_column(db.String(40), nullable=False)
+ # server identifier of the mcp provider
+ server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
+ # encrypted url of the mcp provider
+ server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
+ # hash of server_url for uniqueness check
+ server_url_hash: Mapped[str] = mapped_column(db.String(64), nullable=False)
+ # icon of the mcp provider
+ icon: Mapped[str] = mapped_column(db.String(255), nullable=True)
+ # tenant id
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ # who created this tool
+ user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ # encrypted credentials
+ encrypted_credentials: Mapped[str] = mapped_column(db.Text, nullable=True)
+ # authed
+ authed: Mapped[bool] = mapped_column(db.Boolean, nullable=False, default=False)
+ # tools
+ tools: Mapped[str] = mapped_column(db.Text, nullable=False, default="[]")
+ created_at: Mapped[datetime] = mapped_column(
+ db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
+ )
+ updated_at: Mapped[datetime] = mapped_column(
+ db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
+ )
+
+ def load_user(self) -> Account | None:
+ return db.session.query(Account).filter(Account.id == self.user_id).first()
+
+ @property
+ def tenant(self) -> Tenant | None:
+ return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+
+ @property
+ def credentials(self) -> dict:
+ try:
+ return cast(dict, json.loads(self.encrypted_credentials)) or {}
+ except Exception:
+ return {}
+
+ @property
+ def mcp_tools(self) -> list[Tool]:
+ return [Tool(**tool) for tool in json.loads(self.tools)]
+
+ @property
+ def provider_icon(self) -> dict[str, str] | str:
+ try:
+ return cast(dict[str, str], json.loads(self.icon))
+ except json.JSONDecodeError:
+ return file_helpers.get_signed_file_url(self.icon)
+
+ @property
+ def decrypted_server_url(self) -> str:
+ return cast(str, encrypter.decrypt_token(self.tenant_id, self.server_url))
+
+ @property
+ def masked_server_url(self) -> str:
+ def mask_url(url: str, mask_char: str = "*") -> str:
+ """
+            Mask the URL, keeping only the scheme and host.
+ """
+ parsed = urlparse(url)
+ base_url = f"{parsed.scheme}://{parsed.netloc}"
+
+ if parsed.path and parsed.path != "/":
+ return f"{base_url}/{mask_char * 6}"
+ else:
+ return base_url
+
+ return mask_url(self.decrypted_server_url)
+
+ @property
+ def decrypted_credentials(self) -> dict:
+ from core.helper.provider_cache import NoOpProviderCredentialCache
+ from core.tools.mcp_tool.provider import MCPToolProviderController
+ from core.tools.utils.encryption import create_provider_encrypter
+
+ provider_controller = MCPToolProviderController._from_db(self)
+
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=self.tenant_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
+ cache=NoOpProviderCredentialCache(),
+ )
+
+ return encrypter.decrypt(self.credentials) # type: ignore
+
+
class ToolModelInvoke(Base):
"""
store the invoke logs from tool invoke
@@ -207,7 +358,7 @@ class ToolModelInvoke(Base):
# type
tool_type = db.Column(db.String(40), nullable=False)
# tool name
- tool_name = db.Column(db.String(40), nullable=False)
+ tool_name = db.Column(db.String(128), nullable=False)
# invoke parameters
model_parameters = db.Column(db.Text, nullable=False)
# prompt messages
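
For intuition, `MCPToolProvider.masked_server_url` keeps only the scheme and host and hides any path; a standalone sketch of the same rule with example URLs:

```python
from urllib.parse import urlparse


def mask_url(url: str, mask_char: str = "*") -> str:
    # Same masking rule as MCPToolProvider.masked_server_url: keep scheme://host
    # and replace any non-trivial path with six mask characters.
    parsed = urlparse(url)
    base_url = f"{parsed.scheme}://{parsed.netloc}"
    if parsed.path and parsed.path != "/":
        return f"{base_url}/{mask_char * 6}"
    return base_url


print(mask_url("https://mcp.example.com/servers/abc123"))  # https://mcp.example.com/******
print(mask_url("https://mcp.example.com/"))                # https://mcp.example.com
```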
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 2fff045543..124fb3bb4c 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,16 +1,25 @@
import json
import logging
from collections.abc import Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
from flask_login import current_user
+from sqlalchemy import orm
+from core.file.constants import maybe_file_object
+from core.file.models import File
from core.variables import utils as variable_utils
+from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from factories.variable_factory import build_segment
+from core.workflow.nodes.enums import NodeType
+from factories.variable_factory import TypeMismatchError, build_segment_with_type
+from libs.datetime_utils import naive_utc_now
+from libs.helper import extract_tenant_id
+
+from ._workflow_exc import NodeNotFoundError, WorkflowDataError
if TYPE_CHECKING:
from models.model import AppMode
@@ -72,6 +81,10 @@ class WorkflowType(Enum):
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
+class _InvalidGraphDefinitionError(Exception):
+ pass
+
+
class Workflow(Base):
"""
Workflow, for `Workflow App` and `Chat App workflow mode`.
@@ -126,7 +139,7 @@ class Workflow(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
- default=datetime.now(UTC).replace(tzinfo=None),
+ default=naive_utc_now(),
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
@@ -136,6 +149,8 @@ class Workflow(Base):
"conversation_variables", db.Text, nullable=False, server_default="{}"
)
+ VERSION_DRAFT = "draft"
+
@classmethod
def new(
cls,
@@ -165,7 +180,7 @@ class Workflow(Base):
workflow.conversation_variables = conversation_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
- workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.created_at = naive_utc_now()
workflow.updated_at = workflow.created_at
return workflow
@@ -179,8 +194,72 @@ class Workflow(Base):
@property
def graph_dict(self) -> Mapping[str, Any]:
+ # TODO(QuantumGhost): Consider caching `graph_dict` to avoid repeated JSON decoding.
+ #
+ # Using `functools.cached_property` could help, but some code in the codebase may
+ # modify the returned dict, which can cause issues elsewhere.
+ #
+ # For example, changing this property to a cached property led to errors like the
+ # following when single stepping an `Iteration` node:
+ #
+ # Root node id 1748401971780start not found in the graph
+ #
+ # There is currently no standard way to make a dict deeply immutable in Python,
+ # and tracking modifications to the returned dict is difficult. For now, we leave
+ # the code as-is to avoid these issues.
+ #
+ # Currently, the following functions / methods would mutate the returned dict:
+ #
+ # - `_get_graph_and_variable_pool_of_single_iteration`.
+ # - `_get_graph_and_variable_pool_of_single_loop`.
return json.loads(self.graph) if self.graph else {}
+ def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]:
+ """Extract a node configuration from the workflow graph by node ID.
+ A node configuration is a dictionary containing the node's properties, including
+ the node's id, title, and its data as a dict.
+ """
+ workflow_graph = self.graph_dict
+
+ if not workflow_graph:
+ raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}")
+
+ nodes = workflow_graph.get("nodes")
+ if not nodes:
+ raise WorkflowDataError("nodes not found in workflow graph")
+
+ try:
+ node_config = next(filter(lambda node: node["id"] == node_id, nodes))
+ except StopIteration:
+ raise NodeNotFoundError(node_id)
+ assert isinstance(node_config, dict)
+ return node_config
+
+ @staticmethod
+ def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
+ """Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
+ node_config_data = node_config.get("data", {})
+ # Get node class
+ node_type = NodeType(node_config_data.get("type"))
+ return node_type
+
+ @staticmethod
+ def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
+ in_loop = node_config.get("isInLoop", False)
+ in_iteration = node_config.get("isInIteration", False)
+ if in_loop:
+ loop_id = node_config.get("loop_id")
+ if loop_id is None:
+ raise _InvalidGraphDefinitionError("invalid graph")
+ return NodeType.LOOP, loop_id
+ elif in_iteration:
+ iteration_id = node_config.get("iteration_id")
+ if iteration_id is None:
+ raise _InvalidGraphDefinitionError("invalid graph")
+ return NodeType.ITERATION, iteration_id
+ else:
+ return None
+
@property
def features(self) -> str:
"""
@@ -270,18 +349,13 @@ class Workflow(Base):
)
@property
- def environment_variables(self) -> Sequence[Variable]:
+ def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Get tenant_id from current_user (Account or EndUser)
- if isinstance(current_user, Account):
- # Account user
- tenant_id = current_user.current_tenant_id
- else:
- # EndUser
- tenant_id = current_user.tenant_id
+ tenant_id = extract_tenant_id(current_user)
if not tenant_id:
return []
@@ -295,11 +369,15 @@ class Workflow(Base):
def decrypt_func(var):
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
- else:
+ elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
+ else:
+ raise AssertionError("this statement should be unreachable.")
- results = list(map(decrypt_func, results))
- return results
+ decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
+ map(decrypt_func, results)
+ )
+ return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
@@ -308,12 +386,7 @@ class Workflow(Base):
return
# Get tenant_id from current_user (Account or EndUser)
- if isinstance(current_user, Account):
- # Account user
- tenant_id = current_user.current_tenant_id
- else:
- # EndUser
- tenant_id = current_user.tenant_id
+ tenant_id = extract_tenant_id(current_user)
if not tenant_id:
self._environment_variables = "{}"
@@ -376,6 +449,10 @@ class Workflow(Base):
ensure_ascii=False,
)
+ @staticmethod
+ def version_from_datetime(d: datetime) -> str:
+ return str(d)
+
class WorkflowRun(Base):
"""
@@ -386,7 +463,7 @@ class WorkflowRun(Base):
- id (uuid) Run ID
- tenant_id (uuid) Workspace ID
- app_id (uuid) App ID
- - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1
+
- workflow_id (uuid) Workflow ID
- type (string) Workflow type
- triggered_from (string) Trigger source
@@ -419,13 +496,12 @@ class WorkflowRun(Base):
__table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_run_pkey"),
db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"),
- db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id: Mapped[str] = mapped_column(StringUUID)
- sequence_number: Mapped[int] = mapped_column()
+
workflow_id: Mapped[str] = mapped_column(StringUUID)
type: Mapped[str] = mapped_column(db.String(255))
triggered_from: Mapped[str] = mapped_column(db.String(255))
@@ -485,7 +561,6 @@ class WorkflowRun(Base):
"id": self.id,
"tenant_id": self.tenant_id,
"app_id": self.app_id,
- "sequence_number": self.sequence_number,
"workflow_id": self.workflow_id,
"type": self.type,
"triggered_from": self.triggered_from,
@@ -511,7 +586,6 @@ class WorkflowRun(Base):
id=data.get("id"),
tenant_id=data.get("tenant_id"),
app_id=data.get("app_id"),
- sequence_number=data.get("sequence_number"),
workflow_id=data.get("workflow_id"),
type=data.get("type"),
triggered_from=data.get("triggered_from"),
@@ -834,12 +908,22 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
- return datetime.now(UTC).replace(tzinfo=None)
+ return naive_utc_now()
class WorkflowDraftVariable(Base):
+    """`WorkflowDraftVariable` records variables and outputs generated while
+    debugging a workflow or chatflow.
+
+ IMPORTANT: This model maintains multiple invariant rules that must be preserved.
+ Do not instantiate this class directly with the constructor.
+
+ Instead, use the factory methods (`new_conversation_variable`, `new_sys_variable`,
+ `new_node_variable`) defined below to ensure all invariants are properly maintained.
+ """
+
@staticmethod
- def unique_columns() -> list[str]:
+ def unique_app_id_node_id_name() -> list[str]:
return [
"app_id",
"node_id",
@@ -847,7 +931,9 @@ class WorkflowDraftVariable(Base):
]
__tablename__ = "workflow_draft_variables"
- __table_args__ = (UniqueConstraint(*unique_columns()),)
+ __table_args__ = (UniqueConstraint(*unique_app_id_node_id_name()),)
+ # Required for instance variable annotation.
+ __allow_unmapped__ = True
# id is the unique identifier of a draft variable.
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
@@ -928,6 +1014,36 @@ class WorkflowDraftVariable(Base):
default=None,
)
+ # Cache for deserialized value
+ #
+ # NOTE(QuantumGhost): This field serves two purposes:
+ #
+ # 1. Caches deserialized values to reduce repeated parsing costs
+ # 2. Allows modification of the deserialized value after retrieval,
+    #    particularly important for `File` variables which require database
+ # lookups to obtain storage_key and other metadata
+ #
+ # Use double underscore prefix for better encapsulation,
+ # making this attribute harder to access from outside the class.
+ __value: Segment | None
+
+ def __init__(self, *args, **kwargs):
+ """
+ The constructor of `WorkflowDraftVariable` is not intended for
+        direct use outside this file. Its sole purpose is to set up the private
+        state used by the model instance.
+
+ Please use the factory methods
+ (`new_conversation_variable`, `new_sys_variable`, `new_node_variable`)
+ defined below to create instances of this class.
+ """
+ super().__init__(*args, **kwargs)
+ self.__value = None
+
+ @orm.reconstructor
+ def _init_on_load(self):
+ self.__value = None
+
def get_selector(self) -> list[str]:
selector = json.loads(self.selector)
if not isinstance(selector, list):
@@ -942,15 +1058,92 @@ class WorkflowDraftVariable(Base):
def _set_selector(self, value: list[str]):
self.selector = json.dumps(value)
- def get_value(self) -> Segment | None:
- return build_segment(json.loads(self.value))
+ def _loads_value(self) -> Segment:
+ value = json.loads(self.value)
+ return self.build_segment_with_type(self.value_type, value)
+
+ @staticmethod
+ def rebuild_file_types(value: Any) -> Any:
+ # NOTE(QuantumGhost): Temporary workaround for structured data handling.
+ # By this point, `output` has been converted to dict by
+ # `WorkflowEntry.handle_special_values`, so we need to
+ # reconstruct File objects from their serialized form
+ # to maintain proper variable saving behavior.
+ #
+ # Ideally, we should work with structured data objects directly
+ # rather than their serialized forms.
+ # However, multiple components in the codebase depend on
+ # `WorkflowEntry.handle_special_values`, making a comprehensive migration challenging.
+ if isinstance(value, dict):
+ if not maybe_file_object(value):
+ return value
+ return File.model_validate(value)
+ elif isinstance(value, list) and value:
+ first = value[0]
+ if not maybe_file_object(first):
+ return value
+ return [File.model_validate(i) for i in value]
+ else:
+ return value
+
+ @classmethod
+ def build_segment_with_type(cls, segment_type: SegmentType, value: Any) -> Segment:
+ # Extends `variable_factory.build_segment_with_type` functionality by
+        # reconstructing `FileSegment` or `ArrayFileSegment` objects from
+ # their serialized dictionary or list representations, respectively.
+ if segment_type == SegmentType.FILE:
+ if isinstance(value, File):
+ return build_segment_with_type(segment_type, value)
+ elif isinstance(value, dict):
+ file = cls.rebuild_file_types(value)
+ return build_segment_with_type(segment_type, file)
+ else:
+ raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}")
+ if segment_type == SegmentType.ARRAY_FILE:
+ if not isinstance(value, list):
+ raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}")
+ file_list = cls.rebuild_file_types(value)
+ return build_segment_with_type(segment_type=segment_type, value=file_list)
+
+ return build_segment_with_type(segment_type=segment_type, value=value)
+
+ def get_value(self) -> Segment:
+ """Decode the serialized value into its corresponding `Segment` object.
+
+ This method caches the result, so repeated calls will return the same
+ object instance without re-parsing the serialized data.
+
+ If you need to modify the returned `Segment`, use `value.model_copy()`
+ to create a copy first to avoid affecting the cached instance.
+
+ For more information about the caching mechanism, see the documentation
+ of the `__value` field.
+
+ Returns:
+ Segment: The deserialized value as a Segment object.
+ """
+
+ if self.__value is not None:
+ return self.__value
+ value = self._loads_value()
+ self.__value = value
+ return value
def set_name(self, name: str):
self.name = name
self._set_selector([self.node_id, name])
def set_value(self, value: Segment):
- self.value = json.dumps(value.value)
+ """Updates the `value` and corresponding `value_type` fields in the database model.
+
+ This method also stores the provided Segment object in the deserialized cache
+ without creating a copy, allowing for efficient value access.
+
+ Args:
+ value: The Segment object to store as the variable's value.
+ """
+ self.__value = value
+ self.value = json.dumps(value, cls=variable_utils.SegmentJSONEncoder)
self.value_type = value.value_type
def get_node_id(self) -> str | None:
@@ -976,6 +1169,7 @@ class WorkflowDraftVariable(Base):
node_id: str,
name: str,
value: Segment,
+ node_execution_id: str | None,
description: str = "",
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
@@ -987,6 +1181,7 @@ class WorkflowDraftVariable(Base):
variable.name = name
variable.set_value(value)
variable._set_selector(list(variable_utils.to_selector(node_id, name)))
+ variable.node_execution_id = node_execution_id
return variable
@classmethod
@@ -996,13 +1191,17 @@ class WorkflowDraftVariable(Base):
app_id: str,
name: str,
value: Segment,
+ description: str = "",
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
name=name,
value=value,
+ description=description,
+ node_execution_id=None,
)
+ variable.editable = True
return variable
@classmethod
@@ -1012,9 +1211,16 @@ class WorkflowDraftVariable(Base):
app_id: str,
name: str,
value: Segment,
+ node_execution_id: str,
editable: bool = False,
) -> "WorkflowDraftVariable":
- variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value)
+ variable = cls._new(
+ app_id=app_id,
+ node_id=SYSTEM_VARIABLE_NODE_ID,
+ name=name,
+ node_execution_id=node_execution_id,
+ value=value,
+ )
variable.editable = editable
return variable
@@ -1026,11 +1232,19 @@ class WorkflowDraftVariable(Base):
node_id: str,
name: str,
value: Segment,
+ node_execution_id: str,
visible: bool = True,
+ editable: bool = True,
) -> "WorkflowDraftVariable":
- variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value)
+ variable = cls._new(
+ app_id=app_id,
+ node_id=node_id,
+ name=name,
+ node_execution_id=node_execution_id,
+ value=value,
+ )
variable.visible = visible
- variable.editable = True
+ variable.editable = editable
return variable
@property
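As context for the `WorkflowDraftVariable` changes above, the sketch below illustrates the intended call pattern: create instances through the factory methods rather than the constructor, and treat the value returned by `get_value()` as a cached, shared object. This is a minimal sketch, not part of the patch; the identifiers, the `build_segment` import path, and the mutation example are illustrative assumptions.

```python
# Minimal usage sketch (assumed module paths; identifiers are placeholders).
from factories.variable_factory import build_segment  # assumed import path
from models.workflow import WorkflowDraftVariable     # assumed import path

# Create a node variable via the factory method, not the constructor.
variable = WorkflowDraftVariable.new_node_variable(
    app_id="app-123",
    node_id="node-abc",
    name="answer",
    value=build_segment("hello"),
    node_execution_id="exec-456",
    visible=True,
    editable=True,
)

# get_value() caches the deserialized Segment, so repeated calls return the
# same instance; copy before mutating to avoid touching the cached object.
segment = variable.get_value()
assert variable.get_value() is segment
local_copy = segment.model_copy()
```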
diff --git a/api/mypy.ini b/api/mypy.ini
index 12fe529b08..6836b2602b 100644
--- a/api/mypy.ini
+++ b/api/mypy.ini
@@ -18,4 +18,3 @@ ignore_missing_imports=True
[mypy-flask_restful.inputs]
ignore_missing_imports=True
-
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 9631586ed4..7f1efa671f 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,9 +1,10 @@
[project]
name = "dify-api"
-dynamic = ["version"]
+version = "1.6.0"
requires-python = ">=3.11,<3.13"
dependencies = [
+ "arize-phoenix-otel~=0.9.2",
"authlib==1.3.1",
"azure-identity==1.16.1",
"beautifulsoup4==4.12.2",
@@ -81,6 +82,9 @@ dependencies = [
"weave~=0.51.0",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
+ "sseclient-py>=1.8.0",
+ "httpx-sse>=0.4.0",
+ "sendgrid~=6.12.3",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
@@ -104,7 +108,7 @@ dev = [
"faker~=32.1.0",
"lxml-stubs~=0.5.1",
"mypy~=1.16.0",
- "ruff~=0.11.5",
+ "ruff~=0.12.3",
"pytest~=8.3.2",
"pytest-benchmark~=4.0.0",
"pytest-cov~=4.1.0",
@@ -148,11 +152,13 @@ dev = [
"types-ujson>=5.10.0",
"boto3-stubs>=1.38.20",
"types-jmespath>=1.0.2.20240106",
+ "hypothesis>=6.131.15",
"types_pyOpenSSL>=24.1.0",
"types_cffi>=1.17.0",
"types_setuptools>=80.9.0",
"pandas-stubs~=2.2.3",
"scipy-stubs>=1.15.3.0",
+ "types-python-http-client>=3.3.7.20240910",
]
############################################################
@@ -195,11 +201,12 @@ vdb = [
"pymochow==1.3.1",
"pyobvector~=0.1.6",
"qdrant-client==1.9.0",
- "tablestore==6.1.0",
+ "tablestore==6.2.0",
"tcvectordb~=1.6.4",
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0",
"xinference-client~=1.2.2",
+ "mo-vector~=0.1.13",
]
diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..00a2d1f87d
--- /dev/null
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -0,0 +1,197 @@
+"""
+Service-layer repository protocol for WorkflowNodeExecutionModel operations.
+
+This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
+that abstracts database queries currently done directly in service classes. This repository is
+specifically designed for service-layer needs and is separate from the core domain repository.
+
+The service repository handles operations that require access to database-specific fields like
+tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, Protocol
+
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models.workflow import WorkflowNodeExecutionModel
+
+
+class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
+ """
+ Protocol for service-layer operations on WorkflowNodeExecutionModel.
+
+ This repository provides database access patterns specifically needed by service classes,
+ handling queries that involve database-specific fields and multi-tenancy concerns.
+
+ Key responsibilities:
+ - Manages database operations for workflow node executions
+ - Handles multi-tenant data isolation
+ - Provides batch processing capabilities
+ - Supports execution lifecycle management
+
+ Implementation notes:
+ - Returns database models directly (WorkflowNodeExecutionModel)
+ - Handles tenant/app filtering automatically
+ - Provides service-specific query patterns
+ - Focuses on database operations without domain logic
+ - Supports cleanup and maintenance operations
+ """
+
+ def get_node_last_execution(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ node_id: str,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get the most recent execution for a specific node.
+
+ This method finds the latest execution of a specific node within a workflow,
+ ordered by creation time. Used primarily for debugging and inspection purposes.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_id: The workflow identifier
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ ...
+
+ def get_executions_by_workflow_run(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get all node executions for a specific workflow run.
+
+ This method retrieves all node executions that belong to a specific workflow run,
+ ordered by index in descending order for proper trace visualization.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_run_id: The workflow run identifier
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
+ """
+ ...
+
+ def get_execution_by_id(
+ self,
+ execution_id: str,
+ tenant_id: Optional[str] = None,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get a workflow node execution by its ID.
+
+ This method retrieves a specific execution by its unique identifier.
+ Tenant filtering is optional for cases where the execution ID is globally unique.
+
+ When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
+ If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
+ set `tenant_id` to prevent horizontal privilege escalation.
+
+ Args:
+ execution_id: The execution identifier
+ tenant_id: Optional tenant identifier for additional filtering
+
+ Returns:
+ The WorkflowNodeExecutionModel if found, or None if not found
+ """
+ ...
+
+ def delete_expired_executions(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete workflow node executions that are older than the specified date.
+
+ This method is used for cleanup operations to remove expired executions
+ in batches to avoid overwhelming the database.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Delete executions created before this date
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The number of executions deleted
+ """
+ ...
+
+ def delete_executions_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow node executions for a specific app.
+
+ This method is used when removing an app and all its related data.
+ Executions are deleted in batches to avoid overwhelming the database.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The total number of executions deleted
+ """
+ ...
+
+ def get_expired_executions_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get a batch of expired workflow node executions for backup purposes.
+
+ This method retrieves expired executions without deleting them,
+ allowing the caller to backup the data before deletion.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Get executions created before this date
+ batch_size: Maximum number of executions to retrieve
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances
+ """
+ ...
+
+ def delete_executions_by_ids(
+ self,
+ execution_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow node executions by their IDs.
+
+ This method deletes specific executions by their IDs,
+ typically used after backing up the data.
+
+ This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
+ data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
+ additional tenant validation should be implemented to prevent unauthorized access.
+
+ Args:
+ execution_ids: List of execution IDs to delete
+
+ Returns:
+ The number of executions deleted
+ """
+ ...
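The backup-before-delete split between `get_expired_executions_batch` and `delete_executions_by_ids` is meant to be driven by the caller. Below is a minimal sketch of that loop, assuming any object implementing the protocol above and a caller-supplied `backup` callable (for example, one that serializes the rows to OSS storage):

```python
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta

from models.workflow import WorkflowNodeExecutionModel
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository


def purge_expired_node_executions(
    repo: DifyAPIWorkflowNodeExecutionRepository,
    tenant_id: str,
    backup: Callable[[Sequence[WorkflowNodeExecutionModel]], None],
    retention_days: int = 30,
    batch_size: int = 1000,
) -> int:
    """Back up and then delete expired node executions, batch by batch (sketch)."""
    cutoff = datetime.utcnow() - timedelta(days=retention_days)
    total_deleted = 0
    while True:
        batch = repo.get_expired_executions_batch(tenant_id, cutoff, batch_size)
        if not batch:
            break
        backup(batch)  # persist the rows somewhere durable before deletion
        total_deleted += repo.delete_executions_by_ids([row.id for row in batch])
        if len(batch) < batch_size:
            break
    return total_deleted
```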
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
new file mode 100644
index 0000000000..59e7baeb79
--- /dev/null
+++ b/api/repositories/api_workflow_run_repository.py
@@ -0,0 +1,181 @@
+"""
+API WorkflowRun Repository Protocol
+
+This module defines the protocol for service-layer WorkflowRun operations.
+The repository provides an abstraction layer for WorkflowRun database operations
+used by service classes, separating service-layer concerns from core domain logic.
+
+Key Features:
+- Paginated workflow run queries with filtering
+- Bulk deletion operations with OSS backup support
+- Multi-tenant data isolation
+- Expired record cleanup with data retention
+- Service-layer specific query patterns
+
+Usage:
+ This protocol should be used by service classes that need to perform
+ WorkflowRun database operations. It provides a clean interface that
+ hides implementation details and supports dependency injection.
+
+Example:
+ ```python
+    from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ # Get paginated workflow runs
+ runs = repo.get_paginated_workflow_runs(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ triggered_from="debugging",
+ limit=20
+ )
+ ```
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, Protocol
+
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.workflow import WorkflowRun
+
+
+class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
+ """
+ Protocol for service-layer WorkflowRun repository operations.
+
+ This protocol defines the interface for WorkflowRun database operations
+ that are specific to service-layer needs, including pagination, filtering,
+ and bulk operations with data backup support.
+ """
+
+ def get_paginated_workflow_runs(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ limit: int = 20,
+ last_id: Optional[str] = None,
+ ) -> InfiniteScrollPagination:
+ """
+ Get paginated workflow runs with filtering.
+
+ Retrieves workflow runs for a specific app and trigger source with
+ cursor-based pagination support. Used primarily for debugging and
+ workflow run listing in the UI.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
+ limit: Maximum number of records to return (default: 20)
+ last_id: Cursor for pagination - ID of the last record from previous page
+
+ Returns:
+ InfiniteScrollPagination object containing:
+ - data: List of WorkflowRun objects
+ - limit: Applied limit
+ - has_more: Boolean indicating if more records exist
+
+ Raises:
+ ValueError: If last_id is provided but the corresponding record doesn't exist
+ """
+ ...
+
+ def get_workflow_run_by_id(
+ self,
+ tenant_id: str,
+ app_id: str,
+ run_id: str,
+ ) -> Optional[WorkflowRun]:
+ """
+ Get a specific workflow run by ID.
+
+ Retrieves a single workflow run with tenant and app isolation.
+ Used for workflow run detail views and execution tracking.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ run_id: Workflow run identifier
+
+ Returns:
+ WorkflowRun object if found, None otherwise
+ """
+ ...
+
+ def get_expired_runs_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Get a batch of expired workflow runs for cleanup.
+
+ Retrieves workflow runs created before the specified date for
+ cleanup operations. Used by scheduled tasks to remove old data
+ while maintaining data retention policies.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ before_date: Only return runs created before this date
+ batch_size: Maximum number of records to return
+
+ Returns:
+ Sequence of WorkflowRun objects to be processed for cleanup
+ """
+ ...
+
+ def delete_runs_by_ids(
+ self,
+ run_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow runs by their IDs.
+
+ Performs bulk deletion of workflow runs by ID. This method should
+ be used after backing up the data to OSS storage for retention.
+
+ Args:
+ run_ids: Sequence of workflow run IDs to delete
+
+ Returns:
+ Number of records actually deleted
+
+ Note:
+ This method performs hard deletion. Ensure data is backed up
+ to OSS storage before calling this method for compliance with
+ data retention policies.
+ """
+ ...
+
+ def delete_runs_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow runs for a specific app.
+
+ Performs bulk deletion of all workflow runs associated with an app.
+ Used during app cleanup operations. Processes records in batches
+ to avoid memory issues and long-running transactions.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ batch_size: Number of records to process in each batch
+
+ Returns:
+ Total number of records deleted across all batches
+
+ Note:
+ This method performs hard deletion without backup. Use with caution
+ and ensure proper data retention policies are followed.
+ """
+ ...
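To show how the cursor in `get_paginated_workflow_runs` is meant to be advanced, here is a small iteration sketch. It assumes an implementation of the protocol above, that `InfiniteScrollPagination` exposes the `data` and `has_more` fields described in the docstring, and that `WorkflowRun` rows carry an `id` attribute.

```python
from collections.abc import Iterator

from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository


def iter_workflow_runs(
    repo: APIWorkflowRunRepository,
    tenant_id: str,
    app_id: str,
    triggered_from: str = "debugging",
    page_size: int = 20,
) -> Iterator[WorkflowRun]:
    """Walk every matching workflow run by following the pagination cursor (sketch)."""
    last_id = None
    while True:
        page = repo.get_paginated_workflow_runs(
            tenant_id=tenant_id,
            app_id=app_id,
            triggered_from=triggered_from,
            limit=page_size,
            last_id=last_id,
        )
        yield from page.data
        if not page.has_more:
            break
        last_id = page.data[-1].id  # cursor: last record of the current page
```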
diff --git a/api/repositories/factory.py b/api/repositories/factory.py
new file mode 100644
index 0000000000..0a0adbf2c2
--- /dev/null
+++ b/api/repositories/factory.py
@@ -0,0 +1,103 @@
+"""
+DifyAPI Repository Factory for creating repository instances.
+
+This factory is specifically designed for DifyAPI repositories that handle
+service-layer operations with dependency injection patterns.
+"""
+
+import logging
+
+from sqlalchemy.orm import sessionmaker
+
+from configs import dify_config
+from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+
+logger = logging.getLogger(__name__)
+
+
+class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
+ """
+ Factory for creating DifyAPI repository instances based on configuration.
+
+ This factory handles the creation of repositories that are specifically designed
+ for service-layer operations and use dependency injection with sessionmaker
+ for better testability and separation of concerns.
+ """
+
+ @classmethod
+ def create_api_workflow_node_execution_repository(
+ cls, session_maker: sessionmaker
+ ) -> DifyAPIWorkflowNodeExecutionRepository:
+ """
+ Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.
+
+ This repository is designed for service-layer operations and uses dependency injection
+ with a sessionmaker for better testability and separation of concerns. It provides
+ database access patterns specifically needed by service classes, handling queries
+ that involve database-specific fields and multi-tenancy concerns.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker to inject for database session management.
+
+ Returns:
+ Configured DifyAPIWorkflowNodeExecutionRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be imported or instantiated
+ """
+ class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
+ logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
+ # Service repository requires session_maker parameter
+ cls._validate_constructor_signature(repository_class, ["session_maker"])
+
+ return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+ except RepositoryImportError:
+ # Re-raise our custom errors as-is
+ raise
+ except Exception as e:
+ logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
+ raise RepositoryImportError(
+ f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
+ ) from e
+
+ @classmethod
+ def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository:
+ """
+ Create an APIWorkflowRunRepository instance based on configuration.
+
+ This repository is designed for service-layer WorkflowRun operations and uses dependency
+ injection with a sessionmaker for better testability and separation of concerns. It provides
+ database access patterns specifically needed by service classes for workflow run management,
+ including pagination, filtering, and bulk operations.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker to inject for database session management.
+
+ Returns:
+ Configured APIWorkflowRunRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be imported or instantiated
+ """
+ class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
+ logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
+ # Service repository requires session_maker parameter
+ cls._validate_constructor_signature(repository_class, ["session_maker"])
+
+ return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+ except RepositoryImportError:
+ # Re-raise our custom errors as-is
+ raise
+ except Exception as e:
+ logger.exception("Failed to create APIWorkflowRunRepository")
+ raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
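Putting the factory together with the repositories above, a typical wiring at startup might look like the sketch below. It assumes the two configuration keys hold dotted paths to concrete implementations (such as the SQLAlchemy classes added later in this diff) and that `db.engine` is the API's Flask-SQLAlchemy engine; the identifiers in the example call are placeholders.

```python
# Wiring sketch: one sessionmaker shared by both service-layer repositories.
from sqlalchemy.orm import sessionmaker

from extensions.ext_database import db
from repositories.factory import DifyAPIRepositoryFactory

session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)

node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)

# Example (assumed identifiers): fetch the latest execution of a node for debugging.
last_execution = node_execution_repo.get_node_last_execution(
    tenant_id="tenant-123",
    app_id="app-456",
    workflow_id="workflow-789",
    node_id="llm-node",
)
```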
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..e6a23ddf9f
--- /dev/null
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -0,0 +1,290 @@
+"""
+SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+This module provides a concrete implementation of the service repository protocol
+using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional
+
+from sqlalchemy import delete, desc, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+
+
+class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
+ """
+ SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+ This repository provides service-layer database operations for WorkflowNodeExecutionModel
+ using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
+ protocol with the following features:
+
+ - Multi-tenancy data isolation through tenant_id filtering
+ - Direct database model operations without domain conversion
+ - Batch processing for efficient large-scale operations
+ - Optimized query patterns for common access patterns
+ - Dependency injection for better testability and maintainability
+ - Session management and transaction handling with proper cleanup
+ - Maintenance operations for data lifecycle management
+ - Thread-safe database operations using session-per-request pattern
+ """
+
+ def __init__(self, session_maker: sessionmaker[Session]):
+ """
+ Initialize the repository with a sessionmaker.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker for creating database sessions
+ """
+ self._session_maker = session_maker
+
+ def get_node_last_execution(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ node_id: str,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get the most recent execution for a specific node.
+
+ This method replicates the query pattern from WorkflowService.get_node_last_run()
+ using SQLAlchemy 2.0 style syntax.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_id: The workflow identifier
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ WorkflowNodeExecutionModel.workflow_id == workflow_id,
+ WorkflowNodeExecutionModel.node_id == node_id,
+ )
+ .order_by(desc(WorkflowNodeExecutionModel.created_at))
+ .limit(1)
+ )
+
+ with self._session_maker() as session:
+ return session.scalar(stmt)
+
+ def get_executions_by_workflow_run(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get all node executions for a specific workflow run.
+
+ This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
+ using SQLAlchemy 2.0 style syntax.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_run_id: The workflow run identifier
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
+ )
+ .order_by(desc(WorkflowNodeExecutionModel.index))
+ )
+
+ with self._session_maker() as session:
+ return session.execute(stmt).scalars().all()
+
+ def get_execution_by_id(
+ self,
+ execution_id: str,
+ tenant_id: Optional[str] = None,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get a workflow node execution by its ID.
+
+ This method replicates the query pattern from WorkflowDraftVariableService
+ and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.
+
+ When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
+ If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
+ set `tenant_id` to prevent horizontal privilege escalation.
+
+ Args:
+ execution_id: The execution identifier
+ tenant_id: Optional tenant identifier for additional filtering
+
+ Returns:
+ The WorkflowNodeExecutionModel if found, or None if not found
+ """
+ stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
+
+ # Add tenant filtering if provided
+ if tenant_id is not None:
+ stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
+
+ with self._session_maker() as session:
+ return session.scalar(stmt)
+
+ def delete_expired_executions(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete workflow node executions that are older than the specified date.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Delete executions created before this date
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The number of executions deleted
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Find executions to delete in batches
+ stmt = (
+ select(WorkflowNodeExecutionModel.id)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+
+ execution_ids = session.execute(stmt).scalars().all()
+ if not execution_ids:
+ break
+
+ # Delete the batch
+ delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+ total_deleted += result.rowcount
+
+ # If we deleted fewer than the batch size, we're done
+ if len(execution_ids) < batch_size:
+ break
+
+ return total_deleted
+
+ def delete_executions_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow node executions for a specific app.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The total number of executions deleted
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Find executions to delete in batches
+ stmt = (
+ select(WorkflowNodeExecutionModel.id)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ )
+ .limit(batch_size)
+ )
+
+ execution_ids = session.execute(stmt).scalars().all()
+ if not execution_ids:
+ break
+
+ # Delete the batch
+ delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+ total_deleted += result.rowcount
+
+ # If we deleted fewer than the batch size, we're done
+ if len(execution_ids) < batch_size:
+ break
+
+ return total_deleted
+
+ def get_expired_executions_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get a batch of expired workflow node executions for backup purposes.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Get executions created before this date
+ batch_size: Maximum number of executions to retrieve
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+
+ with self._session_maker() as session:
+ return session.execute(stmt).scalars().all()
+
+ def delete_executions_by_ids(
+ self,
+ execution_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow node executions by their IDs.
+
+ Args:
+ execution_ids: List of execution IDs to delete
+
+ Returns:
+ The number of executions deleted
+ """
+ if not execution_ids:
+ return 0
+
+ with self._session_maker() as session:
+ stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(stmt)
+ session.commit()
+ return result.rowcount
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
new file mode 100644
index 0000000000..ebd1d74b20
--- /dev/null
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -0,0 +1,203 @@
+"""
+SQLAlchemy API WorkflowRun Repository Implementation
+
+This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository
+protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0
+style queries with proper session management and multi-tenant data isolation.
+
+Key Features:
+- SQLAlchemy 2.0 style queries for modern database operations
+- Cursor-based pagination for efficient large dataset handling
+- Bulk operations with batch processing for performance
+- Multi-tenant data isolation and security
+- Proper session management with dependency injection
+
+Implementation Notes:
+- Uses sessionmaker for consistent session management
+- Implements cursor-based pagination using created_at timestamps
+- Provides efficient bulk deletion with batch processing
+- Maintains data consistency with proper transaction handling
+"""
+
+import logging
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+
+logger = logging.getLogger(__name__)
+
+
+class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
+ """
+ SQLAlchemy implementation of APIWorkflowRunRepository.
+
+ Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0
+ style queries. Supports dependency injection through sessionmaker and
+ maintains proper multi-tenant data isolation.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker instance for database connections
+ """
+
+ def __init__(self, session_maker: sessionmaker[Session]) -> None:
+ """
+ Initialize the repository with a sessionmaker.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker for database connections
+ """
+ self._session_maker = session_maker
+
+ def get_paginated_workflow_runs(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ limit: int = 20,
+ last_id: Optional[str] = None,
+ ) -> InfiniteScrollPagination:
+ """
+ Get paginated workflow runs with filtering.
+
+ Implements cursor-based pagination using created_at timestamps for
+ efficient handling of large datasets. Filters by tenant, app, and
+ trigger source for proper data isolation.
+ """
+ with self._session_maker() as session:
+ # Build base query with filters
+ base_stmt = select(WorkflowRun).where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ WorkflowRun.triggered_from == triggered_from,
+ )
+
+ if last_id:
+ # Get the last workflow run for cursor-based pagination
+ last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
+ last_workflow_run = session.scalar(last_run_stmt)
+
+ if not last_workflow_run:
+                    raise ValueError("Last workflow run does not exist")
+
+ # Get records created before the last run's timestamp
+ base_stmt = base_stmt.where(
+ WorkflowRun.created_at < last_workflow_run.created_at,
+ WorkflowRun.id != last_workflow_run.id,
+ )
+
+            # Fetch the most recent records, requesting one extra row to detect whether more pages exist
+ workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
+
+ # Check if there are more records for pagination
+ has_more = len(workflow_runs) > limit
+ if has_more:
+ workflow_runs = workflow_runs[:-1]
+
+ return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
+
+ def get_workflow_run_by_id(
+ self,
+ tenant_id: str,
+ app_id: str,
+ run_id: str,
+ ) -> Optional[WorkflowRun]:
+ """
+ Get a specific workflow run by ID with tenant and app isolation.
+ """
+ with self._session_maker() as session:
+ stmt = select(WorkflowRun).where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ WorkflowRun.id == run_id,
+ )
+ return cast(Optional[WorkflowRun], session.scalar(stmt))
+
+ def get_expired_runs_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Get a batch of expired workflow runs for cleanup operations.
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+ return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
+
+ def delete_runs_by_ids(
+ self,
+ run_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow runs by their IDs using bulk deletion.
+ """
+ if not run_ids:
+ return 0
+
+ with self._session_maker() as session:
+ stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
+ result = session.execute(stmt)
+ session.commit()
+
+ deleted_count = cast(int, result.rowcount)
+ logger.info(f"Deleted {deleted_count} workflow runs by IDs")
+ return deleted_count
+
+ def delete_runs_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow runs for a specific app in batches.
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Get a batch of run IDs to delete
+ stmt = (
+ select(WorkflowRun.id)
+ .where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ )
+ .limit(batch_size)
+ )
+ run_ids = session.scalars(stmt).all()
+
+ if not run_ids:
+ break
+
+ # Delete the batch
+ delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+
+ batch_deleted = result.rowcount
+ total_deleted += batch_deleted
+
+ logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}")
+
+ # If we deleted fewer records than the batch size, we're done
+ if batch_deleted < batch_size:
+ break
+
+ logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}")
+ return total_deleted
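The pagination above relies on the common fetch-one-extra-row trick: ask for `limit + 1` records, report `has_more` if the extra row came back, and return only `limit` rows. Here is a tiny self-contained sketch of the same idea using plain lists in place of query results:

```python
def paginate(rows: list, limit: int) -> tuple[list, bool]:
    """Return one page of `rows` plus a has_more flag, mirroring the limit + 1 trick."""
    window = rows[: limit + 1]       # fetch one extra element
    has_more = len(window) > limit   # an extra element means another page exists
    return window[:limit], has_more


page, has_more = paginate(list(range(45)), 20)
assert len(page) == 20 and has_more is True

last_page, has_more = paginate(list(range(45))[40:], 20)
assert len(last_page) == 5 and has_more is False
```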
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 14d238467d..352efb2f0c 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -16,7 +16,8 @@ from configs import dify_config
from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
-from extensions.ext_redis import redis_client
+from extensions.ext_redis import redis_client, redis_fallback
+from libs.datetime_utils import naive_utc_now
from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
@@ -52,8 +53,14 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
+from tasks.mail_change_mail_task import send_change_mail_task
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
+from tasks.mail_owner_transfer_task import (
+ send_new_owner_transfer_notify_email_task,
+ send_old_owner_transfer_notify_email_task,
+ send_owner_transfer_confirm_task,
+)
from tasks.mail_reset_password_task import send_reset_password_mail_task
@@ -75,8 +82,13 @@ class AccountService:
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
)
+ change_email_rate_limiter = RateLimiter(prefix="change_email_rate_limit", max_attempts=1, time_window=60 * 1)
+ owner_transfer_rate_limiter = RateLimiter(prefix="owner_transfer_rate_limit", max_attempts=1, time_window=60 * 1)
+
LOGIN_MAX_ERROR_LIMITS = 5
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
+ CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
+ OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@@ -124,8 +136,8 @@ class AccountService:
available_ta.current = True
db.session.commit()
- if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
- account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
+ if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
+ account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
@@ -169,7 +181,7 @@ class AccountService:
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
@@ -307,7 +319,7 @@ class AccountService:
# If it exists, update the record
account_integrate.open_id = open_id
account_integrate.encrypted_token = "" # todo
- account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ account_integrate.updated_at = naive_utc_now()
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(
@@ -342,7 +354,7 @@ class AccountService:
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
- account.last_login_at = datetime.now(UTC).replace(tzinfo=None)
+ account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
@@ -419,6 +431,101 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
+ @classmethod
+ def send_change_email_email(
+ cls,
+ account: Optional[Account] = None,
+ email: Optional[str] = None,
+ old_email: Optional[str] = None,
+ language: Optional[str] = "en-US",
+ phase: Optional[str] = None,
+ ):
+ account_email = account.email if account else email
+ if account_email is None:
+ raise ValueError("Email must be provided.")
+
+ if cls.change_email_rate_limiter.is_rate_limited(account_email):
+ from controllers.console.auth.error import EmailChangeRateLimitExceededError
+
+ raise EmailChangeRateLimitExceededError()
+
+ code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
+
+ send_change_mail_task.delay(
+ language=language,
+ to=account_email,
+ code=code,
+ phase=phase,
+ )
+ cls.change_email_rate_limiter.increment_rate_limit(account_email)
+ return token
+
+ @classmethod
+ def send_owner_transfer_email(
+ cls,
+ account: Optional[Account] = None,
+ email: Optional[str] = None,
+ language: Optional[str] = "en-US",
+ workspace_name: Optional[str] = "",
+ ):
+ account_email = account.email if account else email
+ if account_email is None:
+ raise ValueError("Email must be provided.")
+
+ if cls.owner_transfer_rate_limiter.is_rate_limited(account_email):
+ from controllers.console.auth.error import OwnerTransferRateLimitExceededError
+
+ raise OwnerTransferRateLimitExceededError()
+
+ code, token = cls.generate_owner_transfer_token(account_email, account)
+
+ send_owner_transfer_confirm_task.delay(
+ language=language,
+ to=account_email,
+ code=code,
+ workspace=workspace_name,
+ )
+ cls.owner_transfer_rate_limiter.increment_rate_limit(account_email)
+ return token
+
+ @classmethod
+ def send_old_owner_transfer_notify_email(
+ cls,
+ account: Optional[Account] = None,
+ email: Optional[str] = None,
+ language: Optional[str] = "en-US",
+ workspace_name: Optional[str] = "",
+ new_owner_email: Optional[str] = "",
+ ):
+ account_email = account.email if account else email
+ if account_email is None:
+ raise ValueError("Email must be provided.")
+
+ send_old_owner_transfer_notify_email_task.delay(
+ language=language,
+ to=account_email,
+ workspace=workspace_name,
+ new_owner_email=new_owner_email,
+ )
+
+ @classmethod
+ def send_new_owner_transfer_notify_email(
+ cls,
+ account: Optional[Account] = None,
+ email: Optional[str] = None,
+ language: Optional[str] = "en-US",
+ workspace_name: Optional[str] = "",
+ ):
+ account_email = account.email if account else email
+ if account_email is None:
+ raise ValueError("Email must be provided.")
+
+ send_new_owner_transfer_notify_email_task.delay(
+ language=language,
+ to=account_email,
+ workspace=workspace_name,
+ )
+
@classmethod
def generate_reset_password_token(
cls,
@@ -435,14 +542,64 @@ class AccountService:
)
return code, token
+ @classmethod
+ def generate_change_email_token(
+ cls,
+ email: str,
+ account: Optional[Account] = None,
+ code: Optional[str] = None,
+ old_email: Optional[str] = None,
+        additional_data: Optional[dict[str, Any]] = None,
+    ):
+        # Avoid mutating a shared default argument.
+        if additional_data is None:
+            additional_data = {}
+        if not code:
+ code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
+ additional_data["code"] = code
+ additional_data["old_email"] = old_email
+ token = TokenManager.generate_token(
+ account=account, email=email, token_type="change_email", additional_data=additional_data
+ )
+ return code, token
+
+ @classmethod
+ def generate_owner_transfer_token(
+ cls,
+ email: str,
+ account: Optional[Account] = None,
+ code: Optional[str] = None,
+        additional_data: Optional[dict[str, Any]] = None,
+    ):
+        # Avoid mutating a shared default argument.
+        if additional_data is None:
+            additional_data = {}
+        if not code:
+ code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)])
+ additional_data["code"] = code
+ token = TokenManager.generate_token(
+ account=account, email=email, token_type="owner_transfer", additional_data=additional_data
+ )
+ return code, token
+
@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
+ @classmethod
+ def revoke_change_email_token(cls, token: str):
+ TokenManager.revoke_token(token, "change_email")
+
+ @classmethod
+ def revoke_owner_transfer_token(cls, token: str):
+ TokenManager.revoke_token(token, "owner_transfer")
+
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
+ @classmethod
+ def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
+ return TokenManager.get_token_data(token, "change_email")
+
+ @classmethod
+ def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
+ return TokenManager.get_token_data(token, "owner_transfer")
+
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
@@ -495,6 +652,7 @@ class AccountService:
return account
@staticmethod
+ @redis_fallback(default_return=None)
def add_login_error_rate_limit(email: str) -> None:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
@@ -504,6 +662,7 @@ class AccountService:
redis_client.setex(key, dify_config.LOGIN_LOCKOUT_DURATION, count)
@staticmethod
+ @redis_fallback(default_return=False)
def is_login_error_rate_limit(email: str) -> bool:
key = f"login_error_rate_limit:{email}"
count = redis_client.get(key)
@@ -516,11 +675,13 @@ class AccountService:
return False
@staticmethod
+ @redis_fallback(default_return=None)
def reset_login_error_rate_limit(email: str):
key = f"login_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
+ @redis_fallback(default_return=None)
def add_forgot_password_error_rate_limit(email: str) -> None:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
@@ -530,6 +691,7 @@ class AccountService:
redis_client.setex(key, dify_config.FORGOT_PASSWORD_LOCKOUT_DURATION, count)
@staticmethod
+ @redis_fallback(default_return=False)
def is_forgot_password_error_rate_limit(email: str) -> bool:
key = f"forgot_password_error_rate_limit:{email}"
count = redis_client.get(key)
@@ -542,11 +704,69 @@ class AccountService:
return False
@staticmethod
+ @redis_fallback(default_return=None)
def reset_forgot_password_error_rate_limit(email: str):
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
@staticmethod
+ @redis_fallback(default_return=None)
+ def add_change_email_error_rate_limit(email: str) -> None:
+ key = f"change_email_error_rate_limit:{email}"
+ count = redis_client.get(key)
+ if count is None:
+ count = 0
+ count = int(count) + 1
+ redis_client.setex(key, dify_config.CHANGE_EMAIL_LOCKOUT_DURATION, count)
+
+ @staticmethod
+ @redis_fallback(default_return=False)
+ def is_change_email_error_rate_limit(email: str) -> bool:
+ key = f"change_email_error_rate_limit:{email}"
+ count = redis_client.get(key)
+ if count is None:
+ return False
+ count = int(count)
+ if count > AccountService.CHANGE_EMAIL_MAX_ERROR_LIMITS:
+ return True
+ return False
+
+ @staticmethod
+ @redis_fallback(default_return=None)
+ def reset_change_email_error_rate_limit(email: str):
+ key = f"change_email_error_rate_limit:{email}"
+ redis_client.delete(key)
+
+ @staticmethod
+ @redis_fallback(default_return=None)
+ def add_owner_transfer_error_rate_limit(email: str) -> None:
+ key = f"owner_transfer_error_rate_limit:{email}"
+ count = redis_client.get(key)
+ if count is None:
+ count = 0
+ count = int(count) + 1
+ redis_client.setex(key, dify_config.OWNER_TRANSFER_LOCKOUT_DURATION, count)
+
+ @staticmethod
+ @redis_fallback(default_return=False)
+ def is_owner_transfer_error_rate_limit(email: str) -> bool:
+ key = f"owner_transfer_error_rate_limit:{email}"
+ count = redis_client.get(key)
+ if count is None:
+ return False
+ count = int(count)
+ if count > AccountService.OWNER_TRANSFER_MAX_ERROR_LIMITS:
+ return True
+ return False
+
+ @staticmethod
+ @redis_fallback(default_return=None)
+ def reset_owner_transfer_error_rate_limit(email: str):
+ key = f"owner_transfer_error_rate_limit:{email}"
+ redis_client.delete(key)
+
+ @staticmethod
+ @redis_fallback(default_return=False)
def is_email_send_ip_limit(ip_address: str):
minute_key = f"email_send_ip_limit_minute:{ip_address}"
freeze_key = f"email_send_ip_limit_freeze:{ip_address}"
@@ -586,6 +806,10 @@ class AccountService:
return False
+ @staticmethod
+ def check_email_unique(email: str) -> bool:
+ return db.session.query(Account).filter_by(email=email).first() is None
+
class TenantService:
@staticmethod
@@ -843,21 +1067,21 @@ class TenantService:
target_member_join.role = new_role
db.session.commit()
- @staticmethod
- def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
- """Dissolve tenant"""
- if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
- raise NoPermissionError("No permission to dissolve tenant.")
- db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
- db.session.delete(tenant)
- db.session.commit()
-
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
+ @staticmethod
+ def is_owner(account: Account, tenant: Tenant) -> bool:
+ return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
+
+ @staticmethod
+ def is_member(account: Account, tenant: Tenant) -> bool:
+ """Check if the account is a member of the tenant"""
+ return TenantService.get_user_role(account, tenant) is not None
+
class RegisterService:
@classmethod
@@ -885,11 +1109,11 @@ class RegisterService:
)
account.last_login_ip = ip_address
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
- dify_setup = DifySetup(version=dify_config.CURRENT_VERSION)
+ dify_setup = DifySetup(version=dify_config.project.version)
db.session.add(dify_setup)
db.session.commit()
except Exception as e:
@@ -926,7 +1150,7 @@ class RegisterService:
is_setup=is_setup,
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
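The new change-email and owner-transfer counters follow the same pattern as the existing login counters: one Redis key per email, incremented on failure, checked before each attempt, and cleared on success, with `@redis_fallback` presumably returning the declared default when Redis is unavailable. The sketch below shows how a verification handler might use the change-email trio; the handler name and the error type are hypothetical.

```python
# Hypothetical verification flow built on the counters added above.
class TooManyAttemptsError(Exception):
    """Placeholder error type for illustration only."""


def verify_change_email_code(email: str, submitted_code: str, expected_code: str) -> bool:
    # Refuse early once the email has exceeded CHANGE_EMAIL_MAX_ERROR_LIMITS failures.
    if AccountService.is_change_email_error_rate_limit(email):
        raise TooManyAttemptsError(email)
    if submitted_code != expected_code:
        AccountService.add_change_email_error_rate_limit(email)
        return False
    AccountService.reset_change_email_error_rate_limit(email)
    return True
```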
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index d2875180d8..08e13c588e 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -32,6 +32,7 @@ from models import Account, App, AppMode
from models.model import AppModelConfig
from models.workflow import Workflow
from services.plugin.dependencies_analysis import DependenciesAnalysisService
+from services.workflow_draft_variable_service import WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
@@ -292,6 +293,8 @@ class AppDslService:
dependencies=check_dependencies_pending_data,
)
+ draft_var_srv = WorkflowDraftVariableService(session=self._session)
+ draft_var_srv.delete_workflow_variables(app_id=app.id)
return Import(
id=import_id,
status=status,
@@ -421,7 +424,7 @@ class AppDslService:
# Set icon type
icon_type_value = icon_type or app_data.get("icon_type")
- if icon_type_value in ["emoji", "link"]:
+ if icon_type_value in ["emoji", "link", "image"]:
icon_type = icon_type_value
else:
icon_type = "emoji"
@@ -572,13 +575,26 @@ class AppDslService:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
+ # TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
- if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
- dataset_ids = node["data"].get("dataset_ids", [])
- node["data"]["dataset_ids"] = [
+ node_data = node.get("data", {})
+ if not node_data:
+ continue
+ data_type = node_data.get("type", "")
+ if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
+ dataset_ids = node_data.get("dataset_ids", [])
+ node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
+ # filter credential id from tool node
+ if not include_secret and data_type == NodeType.TOOL.value:
+ node_data.pop("credential_id", None)
+ # filter credential id from agent node
+ if not include_secret and data_type == NodeType.AGENT.value:
+ for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
+ tool.pop("credential_id", None)
+
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
@@ -599,7 +615,15 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
- export_data["model_config"] = app_model_config.to_dict()
+ model_config = app_model_config.to_dict()
+
+ # TODO: refactor: we need a better way to filter workspace related data from model config
+ # filter credential id from model config
+ for tool in model_config.get("agent_mode", {}).get("tools", []):
+ tool.pop("credential_id", None)
+
+ export_data["model_config"] = model_config
+
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
diff --git a/api/services/app_service.py b/api/services/app_service.py
index d08462d001..3494b2796b 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,6 +1,5 @@
import json
import logging
-from datetime import UTC, datetime
from typing import Optional, cast
from flask_login import current_user
@@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@@ -47,8 +47,6 @@ class AppService:
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
- elif args["mode"] == "channel":
- filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@@ -235,8 +233,9 @@ class AppService:
app.icon = args.get("icon")
app.icon_background = args.get("icon_background")
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
+ app.max_active_requests = args.get("max_active_requests")
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -250,7 +249,7 @@ class AppService:
"""
app.name = name
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -266,7 +265,7 @@ class AppService:
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -283,7 +282,7 @@ class AppService:
app.enable_site = enable_site
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -300,7 +299,7 @@ class AppService:
app.enable_api = enable_api
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index a259f5a4c4..e8923eb51b 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -1,13 +1,17 @@
import io
import logging
import uuid
+from collections.abc import Generator
from typing import Optional
+from flask import Response, stream_with_context
from werkzeug.datastructures import FileStorage
from constants import AUDIO_EXTENSIONS
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
+from extensions.ext_database import db
+from models.enums import MessageStatus
from models.model import App, AppMode, AppModelConfig, Message
from services.errors.audio import (
AudioTooLargeServiceError,
@@ -16,6 +20,7 @@ from services.errors.audio import (
ProviderNotSupportTextToSpeechServiceError,
UnsupportedAudioTypeServiceError,
)
+from services.workflow_service import WorkflowService
FILE_SIZE = 30
FILE_SIZE_LIMIT = FILE_SIZE * 1024 * 1024
@@ -74,35 +79,36 @@ class AudioService:
voice: Optional[str] = None,
end_user: Optional[str] = None,
message_id: Optional[str] = None,
+ is_draft: bool = False,
):
- from collections.abc import Generator
-
- from flask import Response, stream_with_context
-
from app import app
- from extensions.ext_database import db
- def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None):
+ def invoke_tts(text_content: str, app_model: App, voice: Optional[str] = None, is_draft: bool = False):
with app.app_context():
- if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
- workflow = app_model.workflow
- if workflow is None:
- raise ValueError("TTS is not enabled")
-
- features_dict = workflow.features_dict
- if "text_to_speech" not in features_dict or not features_dict["text_to_speech"].get("enabled"):
- raise ValueError("TTS is not enabled")
-
- voice = features_dict["text_to_speech"].get("voice") if voice is None else voice
- else:
- if app_model.app_model_config is None:
- raise ValueError("AppModelConfig not found")
- text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
-
- if not text_to_speech_dict.get("enabled"):
- raise ValueError("TTS is not enabled")
-
- voice = text_to_speech_dict.get("voice") if voice is None else voice
+ if voice is None:
+ if app_model.mode in {AppMode.ADVANCED_CHAT.value, AppMode.WORKFLOW.value}:
+ if is_draft:
+ workflow = WorkflowService().get_draft_workflow(app_model=app_model)
+ else:
+ workflow = app_model.workflow
+ if (
+ workflow is None
+ or "text_to_speech" not in workflow.features_dict
+ or not workflow.features_dict["text_to_speech"].get("enabled")
+ ):
+ raise ValueError("TTS is not enabled")
+
+ voice = workflow.features_dict["text_to_speech"].get("voice")
+ else:
+ if not is_draft:
+ if app_model.app_model_config is None:
+ raise ValueError("AppModelConfig not found")
+ text_to_speech_dict = app_model.app_model_config.text_to_speech_dict
+
+ if not text_to_speech_dict.get("enabled"):
+ raise ValueError("TTS is not enabled")
+
+ voice = text_to_speech_dict.get("voice")
model_manager = ModelManager()
model_instance = model_manager.get_default_model_instance(
@@ -132,18 +138,18 @@ class AudioService:
message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
return None
- if message.answer == "" and message.status == "normal":
+ if message.answer == "" and message.status == MessageStatus.NORMAL:
return None
else:
- response = invoke_tts(message.answer, app_model=app_model, voice=voice)
+ response = invoke_tts(text_content=message.answer, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
else:
if text is None:
raise ValueError("Text is required")
- response = invoke_tts(text, app_model, voice)
+ response = invoke_tts(text_content=text, app_model=app_model, voice=voice, is_draft=is_draft)
if isinstance(response, Generator):
return Response(stream_with_context(response), content_type="audio/mpeg")
return response
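The new `is_draft` flag changes where TTS settings are resolved from: draft requests consult the draft workflow through `WorkflowService`, published requests use the app's published workflow or model config, and an explicitly passed `voice` always wins. A hypothetical sketch of that resolution order; `app_model` and `load_draft_workflow` are stand-ins for the real SQLAlchemy models and `WorkflowService().get_draft_workflow`:

```python
# Sketch only; duck-typed stand-ins for the real models.
from collections.abc import Callable
from typing import Optional


def resolve_tts_voice(app_model, voice: Optional[str], is_draft: bool, load_draft_workflow: Callable) -> Optional[str]:
    if voice is not None:
        return voice  # an explicitly requested voice always wins
    if app_model.mode in ("advanced-chat", "workflow"):
        workflow = load_draft_workflow(app_model) if is_draft else app_model.workflow
        tts = (workflow.features_dict.get("text_to_speech") if workflow else None) or {}
        if not tts.get("enabled"):
            raise ValueError("TTS is not enabled")
        return tts.get("voice")
    if is_draft:
        return None  # draft chat apps fall back to the provider's default voice
    if app_model.app_model_config is None:
        raise ValueError("AppModelConfig not found")
    tts = app_model.app_model_config.text_to_speech_dict
    if not tts.get("enabled"):
        raise ValueError("TTS is not enabled")
    return tts.get("voice")
```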
diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py
index 1fd560d581..ddd16b2e0c 100644
--- a/api/services/clear_free_plan_tenant_expired_logs.py
+++ b/api/services/clear_free_plan_tenant_expired_logs.py
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
-from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
+from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs:
)
)
- while True:
- with Session(db.engine).no_autoflush as session:
- workflow_node_executions = (
- session.query(WorkflowNodeExecutionModel)
- .filter(
- WorkflowNodeExecutionModel.tenant_id == tenant_id,
- WorkflowNodeExecutionModel.created_at
- < datetime.datetime.now() - datetime.timedelta(days=days),
- )
- .limit(batch)
- .all()
- )
+ # Process expired workflow node executions with backup
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+ before_date = datetime.datetime.now() - datetime.timedelta(days=days)
+ total_deleted = 0
- if len(workflow_node_executions) == 0:
- break
+ while True:
+ # Get a batch of expired executions for backup
+ workflow_node_executions = node_execution_repo.get_expired_executions_batch(
+ tenant_id=tenant_id,
+ before_date=before_date,
+ batch_size=batch,
+ )
- # save workflow node executions
- storage.save(
- f"free_plan_tenant_expired_logs/"
- f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
- f"-{time.time()}.json",
- json.dumps(
- jsonable_encoder(workflow_node_executions),
- ).encode("utf-8"),
- )
+ if len(workflow_node_executions) == 0:
+ break
+
+ # Save workflow node executions to storage
+ storage.save(
+ f"free_plan_tenant_expired_logs/"
+ f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
+ f"-{time.time()}.json",
+ json.dumps(
+ jsonable_encoder(workflow_node_executions),
+ ).encode("utf-8"),
+ )
- workflow_node_execution_ids = [
- workflow_node_execution.id for workflow_node_execution in workflow_node_executions
- ]
+ # Extract IDs for deletion
+ workflow_node_execution_ids = [
+ workflow_node_execution.id for workflow_node_execution in workflow_node_executions
+ ]
- # delete workflow node executions
- session.query(WorkflowNodeExecutionModel).filter(
- WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
- ).delete(synchronize_session=False)
- session.commit()
+ # Delete the backed up executions
+ deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
+ total_deleted += deleted_count
- click.echo(
- click.style(
- f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
- f" workflow node executions for tenant {tenant_id}"
- )
+ click.echo(
+ click.style(
+ f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
+ f" workflow node executions for tenant {tenant_id}"
)
+ )
+
+ # If we got fewer than the batch size, we're done
+ if len(workflow_node_executions) < batch:
+ break
+
+ # Process expired workflow runs with backup
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+ before_date = datetime.datetime.now() - datetime.timedelta(days=days)
+ total_deleted = 0
while True:
- with Session(db.engine).no_autoflush as session:
- workflow_runs = (
- session.query(WorkflowRun)
- .filter(
- WorkflowRun.tenant_id == tenant_id,
- WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
- )
- .limit(batch)
- .all()
- )
+ # Get a batch of expired workflow runs for backup
+ workflow_runs = workflow_run_repo.get_expired_runs_batch(
+ tenant_id=tenant_id,
+ before_date=before_date,
+ batch_size=batch,
+ )
- if len(workflow_runs) == 0:
- break
+ if len(workflow_runs) == 0:
+ break
+
+ # Save workflow runs to storage
+ storage.save(
+ f"free_plan_tenant_expired_logs/"
+ f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
+ f"-{time.time()}.json",
+ json.dumps(
+ jsonable_encoder(
+ [workflow_run.to_dict() for workflow_run in workflow_runs],
+ ),
+ ).encode("utf-8"),
+ )
- # save workflow runs
+ # Extract IDs for deletion
+ workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
- storage.save(
- f"free_plan_tenant_expired_logs/"
- f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
- f"-{time.time()}.json",
- json.dumps(
- jsonable_encoder(
- [workflow_run.to_dict() for workflow_run in workflow_runs],
- ),
- ).encode("utf-8"),
- )
+ # Delete the backed up workflow runs
+ deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
+ total_deleted += deleted_count
- workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
+ click.echo(
+ click.style(
+ f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
+ f" workflow runs for tenant {tenant_id}"
+ )
+ )
- # delete workflow runs
- session.query(WorkflowRun).filter(
- WorkflowRun.id.in_(workflow_run_ids),
- ).delete(synchronize_session=False)
- session.commit()
+ # If we got fewer than the batch size, we're done
+ if len(workflow_runs) < batch:
+ break
@classmethod
def process(cls, days: int, batch: int, tenant_ids: list[str]):
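Both loops now follow the same repository-backed pattern: fetch a batch of expired rows, back them up to storage as JSON, delete them by ID, and stop once a batch comes back short. A generic sketch of that pattern under assumed callables (`fetch_batch`, `delete_by_ids`, and `save_backup` stand in for the repository methods and `storage.save`):

```python
# Generic backup-then-delete batching sketch; the callables are placeholders.
import json
import time
from datetime import datetime, timedelta


def purge_expired(fetch_batch, delete_by_ids, save_backup, *, days: int, batch: int) -> int:
    before_date = datetime.now() - timedelta(days=days)
    total_deleted = 0
    while True:
        rows = fetch_batch(before_date=before_date, batch_size=batch)
        if not rows:
            break
        # Back rows up before deleting them, keyed by date and timestamp.
        key = f"{datetime.now():%Y-%m-%d}-{time.time()}.json"
        save_backup(key, json.dumps([row.to_dict() for row in rows]).encode("utf-8"))
        total_deleted += delete_by_ids([row.id for row in rows])
        if len(rows) < batch:  # a short batch means nothing is left
            break
    return total_deleted
```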
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index afdaa49465..40097d5ed5 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,5 +1,4 @@
from collections.abc import Callable, Sequence
-from datetime import UTC, datetime
from typing import Optional, Union
from sqlalchemy import asc, desc, func, or_, select
@@ -8,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import ConversationVariable
from models.account import Account
@@ -113,7 +113,7 @@ class ConversationService:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
return conversation
@@ -169,7 +169,7 @@ class ConversationService:
conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.is_deleted = True
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
@classmethod
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e98b47921f..09cdd66e04 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -26,6 +26,7 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
+from libs.datetime_utils import naive_utc_now
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@@ -59,6 +60,7 @@ from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureModel, FeatureService
from services.tag_service import TagService
from services.vector_service import VectorService
+from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.batch_clean_document_task import batch_clean_document_task
from tasks.clean_notion_document_task import clean_notion_document_task
from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task
@@ -70,6 +72,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task
from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task
from tasks.enable_segments_to_index_task import enable_segments_to_index_task
from tasks.recover_document_indexing_task import recover_document_indexing_task
+from tasks.remove_document_from_index_task import remove_document_from_index_task
from tasks.retry_document_indexing_task import retry_document_indexing_task
from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task
@@ -276,176 +279,351 @@ class DatasetService:
except ProviderTokenNotInitError as ex:
raise ValueError(ex.description)
+ @staticmethod
+ def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str):
+ try:
+ model_manager = ModelManager()
+ model_manager.get_model_instance(
+ tenant_id=tenant_id,
+ provider=reranking_model_provider,
+ model_type=ModelType.RERANK,
+ model=reranking_model,
+ )
+ except LLMBadRequestError:
+ raise ValueError(
+ "No Rerank Model available. Please configure a valid provider in the Settings -> Model Provider."
+ )
+ except ProviderTokenNotInitError as ex:
+ raise ValueError(ex.description)
+
@staticmethod
def update_dataset(dataset_id, data, user):
+ """
+ Update dataset configuration and settings.
+
+ Args:
+ dataset_id: The unique identifier of the dataset to update
+ data: Dictionary containing the update data
+ user: The user performing the update operation
+
+ Returns:
+ Dataset: The updated dataset object
+
+ Raises:
+ ValueError: If dataset not found or validation fails
+ NoPermissionError: If user lacks permission to update the dataset
+ """
+ # Retrieve and validate dataset existence
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise ValueError("Dataset not found")
+ # Verify user has permission to update this dataset
DatasetService.check_dataset_permission(dataset, user)
+
+ # Handle external dataset updates
if dataset.provider == "external":
- external_retrieval_model = data.get("external_retrieval_model", None)
- if external_retrieval_model:
- dataset.retrieval_model = external_retrieval_model
- dataset.name = data.get("name", dataset.name)
- dataset.description = data.get("description", "")
- permission = data.get("permission")
- if permission:
- dataset.permission = permission
- external_knowledge_id = data.get("external_knowledge_id", None)
- db.session.add(dataset)
- if not external_knowledge_id:
- raise ValueError("External knowledge id is required.")
- external_knowledge_api_id = data.get("external_knowledge_api_id", None)
- if not external_knowledge_api_id:
- raise ValueError("External knowledge api id is required.")
-
- with Session(db.engine) as session:
- external_knowledge_binding = (
- session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
- )
+ return DatasetService._update_external_dataset(dataset, data, user)
+ else:
+ return DatasetService._update_internal_dataset(dataset, data, user)
- if not external_knowledge_binding:
- raise ValueError("External knowledge binding not found.")
+ @staticmethod
+ def _update_external_dataset(dataset, data, user):
+ """
+ Update external dataset configuration.
+
+ Args:
+ dataset: The dataset object to update
+ data: Update data dictionary
+ user: User performing the update
+
+ Returns:
+ Dataset: Updated dataset object
+ """
+ # Update retrieval model if provided
+ external_retrieval_model = data.get("external_retrieval_model", None)
+ if external_retrieval_model:
+ dataset.retrieval_model = external_retrieval_model
+
+ # Update basic dataset properties
+ dataset.name = data.get("name", dataset.name)
+ dataset.description = data.get("description", dataset.description)
+
+ # Update permission if provided
+ permission = data.get("permission")
+ if permission:
+ dataset.permission = permission
+
+ # Validate and update external knowledge configuration
+ external_knowledge_id = data.get("external_knowledge_id", None)
+ external_knowledge_api_id = data.get("external_knowledge_api_id", None)
+
+ if not external_knowledge_id:
+ raise ValueError("External knowledge id is required.")
+ if not external_knowledge_api_id:
+ raise ValueError("External knowledge api id is required.")
+ # Update metadata fields
+ dataset.updated_by = user.id if user else None
+        dataset.updated_at = naive_utc_now()
+ db.session.add(dataset)
- if (
- external_knowledge_binding.external_knowledge_id != external_knowledge_id
- or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
- ):
- external_knowledge_binding.external_knowledge_id = external_knowledge_id
- external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
- db.session.add(external_knowledge_binding)
- db.session.commit()
- else:
- data.pop("partial_member_list", None)
- data.pop("external_knowledge_api_id", None)
- data.pop("external_knowledge_id", None)
- data.pop("external_retrieval_model", None)
- filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
- action = None
- if dataset.indexing_technique != data["indexing_technique"]:
- # if update indexing_technique
- if data["indexing_technique"] == "economy":
- action = "remove"
- filtered_data["embedding_model"] = None
- filtered_data["embedding_model_provider"] = None
- filtered_data["collection_binding_id"] = None
- elif data["indexing_technique"] == "high_quality":
- action = "add"
- # get embedding model setting
- try:
- model_manager = ModelManager()
- embedding_model = model_manager.get_model_instance(
- tenant_id=current_user.current_tenant_id,
- provider=data["embedding_model_provider"],
- model_type=ModelType.TEXT_EMBEDDING,
- model=data["embedding_model"],
- )
- filtered_data["embedding_model"] = embedding_model.model
- filtered_data["embedding_model_provider"] = embedding_model.provider
- dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
- embedding_model.provider, embedding_model.model
- )
- filtered_data["collection_binding_id"] = dataset_collection_binding.id
- except LLMBadRequestError:
- raise ValueError(
- "No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider."
- )
- except ProviderTokenNotInitError as ex:
- raise ValueError(ex.description)
- else:
- # add default plugin id to both setting sets, to make sure the plugin model provider is consistent
- # Skip embedding model checks if not provided in the update request
- if (
- "embedding_model_provider" not in data
- or "embedding_model" not in data
- or not data.get("embedding_model_provider")
- or not data.get("embedding_model")
- ):
- # If the dataset already has embedding model settings, use those
- if dataset.embedding_model_provider and dataset.embedding_model:
- # Keep existing values
- filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
- filtered_data["embedding_model"] = dataset.embedding_model
- # If collection_binding_id exists, keep it too
- if dataset.collection_binding_id:
- filtered_data["collection_binding_id"] = dataset.collection_binding_id
- # Otherwise, don't try to update embedding model settings at all
- # Remove these fields from filtered_data if they exist but are None/empty
- if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]:
- del filtered_data["embedding_model_provider"]
- if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
- del filtered_data["embedding_model"]
- else:
- skip_embedding_update = False
- try:
- # Handle existing model provider
- plugin_model_provider = dataset.embedding_model_provider
- plugin_model_provider_str = None
- if plugin_model_provider:
- plugin_model_provider_str = str(ModelProviderID(plugin_model_provider))
-
- # Handle new model provider from request
- new_plugin_model_provider = data["embedding_model_provider"]
- new_plugin_model_provider_str = None
- if new_plugin_model_provider:
- new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider))
-
- # Only update embedding model if both values are provided and different from current
- if (
- plugin_model_provider_str != new_plugin_model_provider_str
- or data["embedding_model"] != dataset.embedding_model
- ):
- action = "update"
- model_manager = ModelManager()
- try:
- embedding_model = model_manager.get_model_instance(
- tenant_id=current_user.current_tenant_id,
- provider=data["embedding_model_provider"],
- model_type=ModelType.TEXT_EMBEDDING,
- model=data["embedding_model"],
- )
- except ProviderTokenNotInitError:
- # If we can't get the embedding model, skip updating it
- # and keep the existing settings if available
- if dataset.embedding_model_provider and dataset.embedding_model:
- filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
- filtered_data["embedding_model"] = dataset.embedding_model
- if dataset.collection_binding_id:
- filtered_data["collection_binding_id"] = dataset.collection_binding_id
- # Skip the rest of the embedding model update
- skip_embedding_update = True
- if not skip_embedding_update:
- filtered_data["embedding_model"] = embedding_model.model
- filtered_data["embedding_model_provider"] = embedding_model.provider
- dataset_collection_binding = (
- DatasetCollectionBindingService.get_dataset_collection_binding(
- embedding_model.provider, embedding_model.model
- )
- )
- filtered_data["collection_binding_id"] = dataset_collection_binding.id
- except LLMBadRequestError:
- raise ValueError(
- "No Embedding Model available. Please configure a valid provider "
- "in the Settings -> Model Provider."
- )
- except ProviderTokenNotInitError as ex:
- raise ValueError(ex.description)
+ # Update external knowledge binding
+ DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id)
+
+ # Commit changes to database
+ db.session.commit()
+
+ return dataset
- filtered_data["updated_by"] = user.id
- filtered_data["updated_at"] = datetime.datetime.now()
+ @staticmethod
+ def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id):
+ """
+ Update external knowledge binding configuration.
+
+ Args:
+ dataset_id: Dataset identifier
+ external_knowledge_id: External knowledge identifier
+ external_knowledge_api_id: External knowledge API identifier
+ """
+ with Session(db.engine) as session:
+ external_knowledge_binding = (
+ session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first()
+ )
- # update Retrieval model
- filtered_data["retrieval_model"] = data["retrieval_model"]
+ if not external_knowledge_binding:
+ raise ValueError("External knowledge binding not found.")
- db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
+ # Update binding if values have changed
+ if (
+ external_knowledge_binding.external_knowledge_id != external_knowledge_id
+ or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id
+ ):
+ external_knowledge_binding.external_knowledge_id = external_knowledge_id
+ external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id
+ db.session.add(external_knowledge_binding)
+
+ @staticmethod
+ def _update_internal_dataset(dataset, data, user):
+ """
+ Update internal dataset configuration.
+
+ Args:
+ dataset: The dataset object to update
+ data: Update data dictionary
+ user: User performing the update
+
+ Returns:
+ Dataset: Updated dataset object
+ """
+ # Remove external-specific fields from update data
+ data.pop("partial_member_list", None)
+ data.pop("external_knowledge_api_id", None)
+ data.pop("external_knowledge_id", None)
+ data.pop("external_retrieval_model", None)
+
+ # Filter out None values except for description field
+ filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"}
+
+ # Handle indexing technique changes and embedding model updates
+ action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data)
+
+ # Add metadata fields
+ filtered_data["updated_by"] = user.id
+ filtered_data["updated_at"] = naive_utc_now()
+ # update Retrieval model
+ filtered_data["retrieval_model"] = data["retrieval_model"]
+
+ # Update dataset in database
+ db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data)
+ db.session.commit()
+
+ # Trigger vector index task if indexing technique changed
+ if action:
+ deal_dataset_vector_index_task.delay(dataset.id, action)
- db.session.commit()
- if action:
- deal_dataset_vector_index_task.delay(dataset_id, action)
return dataset
+ @staticmethod
+ def _handle_indexing_technique_change(dataset, data, filtered_data):
+ """
+ Handle changes in indexing technique and configure embedding models accordingly.
+
+ Args:
+ dataset: Current dataset object
+ data: Update data dictionary
+ filtered_data: Filtered update data
+
+ Returns:
+ str: Action to perform ('add', 'remove', 'update', or None)
+ """
+ if dataset.indexing_technique != data["indexing_technique"]:
+ if data["indexing_technique"] == "economy":
+ # Remove embedding model configuration for economy mode
+ filtered_data["embedding_model"] = None
+ filtered_data["embedding_model_provider"] = None
+ filtered_data["collection_binding_id"] = None
+ return "remove"
+ elif data["indexing_technique"] == "high_quality":
+ # Configure embedding model for high quality mode
+ DatasetService._configure_embedding_model_for_high_quality(data, filtered_data)
+ return "add"
+ else:
+ # Handle embedding model updates when indexing technique remains the same
+ return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data)
+ return None
+
+ @staticmethod
+ def _configure_embedding_model_for_high_quality(data, filtered_data):
+ """
+ Configure embedding model settings for high quality indexing.
+
+ Args:
+ data: Update data dictionary
+ filtered_data: Filtered update data to modify
+ """
+ try:
+ model_manager = ModelManager()
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=current_user.current_tenant_id,
+ provider=data["embedding_model_provider"],
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=data["embedding_model"],
+ )
+ filtered_data["embedding_model"] = embedding_model.model
+ filtered_data["embedding_model_provider"] = embedding_model.provider
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ embedding_model.provider, embedding_model.model
+ )
+ filtered_data["collection_binding_id"] = dataset_collection_binding.id
+ except LLMBadRequestError:
+ raise ValueError(
+ "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
+ )
+ except ProviderTokenNotInitError as ex:
+ raise ValueError(ex.description)
+
+ @staticmethod
+ def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data):
+ """
+ Handle embedding model updates when indexing technique remains the same.
+
+ Args:
+ dataset: Current dataset object
+ data: Update data dictionary
+ filtered_data: Filtered update data to modify
+
+ Returns:
+ str: Action to perform ('update' or None)
+ """
+ # Skip embedding model checks if not provided in the update request
+ if (
+ "embedding_model_provider" not in data
+ or "embedding_model" not in data
+ or not data.get("embedding_model_provider")
+ or not data.get("embedding_model")
+ ):
+ DatasetService._preserve_existing_embedding_settings(dataset, filtered_data)
+ return None
+ else:
+ return DatasetService._update_embedding_model_settings(dataset, data, filtered_data)
+
+ @staticmethod
+ def _preserve_existing_embedding_settings(dataset, filtered_data):
+ """
+ Preserve existing embedding model settings when not provided in update.
+
+ Args:
+ dataset: Current dataset object
+ filtered_data: Filtered update data to modify
+ """
+ # If the dataset already has embedding model settings, use those
+ if dataset.embedding_model_provider and dataset.embedding_model:
+ filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
+ filtered_data["embedding_model"] = dataset.embedding_model
+ # If collection_binding_id exists, keep it too
+ if dataset.collection_binding_id:
+ filtered_data["collection_binding_id"] = dataset.collection_binding_id
+ # Otherwise, don't try to update embedding model settings at all
+ # Remove these fields from filtered_data if they exist but are None/empty
+ if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]:
+ del filtered_data["embedding_model_provider"]
+ if "embedding_model" in filtered_data and not filtered_data["embedding_model"]:
+ del filtered_data["embedding_model"]
+
+ @staticmethod
+ def _update_embedding_model_settings(dataset, data, filtered_data):
+ """
+ Update embedding model settings with new values.
+
+ Args:
+ dataset: Current dataset object
+ data: Update data dictionary
+ filtered_data: Filtered update data to modify
+
+ Returns:
+ str: Action to perform ('update' or None)
+ """
+ try:
+ # Compare current and new model provider settings
+ current_provider_str = (
+ str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None
+ )
+ new_provider_str = (
+ str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None
+ )
+
+ # Only update if values are different
+ if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model:
+ DatasetService._apply_new_embedding_settings(dataset, data, filtered_data)
+ return "update"
+ except LLMBadRequestError:
+ raise ValueError(
+ "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider."
+ )
+ except ProviderTokenNotInitError as ex:
+ raise ValueError(ex.description)
+ return None
+
+ @staticmethod
+ def _apply_new_embedding_settings(dataset, data, filtered_data):
+ """
+ Apply new embedding model settings to the dataset.
+
+ Args:
+ dataset: Current dataset object
+ data: Update data dictionary
+ filtered_data: Filtered update data to modify
+ """
+ model_manager = ModelManager()
+ try:
+ embedding_model = model_manager.get_model_instance(
+ tenant_id=current_user.current_tenant_id,
+ provider=data["embedding_model_provider"],
+ model_type=ModelType.TEXT_EMBEDDING,
+ model=data["embedding_model"],
+ )
+ except ProviderTokenNotInitError:
+ # If we can't get the embedding model, preserve existing settings
+ logging.warning(
+ f"Failed to initialize embedding model {data['embedding_model_provider']}/{data['embedding_model']}, "
+ f"preserving existing settings"
+ )
+ if dataset.embedding_model_provider and dataset.embedding_model:
+ filtered_data["embedding_model_provider"] = dataset.embedding_model_provider
+ filtered_data["embedding_model"] = dataset.embedding_model
+ if dataset.collection_binding_id:
+ filtered_data["collection_binding_id"] = dataset.collection_binding_id
+ # Skip the rest of the embedding model update
+ return
+
+ # Apply new embedding model settings
+ filtered_data["embedding_model"] = embedding_model.model
+ filtered_data["embedding_model_provider"] = embedding_model.provider
+ dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
+ embedding_model.provider, embedding_model.model
+ )
+ filtered_data["collection_binding_id"] = dataset_collection_binding.id
+
@staticmethod
def delete_dataset(dataset_id, user):
dataset = DatasetService.get_dataset(dataset_id)
@@ -817,7 +995,7 @@ class DocumentService:
# update document to be paused
document.is_paused = True
document.paused_by = current_user.id
- document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.paused_at = naive_utc_now()
db.session.add(document)
db.session.commit()
@@ -976,12 +1154,17 @@ class DocumentService:
process_rule = knowledge_config.process_rule
if process_rule:
if process_rule.mode in ("custom", "hierarchical"):
- dataset_process_rule = DatasetProcessRule(
- dataset_id=dataset.id,
- mode=process_rule.mode,
- rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
- created_by=account.id,
- )
+ if process_rule.rules:
+ dataset_process_rule = DatasetProcessRule(
+ dataset_id=dataset.id,
+ mode=process_rule.mode,
+ rules=process_rule.rules.model_dump_json() if process_rule.rules else None,
+ created_by=account.id,
+ )
+ else:
+ dataset_process_rule = dataset.latest_process_rule
+ if not dataset_process_rule:
+ raise ValueError("No process rule found.")
elif process_rule.mode == "automatic":
dataset_process_rule = DatasetProcessRule(
dataset_id=dataset.id,
@@ -1402,16 +1585,16 @@ class DocumentService:
knowledge_config.embedding_model, # type: ignore
)
dataset_collection_binding_id = dataset_collection_binding.id
- if knowledge_config.retrieval_model:
- retrieval_model = knowledge_config.retrieval_model
- else:
- retrieval_model = RetrievalModel(
- search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
- reranking_enable=False,
- reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
- top_k=2,
- score_threshold_enabled=False,
- )
+ if knowledge_config.retrieval_model:
+ retrieval_model = knowledge_config.retrieval_model
+ else:
+ retrieval_model = RetrievalModel(
+ search_method=RetrievalMethod.SEMANTIC_SEARCH.value,
+ reranking_enable=False,
+ reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""),
+ top_k=2,
+ score_threshold_enabled=False,
+ )
# save dataset
dataset = Dataset(
tenant_id=tenant_id,
@@ -1603,6 +1786,191 @@ class DocumentService:
if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int):
raise ValueError("Process rule segmentation max_tokens is invalid")
+ @staticmethod
+ def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user):
+ """
+ Batch update document status.
+
+ Args:
+ dataset (Dataset): The dataset object
+ document_ids (list[str]): List of document IDs to update
+ action (str): Action to perform (enable, disable, archive, un_archive)
+ user: Current user performing the action
+
+ Raises:
+ DocumentIndexingError: If document is being indexed or not in correct state
+ ValueError: If action is invalid
+ """
+ if not document_ids:
+ return
+
+ # Early validation of action parameter
+ valid_actions = ["enable", "disable", "archive", "un_archive"]
+ if action not in valid_actions:
+ raise ValueError(f"Invalid action: {action}. Must be one of {valid_actions}")
+
+ documents_to_update = []
+
+ # First pass: validate all documents and prepare updates
+ for document_id in document_ids:
+ document = DocumentService.get_document(dataset.id, document_id)
+ if not document:
+ continue
+
+ # Check if document is being indexed
+ indexing_cache_key = f"document_{document.id}_indexing"
+ cache_result = redis_client.get(indexing_cache_key)
+ if cache_result is not None:
+ raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later")
+
+ # Prepare update based on action
+ update_info = DocumentService._prepare_document_status_update(document, action, user)
+ if update_info:
+ documents_to_update.append(update_info)
+
+ # Second pass: apply all updates in a single transaction
+ if documents_to_update:
+ try:
+ for update_info in documents_to_update:
+ document = update_info["document"]
+ updates = update_info["updates"]
+
+ # Apply updates to the document
+ for field, value in updates.items():
+ setattr(document, field, value)
+
+ db.session.add(document)
+
+ # Batch commit all changes
+ db.session.commit()
+ except Exception as e:
+ # Rollback on any error
+ db.session.rollback()
+ raise e
+ # Execute async tasks and set Redis cache after successful commit
+ # propagation_error is used to capture any errors for submitting async task execution
+ propagation_error = None
+ for update_info in documents_to_update:
+ try:
+ # Execute async tasks after successful commit
+ if update_info["async_task"]:
+ task_info = update_info["async_task"]
+ task_func = task_info["function"]
+ task_args = task_info["args"]
+ task_func.delay(*task_args)
+ except Exception as e:
+ # Log the error but do not rollback the transaction
+ logging.exception(f"Error executing async task for document {update_info['document'].id}")
+ # don't raise the error immediately, but capture it for later
+ propagation_error = e
+ try:
+ # Set Redis cache if needed after successful commit
+ if update_info["set_cache"]:
+ document = update_info["document"]
+ indexing_cache_key = f"document_{document.id}_indexing"
+ redis_client.setex(indexing_cache_key, 600, 1)
+                except Exception:
+ # Log the error but do not rollback the transaction
+ logging.exception(f"Error setting cache for document {update_info['document'].id}")
+ # Raise any propagation error after all updates
+ if propagation_error:
+ raise propagation_error
+
+ @staticmethod
+ def _prepare_document_status_update(document, action: str, user):
+ """
+ Prepare document status update information.
+
+ Args:
+ document: Document object to update
+ action: Action to perform
+ user: Current user
+
+ Returns:
+ dict: Update information or None if no update needed
+ """
+        now = naive_utc_now()
+
+ if action == "enable":
+ return DocumentService._prepare_enable_update(document, now)
+ elif action == "disable":
+ return DocumentService._prepare_disable_update(document, user, now)
+ elif action == "archive":
+ return DocumentService._prepare_archive_update(document, user, now)
+ elif action == "un_archive":
+ return DocumentService._prepare_unarchive_update(document, now)
+
+ return None
+
+ @staticmethod
+ def _prepare_enable_update(document, now):
+ """Prepare updates for enabling a document."""
+ if document.enabled:
+ return None
+
+ return {
+ "document": document,
+ "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now},
+ "async_task": {"function": add_document_to_index_task, "args": [document.id]},
+ "set_cache": True,
+ }
+
+ @staticmethod
+ def _prepare_disable_update(document, user, now):
+ """Prepare updates for disabling a document."""
+ if not document.completed_at or document.indexing_status != "completed":
+ raise DocumentIndexingError(f"Document: {document.name} is not completed.")
+
+ if not document.enabled:
+ return None
+
+ return {
+ "document": document,
+ "updates": {"enabled": False, "disabled_at": now, "disabled_by": user.id, "updated_at": now},
+ "async_task": {"function": remove_document_from_index_task, "args": [document.id]},
+ "set_cache": True,
+ }
+
+ @staticmethod
+ def _prepare_archive_update(document, user, now):
+ """Prepare updates for archiving a document."""
+ if document.archived:
+ return None
+
+ update_info = {
+ "document": document,
+ "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now},
+ "async_task": None,
+ "set_cache": False,
+ }
+
+ # Only set async task and cache if document is currently enabled
+ if document.enabled:
+ update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]}
+ update_info["set_cache"] = True
+
+ return update_info
+
+ @staticmethod
+ def _prepare_unarchive_update(document, now):
+ """Prepare updates for unarchiving a document."""
+ if not document.archived:
+ return None
+
+ update_info = {
+ "document": document,
+ "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now},
+ "async_task": None,
+ "set_cache": False,
+ }
+
+ # Only re-index if the document is currently enabled
+ if document.enabled:
+ update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]}
+ update_info["set_cache"] = True
+
+ return update_info
+
class SegmentService:
@classmethod
@@ -1857,6 +2225,7 @@ class SegmentService:
# calc embedding use tokens
if document.doc_form == "qa_model":
+ segment.answer = args.answer
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
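`batch_update_document_status` validates every document first (skipping unknown IDs and refusing documents that are mid-indexing), applies all field updates in one transaction, and only then dispatches the index tasks and Redis cache writes. A rough usage sketch; the dataset ID, document IDs, and `current_user` are placeholders for values that come from the request context in the real API:

```python
# Illustrative call only; IDs and the user object are placeholders.
from services.dataset_service import DatasetService, DocumentService

dataset = DatasetService.get_dataset("dataset-id")
DocumentService.batch_update_document_status(
    dataset=dataset,
    document_ids=["doc-1", "doc-2"],
    action="disable",  # one of: enable, disable, archive, un_archive
    user=current_user,  # the Account performing the action
)
```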
diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py
index 8c06ee9386..54d45f45ea 100644
--- a/api/services/enterprise/enterprise_service.py
+++ b/api/services/enterprise/enterprise_service.py
@@ -29,7 +29,7 @@ class EnterpriseService:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
- return datetime.fromisoformat(data.replace("Z", "+00:00"))
+ return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
@@ -40,7 +40,7 @@ class EnterpriseService:
raise ValueError("No data found.")
try:
# parse the UTC timestamp from the response
- return datetime.fromisoformat(data.replace("Z", "+00:00"))
+ return datetime.fromisoformat(data)
except ValueError as e:
raise ValueError(f"Invalid date format: {data}") from e
diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py
index bb3be61f85..344c67885e 100644
--- a/api/services/entities/knowledge_entities/knowledge_entities.py
+++ b/api/services/entities/knowledge_entities/knowledge_entities.py
@@ -4,13 +4,6 @@ from typing import Literal, Optional
from pydantic import BaseModel
-class SegmentUpdateEntity(BaseModel):
- content: str
- answer: Optional[str] = None
- keywords: Optional[list[str]] = None
- enabled: Optional[bool] = None
-
-
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
@@ -95,13 +88,13 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
- weight_type: Optional[str] = None
+ weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
class RetrievalModel(BaseModel):
- search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
+ search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
reranking_mode: Optional[str] = None
@@ -153,10 +146,6 @@ class MetadataUpdateArgs(BaseModel):
value: Optional[str | int | float] = None
-class MetadataValueUpdateArgs(BaseModel):
- fields: list[MetadataUpdateArgs]
-
-
class MetadataDetail(BaseModel):
id: str
name: str
diff --git a/api/services/errors/app.py b/api/services/errors/app.py
index 87e9e9247d..5d348c61be 100644
--- a/api/services/errors/app.py
+++ b/api/services/errors/app.py
@@ -4,3 +4,7 @@ class MoreLikeThisDisabledError(Exception):
class WorkflowHashNotEqualError(Exception):
pass
+
+
+class IsDraftWorkflowError(Exception):
+ pass
diff --git a/api/services/errors/plugin.py b/api/services/errors/plugin.py
new file mode 100644
index 0000000000..be5b144b3d
--- /dev/null
+++ b/api/services/errors/plugin.py
@@ -0,0 +1,5 @@
+from services.errors.base import BaseServiceError
+
+
+class PluginInstallationForbiddenError(BaseServiceError):
+ pass
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index eb50d79494..06a4c22117 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -1,6 +1,5 @@
import json
from copy import deepcopy
-from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
@@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import (
Dataset,
ExternalKnowledgeApis,
@@ -120,7 +120,7 @@ class ExternalDatasetService:
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
- external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ external_knowledge_api.updated_at = naive_utc_now()
db.session.commit()
return external_knowledge_api
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index be85a03e80..1441e6ce16 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -88,6 +88,26 @@ class WebAppAuthModel(BaseModel):
allow_email_password_login: bool = False
+class PluginInstallationScope(StrEnum):
+ NONE = "none"
+ OFFICIAL_ONLY = "official_only"
+ OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners"
+ ALL = "all"
+
+
+class PluginInstallationPermissionModel(BaseModel):
+ # Plugin installation scope – possible values:
+ # none: prohibit all plugin installations
+ # official_only: allow only Dify official plugins
+ # official_and_specific_partners: allow official and specific partner plugins
+ # all: allow installation of all plugins
+ plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL
+
+ # If True, restrict plugin installation to the marketplace only
+ # Equivalent to ForceEnablePluginVerification
+ restrict_to_marketplace_only: bool = False
+
+
class FeatureModel(BaseModel):
billing: BillingModel = BillingModel()
education: EducationModel = EducationModel()
@@ -103,7 +123,7 @@ class FeatureModel(BaseModel):
dataset_operator_enabled: bool = False
webapp_copyright_enabled: bool = False
workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
-
+ is_allow_transfer_workspace: bool = True
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@@ -128,6 +148,8 @@ class SystemFeatureModel(BaseModel):
license: LicenseModel = LicenseModel()
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
+ plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
+ enable_change_email: bool = True
class FeatureService:
@@ -165,6 +187,7 @@ class FeatureService:
if dify_config.ENTERPRISE_ENABLED:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
+ system_features.enable_change_email = False
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
@@ -207,6 +230,8 @@ class FeatureService:
if features.billing.subscription.plan != "sandbox":
features.webapp_copyright_enabled = True
+ else:
+ features.is_allow_transfer_workspace = False
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
@@ -291,3 +316,12 @@ class FeatureService:
features.license.workspaces.enabled = license_info["workspaces"]["enabled"]
features.license.workspaces.limit = license_info["workspaces"]["limit"]
features.license.workspaces.size = license_info["workspaces"]["used"]
+
+ if "PluginInstallationPermission" in enterprise_info:
+ plugin_installation_info = enterprise_info["PluginInstallationPermission"]
+ features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[
+ "pluginInstallationScope"
+ ]
+ features.plugin_installation_permission.restrict_to_marketplace_only = plugin_installation_info[
+ "restrictToMarketplaceOnly"
+ ]
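The `_fulfill_params_from_enterprise` hunk copies the camelCase keys of the enterprise payload into the new `PluginInstallationPermissionModel`. A small sketch of that mapping in isolation; the payload literal is illustrative, only its key names are taken from the hunk above:

```python
# Illustrative parsing of the enterprise payload fields shown above.
from services.feature_service import (
    PluginInstallationPermissionModel,
    PluginInstallationScope,
)

enterprise_info = {
    "PluginInstallationPermission": {
        "pluginInstallationScope": "official_only",
        "restrictToMarketplaceOnly": True,
    }
}

info = enterprise_info["PluginInstallationPermission"]
permission = PluginInstallationPermissionModel(
    plugin_installation_scope=info["pluginInstallationScope"],
    restrict_to_marketplace_only=info["restrictToMarketplaceOnly"],
)
assert permission.plugin_installation_scope == PluginInstallationScope.OFFICIAL_ONLY
```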
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 2d68f30c5a..286535bd18 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -18,6 +18,7 @@ from core.file import helpers as file_helpers
from core.rag.extractor.extract_processor import ExtractProcessor
from extensions.ext_database import db
from extensions.ext_storage import storage
+from libs.helper import extract_tenant_id
from models.account import Account
from models.enums import CreatorUserRole
from models.model import EndUser, UploadFile
@@ -61,11 +62,7 @@ class FileService:
# generate file key
file_uuid = str(uuid.uuid4())
- if isinstance(user, Account):
- current_tenant_id = user.current_tenant_id
- else:
- # end_user
- current_tenant_id = user.tenant_id
+ current_tenant_id = extract_tenant_id(user)
file_key = "upload_files/" + (current_tenant_id or "") + "/" + file_uuid + "." + extension
diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py
index 26d6d4ce18..cfcb121153 100644
--- a/api/services/metadata_service.py
+++ b/api/services/metadata_service.py
@@ -19,6 +19,10 @@ from services.entities.knowledge_entities.knowledge_entities import (
class MetadataService:
@staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
+ # check if metadata name is too long
+ if len(metadata_args.name) > 255:
+ raise ValueError("Metadata name cannot exceed 255 characters.")
+
# check if metadata name already exists
if (
db.session.query(DatasetMetadata)
@@ -42,6 +46,10 @@ class MetadataService:
@staticmethod
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
+ # check if metadata name is too long
+ if len(name) > 255:
+ raise ValueError("Metadata name cannot exceed 255 characters.")
+
lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists
if (
diff --git a/api/services/moderation_service.py b/api/services/moderation_service.py
deleted file mode 100644
index 082afeed89..0000000000
--- a/api/services/moderation_service.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from typing import Optional
-
-from core.moderation.factory import ModerationFactory, ModerationOutputsResult
-from extensions.ext_database import db
-from models.model import App, AppModelConfig
-
-
-class ModerationService:
- def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
- app_model_config: Optional[AppModelConfig] = None
-
- app_model_config = (
- db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
- )
-
- if not app_model_config:
- raise ValueError("app model config not found")
-
- name = app_model_config.sensitive_word_avoidance_dict["type"]
- config = app_model_config.sensitive_word_avoidance_dict["config"]
-
- moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
- return moderation.moderation_for_outputs(text)
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 792f50703e..dbeb4f1908 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -34,6 +34,24 @@ class OpsService:
)
new_decrypt_tracing_config = OpsTraceManager.obfuscated_decrypt_token(tracing_provider, decrypt_tracing_config)
+ if tracing_provider == "arize" and (
+ "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
+ ):
+ try:
+ project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
+ new_decrypt_tracing_config.update({"project_url": project_url})
+ except Exception:
+ new_decrypt_tracing_config.update({"project_url": "https://app.arize.com/"})
+
+ if tracing_provider == "phoenix" and (
+ "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
+ ):
+ try:
+ project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
+ new_decrypt_tracing_config.update({"project_url": project_url})
+ except Exception:
+ new_decrypt_tracing_config.update({"project_url": "https://app.phoenix.arize.com/projects/"})
+
if tracing_provider == "langfuse" and (
"project_key" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_key")
):
@@ -76,6 +94,16 @@ class OpsService:
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
+
+ if tracing_provider == "aliyun" and (
+ "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
+ ):
+ try:
+ project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
+ new_decrypt_tracing_config.update({"project_url": project_url})
+ except Exception:
+ new_decrypt_tracing_config.update({"project_url": "https://arms.console.aliyun.com/"})
+
trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()
@@ -107,7 +135,9 @@ class OpsService:
return {"error": "Invalid Credentials"}
# get project url
- if tracing_provider == "langfuse":
+ if tracing_provider in ("arize", "phoenix"):
+ project_url = OpsTraceManager.get_trace_config_project_url(tracing_config, tracing_provider)
+ elif tracing_provider == "langfuse":
project_key = OpsTraceManager.get_trace_config_project_key(tracing_config, tracing_provider)
project_url = "{host}/project/{key}".format(host=tracing_config.get("host"), key=project_key)
elif tracing_provider in ("langsmith", "opik"):
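The arize, phoenix, and aliyun branches all share the same shape: if `project_url` is missing, try to resolve it and fall back to a provider-specific default. For illustration only, the pattern could be expressed once with a lookup table (this helper is not part of the diff; `get_trace_config_project_url` is the call used above):

```python
# Hypothetical consolidation of the repeated fallback branches above.
DEFAULT_PROJECT_URLS = {
    "arize": "https://app.arize.com/",
    "phoenix": "https://app.phoenix.arize.com/projects/",
    "aliyun": "https://arms.console.aliyun.com/",
}


def fill_project_url(provider: str, decrypted_config: dict, target_config: dict, trace_manager) -> None:
    if provider not in DEFAULT_PROJECT_URLS or decrypted_config.get("project_url"):
        return
    try:
        project_url = trace_manager.get_trace_config_project_url(decrypted_config, provider)
    except Exception:
        project_url = DEFAULT_PROJECT_URLS[provider]
    target_config.update({"project_url": project_url})
```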
diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py
index 1c5abfecba..5324036414 100644
--- a/api/services/plugin/data_migration.py
+++ b/api/services/plugin/data_migration.py
@@ -3,7 +3,7 @@ import logging
import click
-from core.entities import DEFAULT_PLUGIN_ID
+from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID
from models.engine import db
logger = logging.getLogger(__name__)
@@ -12,17 +12,17 @@ logger = logging.getLogger(__name__)
class PluginDataMigration:
@classmethod
def migrate(cls) -> None:
- cls.migrate_db_records("providers", "provider_name") # large table
- cls.migrate_db_records("provider_models", "provider_name")
- cls.migrate_db_records("provider_orders", "provider_name")
- cls.migrate_db_records("tenant_default_models", "provider_name")
- cls.migrate_db_records("tenant_preferred_model_providers", "provider_name")
- cls.migrate_db_records("provider_model_settings", "provider_name")
- cls.migrate_db_records("load_balancing_model_configs", "provider_name")
+ cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table
+ cls.migrate_db_records("provider_models", "provider_name", ModelProviderID)
+ cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID)
+ cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID)
+ cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID)
+ cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID)
+ cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID)
cls.migrate_datasets()
- cls.migrate_db_records("embeddings", "provider_name") # large table
- cls.migrate_db_records("dataset_collection_bindings", "provider_name")
- cls.migrate_db_records("tool_builtin_providers", "provider")
+ cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table
+ cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID)
+ cls.migrate_db_records("tool_builtin_providers", "provider", ToolProviderID)
@classmethod
def migrate_datasets(cls) -> None:
@@ -66,9 +66,10 @@ limit 1000"""
fg="white",
)
)
- retrieval_model["reranking_model"]["reranking_provider_name"] = (
- f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}"
- )
+ # update google to langgenius/gemini/google etc.
+ retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID(
+ retrieval_model["reranking_model"]["reranking_provider_name"]
+ ).to_string()
retrieval_model_changed = True
click.echo(
@@ -86,9 +87,11 @@ limit 1000"""
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model)
+ params["provider_name"] = ModelProviderID(provider_name).to_string()
+
sql = f"""update {table_name}
set {provider_column_name} =
- concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
+ :provider_name
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
@@ -122,7 +125,9 @@ limit 1000"""
)
@classmethod
- def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None:
+ def migrate_db_records(
+ cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID]
+ ) -> None:
click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white"))
processed_count = 0
@@ -166,7 +171,8 @@ limit 1000"""
)
try:
- updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}"
+ # update jina to langgenius/jina_tool/jina etc.
+ updated_value = provider_cls(provider_name).to_string()
batch_updates.append((updated_value, record_id))
except Exception as e:
failed_ids.append(record_id)
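The migration now derives the new provider identifiers from `ModelProviderID`/`ToolProviderID` rather than blindly prefixing `DEFAULT_PLUGIN_ID`. Based on the inline comments in the hunks above, the expected effect looks roughly like this (output values are examples taken from those comments, not verified here):

```python
# Expected mapping, per the inline comments in the migration above.
from core.plugin.entities.plugin import ModelProviderID, ToolProviderID

print(ModelProviderID("google").to_string())  # e.g. "langgenius/gemini/google"
print(ToolProviderID("jina").to_string())     # e.g. "langgenius/jina_tool/jina"
```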
diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py
index 461247419b..b84dd0afc5 100644
--- a/api/services/plugin/oauth_service.py
+++ b/api/services/plugin/oauth_service.py
@@ -1,7 +1,53 @@
+import json
+import uuid
+
from core.plugin.impl.base import BasePluginClient
+from extensions.ext_redis import redis_client
+
+
+class OAuthProxyService(BasePluginClient):
+ # Default max age for proxy context parameter in seconds
+ __MAX_AGE__ = 5 * 60 # 5 minutes
+ __KEY_PREFIX__ = "oauth_proxy_context:"
+
+ @staticmethod
+ def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str):
+ """
+ Create a proxy context for an OAuth 2.0 authorization request.
+
+        The context is a security measure against Cross-Site Request Forgery (CSRF)
+        attacks: a unique nonce is generated and stored in a distributed cache (Redis)
+        together with the user's session context. The returned nonce should be included
+        as the 'proxy_context' parameter in the authorization URL. On callback,
+        `use_proxy_context` verifies and consumes it, ensuring the request's integrity
+        and authenticity and mitigating replay attacks.
+ """
+ context_id = str(uuid.uuid4())
+ data = {
+ "user_id": user_id,
+ "plugin_id": plugin_id,
+ "tenant_id": tenant_id,
+ "provider": provider,
+ }
+ redis_client.setex(
+ f"{OAuthProxyService.__KEY_PREFIX__}{context_id}",
+ OAuthProxyService.__MAX_AGE__,
+ json.dumps(data),
+ )
+ return context_id
-class OAuthService(BasePluginClient):
- @classmethod
- def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
- return "1234567890"
+ @staticmethod
+ def use_proxy_context(context_id: str):
+ """
+ Validate the proxy context parameter.
+ This checks if the context_id is valid and not expired.
+ """
+ if not context_id:
+ raise ValueError("context_id is required")
+ # get data from redis
+ data = redis_client.getdel(f"{OAuthProxyService.__KEY_PREFIX__}{context_id}")
+ if not data:
+ raise ValueError("context_id is invalid")
+ return json.loads(data)
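
For reference, a standalone sketch of the single-use proxy-context flow that `OAuthProxyService` implements above. An in-memory dict with expiry timestamps stands in for `redis_client.setex`/`getdel`; the function names and example values are illustrative, not part of the Dify API.

```python
# Simplified stand-in for the Redis-backed proxy context: setex -> store with TTL,
# getdel -> single-use retrieval. Not the real Dify service.
import json
import time
import uuid

MAX_AGE = 5 * 60  # mirrors the 5-minute TTL used by the service
KEY_PREFIX = "oauth_proxy_context:"
_store: dict[str, tuple[float, str]] = {}  # key -> (expires_at, payload)


def create_proxy_context(user_id: str, tenant_id: str, plugin_id: str, provider: str) -> str:
    """Generate a nonce and persist the caller's context for MAX_AGE seconds."""
    context_id = str(uuid.uuid4())
    payload = json.dumps(
        {"user_id": user_id, "tenant_id": tenant_id, "plugin_id": plugin_id, "provider": provider}
    )
    _store[KEY_PREFIX + context_id] = (time.time() + MAX_AGE, payload)
    return context_id


def use_proxy_context(context_id: str) -> dict:
    """Validate and consume the nonce (single use, like Redis GETDEL)."""
    if not context_id:
        raise ValueError("context_id is required")
    expires_at, payload = _store.pop(KEY_PREFIX + context_id, (0.0, ""))
    if not payload or expires_at < time.time():
        raise ValueError("context_id is invalid")
    return json.loads(payload)


if __name__ == "__main__":
    ctx = create_proxy_context("user-1", "tenant-1", "langgenius/github", "github")
    print(use_proxy_context(ctx))  # returns the stored context exactly once
    try:
        use_proxy_context(ctx)     # second use fails: replay is rejected
    except ValueError as e:
        print("replay rejected:", e)
```
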
diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py
new file mode 100644
index 0000000000..a1c5639e00
--- /dev/null
+++ b/api/services/plugin/plugin_parameter_service.py
@@ -0,0 +1,72 @@
+from collections.abc import Mapping, Sequence
+from typing import Any, Literal
+
+from sqlalchemy.orm import Session
+
+from core.plugin.entities.parameters import PluginParameterOption
+from core.plugin.impl.dynamic_select import DynamicSelectClient
+from core.tools.tool_manager import ToolManager
+from core.tools.utils.encryption import create_tool_provider_encrypter
+from extensions.ext_database import db
+from models.tools import BuiltinToolProvider
+
+
+class PluginParameterService:
+ @staticmethod
+ def get_dynamic_select_options(
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ action: str,
+ parameter: str,
+ provider_type: Literal["tool"],
+ ) -> Sequence[PluginParameterOption]:
+ """
+ Get dynamic select options for a plugin parameter.
+
+ Args:
+            tenant_id: The tenant ID.
+            user_id: The user ID.
+            plugin_id: The plugin ID.
+            provider: The provider name.
+            action: The action name.
+            parameter: The parameter name.
+            provider_type: The provider type (currently only "tool" is supported).
+ """
+ credentials: Mapping[str, Any] = {}
+
+ match provider_type:
+ case "tool":
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ # init tool configuration
+ encrypter, _ = create_tool_provider_encrypter(
+ tenant_id=tenant_id,
+ controller=provider_controller,
+ )
+
+ # check if credentials are required
+ if not provider_controller.need_credentials:
+ credentials = {}
+ else:
+ # fetch credentials from db
+ with Session(db.engine) as session:
+ db_record = (
+ session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == provider,
+ )
+ .first()
+ )
+
+ if db_record is None:
+ raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
+
+ credentials = encrypter.decrypt(db_record.credentials)
+ case _:
+ raise ValueError(f"Invalid provider type: {provider_type}")
+
+ return (
+ DynamicSelectClient()
+ .fetch_dynamic_select_options(tenant_id, user_id, plugin_id, provider, action, credentials, parameter)
+ .options
+ )
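
A minimal sketch of the credential-resolution branch in `get_dynamic_select_options`: only the `"tool"` provider type is handled, and credentials are looked up only when the provider declares that it needs them. The dataclass and the in-memory store below are illustrative stand-ins for the real controller and database.

```python
# Stand-in controller and "database"; the real service decrypts stored credentials.
from dataclasses import dataclass
from typing import Any, Literal


@dataclass
class FakeProviderController:
    need_credentials: bool


FAKE_DB: dict[tuple[str, str], dict[str, Any]] = {
    ("tenant-1", "github"): {"api_token": "secret"},
}


def resolve_credentials(
    tenant_id: str,
    provider: str,
    controller: FakeProviderController,
    provider_type: Literal["tool"],
) -> dict[str, Any]:
    match provider_type:
        case "tool":
            if not controller.need_credentials:
                return {}
            record = FAKE_DB.get((tenant_id, provider))
            if record is None:
                raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
            return record  # the real service returns the decrypted credentials here
        case _:
            raise ValueError(f"Invalid provider type: {provider_type}")


print(resolve_credentials("tenant-1", "github", FakeProviderController(need_credentials=True), "tool"))
```
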
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index a8b64f27db..9005f0669b 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -17,11 +17,18 @@ from core.plugin.entities.plugin import (
PluginInstallation,
PluginInstallationSource,
)
-from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginListResponse, PluginUploadResponse
+from core.plugin.entities.plugin_daemon import (
+ PluginDecodeResponse,
+ PluginInstallTask,
+ PluginListResponse,
+ PluginVerification,
+)
from core.plugin.impl.asset import PluginAssetManager
from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
+from services.errors.plugin import PluginInstallationForbiddenError
+from services.feature_service import FeatureService, PluginInstallationScope
logger = logging.getLogger(__name__)
@@ -31,6 +38,9 @@ class PluginService:
plugin_id: str
version: str
unique_identifier: str
+ status: str
+ deprecated_reason: str
+ alternative_plugin_id: str
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes
@@ -64,6 +74,9 @@ class PluginService:
plugin_id=plugin_id,
version=manifest.latest_version,
unique_identifier=manifest.latest_package_identifier,
+ status=manifest.status,
+ deprecated_reason=manifest.deprecated_reason,
+ alternative_plugin_id=manifest.alternative_plugin_id,
)
# Store in Redis
@@ -86,6 +99,42 @@ class PluginService:
logger.exception("failed to fetch latest plugin version")
return result
+ @staticmethod
+ def _check_marketplace_only_permission():
+ """
+ Check if the marketplace only permission is enabled
+ """
+ features = FeatureService.get_system_features()
+ if features.plugin_installation_permission.restrict_to_marketplace_only:
+ raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only")
+
+ @staticmethod
+ def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]):
+ """
+ Check the plugin installation scope
+ """
+ features = FeatureService.get_system_features()
+
+ match features.plugin_installation_permission.plugin_installation_scope:
+ case PluginInstallationScope.OFFICIAL_ONLY:
+ if (
+ plugin_verification is None
+ or plugin_verification.authorized_category != PluginVerification.AuthorizedCategory.Langgenius
+ ):
+ raise PluginInstallationForbiddenError("Plugin installation is restricted to official only")
+ case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS:
+ if plugin_verification is None or plugin_verification.authorized_category not in [
+ PluginVerification.AuthorizedCategory.Langgenius,
+ PluginVerification.AuthorizedCategory.Partner,
+ ]:
+ raise PluginInstallationForbiddenError(
+ "Plugin installation is restricted to official and specific partners"
+ )
+ case PluginInstallationScope.NONE:
+ raise PluginInstallationForbiddenError("Installing plugins is not allowed")
+ case PluginInstallationScope.ALL:
+ pass
+
@staticmethod
def get_debugging_key(tenant_id: str) -> str:
"""
@@ -153,6 +202,17 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
+ @staticmethod
+ def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
+ """
+ Check if the plugin is verified
+ """
+ manager = PluginInstaller()
+ try:
+ return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
+ except Exception:
+ return False
+
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""
@@ -208,6 +268,8 @@ class PluginService:
# check if plugin pkg is already downloaded
manager = PluginInstaller()
+ features = FeatureService.get_system_features()
+
try:
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
# already downloaded, skip, and record install event
@@ -215,7 +277,14 @@ class PluginService:
except Exception:
# plugin not installed, download and upload pkg
pkg = download_plugin_pkg(new_plugin_unique_identifier)
- manager.upload_pkg(tenant_id, pkg, verify_signature=False)
+ response = manager.upload_pkg(
+ tenant_id,
+ pkg,
+ verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
+ )
+
+ # check if the plugin is available to install
+ PluginService._check_plugin_installation_scope(response.verification)
return manager.upgrade_plugin(
tenant_id,
@@ -239,6 +308,7 @@ class PluginService:
"""
Upgrade plugin with github
"""
+ PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
return manager.upgrade_plugin(
tenant_id,
@@ -253,33 +323,43 @@ class PluginService:
)
@staticmethod
- def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse:
+ def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse:
"""
Upload plugin package files
returns: plugin_unique_identifier
"""
+ PluginService._check_marketplace_only_permission()
manager = PluginInstaller()
- return manager.upload_pkg(tenant_id, pkg, verify_signature)
+ features = FeatureService.get_system_features()
+ response = manager.upload_pkg(
+ tenant_id,
+ pkg,
+ verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
+ )
+ return response
@staticmethod
def upload_pkg_from_github(
tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False
- ) -> PluginUploadResponse:
+ ) -> PluginDecodeResponse:
"""
Install plugin from github release package files,
returns plugin_unique_identifier
"""
+ PluginService._check_marketplace_only_permission()
pkg = download_with_size_limit(
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
)
+ features = FeatureService.get_system_features()
manager = PluginInstaller()
- return manager.upload_pkg(
+ response = manager.upload_pkg(
tenant_id,
pkg,
- verify_signature,
+ verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
)
+ return response
@staticmethod
def upload_bundle(
@@ -289,11 +369,15 @@ class PluginService:
Upload a plugin bundle and return the dependencies.
"""
manager = PluginInstaller()
+ PluginService._check_marketplace_only_permission()
return manager.upload_bundle(tenant_id, bundle, verify_signature)
@staticmethod
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
+ PluginService._check_marketplace_only_permission()
+
manager = PluginInstaller()
+
return manager.install_from_identifiers(
tenant_id,
plugin_unique_identifiers,
@@ -307,6 +391,8 @@ class PluginService:
Install plugin from github release package files,
returns plugin_unique_identifier
"""
+ PluginService._check_marketplace_only_permission()
+
manager = PluginInstaller()
return manager.install_from_identifiers(
tenant_id,
@@ -322,28 +408,33 @@ class PluginService:
)
@staticmethod
- def fetch_marketplace_pkg(
- tenant_id: str, plugin_unique_identifier: str, verify_signature: bool = False
- ) -> PluginDeclaration:
+ def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
"""
Fetch marketplace package
"""
if not dify_config.MARKETPLACE_ENABLED:
raise ValueError("marketplace is not enabled")
+ features = FeatureService.get_system_features()
+
manager = PluginInstaller()
try:
declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
except Exception:
pkg = download_plugin_pkg(plugin_unique_identifier)
- declaration = manager.upload_pkg(tenant_id, pkg, verify_signature).manifest
+ response = manager.upload_pkg(
+ tenant_id,
+ pkg,
+ verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
+ )
+ # check if the plugin is available to install
+ PluginService._check_plugin_installation_scope(response.verification)
+ declaration = response.manifest
return declaration
@staticmethod
- def install_from_marketplace_pkg(
- tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False
- ):
+ def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
"""
Install plugin from marketplace package files,
returns installation task id
@@ -353,26 +444,40 @@ class PluginService:
manager = PluginInstaller()
+ # collect actual plugin_unique_identifiers
+ actual_plugin_unique_identifiers = []
+ metas = []
+ features = FeatureService.get_system_features()
+
# check if already downloaded
for plugin_unique_identifier in plugin_unique_identifiers:
try:
manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
+ plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier)
+ # check if the plugin is available to install
+ PluginService._check_plugin_installation_scope(plugin_decode_response.verification)
# already downloaded, skip
+ actual_plugin_unique_identifiers.append(plugin_unique_identifier)
+ metas.append({"plugin_unique_identifier": plugin_unique_identifier})
except Exception:
# plugin not installed, download and upload pkg
pkg = download_plugin_pkg(plugin_unique_identifier)
- manager.upload_pkg(tenant_id, pkg, verify_signature)
+ response = manager.upload_pkg(
+ tenant_id,
+ pkg,
+ verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only,
+ )
+ # check if the plugin is available to install
+ PluginService._check_plugin_installation_scope(response.verification)
+ # use response plugin_unique_identifier
+ actual_plugin_unique_identifiers.append(response.unique_identifier)
+ metas.append({"plugin_unique_identifier": response.unique_identifier})
return manager.install_from_identifiers(
tenant_id,
- plugin_unique_identifiers,
+ actual_plugin_unique_identifiers,
PluginInstallationSource.Marketplace,
- [
- {
- "plugin_unique_identifier": plugin_unique_identifier,
- }
- for plugin_unique_identifier in plugin_unique_identifiers
- ],
+ metas,
)
@staticmethod
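
A simplified, self-contained sketch of the installation-scope gate added by `_check_plugin_installation_scope`. The enum member names and error messages mirror the diff; the string values, exception type, and class shapes are assumptions made for illustration only.

```python
# Gate plugin installation by verification category; stand-in enums, not Dify's entities.
from enum import Enum
from typing import Optional


class PluginInstallationScope(str, Enum):
    NONE = "none"
    OFFICIAL_ONLY = "official_only"
    OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners"
    ALL = "all"


class AuthorizedCategory(str, Enum):
    Langgenius = "langgenius"
    Partner = "partner"


def check_installation_scope(scope: PluginInstallationScope, category: Optional[AuthorizedCategory]) -> None:
    """Raise when a plugin with the given verification category may not be installed."""
    match scope:
        case PluginInstallationScope.NONE:
            raise PermissionError("Installing plugins is not allowed")
        case PluginInstallationScope.OFFICIAL_ONLY:
            if category != AuthorizedCategory.Langgenius:
                raise PermissionError("Plugin installation is restricted to official only")
        case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS:
            if category not in (AuthorizedCategory.Langgenius, AuthorizedCategory.Partner):
                raise PermissionError("Plugin installation is restricted to official and specific partners")
        case PluginInstallationScope.ALL:
            pass


check_installation_scope(PluginInstallationScope.ALL, None)  # always allowed
check_installation_scope(PluginInstallationScope.OFFICIAL_ONLY, AuthorizedCategory.Langgenius)  # allowed
```
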
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 6f848d49c4..80badf2335 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
-
- encrypted_credentials = tool_configuration.encrypt(credentials)
- db_provider.credentials_str = json.dumps(encrypted_credentials)
+ db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@@ -297,28 +293,26 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
- original_credentials = tool_configuration.decrypt(provider.credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
+ original_credentials = encrypter.decrypt(provider.credentials)
+ masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
- credentials = tool_configuration.encrypt(credentials)
+ credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
- tool_configuration.delete_tool_credentials_cache()
+ cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@@ -416,15 +410,13 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
+ decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
- masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+ masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
@@ -446,7 +438,7 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
- def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
+ def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
@@ -474,7 +466,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
- tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
+ tenant_id=tenant_id, tool=tool, labels=labels
)
)
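
The update paths above all follow the same masked-credential round trip: decrypt the stored values, mask them for display, and restore the original secret for any field the client sends back unchanged. A hedged sketch of that pattern, with `mask_value` standing in for the real encrypter's masking:

```python
# Restore original secrets for fields the client returned still masked.
def mask_value(value: str) -> str:
    if len(value) <= 4:
        return "*" * len(value)
    return value[:2] + "*" * (len(value) - 4) + value[-2:]


def merge_credentials(submitted: dict[str, str], original: dict[str, str]) -> dict[str, str]:
    masked = {k: mask_value(v) for k, v in original.items()}
    merged = dict(submitted)
    for name, value in submitted.items():
        # value unchanged by the client -> keep the original secret
        if name in masked and value == masked[name]:
            merged[name] = original[name]
    return merged


original = {"api_key": "sk-1234567890"}
submitted = {"api_key": mask_value("sk-1234567890"), "endpoint": "https://api.example.com"}
print(merge_credentials(submitted, original))  # api_key restored, endpoint kept as submitted
```
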
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 58a4b2f179..430575b532 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -1,28 +1,84 @@
import json
import logging
+import re
+from collections.abc import Mapping
from pathlib import Path
+from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
+from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
-from core.model_runtime.utils.encoders import jsonable_encoder
+from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
-from core.plugin.impl.exc import PluginDaemonClientSideError
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
-from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
-from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
+from core.tools.entities.api_entities import (
+ ToolApiEntity,
+ ToolProviderApiEntity,
+ ToolProviderCredentialApiEntity,
+ ToolProviderCredentialInfoApiEntity,
+)
+from core.tools.entities.tool_entities import CredentialType
+from core.tools.errors import ToolProviderNotFoundError
+from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter
+from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
-from models.tools import BuiltinToolProvider
+from extensions.ext_redis import redis_client
+from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
+from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
+ __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
+
+ @staticmethod
+ def delete_custom_oauth_client_params(tenant_id: str, provider: str):
+ """
+ delete custom oauth client params
+ """
+ tool_provider = ToolProviderID(provider)
+ with Session(db.engine) as session:
+ session.query(ToolOAuthTenantClient).filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ ).delete()
+ session.commit()
+ return {"result": "success"}
+
+ @staticmethod
+ def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
+ """
+ get builtin tool provider oauth client schema
+ """
+ provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
+ verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
+ tenant_id, provider.plugin_unique_identifier
+ )
+
+ is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
+ tenant_id, provider_name
+ )
+ is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
+ provider_name
+ )
+ result = {
+ "schema": provider.get_oauth_client_schema(),
+ "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
+ "is_system_oauth_params_exists": is_system_oauth_params_exists,
+ "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
+ "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
+ }
+ return result
+
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@@ -36,27 +92,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
- tool_provider_configurations = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
- # check if user has added the provider
- builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
-
- credentials = {}
- if builtin_provider is not None:
- # get credentials
- credentials = builtin_provider.credentials
- credentials = tool_provider_configurations.decrypt(credentials)
-
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
- credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@@ -65,25 +105,15 @@ class BuiltinToolManageService:
return result
@staticmethod
- def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
+ def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
- tool_provider_configurations = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
# check if user has added the provider
- builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
-
- credentials = {}
- if builtin_provider is not None:
- # get credentials
- credentials = builtin_provider.credentials
- credentials = tool_provider_configurations.decrypt(credentials)
+ builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
+ if builtin_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -92,127 +122,406 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
-
return entity
@staticmethod
- def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
+ def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
+ :param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
- return jsonable_encoder(provider.get_credentials_schema())
+ return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
- session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
+ user_id: str,
+ tenant_id: str,
+ provider: str,
+ credential_id: str,
+ credentials: dict | None = None,
+ name: str | None = None,
):
"""
update builtin tool provider
"""
- # get if the provider exists
- provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with Session(db.engine) as session:
+ # get if the provider exists
+ db_provider = (
+ session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
+ if db_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
+
+ try:
+ if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller.need_credentials:
+ raise ValueError(f"provider {provider} does not need credentials")
+
+ encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
+ tenant_id, db_provider, provider, provider_controller
+ )
+
+ original_credentials = encrypter.decrypt(db_provider.credentials)
+ new_credentials: dict = {
+ key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
+ for key, value in credentials.items()
+ }
+
+ if CredentialType.of(db_provider.credential_type).is_validate_allowed():
+ provider_controller.validate_credentials(user_id, new_credentials)
+ # encrypt credentials
+ db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
+
+ cache.delete()
+
+ # update name if provided
+ if name and name != db_provider.name:
+ # check if the name is already used
+ if (
+ session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider, name=name)
+ .count()
+ > 0
+ ):
+ raise ValueError(f"the credential name '{name}' is already used")
+
+ db_provider.name = name
+
+ session.commit()
+ except Exception as e:
+ session.rollback()
+ raise ValueError(str(e))
+ return {"result": "success"}
+
+ @staticmethod
+ def add_builtin_tool_provider(
+ user_id: str,
+ api_type: CredentialType,
+ tenant_id: str,
+ provider: str,
+ credentials: dict,
+ name: str | None = None,
+ ):
+ """
+ add builtin tool provider
+ """
try:
- # get provider
- provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
- if not provider_controller.need_credentials:
- raise ValueError(f"provider {provider_name} does not need credentials")
- tool_configuration = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
+ with Session(db.engine) as session:
+ lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
+ with redis_client.lock(lock, timeout=20):
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller.need_credentials:
+ raise ValueError(f"provider {provider} does not need credentials")
+
+ provider_count = (
+ session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
+ )
+
+                    # check if the provider count has reached the limit
+ if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
+ raise ValueError(f"you have reached the maximum number of providers for {provider}")
+
+ # validate credentials if allowed
+ if CredentialType.of(api_type).is_validate_allowed():
+ provider_controller.validate_credentials(user_id, credentials)
+
+ # generate name if not provided
+ if name is None or name == "":
+ name = BuiltinToolManageService.generate_builtin_tool_provider_name(
+ session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
+ )
+ else:
+ # check if the name is already used
+ if (
+ session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider, name=name)
+ .count()
+ > 0
+ ):
+ raise ValueError(f"the credential name '{name}' is already used")
+
+ # create encrypter
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(api_type)
+ ],
+ cache=NoOpProviderCredentialCache(),
+ )
+
+ db_provider = BuiltinToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ provider=provider,
+ encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
+ credential_type=api_type.value,
+ name=name,
+ )
- # get original credentials if exists
- if provider is not None:
- original_credentials = tool_configuration.decrypt(provider.credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
- # check if the credential has changed, save the original credential
- for name, value in credentials.items():
- if name in masked_credentials and value == masked_credentials[name]:
- credentials[name] = original_credentials[name]
- # validate credentials
- provider_controller.validate_credentials(user_id, credentials)
- # encrypt credentials
- credentials = tool_configuration.encrypt(credentials)
- except (
- PluginDaemonClientSideError,
- ToolProviderNotFoundError,
- ToolNotFoundError,
- ToolProviderCredentialValidationError,
- ) as e:
+ session.add(db_provider)
+ session.commit()
+ except Exception as e:
+ session.rollback()
raise ValueError(str(e))
+ return {"result": "success"}
- if provider is None:
- # create provider
- provider = BuiltinToolProvider(
- tenant_id=tenant_id,
- user_id=user_id,
- provider=provider_name,
- encrypted_credentials=json.dumps(credentials),
+ @staticmethod
+ def create_tool_encrypter(
+ tenant_id: str,
+ db_provider: BuiltinToolProvider,
+ provider: str,
+ provider_controller: BuiltinToolProviderController,
+ ):
+ encrypter, cache = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
+ ],
+ cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
+ )
+ return encrypter, cache
+
+ @staticmethod
+ def generate_builtin_tool_provider_name(
+ session: Session, tenant_id: str, provider: str, credential_type: CredentialType
+ ) -> str:
+ try:
+ db_providers = (
+ session.query(BuiltinToolProvider)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=provider,
+ credential_type=credential_type.value,
+ )
+ .order_by(BuiltinToolProvider.created_at.desc())
+ .all()
)
- db.session.add(provider)
- else:
- provider.encrypted_credentials = json.dumps(credentials)
+ # Get the default name pattern
+ default_pattern = f"{credential_type.get_name()}"
- # delete cache
- tool_configuration.delete_tool_credentials_cache()
+ # Find all names that match the default pattern: "{default_pattern} {number}"
+ pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
+ numbers = []
- db.session.commit()
- return {"result": "success"}
+ for db_provider in db_providers:
+ if db_provider.name:
+ match = re.match(pattern, db_provider.name.strip())
+ if match:
+ numbers.append(int(match.group(1)))
+
+ # If no default pattern names found, start with 1
+ if not numbers:
+ return f"{default_pattern} 1"
+
+ # Find the next number
+ max_number = max(numbers)
+ return f"{default_pattern} {max_number + 1}"
+ except Exception as e:
+ logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
+ # fallback
+ return f"{credential_type.get_name()} 1"
@staticmethod
- def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
+ def get_builtin_tool_provider_credentials(
+ tenant_id: str, provider_name: str
+ ) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
- provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with db.session.no_autoflush:
+ providers = (
+ db.session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider_name)
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
+ .all()
+ )
- if provider_obj is None:
- return {}
+ if len(providers) == 0:
+ return []
+
+ default_provider = providers[0]
+ default_provider.is_default = True
+ provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
+
+ credentials: list[ToolProviderCredentialApiEntity] = []
+ encrypters = {}
+ for provider in providers:
+ credential_type = provider.credential_type
+ if credential_type not in encrypters:
+ encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter(
+ tenant_id, provider, provider.provider, provider_controller
+ )[0]
+ encrypter = encrypters[credential_type]
+ decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
+ credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
+ provider=provider,
+ credentials=decrypt_credential,
+ )
+ credentials.append(credential_entity)
+ return credentials
- provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
- tool_configuration = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ @staticmethod
+ def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
+ """
+ get builtin tool provider credential info
+ """
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ supported_credential_types = provider_controller.get_supported_credential_types()
+ credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
+ credential_info = ToolProviderCredentialInfoApiEntity(
+ supported_credential_types=supported_credential_types,
+ is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
+ credentials=credentials,
)
- credentials = tool_configuration.decrypt(provider_obj.credentials)
- credentials = tool_configuration.mask_tool_credentials(credentials)
- return credentials
+
+ return credential_info
@staticmethod
- def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
+ def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
- provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with Session(db.engine) as session:
+ db_provider = (
+ session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
- if provider_obj is None:
- raise ValueError(f"you have not added provider {provider_name}")
+ if db_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
- db.session.delete(provider_obj)
- db.session.commit()
+ session.delete(db_provider)
+ session.commit()
- # delete cache
- provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
- tool_configuration = ProviderConfigEncrypter(
+ # delete cache
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ _, cache = BuiltinToolManageService.create_tool_encrypter(
+ tenant_id, db_provider, provider, provider_controller
+ )
+ cache.delete()
+
+ return {"result": "success"}
+
+ @staticmethod
+ def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
+ """
+ set default provider
+ """
+ with Session(db.engine) as session:
+ # get provider
+ target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
+ if target_provider is None:
+ raise ValueError("provider not found")
+
+ # clear default provider
+ session.query(BuiltinToolProvider).filter_by(
+ tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
+ ).update({"is_default": False})
+
+ # set new default provider
+ target_provider.is_default = True
+ session.commit()
+ return {"result": "success"}
+
+ @staticmethod
+ def is_oauth_system_client_exists(provider_name: str) -> bool:
+ """
+ check if oauth system client exists
+ """
+ tool_provider = ToolProviderID(provider_name)
+ with Session(db.engine).no_autoflush as session:
+ system_client: ToolOAuthSystemClient | None = (
+ session.query(ToolOAuthSystemClient)
+ .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
+ .first()
+ )
+ return system_client is not None
+
+ @staticmethod
+ def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
+ """
+ check if oauth custom client is enabled
+ """
+ tool_provider = ToolProviderID(provider)
+ with Session(db.engine).no_autoflush as session:
+ user_client: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ enabled=True,
+ )
+ .first()
+ )
+ return user_client is not None and user_client.enabled
+
+ @staticmethod
+ def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
+ """
+        get the oauth client params for a builtin tool provider
+ """
+ tool_provider = ToolProviderID(provider)
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
)
- tool_configuration.delete_tool_credentials_cache()
+ with Session(db.engine).no_autoflush as session:
+ user_client: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ enabled=True,
+ )
+ .first()
+ )
+ oauth_params: Mapping[str, Any] | None = None
+ if user_client:
+ oauth_params = encrypter.decrypt(user_client.oauth_params)
+                return oauth_params
+
+            # only a verified provider may fall back to the system oauth client
+            is_verified = not isinstance(
+                provider_controller, PluginToolProviderController
+            ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
+ if not is_verified:
+ return oauth_params
- return {"result": "success"}
+ system_client: ToolOAuthSystemClient | None = (
+ session.query(ToolOAuthSystemClient)
+ .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
+ .first()
+ )
+ if system_client:
+ try:
+ oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
+ except Exception as e:
+ raise ValueError(f"Error decrypting system oauth params: {e}")
+
+ return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@@ -234,9 +543,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
- db_providers: list[BuiltinToolProvider] = (
- db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
- )
+ db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@@ -275,7 +582,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
- credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@@ -287,43 +593,153 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
- def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
- try:
- full_provider_name = provider_name
- provider_id_entity = ToolProviderID(provider_name)
- provider_name = provider_id_entity.provider_name
- if provider_id_entity.organization != "langgenius":
- provider_obj = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == full_provider_name,
+ def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
+ """
+        Fetch the builtin provider record from the database.
+        1. If a default provider exists, return it.
+        2. Otherwise, return the oldest provider.
+ """
+ with Session(db.engine) as session:
+ try:
+ full_provider_name = provider_name
+ provider_id_entity = ToolProviderID(provider_name)
+ provider_name = provider_id_entity.provider_name
+
+ if provider_id_entity.organization != "langgenius":
+ provider = (
+ session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == full_provider_name,
+ )
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
+ )
+ .first()
)
- .first()
- )
- else:
- provider_obj = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == provider_name)
- | (BuiltinToolProvider.provider == full_provider_name),
+ else:
+ provider = (
+ session.query(BuiltinToolProvider)
+ .filter(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ (BuiltinToolProvider.provider == provider_name)
+ | (BuiltinToolProvider.provider == full_provider_name),
+ )
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
+ )
+ .first()
+ )
+
+ if provider is None:
+ return None
+
+ provider.provider = ToolProviderID(provider.provider).to_string()
+ return provider
+ except Exception:
+ # it's an old provider without organization
+ return (
+ session.query(BuiltinToolProvider)
+ .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
- if provider_obj is None:
- return None
+ @staticmethod
+ def save_custom_oauth_client_params(
+ tenant_id: str,
+ provider: str,
+ client_params: Optional[dict] = None,
+ enable_oauth_custom_client: Optional[bool] = None,
+ ):
+ """
+ setup oauth custom client
+ """
+ if client_params is None and enable_oauth_custom_client is None:
+ return {"result": "success"}
+
+ tool_provider = ToolProviderID(provider)
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller:
+ raise ToolProviderNotFoundError(f"Provider {provider} not found")
- provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
- return provider_obj
- except Exception:
- # it's an old provider without organization
- return (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == provider_name),
+ if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
+ raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
+
+ with Session(db.engine) as session:
+ custom_client_params = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
)
.first()
)
+
+ # if the record does not exist, create a basic record
+ if custom_client_params is None:
+ custom_client_params = ToolOAuthTenantClient(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
+ )
+ session.add(custom_client_params)
+
+ if client_params is not None:
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
+ )
+ original_params = encrypter.decrypt(custom_client_params.oauth_params)
+ new_params: dict = {
+ key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
+ for key, value in client_params.items()
+ }
+ custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
+
+ if enable_oauth_custom_client is not None:
+ custom_client_params.enabled = enable_oauth_custom_client
+
+ session.commit()
+ return {"result": "success"}
+
+ @staticmethod
+ def get_custom_oauth_client_params(tenant_id: str, provider: str):
+ """
+ get custom oauth client params
+ """
+ with Session(db.engine) as session:
+ tool_provider = ToolProviderID(provider)
+ custom_oauth_client_params: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
+ )
+ .first()
+ )
+ if custom_oauth_client_params is None:
+ return {}
+
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller:
+ raise ToolProviderNotFoundError(f"Provider {provider} not found")
+
+ if not isinstance(provider_controller, BuiltinToolProviderController):
+ raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
+
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
+ )
+
+ return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
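
A self-contained sketch of the naming scheme used by `generate_builtin_tool_provider_name`: existing names matching "<credential type name> <n>" are scanned and the next number is chosen, falling back to "<name> 1" when none match. The sample names and the "API Key" label are illustrative.

```python
# Pick the next "<pattern> <n>" name from a list of existing credential names.
import re


def generate_next_name(existing_names: list[str], default_pattern: str) -> str:
    pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
    numbers = [
        int(m.group(1))
        for name in existing_names
        if name and (m := re.match(pattern, name.strip()))
    ]
    if not numbers:
        return f"{default_pattern} 1"
    return f"{default_pattern} {max(numbers) + 1}"


print(generate_next_name([], "API Key"))                          # "API Key 1"
print(generate_next_name(["API Key 1", "API Key 3"], "API Key"))  # "API Key 4"
print(generate_next_name(["custom name"], "OAuth"))               # "OAuth 1"
```
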
diff --git a/api/services/tools/mcp_tools_manage_service.py b/api/services/tools/mcp_tools_manage_service.py
new file mode 100644
index 0000000000..e0e256912e
--- /dev/null
+++ b/api/services/tools/mcp_tools_manage_service.py
@@ -0,0 +1,229 @@
+import hashlib
+import json
+from datetime import datetime
+from typing import Any
+
+from sqlalchemy import or_
+from sqlalchemy.exc import IntegrityError
+
+from core.helper import encrypter
+from core.helper.provider_cache import NoOpProviderCredentialCache
+from core.mcp.error import MCPAuthError, MCPError
+from core.mcp.mcp_client import MCPClient
+from core.tools.entities.api_entities import ToolProviderApiEntity
+from core.tools.entities.common_entities import I18nObject
+from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.mcp_tool.provider import MCPToolProviderController
+from core.tools.utils.encryption import ProviderConfigEncrypter
+from extensions.ext_database import db
+from models.tools import MCPToolProvider
+from services.tools.tools_transform_service import ToolTransformService
+
+UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"
+
+
+class MCPToolManageService:
+ """
+ Service class for managing mcp tools.
+ """
+
+ @staticmethod
+ def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
+ res = (
+ db.session.query(MCPToolProvider)
+ .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
+ .first()
+ )
+ if not res:
+ raise ValueError("MCP tool not found")
+ return res
+
+ @staticmethod
+ def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
+ res = (
+ db.session.query(MCPToolProvider)
+ .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
+ .first()
+ )
+ if not res:
+ raise ValueError("MCP tool not found")
+ return res
+
+ @staticmethod
+ def create_mcp_provider(
+ tenant_id: str,
+ name: str,
+ server_url: str,
+ user_id: str,
+ icon: str,
+ icon_type: str,
+ icon_background: str,
+ server_identifier: str,
+ ) -> ToolProviderApiEntity:
+ server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
+ existing_provider = (
+ db.session.query(MCPToolProvider)
+ .filter(
+ MCPToolProvider.tenant_id == tenant_id,
+ or_(
+ MCPToolProvider.name == name,
+ MCPToolProvider.server_url_hash == server_url_hash,
+ MCPToolProvider.server_identifier == server_identifier,
+ ),
+ )
+ .first()
+ )
+ if existing_provider:
+ if existing_provider.name == name:
+ raise ValueError(f"MCP tool {name} already exists")
+ if existing_provider.server_url_hash == server_url_hash:
+ raise ValueError(f"MCP tool {server_url} already exists")
+ if existing_provider.server_identifier == server_identifier:
+ raise ValueError(f"MCP tool {server_identifier} already exists")
+ encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
+ mcp_tool = MCPToolProvider(
+ tenant_id=tenant_id,
+ name=name,
+ server_url=encrypted_server_url,
+ server_url_hash=server_url_hash,
+ user_id=user_id,
+ authed=False,
+ tools="[]",
+ icon=json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon,
+ server_identifier=server_identifier,
+ )
+ db.session.add(mcp_tool)
+ db.session.commit()
+ return ToolTransformService.mcp_provider_to_user_provider(mcp_tool, for_list=True)
+
+ @staticmethod
+ def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
+ mcp_providers = (
+ db.session.query(MCPToolProvider)
+ .filter(MCPToolProvider.tenant_id == tenant_id)
+ .order_by(MCPToolProvider.name)
+ .all()
+ )
+ return [
+ ToolTransformService.mcp_provider_to_user_provider(mcp_provider, for_list=for_list)
+ for mcp_provider in mcp_providers
+ ]
+
+ @classmethod
+ def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
+ mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ try:
+ with MCPClient(
+ mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
+ ) as mcp_client:
+ tools = mcp_client.list_tools()
+ except MCPAuthError:
+ raise ValueError("Please auth the tool first")
+ except MCPError as e:
+ raise ValueError(f"Failed to connect to MCP server: {e}")
+ mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ mcp_provider.authed = True
+ mcp_provider.updated_at = datetime.now()
+ db.session.commit()
+ user = mcp_provider.load_user()
+ return ToolProviderApiEntity(
+ id=mcp_provider.id,
+ name=mcp_provider.name,
+ tools=ToolTransformService.mcp_tool_to_user_tool(mcp_provider, tools),
+ type=ToolProviderType.MCP,
+ icon=mcp_provider.icon,
+ author=user.name if user else "Anonymous",
+ server_url=mcp_provider.masked_server_url,
+ updated_at=int(mcp_provider.updated_at.timestamp()),
+ description=I18nObject(en_US="", zh_Hans=""),
+ label=I18nObject(en_US=mcp_provider.name, zh_Hans=mcp_provider.name),
+ plugin_unique_identifier=mcp_provider.server_identifier,
+ )
+
+ @classmethod
+ def delete_mcp_tool(cls, tenant_id: str, provider_id: str):
+ mcp_tool = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+
+ db.session.delete(mcp_tool)
+ db.session.commit()
+
+ @classmethod
+ def update_mcp_provider(
+ cls,
+ tenant_id: str,
+ provider_id: str,
+ name: str,
+ server_url: str,
+ icon: str,
+ icon_type: str,
+ icon_background: str,
+ server_identifier: str,
+ ):
+ mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ mcp_provider.updated_at = datetime.now()
+ mcp_provider.name = name
+ mcp_provider.icon = (
+ json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
+ )
+ mcp_provider.server_identifier = server_identifier
+
+ if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
+ encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
+ mcp_provider.server_url = encrypted_server_url
+ server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
+
+ if server_url_hash != mcp_provider.server_url_hash:
+ cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
+ mcp_provider.server_url_hash = server_url_hash
+ try:
+ db.session.commit()
+ except IntegrityError as e:
+ db.session.rollback()
+ error_msg = str(e.orig)
+ if "unique_mcp_provider_name" in error_msg:
+ raise ValueError(f"MCP tool {name} already exists")
+ if "unique_mcp_provider_server_url" in error_msg:
+ raise ValueError(f"MCP tool {server_url} already exists")
+ if "unique_mcp_provider_server_identifier" in error_msg:
+ raise ValueError(f"MCP tool {server_identifier} already exists")
+ raise
+
+ @classmethod
+ def update_mcp_provider_credentials(
+ cls, mcp_provider: MCPToolProvider, credentials: dict[str, Any], authed: bool = False
+ ):
+ provider_controller = MCPToolProviderController._from_db(mcp_provider)
+ tool_configuration = ProviderConfigEncrypter(
+ tenant_id=mcp_provider.tenant_id,
+ config=list(provider_controller.get_credentials_schema()),
+ provider_config_cache=NoOpProviderCredentialCache(),
+ )
+ credentials = tool_configuration.encrypt(credentials)
+ mcp_provider.updated_at = datetime.now()
+ mcp_provider.encrypted_credentials = json.dumps({**mcp_provider.credentials, **credentials})
+ mcp_provider.authed = authed
+ if not authed:
+ mcp_provider.tools = "[]"
+ db.session.commit()
+
+ @classmethod
+ def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
+ """re-connect mcp provider"""
+ try:
+ with MCPClient(
+ mcp_provider.decrypted_server_url,
+ provider_id,
+ tenant_id,
+ authed=False,
+ for_list=True,
+ ) as mcp_client:
+ tools = mcp_client.list_tools()
+ mcp_provider.authed = True
+ mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ except MCPAuthError:
+ mcp_provider.authed = False
+ mcp_provider.tools = "[]"
+ except MCPError as e:
+ raise ValueError(f"Failed to re-connect MCP server: {e}") from e
+ # reset credentials
+ mcp_provider.encrypted_credentials = "{}"
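
A sketch of the server-URL update rule in `update_mcp_provider`: the client submits the `[__HIDDEN__]` placeholder when the URL is unchanged, so re-encryption, re-hashing, and reconnection only happen for a genuinely new URL. `encrypt_token` below is an illustrative stand-in for Dify's tenant-scoped encrypter, and the record is a plain dict rather than the real model.

```python
# Decide whether a submitted MCP server URL requires re-encryption and a reconnect.
import hashlib

UNCHANGED_SERVER_URL_PLACEHOLDER = "[__HIDDEN__]"


def encrypt_token(tenant_id: str, token: str) -> str:
    return f"enc({tenant_id}:{token})"  # placeholder for the real encryption


def apply_server_url_update(record: dict, tenant_id: str, submitted_url: str) -> bool:
    """Update the record in place; return True when a reconnect is required."""
    if UNCHANGED_SERVER_URL_PLACEHOLDER in submitted_url:
        return False  # URL untouched by the client
    new_hash = hashlib.sha256(submitted_url.encode()).hexdigest()
    record["server_url"] = encrypt_token(tenant_id, submitted_url)
    needs_reconnect = new_hash != record["server_url_hash"]
    record["server_url_hash"] = new_hash
    return needs_reconnect


record = {"server_url": "enc(old)", "server_url_hash": hashlib.sha256(b"http://old").hexdigest()}
print(apply_server_url_update(record, "tenant-1", "[__HIDDEN__]"))       # False: nothing changed
print(apply_server_url_update(record, "tenant-1", "http://new-server"))  # True: reconnect needed
```
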
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index 367121125b..2d192e6f7f 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -1,27 +1,30 @@
import json
import logging
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
from yarl import URL
from configs import dify_config
+from core.helper.provider_cache import ToolProviderCredentialsCache
+from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
-from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
+from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
+ CredentialType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
-from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider
+from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
logger = logging.getLogger(__name__)
@@ -52,7 +55,8 @@ class ToolTransformService:
return icon
except Exception:
return {"background": "#252525", "content": "\ud83d\ude01"}
-
+ elif provider_type == ToolProviderType.MCP.value:
+ return icon
return ""
@staticmethod
@@ -73,10 +77,18 @@ class ToolTransformService:
provider.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon
)
+ if isinstance(provider.icon_dark, str) and provider.icon_dark:
+ provider.icon_dark = ToolTransformService.get_plugin_icon_url(
+ tenant_id=tenant_id, filename=provider.icon_dark
+ )
else:
provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon
)
+ if provider.icon_dark:
+ provider.icon_dark = ToolTransformService.get_tool_provider_icon_url(
+ provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon_dark
+ )
@classmethod
def builtin_provider_to_user_provider(
@@ -94,6 +106,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
+ icon_dark=provider_controller.entity.identity.icon_dark,
label=provider_controller.entity.identity.label,
type=ToolProviderType.BUILT_IN,
masked_credentials={},
@@ -108,7 +121,12 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
- schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
+ schema = {
+ x.to_basic_provider_config().name: x
+ for x in provider_controller.get_credentials_schema_by_type(
+ CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
+ )
+ }
for name, value in schema.items():
if result.masked_credentials:
@@ -125,15 +143,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(
+ CredentialType.of(db_provider.credential_type)
+ )
+ ],
+ cache=ToolProviderCredentialsCache(
+ tenant_id=db_provider.tenant_id,
+ provider=db_provider.provider,
+ credential_id=db_provider.id,
+ ),
)
# decrypt the credentials and mask the credentials
- decrypted_credentials = tool_configuration.decrypt(data=credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
+ decrypted_credentials = encrypter.decrypt(data=credentials)
+ masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@@ -148,11 +174,16 @@ class ToolTransformService:
convert provider controller to user provider
"""
# package tool provider controller
+ auth_type = ApiProviderAuthType.NONE
+ credentials_auth_type = db_provider.credentials.get("auth_type")
+ if credentials_auth_type in ("api_key_header", "api_key"): # backward compatibility
+ auth_type = ApiProviderAuthType.API_KEY_HEADER
+ elif credentials_auth_type == "api_key_query":
+ auth_type = ApiProviderAuthType.API_KEY_QUERY
+
controller = ApiToolProviderController.from_db(
db_provider=db_provider,
- auth_type=ApiProviderAuthType.API_KEY
- if db_provider.credentials["auth_type"] == "api_key"
- else ApiProviderAuthType.NONE,
+ auth_type=auth_type,
)
return controller
@@ -177,6 +208,7 @@ class ToolTransformService:
name=provider_controller.entity.identity.name,
description=provider_controller.entity.identity.description,
icon=provider_controller.entity.identity.icon,
+ icon_dark=provider_controller.entity.identity.icon_dark,
label=provider_controller.entity.identity.label,
type=ToolProviderType.WORKFLOW,
masked_credentials={},
@@ -187,6 +219,41 @@ class ToolTransformService:
labels=labels or [],
)
+ @staticmethod
+ def mcp_provider_to_user_provider(db_provider: MCPToolProvider, for_list: bool = False) -> ToolProviderApiEntity:
+ user = db_provider.load_user()
+ return ToolProviderApiEntity(
+ id=db_provider.server_identifier if not for_list else db_provider.id,
+ author=user.name if user else "Anonymous",
+ name=db_provider.name,
+ icon=db_provider.provider_icon,
+ type=ToolProviderType.MCP,
+ is_team_authorization=db_provider.authed,
+ server_url=db_provider.masked_server_url,
+ tools=ToolTransformService.mcp_tool_to_user_tool(
+ db_provider, [MCPTool(**tool) for tool in json.loads(db_provider.tools)]
+ ),
+ updated_at=int(db_provider.updated_at.timestamp()),
+ label=I18nObject(en_US=db_provider.name, zh_Hans=db_provider.name),
+ description=I18nObject(en_US="", zh_Hans=""),
+ server_identifier=db_provider.server_identifier,
+ )
+
+ @staticmethod
+ def mcp_tool_to_user_tool(mcp_provider: MCPToolProvider, tools: list[MCPTool]) -> list[ToolApiEntity]:
+ user = mcp_provider.load_user()
+ return [
+ ToolApiEntity(
+ author=user.name if user else "Anonymous",
+ name=tool.name,
+ label=I18nObject(en_US=tool.name, zh_Hans=tool.name),
+ description=I18nObject(en_US=tool.description, zh_Hans=tool.description),
+ parameters=ToolTransformService.convert_mcp_schema_to_parameter(tool.inputSchema),
+ labels=[],
+ )
+ for tool in tools
+ ]
+
@classmethod
def api_provider_to_user_provider(
cls,
@@ -235,16 +302,14 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
# decrypt the credentials and mask the credentials
- decrypted_credentials = tool_configuration.decrypt(data=credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
+ decrypted_credentials = encrypter.decrypt(data=credentials)
+ masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@@ -254,7 +319,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
- credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@@ -264,27 +328,39 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
- credentials=credentials or {},
+ credentials={},
tenant_id=tenant_id,
)
)
# get tool parameters
- parameters = tool.entity.parameters or []
+ base_parameters = tool.entity.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
- # override parameters
- current_parameters = parameters.copy()
- for runtime_parameter in runtime_parameters:
- found = False
- for index, parameter in enumerate(current_parameters):
- if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
- current_parameters[index] = runtime_parameter
- found = True
- break
- if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
- current_parameters.append(runtime_parameter)
+ # merge parameters using a functional approach to avoid type issues
+ merged_parameters: list[ToolParameter] = []
+
+ # create a mapping of runtime parameters for quick lookup
+ runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
+
+ # process base parameters, replacing with runtime versions if they exist
+ for base_param in base_parameters:
+ key = (base_param.name, base_param.form)
+ if key in runtime_param_map:
+ merged_parameters.append(runtime_param_map[key])
+ else:
+ merged_parameters.append(base_param)
+
+ # add any runtime parameters that weren't in base parameters
+ for runtime_parameter in runtime_parameters:
+ if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
+ # check if this parameter is already in merged_parameters
+ already_exists = any(
+ p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
+ )
+ if not already_exists:
+ merged_parameters.append(runtime_parameter)
return ToolApiEntity(
author=tool.entity.identity.author,
@@ -292,10 +368,10 @@ class ToolTransformService:
label=tool.entity.identity.label,
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
output_schema=tool.entity.output_schema,
- parameters=current_parameters,
+ parameters=merged_parameters,
labels=labels or [],
)
- if isinstance(tool, ApiToolBundle):
+ elif isinstance(tool, ApiToolBundle):
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
@@ -304,3 +380,69 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
+ else:
+ # No matching branch above: reject unsupported tool types explicitly
+ raise ValueError(f"Unsupported tool type: {type(tool)}")
+
+ @staticmethod
+ def convert_builtin_provider_to_credential_entity(
+ provider: BuiltinToolProvider, credentials: dict
+ ) -> ToolProviderCredentialApiEntity:
+ return ToolProviderCredentialApiEntity(
+ id=provider.id,
+ name=provider.name,
+ provider=provider.provider,
+ credential_type=CredentialType.of(provider.credential_type),
+ is_default=provider.is_default,
+ credentials=credentials,
+ )
+
+ @staticmethod
+ def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
+ """
+ Convert MCP JSON schema to tool parameters
+
+ :param schema: JSON schema dictionary
+ :return: list of ToolParameter instances
+ """
+
+ def create_parameter(
+ name: str, description: str, param_type: str, required: bool, input_schema: dict | None = None
+ ) -> ToolParameter:
+ """Create a ToolParameter instance with given attributes"""
+ input_schema_dict: dict[str, Any] = {"input_schema": input_schema} if input_schema else {}
+ return ToolParameter(
+ name=name,
+ llm_description=description,
+ label=I18nObject(en_US=name),
+ form=ToolParameter.ToolParameterForm.LLM,
+ required=required,
+ type=ToolParameter.ToolParameterType(param_type),
+ human_description=I18nObject(en_US=description),
+ **input_schema_dict,
+ )
+
+ def process_properties(props: dict, required: list, prefix: str = "") -> list[ToolParameter]:
+ """Process properties recursively"""
+ TYPE_MAPPING = {"integer": "number", "float": "number"}
+ COMPLEX_TYPES = ["array", "object"]
+
+ parameters = []
+ for name, prop in props.items():
+ current_description = prop.get("description", "")
+ prop_type = prop.get("type", "string")
+
+ if isinstance(prop_type, list):
+ prop_type = prop_type[0]
+ if prop_type in TYPE_MAPPING:
+ prop_type = TYPE_MAPPING[prop_type]
+ input_schema = prop if prop_type in COMPLEX_TYPES else None
+ parameters.append(
+ create_parameter(name, current_description, prop_type, name in required, input_schema)
+ )
+
+ return parameters
+
+ if schema.get("type") == "object" and "properties" in schema:
+ return process_properties(schema["properties"], schema.get("required", []))
+ return []
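A worked illustration of the conversion rules above: list-typed `type` fields collapse to their first entry, `integer`/`float` map to `number`, and `array`/`object` keep their raw schema as `input_schema`. The schema below is invented; the import path assumes the `api` package root.

```python
from services.tools.tools_transform_service import ToolTransformService

schema = {
    "type": "object",
    "properties": {
        "query": {"type": "string", "description": "Search text"},
        "max_results": {"type": "integer", "description": "Upper bound on hits"},
        "filters": {"type": "object", "description": "Provider-specific filters"},
    },
    "required": ["query"],
}

params = ToolTransformService.convert_mcp_schema_to_parameter(schema)
# - "query":       required, form=LLM, type "string"
# - "max_results": optional, "integer" remapped to "number"
# - "filters":     optional, complex type, raw schema carried as input_schema
```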
diff --git a/api/services/vector_service.py b/api/services/vector_service.py
index 19e37f4ee3..9165139193 100644
--- a/api/services/vector_service.py
+++ b/api/services/vector_service.py
@@ -97,16 +97,16 @@ class VectorService:
vector = Vector(dataset=dataset)
vector.delete_by_ids([segment.index_node_id])
vector.add_texts([document], duplicate_check=True)
-
- # update keyword index
- keyword = Keyword(dataset)
- keyword.delete_by_ids([segment.index_node_id])
-
- # save keyword index
- if keywords and len(keywords) > 0:
- keyword.add_texts([document], keywords_list=[keywords])
else:
- keyword.add_texts([document])
+ # update keyword index
+ keyword = Keyword(dataset)
+ keyword.delete_by_ids([segment.index_node_id])
+
+ # save keyword index
+ if keywords and len(keywords) > 0:
+ keyword.add_texts([document], keywords_list=[keywords])
+ else:
+ keyword.add_texts([document])
@classmethod
def generate_child_chunks(
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 6720932a3a..991b669737 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,6 +1,7 @@
import datetime
import json
-from typing import Any
+from dataclasses import dataclass
+from typing import Any, Optional
import requests
from flask_login import current_user
@@ -13,241 +14,392 @@ from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
+@dataclass
+class CrawlOptions:
+ """Options for crawling operations."""
+
+ limit: int = 1
+ crawl_sub_pages: bool = False
+ only_main_content: bool = False
+ includes: Optional[str] = None
+ excludes: Optional[str] = None
+ max_depth: Optional[int] = None
+ use_sitemap: bool = True
+
+ def get_include_paths(self) -> list[str]:
+ """Get list of include paths from comma-separated string."""
+ return self.includes.split(",") if self.includes else []
+
+ def get_exclude_paths(self) -> list[str]:
+ """Get list of exclude paths from comma-separated string."""
+ return self.excludes.split(",") if self.excludes else []
+
+
+@dataclass
+class CrawlRequest:
+ """Request container for crawling operations."""
+
+ url: str
+ provider: str
+ options: CrawlOptions
+
+
+@dataclass
+class ScrapeRequest:
+ """Request container for scraping operations."""
+
+ provider: str
+ url: str
+ tenant_id: str
+ only_main_content: bool
+
+
+@dataclass
+class WebsiteCrawlApiRequest:
+ """Request container for website crawl API arguments."""
+
+ provider: str
+ url: str
+ options: dict[str, Any]
+
+ def to_crawl_request(self) -> CrawlRequest:
+ """Convert API request to internal CrawlRequest."""
+ options = CrawlOptions(
+ limit=self.options.get("limit", 1),
+ crawl_sub_pages=self.options.get("crawl_sub_pages", False),
+ only_main_content=self.options.get("only_main_content", False),
+ includes=self.options.get("includes"),
+ excludes=self.options.get("excludes"),
+ max_depth=self.options.get("max_depth"),
+ use_sitemap=self.options.get("use_sitemap", True),
+ )
+ return CrawlRequest(url=self.url, provider=self.provider, options=options)
+
+ @classmethod
+ def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
+ """Create from Flask-RESTful parsed arguments."""
+ provider = args.get("provider")
+ url = args.get("url")
+ options = args.get("options", {})
+
+ if not provider:
+ raise ValueError("Provider is required")
+ if not url:
+ raise ValueError("URL is required")
+ if not options:
+ raise ValueError("Options are required")
+
+ return cls(provider=provider, url=url, options=options)
+
+
+@dataclass
+class WebsiteCrawlStatusApiRequest:
+ """Request container for website crawl status API arguments."""
+
+ provider: str
+ job_id: str
+
+ @classmethod
+ def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
+ """Create from Flask-RESTful parsed arguments."""
+ provider = args.get("provider")
+
+ if not provider:
+ raise ValueError("Provider is required")
+ if not job_id:
+ raise ValueError("Job ID is required")
+
+ return cls(provider=provider, job_id=job_id)
+
+
class WebsiteService:
+ """Service class for website crawling operations using different providers."""
+
@classmethod
- def document_create_args_validate(cls, args: dict):
- if "url" not in args or not args["url"]:
- raise ValueError("url is required")
- if "options" not in args or not args["options"]:
- raise ValueError("options is required")
- if "limit" not in args["options"] or not args["options"]["limit"]:
- raise ValueError("limit is required")
+ def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
+ """Get and validate credentials for a provider."""
+ credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
+ if not credentials or "config" not in credentials:
+ raise ValueError("No valid credentials found for the provider")
+ return credentials, credentials["config"]
@classmethod
- def crawl_url(cls, args: dict) -> dict:
- provider = args.get("provider", "")
- url = args.get("url")
- options = args.get("options", "")
- credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- crawl_sub_pages = options.get("crawl_sub_pages", False)
- only_main_content = options.get("only_main_content", False)
- if not crawl_sub_pages:
- params = {
- "includePaths": [],
- "excludePaths": [],
- "limit": 1,
- "scrapeOptions": {"onlyMainContent": only_main_content},
- }
- else:
- includes = options.get("includes").split(",") if options.get("includes") else []
- excludes = options.get("excludes").split(",") if options.get("excludes") else []
- params = {
- "includePaths": includes,
- "excludePaths": excludes,
- "limit": options.get("limit", 1),
- "scrapeOptions": {"onlyMainContent": only_main_content},
- }
- if options.get("max_depth"):
- params["maxDepth"] = options.get("max_depth")
- job_id = firecrawl_app.crawl_url(url, params)
- website_crawl_time_cache_key = f"website_crawl_{job_id}"
- time = str(datetime.datetime.now().timestamp())
- redis_client.setex(website_crawl_time_cache_key, 3600, time)
- return {"status": "active", "job_id": job_id}
- elif provider == "watercrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
+ def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
+ """Decrypt and return the API key from config."""
+ api_key = config.get("api_key")
+ if not api_key:
+ raise ValueError("API key not found in configuration")
+ return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
- elif provider == "jinareader":
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- crawl_sub_pages = options.get("crawl_sub_pages", False)
- if not crawl_sub_pages:
- response = requests.get(
- f"https://r.jina.ai/{url}",
- headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return {"status": "active", "data": response.json().get("data")}
- else:
- response = requests.post(
- "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
- json={
- "url": url,
- "maxPages": options.get("limit", 1),
- "useSitemap": options.get("use_sitemap", True),
- },
- headers={
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}",
- },
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
+ @classmethod
+ def document_create_args_validate(cls, args: dict) -> None:
+ """Validate arguments for document creation."""
+ try:
+ WebsiteCrawlApiRequest.from_args(args)
+ except ValueError as e:
+ raise ValueError(f"Invalid arguments: {e}")
+
+ @classmethod
+ def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]:
+ """Crawl a URL using the specified provider with typed request."""
+ request = api_request.to_crawl_request()
+
+ _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
+ api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+
+ if request.provider == "firecrawl":
+ return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "watercrawl":
+ return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "jinareader":
+ return cls._crawl_with_jinareader(request=request, api_key=api_key)
else:
raise ValueError("Invalid provider")
@classmethod
- def get_crawl_status(cls, job_id: str, provider: str) -> dict:
- credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- result = firecrawl_app.check_crawl_status(job_id)
- crawl_status_data = {
- "status": result.get("status", "active"),
- "job_id": job_id,
- "total": result.get("total", 0),
- "current": result.get("current", 0),
- "data": result.get("data", []),
+ def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+
+ if not request.options.crawl_sub_pages:
+ params = {
+ "includePaths": [],
+ "excludePaths": [],
+ "limit": 1,
+ "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
}
- if crawl_status_data["status"] == "completed":
- website_crawl_time_cache_key = f"website_crawl_{job_id}"
- start_time = redis_client.get(website_crawl_time_cache_key)
- if start_time:
- end_time = datetime.datetime.now().timestamp()
- time_consuming = abs(end_time - float(start_time))
- crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
- redis_client.delete(website_crawl_time_cache_key)
- elif provider == "watercrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+ else:
+ params = {
+ "includePaths": request.options.get_include_paths(),
+ "excludePaths": request.options.get_exclude_paths(),
+ "limit": request.options.limit,
+ "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
+ }
+ if request.options.max_depth:
+ params["maxDepth"] = request.options.max_depth
+
+ job_id = firecrawl_app.crawl_url(request.url, params)
+ website_crawl_time_cache_key = f"website_crawl_{job_id}"
+ time = str(datetime.datetime.now().timestamp())
+ redis_client.setex(website_crawl_time_cache_key, 3600, time)
+ return {"status": "active", "job_id": job_id}
+
+ @classmethod
+ def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
+ # Convert CrawlOptions back to dict format for WaterCrawlProvider
+ options = {
+ "limit": request.options.limit,
+ "crawl_sub_pages": request.options.crawl_sub_pages,
+ "only_main_content": request.options.only_main_content,
+ "includes": request.options.includes,
+ "excludes": request.options.excludes,
+ "max_depth": request.options.max_depth,
+ "use_sitemap": request.options.use_sitemap,
+ }
+ return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
+ url=request.url, options=options
+ )
+
+ @classmethod
+ def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
+ if not request.options.crawl_sub_pages:
+ response = requests.get(
+ f"https://r.jina.ai/{request.url}",
+ headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
- crawl_status_data = WaterCrawlProvider(
- api_key, credentials.get("config").get("base_url", None)
- ).get_crawl_status(job_id)
- elif provider == "jinareader":
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return {"status": "active", "data": response.json().get("data")}
+ else:
+ response = requests.post(
+ "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
+ json={
+ "url": request.url,
+ "maxPages": request.options.limit,
+ "useSitemap": request.options.use_sitemap,
+ },
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ },
)
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
+
+ @classmethod
+ def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
+ """Get crawl status using string parameters."""
+ api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
+ return cls.get_crawl_status_typed(api_request)
+
+ @classmethod
+ def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
+ """Get crawl status using typed request."""
+ _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
+ api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+
+ if api_request.provider == "firecrawl":
+ return cls._get_firecrawl_status(api_request.job_id, api_key, config)
+ elif api_request.provider == "watercrawl":
+ return cls._get_watercrawl_status(api_request.job_id, api_key, config)
+ elif api_request.provider == "jinareader":
+ return cls._get_jinareader_status(api_request.job_id, api_key)
+ else:
+ raise ValueError("Invalid provider")
+
+ @classmethod
+ def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ result = firecrawl_app.check_crawl_status(job_id)
+ crawl_status_data = {
+ "status": result.get("status", "active"),
+ "job_id": job_id,
+ "total": result.get("total", 0),
+ "current": result.get("current", 0),
+ "data": result.get("data", []),
+ }
+ if crawl_status_data["status"] == "completed":
+ website_crawl_time_cache_key = f"website_crawl_{job_id}"
+ start_time = redis_client.get(website_crawl_time_cache_key)
+ if start_time:
+ end_time = datetime.datetime.now().timestamp()
+ time_consuming = abs(end_time - float(start_time))
+ crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
+ redis_client.delete(website_crawl_time_cache_key)
+ return crawl_status_data
+
+ @classmethod
+ def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
+ return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
+
+ @classmethod
+ def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
+ response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id},
+ )
+ data = response.json().get("data", {})
+ crawl_status_data = {
+ "status": data.get("status", "active"),
+ "job_id": job_id,
+ "total": len(data.get("urls", [])),
+ "current": len(data.get("processed", [])) + len(data.get("failed", [])),
+ "data": [],
+ "time_consuming": data.get("duration", 0) / 1000,
+ }
+
+ if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id},
+ json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
- crawl_status_data = {
- "status": data.get("status", "active"),
- "job_id": job_id,
- "total": len(data.get("urls", [])),
- "current": len(data.get("processed", [])) + len(data.get("failed", [])),
- "data": [],
- "time_consuming": data.get("duration", 0) / 1000,
- }
-
- if crawl_status_data["status"] == "completed":
- response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
- )
- data = response.json().get("data", {})
- formatted_data = [
- {
- "title": item.get("data", {}).get("title"),
- "source_url": item.get("data", {}).get("url"),
- "description": item.get("data", {}).get("description"),
- "markdown": item.get("data", {}).get("content"),
- }
- for item in data.get("processed", {}).values()
- ]
- crawl_status_data["data"] = formatted_data
- else:
- raise ValueError("Invalid provider")
+ formatted_data = [
+ {
+ "title": item.get("data", {}).get("title"),
+ "source_url": item.get("data", {}).get("url"),
+ "description": item.get("data", {}).get("description"),
+ "markdown": item.get("data", {}).get("content"),
+ }
+ for item in data.get("processed", {}).values()
+ ]
+ crawl_status_data["data"] = formatted_data
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
- credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
- # decrypt api_key
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+ _, config = cls._get_credentials_and_config(tenant_id, provider)
+ api_key = cls._get_decrypted_api_key(tenant_id, config)
if provider == "firecrawl":
- crawl_data: list[dict[str, Any]] | None = None
- file_key = "website_files/" + job_id + ".txt"
- if storage.exists(file_key):
- stored_data = storage.load_once(file_key)
- if stored_data:
- crawl_data = json.loads(stored_data.decode("utf-8"))
- else:
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- result = firecrawl_app.check_crawl_status(job_id)
- if result.get("status") != "completed":
- raise ValueError("Crawl job is not completed")
- crawl_data = result.get("data")
-
- if crawl_data:
- for item in crawl_data:
- if item.get("source_url") == url:
- return dict(item)
- return None
+ return cls._get_firecrawl_url_data(job_id, url, api_key, config)
elif provider == "watercrawl":
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
- job_id, url
- )
+ return cls._get_watercrawl_url_data(job_id, url, api_key, config)
elif provider == "jinareader":
- if not job_id:
- response = requests.get(
- f"https://r.jina.ai/{url}",
- headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return dict(response.json().get("data", {}))
- else:
- # Get crawl status first
- status_response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id},
- )
- status_data = status_response.json().get("data", {})
- if status_data.get("status") != "completed":
- raise ValueError("Crawl job is not completed")
-
- # Get processed data
- data_response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
- )
- processed_data = data_response.json().get("data", {})
- for item in processed_data.get("processed", {}).values():
- if item.get("data", {}).get("url") == url:
- return dict(item.get("data", {}))
- return None
+ return cls._get_jinareader_url_data(job_id, url, api_key)
else:
raise ValueError("Invalid provider")
@classmethod
- def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
- credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- params = {"onlyMainContent": only_main_content}
- result = firecrawl_app.scrape_url(url, params)
- return result
- elif provider == "watercrawl":
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
+ def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
+ crawl_data: list[dict[str, Any]] | None = None
+ file_key = "website_files/" + job_id + ".txt"
+ if storage.exists(file_key):
+ stored_data = storage.load_once(file_key)
+ if stored_data:
+ crawl_data = json.loads(stored_data.decode("utf-8"))
+ else:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ result = firecrawl_app.check_crawl_status(job_id)
+ if result.get("status") != "completed":
+ raise ValueError("Crawl job is not completed")
+ crawl_data = result.get("data")
+
+ if crawl_data:
+ for item in crawl_data:
+ if item.get("source_url") == url:
+ return dict(item)
+ return None
+
+ @classmethod
+ def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
+ return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
+
+ @classmethod
+ def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
+ if not job_id:
+ response = requests.get(
+ f"https://r.jina.ai/{url}",
+ headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
+ )
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return dict(response.json().get("data", {}))
+ else:
+ # Get crawl status first
+ status_response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id},
+ )
+ status_data = status_response.json().get("data", {})
+ if status_data.get("status") != "completed":
+ raise ValueError("Crawl job is not completed")
+
+ # Get processed data
+ data_response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
+ )
+ processed_data = data_response.json().get("data", {})
+ for item in processed_data.get("processed", {}).values():
+ if item.get("data", {}).get("url") == url:
+ return dict(item.get("data", {}))
+ return None
+
+ @classmethod
+ def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
+ request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
+
+ _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
+ api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
+
+ if request.provider == "firecrawl":
+ return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "watercrawl":
+ return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config)
else:
raise ValueError("Invalid provider")
+
+ @classmethod
+ def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ params = {"onlyMainContent": request.only_main_content}
+ return firecrawl_app.scrape_url(url=request.url, params=params)
+
+ @classmethod
+ def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
+ return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
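A hedged usage sketch for the refactored service (argument values are invented; `crawl_url` reads the tenant from flask_login's `current_user`, so this only runs inside an authenticated request context):

```python
from services.website_service import WebsiteCrawlApiRequest, WebsiteService

args = {
    "provider": "firecrawl",
    "url": "https://example.com",
    "options": {"crawl_sub_pages": True, "limit": 10, "max_depth": 2, "only_main_content": True},
}

# from_args() rejects missing provider/url/options with a ValueError.
api_request = WebsiteCrawlApiRequest.from_args(args)

# Dispatches to the provider-specific helper and returns e.g. {"status": "active", "job_id": "..."}.
result = WebsiteService.crawl_url(api_request)

# Poll the job later with the same provider name.
status = WebsiteService.get_crawl_status(result["job_id"], provider="firecrawl")
```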
diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py
index 6b30a70372..6eabf03018 100644
--- a/api/services/workflow_app_service.py
+++ b/api/services/workflow_app_service.py
@@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
-from models import App, EndUser, WorkflowAppLog, WorkflowRun
+from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models.enums import CreatorUserRole
@@ -21,6 +21,8 @@ class WorkflowAppService:
created_at_after: datetime | None = None,
page: int = 1,
limit: int = 20,
+ created_by_end_user_session_id: str | None = None,
+ created_by_account: str | None = None,
) -> dict:
"""
Get paginate workflow app logs using SQLAlchemy 2.0 style
@@ -32,6 +34,8 @@ class WorkflowAppService:
:param created_at_after: filter logs created after this timestamp
:param page: page number
:param limit: items per page
+ :param created_by_end_user_session_id: filter by end user session id
+ :param created_by_account: filter by account email
:return: Pagination object
"""
# Build base statement using SQLAlchemy 2.0 style
@@ -71,6 +75,26 @@ class WorkflowAppService:
if created_at_after:
stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after)
+ # Filter by end user session id or account email
+ if created_by_end_user_session_id:
+ stmt = stmt.join(
+ EndUser,
+ and_(
+ WorkflowAppLog.created_by == EndUser.id,
+ WorkflowAppLog.created_by_role == CreatorUserRole.END_USER,
+ EndUser.session_id == created_by_end_user_session_id,
+ ),
+ )
+ if created_by_account:
+ stmt = stmt.join(
+ Account,
+ and_(
+ WorkflowAppLog.created_by == Account.id,
+ WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT,
+ Account.email == created_by_account,
+ ),
+ )
+
stmt = stmt.order_by(WorkflowAppLog.created_at.desc())
# Get total count using the same filters
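For reference, the end-user filter added above in isolation (a sketch only; the real statement also carries the tenant/app/status filters built earlier in the method):

```python
from sqlalchemy import and_, select

from models import EndUser, WorkflowAppLog
from models.enums import CreatorUserRole


def logs_created_by_session(session_id: str):
    """Select workflow app logs created by the end user with the given session id."""
    return select(WorkflowAppLog).join(
        EndUser,
        and_(
            WorkflowAppLog.created_by == EndUser.id,
            WorkflowAppLog.created_by_role == CreatorUserRole.END_USER,
            EndUser.session_id == session_id,
        ),
    )
```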
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
new file mode 100644
index 0000000000..f306e1f062
--- /dev/null
+++ b/api/services/workflow_draft_variable_service.py
@@ -0,0 +1,746 @@
+import dataclasses
+import datetime
+import logging
+from collections.abc import Mapping, Sequence
+from enum import StrEnum
+from typing import Any, ClassVar
+
+from sqlalchemy import Engine, orm
+from sqlalchemy.dialects.postgresql import insert
+from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.sql.expression import and_, or_
+
+from core.app.entities.app_invoke_entities import InvokeFrom
+from core.file.models import File
+from core.variables import Segment, StringSegment, Variable
+from core.variables.consts import MIN_SELECTORS_LENGTH
+from core.variables.segments import ArrayFileSegment, FileSegment
+from core.variables.types import SegmentType
+from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
+from core.workflow.enums import SystemVariableKey
+from core.workflow.nodes import NodeType
+from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
+from core.workflow.variable_loader import VariableLoader
+from factories.file_factory import StorageKeyLoader
+from factories.variable_factory import build_segment, segment_to_variable
+from models import App, Conversation
+from models.enums import DraftVariableType
+from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
+from repositories.factory import DifyAPIRepositoryFactory
+
+_logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass(frozen=True)
+class WorkflowDraftVariableList:
+ variables: list[WorkflowDraftVariable]
+ total: int | None = None
+
+
+class WorkflowDraftVariableError(Exception):
+ pass
+
+
+class VariableResetError(WorkflowDraftVariableError):
+ pass
+
+
+class UpdateNotSupportedError(WorkflowDraftVariableError):
+ pass
+
+
+class DraftVarLoader(VariableLoader):
+ # This implements the VariableLoader interface for loading draft variables.
+ #
+ # ref: core.workflow.variable_loader.VariableLoader
+
+ # Database engine used for loading variables.
+ _engine: Engine
+ # Application ID for which variables are being loaded.
+ _app_id: str
+ _tenant_id: str
+ _fallback_variables: Sequence[Variable]
+
+ def __init__(
+ self,
+ engine: Engine,
+ app_id: str,
+ tenant_id: str,
+ fallback_variables: Sequence[Variable] | None = None,
+ ) -> None:
+ self._engine = engine
+ self._app_id = app_id
+ self._tenant_id = tenant_id
+ self._fallback_variables = fallback_variables or []
+
+ def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
+ return (selector[0], selector[1])
+
+ def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
+ if not selectors:
+ return []
+
+ # Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
+ variable_by_selector: dict[tuple[str, str], Variable] = {}
+
+ with Session(bind=self._engine, expire_on_commit=False) as session:
+ srv = WorkflowDraftVariableService(session)
+ draft_vars = srv.get_draft_variables_by_selectors(self._app_id, selectors)
+
+ for draft_var in draft_vars:
+ segment = draft_var.get_value()
+ variable = segment_to_variable(
+ segment=segment,
+ selector=draft_var.get_selector(),
+ id=draft_var.id,
+ name=draft_var.name,
+ description=draft_var.description,
+ )
+ selector_tuple = self._selector_to_tuple(variable.selector)
+ variable_by_selector[selector_tuple] = variable
+
+ # Important: File values need their storage keys re-populated before they can be used.
+ files: list[File] = []
+ for draft_var in draft_vars:
+ value = draft_var.get_value()
+ if isinstance(value, FileSegment):
+ files.append(value.value)
+ elif isinstance(value, ArrayFileSegment):
+ files.extend(value.value)
+ with Session(bind=self._engine) as session:
+ storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
+ storage_key_loader.load_storage_keys(files)
+
+ return list(variable_by_selector.values())
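A hedged usage sketch of the loader above; the database URL, app/tenant IDs and the "start" node id are placeholders.

```python
from sqlalchemy import create_engine

from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID

engine = create_engine("postgresql+psycopg2://user:pass@localhost/dify")  # placeholder DSN
loader = DraftVarLoader(engine=engine, app_id="<app-id>", tenant_id="<tenant-id>")
variables = loader.load_variables(
    [
        ["start", "query"],                            # node output variable: [node_id, name]
        [SYSTEM_VARIABLE_NODE_ID, "conversation_id"],  # system variable
    ]
)
# Returns Variable objects; File-typed values have had their storage keys reloaded.
```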
+
+
+class WorkflowDraftVariableService:
+ _session: Session
+
+ def __init__(self, session: Session) -> None:
+ """
+ Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
+
+ Args:
+ session (Session): The SQLAlchemy session used to execute database queries.
+ The provided session must be bound to an `Engine` object, not a specific `Connection`.
+
+ Raises:
+ AssertionError: If the provided session is not bound to an `Engine` object.
+ """
+ self._session = session
+ engine = session.get_bind()
+ # Ensure the session is bound to an engine.
+ assert isinstance(engine, Engine)
+ session_maker = sessionmaker(bind=engine, expire_on_commit=False)
+ self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
+ )
+
+ def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
+ return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
+
+ def get_draft_variables_by_selectors(
+ self,
+ app_id: str,
+ selectors: Sequence[list[str]],
+ ) -> list[WorkflowDraftVariable]:
+ ors = []
+ for selector in selectors:
+ assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
+ node_id, name = selector[:2]
+ ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
+
+ # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as
+ # each expression includes conditions on both `node_id` and `name` (which are covered by the unique index),
+ # PostgreSQL can efficiently retrieve the results using a bitmap index scan.
+ #
+ # Alternatively, a `SELECT` statement could be constructed for each selector and
+ # combined using `UNION` to fetch all rows.
+ # Benchmarking indicates that both approaches yield comparable performance.
+ variables = (
+ self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.app_id == app_id, or_(*ors)).all()
+ )
+ return variables
+
+ def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
+ criteria = WorkflowDraftVariable.app_id == app_id
+ total = None
+ query = self._session.query(WorkflowDraftVariable).filter(criteria)
+ if page == 1:
+ total = query.count()
+ variables = (
+ # Do not load the `value` field.
+ query.options(orm.defer(WorkflowDraftVariable.value))
+ .order_by(WorkflowDraftVariable.created_at.desc())
+ .limit(limit)
+ .offset((page - 1) * limit)
+ .all()
+ )
+
+ return WorkflowDraftVariableList(variables=variables, total=total)
+
+ def _list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
+ criteria = (
+ WorkflowDraftVariable.app_id == app_id,
+ WorkflowDraftVariable.node_id == node_id,
+ )
+ query = self._session.query(WorkflowDraftVariable).filter(*criteria)
+ variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
+ return WorkflowDraftVariableList(variables=variables)
+
+ def list_node_variables(self, app_id: str, node_id: str) -> WorkflowDraftVariableList:
+ return self._list_node_variables(app_id, node_id)
+
+ def list_conversation_variables(self, app_id: str) -> WorkflowDraftVariableList:
+ return self._list_node_variables(app_id, CONVERSATION_VARIABLE_NODE_ID)
+
+ def list_system_variables(self, app_id: str) -> WorkflowDraftVariableList:
+ return self._list_node_variables(app_id, SYSTEM_VARIABLE_NODE_ID)
+
+ def get_conversation_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
+ return self._get_variable(app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name)
+
+ def get_system_variable(self, app_id: str, name: str) -> WorkflowDraftVariable | None:
+ return self._get_variable(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name)
+
+ def get_node_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
+ return self._get_variable(app_id, node_id, name)
+
+ def _get_variable(self, app_id: str, node_id: str, name: str) -> WorkflowDraftVariable | None:
+ variable = (
+ self._session.query(WorkflowDraftVariable)
+ .where(
+ WorkflowDraftVariable.app_id == app_id,
+ WorkflowDraftVariable.node_id == node_id,
+ WorkflowDraftVariable.name == name,
+ )
+ .first()
+ )
+ return variable
+
+ def update_variable(
+ self,
+ variable: WorkflowDraftVariable,
+ name: str | None = None,
+ value: Segment | None = None,
+ ) -> WorkflowDraftVariable:
+ if not variable.editable:
+ raise UpdateNotSupportedError(f"variable not support updating, id={variable.id}")
+ if name is not None:
+ variable.set_name(name)
+ if value is not None:
+ variable.set_value(value)
+ variable.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ self._session.flush()
+ return variable
+
+ def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
+ conv_var_by_name = {i.name: i for i in workflow.conversation_variables}
+ conv_var = conv_var_by_name.get(variable.name)
+
+ if conv_var is None:
+ self._session.delete(instance=variable)
+ self._session.flush()
+ _logger.warning(
+ "Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name
+ )
+ return None
+
+ variable.set_value(conv_var)
+ variable.last_edited_at = None
+ self._session.add(variable)
+ self._session.flush()
+ return variable
+
+ def _reset_node_var_or_sys_var(
+ self, workflow: Workflow, variable: WorkflowDraftVariable
+ ) -> WorkflowDraftVariable | None:
+ # If a variable does not allow updating, it makes no sense to reset it.
+ if not variable.editable:
+ return variable
+ # No execution record for this variable, delete the variable instead.
+ if variable.node_execution_id is None:
+ self._session.delete(instance=variable)
+ self._session.flush()
+ _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
+ return None
+
+ node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
+ if node_exec is None:
+ _logger.warning(
+ "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
+ variable.id,
+ variable.name,
+ variable.node_execution_id,
+ )
+ self._session.delete(instance=variable)
+ self._session.flush()
+ return None
+
+ outputs_dict = node_exec.outputs_dict or {}
+ # A sentinel value used to check for the absence of the output variable key.
+ absent = object()
+
+ if variable.get_variable_type() == DraftVariableType.NODE:
+ # Get node type for proper value extraction
+ node_config = workflow.get_node_config_by_id(variable.node_id)
+ node_type = workflow.get_node_type_from_node_config(node_config)
+
+ # Note: Based on the implementation in `_build_from_variable_assigner_mapping`,
+ # VariableAssignerNode (both v1 and v2) can only create conversation draft variables.
+ # For consistency, we should simply return when processing VARIABLE_ASSIGNER nodes.
+ #
+ # This implementation must remain synchronized with the `_build_from_variable_assigner_mapping`
+ # and `save` methods.
+ if node_type == NodeType.VARIABLE_ASSIGNER:
+ return variable
+ output_value = outputs_dict.get(variable.name, absent)
+ else:
+ output_value = outputs_dict.get(f"sys.{variable.name}", absent)
+
+ # We cannot use `is None` to check the existence of an output variable here as
+ # the value of the output may be `None`.
+ if output_value is absent:
+ # If variable not found in execution data, delete the variable
+ self._session.delete(instance=variable)
+ self._session.flush()
+ return None
+ value_seg = WorkflowDraftVariable.build_segment_with_type(variable.value_type, output_value)
+ # Extract variable value using unified logic
+ variable.set_value(value_seg)
+ variable.last_edited_at = None # Reset to indicate this is a reset operation
+ self._session.flush()
+ return variable
+
+ def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
+ variable_type = variable.get_variable_type()
+ if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name):
+ raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
+ if variable_type == DraftVariableType.CONVERSATION:
+ return self._reset_conv_var(workflow, variable)
+ else:
+ return self._reset_node_var_or_sys_var(workflow, variable)
+
+ def delete_variable(self, variable: WorkflowDraftVariable):
+ self._session.delete(variable)
+
+ def delete_workflow_variables(self, app_id: str):
+ (
+ self._session.query(WorkflowDraftVariable)
+ .filter(WorkflowDraftVariable.app_id == app_id)
+ .delete(synchronize_session=False)
+ )
+
+ def delete_node_variables(self, app_id: str, node_id: str):
+ return self._delete_node_variables(app_id, node_id)
+
+ def _delete_node_variables(self, app_id: str, node_id: str):
+ self._session.query(WorkflowDraftVariable).where(
+ WorkflowDraftVariable.app_id == app_id,
+ WorkflowDraftVariable.node_id == node_id,
+ ).delete()
+
+ def _get_conversation_id_from_draft_variable(self, app_id: str) -> str | None:
+ draft_var = self._get_variable(
+ app_id=app_id,
+ node_id=SYSTEM_VARIABLE_NODE_ID,
+ name=str(SystemVariableKey.CONVERSATION_ID),
+ )
+ if draft_var is None:
+ return None
+ segment = draft_var.get_value()
+ if not isinstance(segment, StringSegment):
+ _logger.warning(
+ "sys.conversation_id variable is not a string: app_id=%s, id=%s",
+ app_id,
+ draft_var.id,
+ )
+ return None
+ return segment.value
+
+ def get_or_create_conversation(
+ self,
+ account_id: str,
+ app: App,
+ workflow: Workflow,
+ ) -> str:
+ """
+ get_or_create_conversation creates and returns the ID of a conversation for debugging.
+
+ If a conversation already exists, as determined by the following criteria, its ID is returned:
+ - The system variable `sys.conversation_id` exists in the draft variable table, and
+ - A corresponding conversation record is found in the database.
+
+ If no such conversation exists, a new conversation is created and its ID is returned.
+ """
+ conv_id = self._get_conversation_id_from_draft_variable(workflow.app_id)
+
+ if conv_id is not None:
+ conversation = (
+ self._session.query(Conversation)
+ .filter(
+ Conversation.id == conv_id,
+ Conversation.app_id == workflow.app_id,
+ )
+ .first()
+ )
+ # Only return the conversation ID if it exists and is valid (has a corresponding conversation record in the DB).
+ if conversation is not None:
+ return conv_id
+ conversation = Conversation(
+ app_id=workflow.app_id,
+ app_model_config_id=app.app_model_config_id,
+ model_provider=None,
+ model_id="",
+ override_model_configs=None,
+ mode=app.mode,
+ name="Draft Debugging Conversation",
+ inputs={},
+ introduction="",
+ system_instruction="",
+ system_instruction_tokens=0,
+ status="normal",
+ invoke_from=InvokeFrom.DEBUGGER.value,
+ from_source="console",
+ from_end_user_id=None,
+ from_account_id=account_id,
+ )
+
+ self._session.add(conversation)
+ self._session.flush()
+ return conversation.id
+
+ def prefill_conversation_variable_default_values(self, workflow: Workflow):
+ """"""
+ draft_conv_vars: list[WorkflowDraftVariable] = []
+ for conv_var in workflow.conversation_variables:
+ draft_var = WorkflowDraftVariable.new_conversation_variable(
+ app_id=workflow.app_id,
+ name=conv_var.name,
+ value=conv_var,
+ description=conv_var.description,
+ )
+ draft_conv_vars.append(draft_var)
+ _batch_upsert_draft_varaible(
+ self._session,
+ draft_conv_vars,
+ policy=_UpsertPolicy.IGNORE,
+ )
+
+
+class _UpsertPolicy(StrEnum):
+ IGNORE = "ignore"
+ OVERWRITE = "overwrite"
+
+
+def _batch_upsert_draft_varaible(
+ session: Session,
+ draft_vars: Sequence[WorkflowDraftVariable],
+ policy: _UpsertPolicy = _UpsertPolicy.OVERWRITE,
+) -> None:
+ if not draft_vars:
+ return None
+ # Although we could use SQLAlchemy ORM operations here, we choose not to for several reasons:
+ #
+ # 1. The variable saving process involves writing multiple rows to the
+ # `workflow_draft_variables` table. Batch insertion significantly improves performance.
+ # 2. Using the ORM would require either:
+ #
+ # a. Checking for the existence of each variable before insertion,
+ # resulting in 2n SQL statements for n variables and potential concurrency issues.
+ # b. Attempting insertion first, then updating if a unique index violation occurs,
+ # which still results in n to 2n SQL statements.
+ #
+ # Both approaches are inefficient and suboptimal.
+ # 3. We do not need to retrieve the results of the SQL execution or populate ORM
+ # model instances with the returned values.
+ # 4. Batch insertion with `ON CONFLICT DO UPDATE` allows us to insert or update all
+ # variables in a single SQL statement, avoiding the issues above.
+ #
+ # For these reasons, we use the SQLAlchemy query builder and rely on dialect-specific
+ # insert operations instead of the ORM layer.
+ stmt = insert(WorkflowDraftVariable).values([_model_to_insertion_dict(v) for v in draft_vars])
+ if policy == _UpsertPolicy.OVERWRITE:
+ stmt = stmt.on_conflict_do_update(
+ index_elements=WorkflowDraftVariable.unique_app_id_node_id_name(),
+ set_={
+ # Refresh creation timestamp to ensure updated variables
+ # appear first in chronologically sorted result sets.
+ "created_at": stmt.excluded.created_at,
+ "updated_at": stmt.excluded.updated_at,
+ "last_edited_at": stmt.excluded.last_edited_at,
+ "description": stmt.excluded.description,
+ "value_type": stmt.excluded.value_type,
+ "value": stmt.excluded.value,
+ "visible": stmt.excluded.visible,
+ "editable": stmt.excluded.editable,
+ "node_execution_id": stmt.excluded.node_execution_id,
+ },
+ )
+ elif policy == _UpsertPolicy.IGNORE:
+ stmt = stmt.on_conflict_do_nothing(index_elements=WorkflowDraftVariable.unique_app_id_node_id_name())
+ else:
+ raise Exception("Invalid value for update policy.")
+ session.execute(stmt)
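A short sketch of how the helper is expected to be driven: the caller owns the transaction, since the helper only executes the statement. Names of the engine and variable lists are placeholders.

```python
def upsert_draft_variables(engine, refreshed_vars, default_conv_vars):
    """Hedged sketch: commit happens in the caller, not in the upsert helper."""
    with Session(bind=engine, expire_on_commit=False) as session:
        # OVERWRITE replaces value/timestamps/flags on (app_id, node_id, name) collisions.
        _batch_upsert_draft_varaible(session, refreshed_vars, policy=_UpsertPolicy.OVERWRITE)
        # IGNORE skips collisions, as prefill_conversation_variable_default_values relies on,
        # so values the user already edited are not clobbered.
        _batch_upsert_draft_varaible(session, default_conv_vars, policy=_UpsertPolicy.IGNORE)
        session.commit()
```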
+
+
+def _model_to_insertion_dict(model: WorkflowDraftVariable) -> dict[str, Any]:
+ d: dict[str, Any] = {
+ "app_id": model.app_id,
+ "last_edited_at": None,
+ "node_id": model.node_id,
+ "name": model.name,
+ "selector": model.selector,
+ "value_type": model.value_type,
+ "value": model.value,
+ "node_execution_id": model.node_execution_id,
+ }
+ if model.visible is not None:
+ d["visible"] = model.visible
+ if model.editable is not None:
+ d["editable"] = model.editable
+ if model.created_at is not None:
+ d["created_at"] = model.created_at
+ if model.updated_at is not None:
+ d["updated_at"] = model.updated_at
+ if model.description is not None:
+ d["description"] = model.description
+ return d
+
+
+def _build_segment_for_serialized_values(v: Any) -> Segment:
+ """
+ Reconstructs Segment objects from serialized values, with special handling
+ for FileSegment and ArrayFileSegment types.
+
+ This function should only be used when:
+ 1. No explicit type information is available
+ 2. The input value is in serialized form (dict or list)
+
+ It detects potential file objects in the serialized data and properly rebuilds the
+ appropriate segment type.
+ """
+ return build_segment(WorkflowDraftVariable.rebuild_file_types(v))
+
+
+class DraftVariableSaver:
+ # _DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
+ # Its sole possible value is `None`.
+ #
+ # This is used to signal the execution of a workflow node when it has no other outputs.
+ _DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__"
+ _DUMMY_OUTPUT_VALUE: ClassVar[None] = None
+
+ # _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that
+ # should be excluded when saving draft variables. This prevents certain internal or
+ # technical variables from being exposed in the draft environment, particularly those
+ # that aren't meant to be directly edited or viewed by users.
+ _EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
+ NodeType.LLM: frozenset(["finish_reason"]),
+ NodeType.LOOP: frozenset(["loop_round"]),
+ }
+
+ # Database session used for persisting draft variables.
+ _session: Session
+
+ # The application ID associated with the draft variables.
+ # This should match the `Workflow.app_id` of the workflow to which the current node belongs.
+ _app_id: str
+
+ # The ID of the node for which DraftVariableSaver is saving output variables.
+ _node_id: str
+
+ # The type of the current node (see NodeType).
+ _node_type: NodeType
+
+ # The primary key (`id`) of the corresponding WorkflowNodeExecutionModel / WorkflowNodeExecution
+ # record produced by this node run (see the note in `__init__`).
+ _node_execution_id: str
+
+ # _enclosing_node_id identifies the container node that the current node belongs to.
+ # For example, if the current node is an LLM node inside an Iteration node
+ # or Loop node, then `_enclosing_node_id` refers to the ID of
+ # the containing Iteration or Loop node.
+ #
+ # If the current node is not nested within another node, `_enclosing_node_id` is
+ # `None`.
+ _enclosing_node_id: str | None
+
+ def __init__(
+ self,
+ session: Session,
+ app_id: str,
+ node_id: str,
+ node_type: NodeType,
+ node_execution_id: str,
+ enclosing_node_id: str | None = None,
+ ):
+ # Important: `node_execution_id` parameter refers to the primary key (`id`) of the
+ # WorkflowNodeExecutionModel/WorkflowNodeExecution, not their `node_execution_id`
+ # field. These are distinct database fields with different purposes.
+ self._session = session
+ self._app_id = app_id
+ self._node_id = node_id
+ self._node_type = node_type
+ self._node_execution_id = node_execution_id
+ self._enclosing_node_id = enclosing_node_id
+
+ def _create_dummy_output_variable(self):
+ return WorkflowDraftVariable.new_node_variable(
+ app_id=self._app_id,
+ node_id=self._node_id,
+ name=self._DUMMY_OUTPUT_IDENTITY,
+ node_execution_id=self._node_execution_id,
+ value=build_segment(self._DUMMY_OUTPUT_VALUE),
+ visible=False,
+ editable=False,
+ )
+
+ def _should_save_output_variables_for_draft(self) -> bool:
+ if self._enclosing_node_id is not None and self._node_type != NodeType.VARIABLE_ASSIGNER:
+ # Currently we do not save output variables for nodes inside loop or iteration.
+ return False
+ return True
+
+ def _build_from_variable_assigner_mapping(self, process_data: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
+ draft_vars: list[WorkflowDraftVariable] = []
+ updated_variables = get_updated_variables(process_data) or []
+
+ for item in updated_variables:
+ selector = item.selector
+ if len(selector) < MIN_SELECTORS_LENGTH:
+ raise Exception("selector too short")
+ # NOTE(QuantumGhost): Only two kinds of variables can be updated by
+ # VariableAssigner: conversation variables and iteration variables.
+ # We only save conversation variables here.
+ if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
+ continue
+ segment = WorkflowDraftVariable.build_segment_with_type(segment_type=item.value_type, value=item.new_value)
+ draft_vars.append(
+ WorkflowDraftVariable.new_conversation_variable(
+ app_id=self._app_id,
+ name=item.name,
+ value=segment,
+ )
+ )
+ # Add a dummy output variable to indicate that this node is executed.
+ draft_vars.append(self._create_dummy_output_variable())
+ return draft_vars
+
+ def _build_variables_from_start_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
+ draft_vars = []
+ has_non_sys_variables = False
+ for name, value in output.items():
+ value_seg = _build_segment_for_serialized_values(value)
+ node_id, name = self._normalize_variable_for_start_node(name)
+ # If node_id is not `sys`, the variable is a user-defined input field
+ # of the `Start` node.
+ if node_id != SYSTEM_VARIABLE_NODE_ID:
+ draft_vars.append(
+ WorkflowDraftVariable.new_node_variable(
+ app_id=self._app_id,
+ node_id=self._node_id,
+ name=name,
+ node_execution_id=self._node_execution_id,
+ value=value_seg,
+ visible=True,
+ editable=True,
+ )
+ )
+ has_non_sys_variables = True
+ else:
+ if name == SystemVariableKey.FILES:
+ # Here the variable type must be `array[file]`, so we simply
+ # build the files from the serialized value.
+ files = [File.model_validate(v) for v in value]
+ if files:
+ value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files)
+ else:
+ value_seg = ArrayFileSegment(value=[])
+
+ draft_vars.append(
+ WorkflowDraftVariable.new_sys_variable(
+ app_id=self._app_id,
+ name=name,
+ node_execution_id=self._node_execution_id,
+ value=value_seg,
+ editable=self._should_variable_be_editable(node_id, name),
+ )
+ )
+ if not has_non_sys_variables:
+ draft_vars.append(self._create_dummy_output_variable())
+ return draft_vars
+
+ def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]:
+ if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
+ return self._node_id, name
+ _, name_ = name.split(".", maxsplit=1)
+ return SYSTEM_VARIABLE_NODE_ID, name_
+
+ def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
+ draft_vars = []
+ for name, value in output.items():
+ if not self._should_variable_be_saved(name):
+ _logger.debug(
+ "Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s",
+ name,
+ self._node_type,
+ )
+ continue
+ if isinstance(value, Segment):
+ value_seg = value
+ else:
+ value_seg = _build_segment_for_serialized_values(value)
+ draft_vars.append(
+ WorkflowDraftVariable.new_node_variable(
+ app_id=self._app_id,
+ node_id=self._node_id,
+ name=name,
+ node_execution_id=self._node_execution_id,
+ value=value_seg,
+ visible=self._should_variable_be_visible(self._node_id, self._node_type, name),
+ )
+ )
+ return draft_vars
+
+ def save(
+ self,
+ process_data: Mapping[str, Any] | None = None,
+ outputs: Mapping[str, Any] | None = None,
+ ):
+ draft_vars: list[WorkflowDraftVariable] = []
+ if outputs is None:
+ outputs = {}
+ if process_data is None:
+ process_data = {}
+ if not self._should_save_output_variables_for_draft():
+ return
+ if self._node_type == NodeType.VARIABLE_ASSIGNER:
+ draft_vars = self._build_from_variable_assigner_mapping(process_data=process_data)
+ elif self._node_type == NodeType.START:
+ draft_vars = self._build_variables_from_start_mapping(outputs)
+ else:
+ draft_vars = self._build_variables_from_mapping(outputs)
+ _batch_upsert_draft_varaible(self._session, draft_vars)
+
+ @staticmethod
+ def _should_variable_be_editable(node_id: str, name: str) -> bool:
+ if node_id in (CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID):
+ return False
+ if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
+ return False
+ return True
+
+ @staticmethod
+ def _should_variable_be_visible(node_id: str, node_type: NodeType, name: str) -> bool:
+ if node_type == NodeType.IF_ELSE:
+ return False
+ if node_id == SYSTEM_VARIABLE_NODE_ID and not is_system_variable_editable(name):
+ return False
+ return True
+
+ def _should_variable_be_saved(self, name: str) -> bool:
+ exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type)
+ if exclude_var_names is None:
+ return True
+ return name not in exclude_var_names
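Note: the upsert in `_batch_upsert_draft_varaible` above relies on PostgreSQL's `INSERT ... ON CONFLICT` through SQLAlchemy's dialect-specific `insert()`. A minimal, self-contained sketch of the same pattern on a toy table (table and column names here are illustrative, not the Dify schema):

    from sqlalchemy import Column, MetaData, String, Table
    from sqlalchemy.dialects.postgresql import insert
    from sqlalchemy.orm import Session

    metadata = MetaData()
    toy_variables = Table(
        "toy_variables",
        metadata,
        Column("app_id", String, primary_key=True),
        Column("name", String, primary_key=True),
        Column("value", String),
    )

    def toy_upsert(session: Session, rows: list[dict], overwrite: bool = True) -> None:
        stmt = insert(toy_variables).values(rows)
        if overwrite:
            # OVERWRITE policy: one statement inserts new rows and refreshes existing ones.
            stmt = stmt.on_conflict_do_update(
                index_elements=["app_id", "name"],
                set_={"value": stmt.excluded.value},
            )
        else:
            # IGNORE policy: conflicting rows are left untouched.
            stmt = stmt.on_conflict_do_nothing(index_elements=["app_id", "name"])
        session.execute(stmt)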
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 483c0d3086..e43999a8c9 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -2,9 +2,9 @@ import threading
from collections.abc import Sequence
from typing import Optional
+from sqlalchemy.orm import sessionmaker
+
import contexts
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import (
@@ -15,10 +15,18 @@ from models import (
WorkflowRun,
WorkflowRunTriggeredFrom,
)
-from models.workflow import WorkflowNodeExecutionTriggeredFrom
+from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunService:
+ def __init__(self):
+ """Initialize WorkflowRunService with repository dependencies."""
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
+ )
+ self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
@@ -62,45 +70,16 @@ class WorkflowRunService:
:param args: request args
"""
limit = int(args.get("limit", 20))
+ last_id = args.get("last_id")
- base_query = db.session.query(WorkflowRun).filter(
- WorkflowRun.tenant_id == app_model.tenant_id,
- WorkflowRun.app_id == app_model.id,
- WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
+ return self._workflow_run_repo.get_paginated_workflow_runs(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
+ limit=limit,
+ last_id=last_id,
)
- if args.get("last_id"):
- last_workflow_run = base_query.filter(
- WorkflowRun.id == args.get("last_id"),
- ).first()
-
- if not last_workflow_run:
- raise ValueError("Last workflow run not exists")
-
- workflow_runs = (
- base_query.filter(
- WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
- )
- .order_by(WorkflowRun.created_at.desc())
- .limit(limit)
- .all()
- )
- else:
- workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
-
- has_more = False
- if len(workflow_runs) == limit:
- current_page_first_workflow_run = workflow_runs[-1]
- rest_count = base_query.filter(
- WorkflowRun.created_at < current_page_first_workflow_run.created_at,
- WorkflowRun.id != current_page_first_workflow_run.id,
- ).count()
-
- if rest_count > 0:
- has_more = True
-
- return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
-
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
@@ -108,18 +87,12 @@ class WorkflowRunService:
:param app_model: app model
:param run_id: workflow run id
"""
- workflow_run = (
- db.session.query(WorkflowRun)
- .filter(
- WorkflowRun.tenant_id == app_model.tenant_id,
- WorkflowRun.app_id == app_model.id,
- WorkflowRun.id == run_id,
- )
- .first()
+ return self._workflow_run_repo.get_workflow_run_by_id(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ run_id=run_id,
)
- return workflow_run
-
def get_workflow_run_node_executions(
self,
app_model: App,
@@ -137,17 +110,13 @@ class WorkflowRunService:
if not workflow_run:
return []
- repository = SQLAlchemyWorkflowNodeExecutionRepository(
- session_factory=db.engine,
- user=user,
- app_id=app_model.id,
- triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
- )
+ # Get tenant_id from user
+ tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
+ if tenant_id is None:
+ raise ValueError("User tenant_id cannot be None")
- # Use the repository to get the database models directly
- order_config = OrderConfig(order_by=["index"], order_direction="desc")
- workflow_node_executions = repository.get_db_models_by_workflow_run(
- workflow_run_id=run_id, order_config=order_config
+ return self._node_execution_service_repo.get_executions_by_workflow_run(
+ tenant_id=tenant_id,
+ app_id=app_model.id,
+ workflow_run_id=run_id,
)
-
- return workflow_node_executions
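Note: the manual `last_id` pagination removed above now lives behind `get_paginated_workflow_runs`, whose implementation is not part of this diff. A minimal sketch of the keyset pagination it would need to reproduce, assuming the same `WorkflowRun` columns (the `triggered_from` filter is omitted for brevity):

    from datetime import datetime

    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from models import WorkflowRun  # assumed importable as in the service above

    def paginate_debug_runs(
        session: Session,
        tenant_id: str,
        app_id: str,
        last_created_at: datetime | None = None,
        last_id: str | None = None,
        limit: int = 20,
    ):
        stmt = select(WorkflowRun).where(
            WorkflowRun.tenant_id == tenant_id,
            WorkflowRun.app_id == app_id,
        )
        if last_created_at is not None:
            # Resume strictly after the cursor row, newest first.
            stmt = stmt.where(WorkflowRun.created_at < last_created_at, WorkflowRun.id != last_id)
        rows = session.scalars(stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
        # Fetching limit + 1 rows answers has_more without the extra COUNT the old code issued.
        return rows[:limit], len(rows) > limit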
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index bc213ccce6..403e559743 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -1,18 +1,22 @@
import json
import time
-from collections.abc import Callable, Generator, Sequence
-from datetime import UTC, datetime
-from typing import Any, Optional
+import uuid
+from collections.abc import Callable, Generator, Mapping, Sequence
+from typing import Any, Optional, cast
from uuid import uuid4
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
+from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.file import File
+from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
+from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
@@ -22,9 +26,13 @@ from core.workflow.nodes.enums import ErrorStrategy
from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
+from factories.file_factory import build_from_mapping, build_from_mappings
+from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@@ -34,10 +42,16 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
-from services.errors.app import WorkflowHashNotEqualError
+from repositories.factory import DifyAPIRepositoryFactory
+from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
from .errors.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError
+from .workflow_draft_variable_service import (
+ DraftVariableSaver,
+ DraftVarLoader,
+ WorkflowDraftVariableService,
+)
class WorkflowService:
@@ -45,6 +59,44 @@ class WorkflowService:
Workflow Service
"""
+ def __init__(self, session_maker: sessionmaker | None = None):
+ """Initialize WorkflowService with repository dependencies."""
+ if session_maker is None:
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
+ )
+
+ def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
+ """
+ Get the most recent execution for a specific node.
+
+ Args:
+ app_model: The application model
+ workflow: The workflow model
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ return self._node_execution_service_repo.get_node_last_execution(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ workflow_id=workflow.id,
+ node_id=node_id,
+ )
+
+ def is_workflow_exist(self, app_model: App) -> bool:
+ return (
+ db.session.query(Workflow)
+ .filter(
+ Workflow.tenant_id == app_model.tenant_id,
+ Workflow.app_id == app_model.id,
+ Workflow.version == Workflow.VERSION_DRAFT,
+ )
+ .count()
+ ) > 0
+
def get_draft_workflow(self, app_model: App) -> Optional[Workflow]:
"""
Get draft workflow
@@ -61,6 +113,23 @@ class WorkflowService:
# return draft workflow
return workflow
+ def get_published_workflow_by_id(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
+ # fetch published workflow by workflow_id
+ workflow = (
+ db.session.query(Workflow)
+ .filter(
+ Workflow.tenant_id == app_model.tenant_id,
+ Workflow.app_id == app_model.id,
+ Workflow.id == workflow_id,
+ )
+ .first()
+ )
+ if not workflow:
+ return None
+ if workflow.version == Workflow.VERSION_DRAFT:
+ raise IsDraftWorkflowError(f"Workflow is draft version, id={workflow_id}")
+ return workflow
+
def get_published_workflow(self, app_model: App) -> Optional[Workflow]:
"""
Get published workflow
@@ -163,7 +232,7 @@ class WorkflowService:
workflow.graph = json.dumps(graph)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
- workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
@@ -199,7 +268,7 @@ class WorkflowService:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=draft_workflow.type,
- version=str(datetime.now(UTC).replace(tzinfo=None)),
+ version=Workflow.version_from_datetime(naive_utc_now()),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
@@ -253,26 +322,85 @@ class WorkflowService:
return default_config
def run_draft_workflow_node(
- self, app_model: App, node_id: str, user_inputs: dict, account: Account
+ self,
+ app_model: App,
+ draft_workflow: Workflow,
+ node_id: str,
+ user_inputs: Mapping[str, Any],
+ account: Account,
+ query: str = "",
+ files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run draft workflow node
"""
- # fetch draft workflow by app_model
- draft_workflow = self.get_draft_workflow(app_model=app_model)
- if not draft_workflow:
- raise ValueError("Workflow not initialized")
+ files = files or []
+
+ with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
+ draft_var_srv = WorkflowDraftVariableService(session)
+ draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
+
+ node_config = draft_workflow.get_node_config_by_id(node_id)
+ node_type = Workflow.get_node_type_from_node_config(node_config)
+ node_data = node_config.get("data", {})
+ if node_type == NodeType.START:
+ with Session(bind=db.engine) as session, session.begin():
+ draft_var_srv = WorkflowDraftVariableService(session)
+ conversation_id = draft_var_srv.get_or_create_conversation(
+ account_id=account.id,
+ app=app_model,
+ workflow=draft_workflow,
+ )
+ start_data = StartNodeData.model_validate(node_data)
+ user_inputs = _rebuild_file_for_user_inputs_in_start_node(
+ tenant_id=draft_workflow.tenant_id, start_node_data=start_data, user_inputs=user_inputs
+ )
+ # init variable pool
+ variable_pool = _setup_variable_pool(
+ query=query,
+ files=files or [],
+ user_id=account.id,
+ user_inputs=user_inputs,
+ workflow=draft_workflow,
+ # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables.
+ conversation_variables=[],
+ node_type=node_type,
+ conversation_id=conversation_id,
+ )
+
+ else:
+ variable_pool = VariablePool(
+ system_variables=SystemVariable.empty(),
+ user_inputs=user_inputs,
+ environment_variables=draft_workflow.environment_variables,
+ conversation_variables=[],
+ )
+
+ variable_loader = DraftVarLoader(
+ engine=db.engine,
+ app_id=app_model.id,
+ tenant_id=app_model.tenant_id,
+ )
+
+ enclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)
+ if enclosing_node_type_and_id:
+ _, enclosing_node_id = enclosing_node_type_and_id
+ else:
+ enclosing_node_id = None
+
+ run = WorkflowEntry.single_step_run(
+ workflow=draft_workflow,
+ node_id=node_id,
+ user_inputs=user_inputs,
+ user_id=account.id,
+ variable_pool=variable_pool,
+ variable_loader=variable_loader,
+ )
# run draft workflow node
start_at = time.perf_counter()
-
node_execution = self._handle_node_run_result(
- invoke_node_fn=lambda: WorkflowEntry.single_step_run(
- workflow=draft_workflow,
- node_id=node_id,
- user_inputs=user_inputs,
- user_id=account.id,
- ),
+ invoke_node_fn=lambda: run,
start_at=start_at,
node_id=node_id,
)
@@ -281,7 +409,7 @@ class WorkflowService:
node_execution.workflow_id = draft_workflow.id
# Create repository and save the node execution
- repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=db.engine,
user=account,
app_id=app_model.id,
@@ -289,8 +417,21 @@ class WorkflowService:
)
repository.save(node_execution)
- # Convert node_execution to WorkflowNodeExecution after save
- workflow_node_execution = repository.to_db_model(node_execution)
+ workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
+ if workflow_node_execution is None:
+ raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
+
+ with Session(bind=db.engine) as session, session.begin():
+ draft_var_saver = DraftVariableSaver(
+ session=session,
+ app_id=app_model.id,
+ node_id=workflow_node_execution.node_id,
+ node_type=NodeType(workflow_node_execution.node_type),
+ enclosing_node_id=enclosing_node_id,
+ node_execution_id=node_execution.id,
+ )
+ draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
+ session.commit()
return workflow_node_execution
@@ -303,7 +444,7 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
- workflow_node_execution = self._handle_node_run_result(
+ node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
@@ -315,7 +456,7 @@ class WorkflowService:
node_id=node_id,
)
- return workflow_node_execution
+ return node_execution
def _handle_node_run_result(
self,
@@ -324,32 +465,32 @@ class WorkflowService:
node_id: str,
) -> WorkflowNodeExecution:
try:
- node_instance, generator = invoke_node_fn()
+ node, node_events = invoke_node_fn()
node_run_result: NodeRunResult | None = None
- for event in generator:
+ for event in node_events:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
# sign output files
- node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
+ # node_run_result.outputs = WorkflowEntry.handle_special_values(node_run_result.outputs)
break
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
- if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
+ if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
- "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+ "metadata": {"error_strategy": node.error_strategy},
}
- if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+ if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
- **node_instance.node_data.default_value_dict,
+ **node.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
@@ -368,10 +509,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
- node_instance = e.node_instance
+ node = e._node
run_succeeded = False
node_run_result = None
- error = e.error
+ error = e._error
# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(
@@ -379,11 +520,11 @@ class WorkflowService:
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
- node_type=node_instance.node_type,
- title=node_instance.node_data.title,
+ node_type=node.type_,
+ title=node.title,
elapsed_time=time.perf_counter() - start_at,
- created_at=datetime.now(UTC).replace(tzinfo=None),
- finished_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
+ finished_at=naive_utc_now(),
)
if run_succeeded and node_run_result:
@@ -394,7 +535,7 @@ class WorkflowService:
if node_run_result.process_data
else None
)
- outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None
+ outputs = node_run_result.outputs
node_execution.inputs = inputs
node_execution.process_data = process_data
@@ -480,7 +621,7 @@ class WorkflowService:
setattr(workflow, field, value)
workflow.updated_by = account_id
- workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.updated_at = naive_utc_now()
return workflow
@@ -531,3 +672,77 @@ class WorkflowService:
session.delete(workflow)
return True
+
+
+def _setup_variable_pool(
+ query: str,
+ files: Sequence[File],
+ user_id: str,
+ user_inputs: Mapping[str, Any],
+ workflow: Workflow,
+ node_type: NodeType,
+ conversation_id: str,
+ conversation_variables: list[Variable],
+):
+ # Only inject system variables for START node type.
+ if node_type == NodeType.START:
+ system_variable = SystemVariable(
+ user_id=user_id,
+ app_id=workflow.app_id,
+ workflow_id=workflow.id,
+ files=files or [],
+ workflow_execution_id=str(uuid.uuid4()),
+ )
+
+ # Only add chatflow-specific variables for non-workflow types
+ if workflow.type != WorkflowType.WORKFLOW.value:
+ system_variable.query = query
+ system_variable.conversation_id = conversation_id
+ system_variable.dialogue_count = 0
+ else:
+ system_variable = SystemVariable.empty()
+
+ # init variable pool
+ variable_pool = VariablePool(
+ system_variables=system_variable,
+ user_inputs=user_inputs,
+ environment_variables=workflow.environment_variables,
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ conversation_variables=cast(list[VariableUnion], conversation_variables),
+ )
+
+ return variable_pool
+
+
+def _rebuild_file_for_user_inputs_in_start_node(
+ tenant_id: str, start_node_data: StartNodeData, user_inputs: Mapping[str, Any]
+) -> Mapping[str, Any]:
+ inputs_copy = dict(user_inputs)
+
+ for variable in start_node_data.variables:
+ if variable.type not in (VariableEntityType.FILE, VariableEntityType.FILE_LIST):
+ continue
+ if variable.variable not in user_inputs:
+ continue
+ value = user_inputs[variable.variable]
+ file = _rebuild_single_file(tenant_id=tenant_id, value=value, variable_entity_type=variable.type)
+ inputs_copy[variable.variable] = file
+ return inputs_copy
+
+
+def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: VariableEntityType) -> File | Sequence[File]:
+ if variable_entity_type == VariableEntityType.FILE:
+ if not isinstance(value, dict):
+ raise ValueError(f"expected dict for file object, got {type(value)}")
+ return build_from_mapping(mapping=value, tenant_id=tenant_id)
+ elif variable_entity_type == VariableEntityType.FILE_LIST:
+ if not isinstance(value, list):
+ raise ValueError(f"expected list for file list object, got {type(value)}")
+ if len(value) == 0:
+ return []
+ if not isinstance(value[0], dict):
+ raise ValueError(f"expected dict for first element in the file list, got {type(value)}")
+ return build_from_mappings(mappings=value, tenant_id=tenant_id)
+ else:
+ raise Exception("unreachable")
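Note: several call sites above swap `datetime.now(UTC).replace(tzinfo=None)` for `naive_utc_now()`. The helper's body is not in this diff; judging from the expression it replaces, it is presumably equivalent to:

    from datetime import UTC, datetime

    def naive_utc_now() -> datetime:
        # Current UTC time with tzinfo stripped, matching the expression it replaces.
        return datetime.now(UTC).replace(tzinfo=None)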
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 125e0c1b1e..bb35645c50 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -31,7 +31,7 @@ class WorkspaceService:
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
- can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
+ can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index 5824121e8f..c72a3319c1 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -72,6 +72,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
+ db.session.commit()
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 55cac6a9af..a85aab0bb7 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -1,4 +1,3 @@
-import datetime
import logging
import time
@@ -8,6 +7,7 @@ from celery import shared_task # type: ignore
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@@ -53,7 +53,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
@@ -68,7 +68,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
if document:
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py
new file mode 100644
index 0000000000..da44040b7d
--- /dev/null
+++ b/api/tasks/mail_change_mail_task.py
@@ -0,0 +1,78 @@
+import logging
+import time
+
+import click
+from celery import shared_task # type: ignore
+from flask import render_template
+
+from extensions.ext_mail import mail
+from services.feature_service import FeatureService
+
+
+@shared_task(queue="mail")
+def send_change_mail_task(language: str, to: str, code: str, phase: str):
+ """
+ Async Send change email mail
+ :param language: Language in which the email should be sent (e.g., 'en', 'zh-Hans')
+ :param to: Recipient email address
+ :param code: Change email code
+ :param phase: Change email phase (new_email, old_email)
+ """
+ if not mail.is_inited():
+ return
+
+ logging.info(click.style("Start change email mail to {}".format(to), fg="green"))
+ start_at = time.perf_counter()
+
+ email_config = {
+ "zh-Hans": {
+ "old_email": {
+ "subject": "检测您现在的邮箱",
+ "template_with_brand": "change_mail_confirm_old_template_zh-CN.html",
+ "template_without_brand": "without-brand/change_mail_confirm_old_template_zh-CN.html",
+ },
+ "new_email": {
+ "subject": "确认您的邮箱地址变更",
+ "template_with_brand": "change_mail_confirm_new_template_zh-CN.html",
+ "template_without_brand": "without-brand/change_mail_confirm_new_template_zh-CN.html",
+ },
+ },
+ "en": {
+ "old_email": {
+ "subject": "Check your current email",
+ "template_with_brand": "change_mail_confirm_old_template_en-US.html",
+ "template_without_brand": "without-brand/change_mail_confirm_old_template_en-US.html",
+ },
+ "new_email": {
+ "subject": "Confirm your new email address",
+ "template_with_brand": "change_mail_confirm_new_template_en-US.html",
+ "template_without_brand": "without-brand/change_mail_confirm_new_template_en-US.html",
+ },
+ },
+ }
+
+ # send change email mail using different languages
+ try:
+ system_features = FeatureService.get_system_features()
+ lang_key = "zh-Hans" if language == "zh-Hans" else "en"
+
+ if phase not in ["old_email", "new_email"]:
+ raise ValueError("Invalid phase")
+
+ config = email_config[lang_key][phase]
+ subject = config["subject"]
+
+ if system_features.branding.enabled:
+ template = config["template_without_brand"]
+ else:
+ template = config["template_with_brand"]
+
+ html_content = render_template(template, to=to, code=code)
+ mail.send(to=to, subject=subject, html=html_content)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style("Send change email mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green")
+ )
+ except Exception:
+ logging.exception("Send change email mail to {} failed".format(to))
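Note: as a Celery `shared_task` bound to the `mail` queue, the task is meant to be enqueued rather than called inline. An illustrative invocation (argument values are placeholders):

    # Enqueue the change-email verification mail for asynchronous delivery.
    send_change_mail_task.delay(
        language="en",
        to="user@example.com",
        code="123456",
        phase="new_email",
    )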
diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py
new file mode 100644
index 0000000000..8d05c6dc0f
--- /dev/null
+++ b/api/tasks/mail_owner_transfer_task.py
@@ -0,0 +1,152 @@
+import logging
+import time
+
+import click
+from celery import shared_task # type: ignore
+from flask import render_template
+
+from extensions.ext_mail import mail
+from services.feature_service import FeatureService
+
+
+@shared_task(queue="mail")
+def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str):
+ """
+ Async Send owner transfer confirm mail
+ :param language: Language in which the email should be sent (e.g., 'en', 'zh-Hans')
+ :param to: Recipient email address
+ :param code: Owner transfer verification code
+ :param workspace: Workspace name
+ """
+ if not mail.is_inited():
+ return
+
+ logging.info(click.style("Start change email mail to {}".format(to), fg="green"))
+ start_at = time.perf_counter()
+ # send change email mail using different languages
+ try:
+ if language == "zh-Hans":
+ template = "transfer_workspace_owner_confirm_template_zh-CN.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_owner_confirm_template_zh-CN.html"
+ html_content = render_template(template, to=to, code=code, WorkspaceName=workspace)
+ mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content)
+ else:
+ html_content = render_template(template, to=to, code=code, WorkspaceName=workspace)
+ mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content)
+ else:
+ template = "transfer_workspace_owner_confirm_template_en-US.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_owner_confirm_template_en-US.html"
+ html_content = render_template(template, to=to, code=code, WorkspaceName=workspace)
+ mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content)
+ else:
+ html_content = render_template(template, to=to, code=code, WorkspaceName=workspace)
+ mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("owner transfer confirm email mail to {} failed".format(to))
+
+
+@shared_task(queue="mail")
+def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str):
+ """
+ Async Send old owner transfer notify mail
+ :param language: Language in which the email should be sent (e.g., 'en', 'zh-Hans')
+ :param to: Recipient email address
+ :param workspace: Workspace name
+ :param new_owner_email: New owner email
+ """
+ if not mail.is_inited():
+ return
+
+ logging.info(click.style("Start change email mail to {}".format(to), fg="green"))
+ start_at = time.perf_counter()
+ # send change email mail using different languages
+ try:
+ if language == "zh-Hans":
+ template = "transfer_workspace_old_owner_notify_template_zh-CN.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html"
+ html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email)
+ mail.send(to=to, subject="工作区所有权已转移", html=html_content)
+ else:
+ html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email)
+ mail.send(to=to, subject="工作区所有权已转移", html=html_content)
+ else:
+ template = "transfer_workspace_old_owner_notify_template_en-US.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_old_owner_notify_template_en-US.html"
+ html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email)
+ mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content)
+ else:
+ html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email)
+ mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("owner transfer confirm email mail to {} failed".format(to))
+
+
+@shared_task(queue="mail")
+def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str):
+ """
+ Async Send new owner transfer notify mail
+ :param language: Language in which the email should be sent (e.g., 'en', 'zh-Hans')
+ :param to: Recipient email address
+ :param workspace: Workspace name
+ """
+ if not mail.is_inited():
+ return
+
+ logging.info(click.style("Start change email mail to {}".format(to), fg="green"))
+ start_at = time.perf_counter()
+ # send change email mail using different languages
+ try:
+ if language == "zh-Hans":
+ template = "transfer_workspace_new_owner_notify_template_zh-CN.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html"
+ html_content = render_template(template, to=to, WorkspaceName=workspace)
+ mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content)
+ else:
+ html_content = render_template(template, to=to, WorkspaceName=workspace)
+ mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content)
+ else:
+ template = "transfer_workspace_new_owner_notify_template_en-US.html"
+ system_features = FeatureService.get_system_features()
+ if system_features.branding.enabled:
+ template = "without-brand/transfer_workspace_new_owner_notify_template_en-US.html"
+ html_content = render_template(template, to=to, WorkspaceName=workspace)
+ mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content)
+ else:
+ html_content = render_template(template, to=to, WorkspaceName=workspace)
+ mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content)
+
+ end_at = time.perf_counter()
+ logging.info(
+ click.style(
+ "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at),
+ fg="green",
+ )
+ )
+ except Exception:
+ logging.exception("owner transfer confirm email mail to {} failed".format(to))
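Note: each of the three tasks above repeats the render-and-send call in both arms of the branding branch. A possible consolidation, not part of this diff, mirroring the table-driven selection already used in send_change_mail_task (helper name and signature are hypothetical):

    def _send_localized_mail(language: str, to: str, subjects: dict[str, str],
                             templates: dict[str, str], **context) -> None:
        # Pick subject and template first, then render and send exactly once.
        lang_key = "zh-Hans" if language == "zh-Hans" else "en"
        template = templates[lang_key]
        if FeatureService.get_system_features().branding.enabled:
            template = f"without-brand/{template}"
        html_content = render_template(template, to=to, **context)
        mail.send(to=to, subject=subjects[lang_key], html=html_content)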
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index d366efd6f2..179adcbd6e 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -6,6 +6,7 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models import (
@@ -13,6 +14,7 @@ from models import (
AppAnnotationHitHistory,
AppAnnotationSetting,
AppDatasetJoin,
+ AppMCPServer,
AppModelConfig,
Conversation,
EndUser,
@@ -30,7 +32,8 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
+from repositories.factory import DifyAPIRepositoryFactory
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -41,6 +44,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
# Delete related data
_delete_app_model_configs(tenant_id, app_id)
_delete_app_site(tenant_id, app_id)
+ _delete_app_mcp_servers(tenant_id, app_id)
_delete_app_api_tokens(tenant_id, app_id)
_delete_installed_apps(tenant_id, app_id)
_delete_recommended_apps(tenant_id, app_id)
@@ -89,6 +93,18 @@ def _delete_app_site(tenant_id: str, app_id: str):
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
+def _delete_app_mcp_servers(tenant_id: str, app_id: str):
+ def del_mcp_server(mcp_server_id: str):
+ db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+
+ _delete_records(
+ """select id from app_mcp_servers where app_id=:app_id limit 1000""",
+ {"app_id": app_id},
+ del_mcp_server,
+ "app mcp server",
+ )
+
+
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)
@@ -175,30 +191,32 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
- def del_workflow_run(workflow_run_id: str):
- db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
-
- _delete_records(
- """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
- {"tenant_id": tenant_id, "app_id": app_id},
- del_workflow_run,
- "workflow run",
+ """Delete all workflow runs for an app using the service repository."""
+ session_maker = sessionmaker(bind=db.engine)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ deleted_count = workflow_run_repo.delete_runs_by_app(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ batch_size=1000,
)
+ logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}")
-def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
- def del_workflow_node_execution(workflow_node_execution_id: str):
- db.session.query(WorkflowNodeExecutionModel).filter(
- WorkflowNodeExecutionModel.id == workflow_node_execution_id
- ).delete(synchronize_session=False)
- _delete_records(
- """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
- {"tenant_id": tenant_id, "app_id": app_id},
- del_workflow_node_execution,
- "workflow node execution",
+def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+ """Delete all workflow node executions for an app using the service repository."""
+ session_maker = sessionmaker(bind=db.engine)
+ node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+
+ deleted_count = node_execution_repo.delete_executions_by_app(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ batch_size=1000,
)
+ logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")
+
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
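Note: the `_delete_records` loops above are replaced by `delete_runs_by_app` / `delete_executions_by_app`, whose bodies are not shown in this diff. A sketch of the batch-delete strategy such a repository method would plausibly use (generic model with assumed `id`, `tenant_id`, and `app_id` columns):

    from sqlalchemy import delete, select
    from sqlalchemy.orm import Session

    def delete_in_batches(session: Session, model, tenant_id: str, app_id: str, batch_size: int = 1000) -> int:
        total = 0
        while True:
            # Collect the next batch of ids, then delete them in one statement.
            ids = session.scalars(
                select(model.id)
                .where(model.tenant_id == tenant_id, model.app_id == app_id)
                .limit(batch_size)
            ).all()
            if not ids:
                return total
            session.execute(delete(model).where(model.id.in_(ids)))
            session.commit()
            total += len(ids)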
diff --git a/api/templates/change_mail_confirm_new_template_en-US.html b/api/templates/change_mail_confirm_new_template_en-US.html
new file mode 100644
index 0000000000..88721e787c
--- /dev/null
+++ b/api/templates/change_mail_confirm_new_template_en-US.html
@@ -0,0 +1,125 @@
+ [HTML boilerplate omitted; recoverable template copy:]
+ Confirm Your New Email Address
+ You’re updating the email address linked to your Dify account.
+ To confirm this action, please use the verification code below.
+ This code will only be valid for the next 5 minutes:
+ {{code}}
+ If you didn’t make this request, please ignore this email or contact support immediately.
[Remaining template diffs are truncated to their text content; the visible changes are markup-only. Affected copy: the knowledge-base notification template ("Some Documents in Your Knowledge Base Have Been Disabled", "Dear {{userName}},", "We're sorry for the inconvenience. To ensure optimal performance, documents that haven’t been updated or accessed in the past 30 days have been disabled in your knowledge bases:") and the invite-member templates ("{{ inviter_name }} is pleased to invite you to join our workspace on Dify (or {{application_title}}), a platform specifically designed for LLM application development. On Dify (or {{application_title}}), you can explore, create, and collaborate to build and operate AI applications.", "Click the button below to log in to Dify (or {{application_title}}) and join the workspace.")]