diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh
index 93ecac48f2..022f71bfb4 100755
--- a/.devcontainer/post_create_command.sh
+++ b/.devcontainer/post_create_command.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-npm add -g pnpm@10.11.1
+npm add -g pnpm@10.13.1
cd web && pnpm install
pipx install uv
@@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do
echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc
source /home/vscode/.bashrc
+
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
index a9580a3ba3..d684fe9144 100644
--- a/.github/ISSUE_TEMPLATE/bug_report.yml
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -8,13 +8,15 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
+ - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
+ required: true
- label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general).
required: true
- label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+ - label: I confirm that I am using English to submit this report, otherwise it will be closed.
required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
+ - label: 【中文用户 & Non-English Users】请使用英语提交,否则会被关闭 :)
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
@@ -42,20 +44,22 @@ body:
attributes:
label: Steps to reproduce
description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks.
- placeholder: Having detailed steps helps us reproduce the bug.
+ placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them.
validations:
required: true
- type: textarea
attributes:
label: ✔️ Expected Behavior
- placeholder: What were you expecting?
+ description: Describe what you expected to happen.
+ placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here.
validations:
- required: false
+ required: true
- type: textarea
attributes:
label: ❌ Actual Behavior
- placeholder: What happened instead?
+ description: Describe what actually happened.
+ placeholder: What happened instead? Please do not copy and paste the steps to reproduce here.
validations:
required: false
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
index 6877c382c4..c1666d24cf 100644
--- a/.github/ISSUE_TEMPLATE/config.yml
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -1,5 +1,11 @@
blank_issues_enabled: false
contact_links:
+ - name: "\U0001F4A1 Model Providers & Plugins"
+ url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
+ about: Report issues with official plugins or model providers. You will need to provide the plugin version and other relevant details.
+ - name: "\U0001F4AC Documentation Issues"
+ url: "https://github.com/langgenius/dify-docs/issues/new"
+ about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue.
- name: "\U0001F4E7 Discussions"
url: https://github.com/langgenius/dify/discussions/categories/general
- about: General discussions and request help from the community
+ about: General discussions and seek help from the community
diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml
deleted file mode 100644
index 8fdbc0fb9a..0000000000
--- a/.github/ISSUE_TEMPLATE/document_issue.yml
+++ /dev/null
@@ -1,24 +0,0 @@
-name: "📚 Documentation Issue"
-description: Report issues in our documentation
-labels:
- - documentation
-body:
- - type: checkboxes
- attributes:
- label: Self Checks
- description: "To make sure we get to you in time, please check the following :)"
- options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
- required: true
- - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
- required: true
- - label: "Please do not modify this template :) and fill in all the required fields."
- required: true
- - type: textarea
- attributes:
- label: Provide a description of requested docs changes
- placeholder: Briefly describe which document needs to be corrected and why.
- validations:
- required: true
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
index b1952c63a9..bd293e2442 100644
--- a/.github/ISSUE_TEMPLATE/feature_request.yml
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -8,11 +8,11 @@ body:
label: Self Checks
description: "To make sure we get to you in time, please check the following :)"
options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
+ - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542).
required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
+ - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
+ - label: I confirm that I am using English to submit this report, otherwise it will be closed.
required: true
- label: "Please do not modify this template :) and fill in all the required fields."
required: true
diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml
deleted file mode 100644
index f9c2dfb7d2..0000000000
--- a/.github/ISSUE_TEMPLATE/translation_issue.yml
+++ /dev/null
@@ -1,55 +0,0 @@
-name: "🌐 Localization/Translation issue"
-description: Report incorrect translations. [please use English :)]
-labels:
- - translation
-body:
- - type: checkboxes
- attributes:
- label: Self Checks
- description: "To make sure we get to you in time, please check the following :)"
- options:
- - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones.
- required: true
- - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)).
- required: true
- - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)"
- required: true
- - label: "Please do not modify this template :) and fill in all the required fields."
- required: true
- - type: input
- attributes:
- label: Dify version
- description: Hover over system tray icon or look at Settings
- validations:
- required: true
- - type: input
- attributes:
- label: Utility with translation issue
- placeholder: Some area
- description: Please input here the utility with the translation issue
- validations:
- required: true
- - type: input
- attributes:
- label: 🌐 Language affected
- placeholder: "German"
- validations:
- required: true
- - type: textarea
- attributes:
- label: ❌ Actual phrase(s)
- placeholder: What is there? Please include a screenshot as that is extremely helpful.
- validations:
- required: true
- - type: textarea
- attributes:
- label: ✔️ Expected phrase(s)
- placeholder: What was expected?
- validations:
- required: true
- - type: textarea
- attributes:
- label: ℹ Why is the current translation wrong
- placeholder: Why do you feel this is incorrect?
- validations:
- required: true
diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml
new file mode 100644
index 0000000000..5e290c5d02
--- /dev/null
+++ b/.github/workflows/autofix.yml
@@ -0,0 +1,27 @@
+name: autofix.ci
+on:
+ workflow_call:
+ pull_request:
+ push:
+ branches: [ "main" ]
+permissions:
+ contents: read
+
+jobs:
+ autofix:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v4
+
+ # Use uv to ensure we have the same ruff version in CI and locally.
+ - uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f
+ - run: |
+ cd api
+ uv sync --dev
+ # Fix lint errors
+ uv run ruff check --fix-only .
+ # Format code
+ uv run ruff format .
+
+ - uses: autofix-ci/action@635ffb0c9798bd160680f18fd73371e355b85f27
+
diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml
index cc735ae67c..b933560a5e 100644
--- a/.github/workflows/build-push.yml
+++ b/.github/workflows/build-push.yml
@@ -6,6 +6,7 @@ on:
- "main"
- "deploy/dev"
- "deploy/enterprise"
+ - "build/**"
tags:
- "*"
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index b06ab9653e..a283f8d5ca 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -28,7 +28,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
api/**
@@ -75,7 +75,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: web/**
@@ -113,7 +113,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
docker/generate_docker_compose
@@ -144,7 +144,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: |
**.sh
@@ -152,13 +152,15 @@ jobs:
**.yml
**Dockerfile
dev/**
+ .editorconfig
- name: Super-linter
- uses: super-linter/super-linter/slim@v7
+ uses: super-linter/super-linter/slim@v8
if: steps.changed-files.outputs.any_changed == 'true'
env:
BASH_SEVERITY: warning
- DEFAULT_BRANCH: main
+ DEFAULT_BRANCH: origin/main
+ EDITORCONFIG_FILE_NAME: editorconfig-checker.json
FILTER_REGEX_INCLUDE: pnpm-lock.yaml
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
@@ -168,16 +170,6 @@ jobs:
# FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck
# VALIDATE_GITHUB_ACTIONS: true
VALIDATE_DOCKERFILE_HADOLINT: true
+ VALIDATE_EDITORCONFIG: true
VALIDATE_XML: true
VALIDATE_YAML: true
-
- - name: EditorConfig checks
- uses: super-linter/super-linter/slim@v7
- env:
- DEFAULT_BRANCH: main
- GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- IGNORE_GENERATED_FILES: true
- IGNORE_GITIGNORED_FILES: true
- # EditorConfig validation
- VALIDATE_EDITORCONFIG: true
- EDITORCONFIG_FILE_NAME: editorconfig-checker.json
diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml
index 37cfdc5c1e..c3f8fdbaf6 100644
--- a/.github/workflows/web-tests.yml
+++ b/.github/workflows/web-tests.yml
@@ -27,7 +27,7 @@ jobs:
- name: Check changed files
id: changed-files
- uses: tj-actions/changed-files@v45
+ uses: tj-actions/changed-files@v46
with:
files: web/**
diff --git a/README.md b/README.md
index 1dc7e2dd98..2909e0e6cf 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@
-Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
+Dify is an open-source platform for developing LLM applications. Its intuitive interface combines agentic AI workflows, RAG pipelines, agent capabilities, model management, observability features, and more—allowing you to quickly move from prototype to production.
## Quick start
@@ -65,7 +65,7 @@ Dify is an open-source LLM app development platform. Its intuitive interface com
-The easiest way to start the Dify server is through [docker compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
+The easiest way to start the Dify server is through [Docker Compose](docker/docker-compose.yaml). Before running Dify with the following commands, make sure that [Docker](https://docs.docker.com/get-docker/) and [Docker Compose](https://docs.docker.com/compose/install/) are installed on your machine:
```bash
cd dify
@@ -205,6 +205,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Using Terraform for Deployment
@@ -261,8 +262,8 @@ At the same time, please consider supporting Dify by sharing it on social media
## Security disclosure
-To protect your privacy, please avoid posting security issues on GitHub. Instead, send your questions to security@dify.ai and we will provide you with a more detailed answer.
+To protect your privacy, please avoid posting security issues on GitHub. Instead, report issues to security@dify.ai, and our team will respond with a detailed answer.
## License
-This repository is available under the [Dify Open Source License](LICENSE), which is essentially Apache 2.0 with a few additional restrictions.
+This repository is licensed under the [Dify Open Source License](LICENSE), based on Apache 2.0 with additional conditions.
diff --git a/README_AR.md b/README_AR.md
index d93bca8646..e959ca0f78 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -188,6 +188,7 @@ docker compose up -d
- [رسم بياني Helm من قبل @magicsong](https://github.com/magicsong/ai-charts)
- [ملف YAML من قبل @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [ملف YAML من قبل @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 جديد! ملفات YAML (تدعم Dify v1.6.0) بواسطة @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### استخدام Terraform للتوزيع
diff --git a/README_BN.md b/README_BN.md
index 3efee3684d..29d7374ea5 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -204,6 +204,8 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 নতুন! YAML ফাইলসমূহ (Dify v1.6.0 সমর্থিত) তৈরি করেছেন @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
+
#### টেরাফর্ম ব্যবহার করে ডিপ্লয়
diff --git a/README_CN.md b/README_CN.md
index 21e27429ec..486a368c09 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -194,9 +194,9 @@ docker compose up -d
如果您需要自定义配置,请参考 [.env.example](docker/.env.example) 文件中的注释,并更新 `.env` 文件中对应的值。此外,您可能需要根据您的具体部署环境和需求对 `docker-compose.yaml` 文件本身进行调整,例如更改镜像版本、端口映射或卷挂载。完成任何更改后,请重新运行 `docker-compose up -d`。您可以在[此处](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用环境变量的完整列表。
-#### 使用 Helm Chart 部署
+#### 使用 Helm Chart 或 Kubernetes 资源清单(YAML)部署
-使用 [Helm Chart](https://helm.sh/) 版本或者 YAML 文件,可以在 Kubernetes 上部署 Dify。
+使用 [Helm Chart](https://helm.sh/) 版本或者 Kubernetes 资源清单(YAML),可以在 Kubernetes 上部署 Dify。
- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
@@ -204,6 +204,10 @@ docker compose up -d
- [YAML 文件 by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML 文件 (支持 Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
+
+
+
#### 使用 Terraform 部署
使用 [terraform](https://www.terraform.io/) 一键将 Dify 部署到云平台
diff --git a/README_DE.md b/README_DE.md
index 20c313035e..fce52c34c2 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -203,6 +203,7 @@ Falls Sie eine hochverfügbare Konfiguration einrichten möchten, gibt es von de
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform für die Bereitstellung verwenden
diff --git a/README_ES.md b/README_ES.md
index e4b7df6686..6fd6dfcee8 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -203,6 +203,7 @@ Si desea configurar una configuración de alta disponibilidad, la comunidad prop
- [Gráfico Helm por @magicsong](https://github.com/magicsong/ai-charts)
- [Ficheros YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Ficheros YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 ¡NUEVO! Archivos YAML (compatible con Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Uso de Terraform para el despliegue
diff --git a/README_FR.md b/README_FR.md
index 8fd17fb7c3..b2209fb495 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -201,6 +201,7 @@ Si vous souhaitez configurer une configuration haute disponibilité, la communau
- [Helm Chart par @magicsong](https://github.com/magicsong/ai-charts)
- [Fichier YAML par @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Fichier YAML par @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NOUVEAU ! Fichiers YAML (compatible avec Dify v1.6.0) par @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Utilisation de Terraform pour le déploiement
diff --git a/README_JA.md b/README_JA.md
index a3ee81e1f2..c658225f90 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -202,6 +202,7 @@ docker compose up -d
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 新着!YAML ファイル(Dify v1.6.0 対応)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraformを使用したデプロイ
diff --git a/README_KL.md b/README_KL.md
index 3e5ab1a74f..bfafcc7407 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -201,6 +201,7 @@ If you'd like to configure a highly-available setup, there are community-contrib
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform atorlugu pilersitsineq
diff --git a/README_KR.md b/README_KR.md
index 3c504900e1..282117e776 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -195,6 +195,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
- [Helm Chart by @magicsong](https://github.com/magicsong/ai-charts)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Terraform을 사용한 배포
diff --git a/README_PT.md b/README_PT.md
index fb5f3662ae..576f6b48f7 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -200,6 +200,7 @@ Se deseja configurar uma instalação de alta disponibilidade, há [Helm Charts]
- [Helm Chart de @magicsong](https://github.com/magicsong/ai-charts)
- [Arquivo YAML por @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Arquivo YAML por @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NOVO! Arquivos YAML (Compatível com Dify v1.6.0) por @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Usando o Terraform para Implantação
diff --git a/README_SI.md b/README_SI.md
index 647069a220..7ded001d86 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -201,6 +201,7 @@ Star Dify on GitHub and be instantly notified of new releases.
- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML files (Supports Dify v1.6.0) by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Uporaba Terraform za uvajanje
diff --git a/README_TR.md b/README_TR.md
index f52335646a..6e94e54fa0 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -194,6 +194,7 @@ Yüksek kullanılabilirliğe sahip bir kurulum yapılandırmak isterseniz, Dify'
- [@BorisPolonsky tarafından Helm Chart](https://github.com/BorisPolonsky/dify-helm)
- [@Winson-030 tarafından YAML dosyası](https://github.com/Winson-030/dify-kubernetes)
- [@wyy-holding tarafından YAML dosyası](https://github.com/wyy-holding/dify-k8s)
+- [🚀 YENİ! YAML dosyaları (Dify v1.6.0 destekli) @Zhoneym tarafından](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Dağıtım için Terraform Kullanımı
diff --git a/README_TW.md b/README_TW.md
index 71082ff893..6e3e22b5c1 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -197,12 +197,13 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify
如果您需要自定義配置,請參考我們的 [.env.example](docker/.env.example) 文件中的註釋,並在您的 `.env` 文件中更新相應的值。此外,根據您特定的部署環境和需求,您可能需要調整 `docker-compose.yaml` 文件本身,例如更改映像版本、端口映射或卷掛載。進行任何更改後,請重新運行 `docker-compose up -d`。您可以在[這裡](https://docs.dify.ai/getting-started/install-self-hosted/environments)找到可用環境變數的完整列表。
-如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 YAML 文件允許在 Kubernetes 上部署 Dify。
+如果您想配置高可用性設置,社區貢獻的 [Helm Charts](https://helm.sh/) 和 Kubernetes 資源清單(YAML)允許在 Kubernetes 上部署 Dify。
- [由 @LeoQuote 提供的 Helm Chart](https://github.com/douban/charts/tree/master/charts/dify)
- [由 @BorisPolonsky 提供的 Helm Chart](https://github.com/BorisPolonsky/dify-helm)
- [由 @Winson-030 提供的 YAML 文件](https://github.com/Winson-030/dify-kubernetes)
- [由 @wyy-holding 提供的 YAML 文件](https://github.com/wyy-holding/dify-k8s)
+- [🚀 NEW! YAML 檔案(支援 Dify v1.6.0)by @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
### 使用 Terraform 進行部署
diff --git a/README_VI.md b/README_VI.md
index 58d8434fff..51314e6de5 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -196,6 +196,7 @@ Nếu bạn muốn cấu hình một cài đặt có độ sẵn sàng cao, có
- [Helm Chart bởi @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
- [Tệp YAML bởi @Winson-030](https://github.com/Winson-030/dify-kubernetes)
- [Tệp YAML bởi @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+- [🚀 MỚI! Tệp YAML (Hỗ trợ Dify v1.6.0) bởi @Zhoneym](https://github.com/Zhoneym/DifyAI-Kubernetes)
#### Sử dụng Terraform để Triển khai
diff --git a/api/.env.example b/api/.env.example
index a7ea6cf937..80b1c12cd8 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -5,17 +5,17 @@
SECRET_KEY=
# Console API base URL
-CONSOLE_API_URL=http://127.0.0.1:5001
-CONSOLE_WEB_URL=http://127.0.0.1:3000
+CONSOLE_API_URL=http://localhost:5001
+CONSOLE_WEB_URL=http://localhost:3000
# Service API base URL
-SERVICE_API_URL=http://127.0.0.1:5001
+SERVICE_API_URL=http://localhost:5001
# Web APP base URL
-APP_WEB_URL=http://127.0.0.1:3000
+APP_WEB_URL=http://localhost:3000
# Files URL
-FILES_URL=http://127.0.0.1:5001
+FILES_URL=http://localhost:5001
# INTERNAL_FILES_URL is used for plugin daemon communication within Docker network.
# Set this to the internal Docker service URL for proper plugin file access.
@@ -54,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD=
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1
-
+CELERY_BACKEND=redis
# PostgreSQL database configuration
DB_USERNAME=postgres
DB_PASSWORD=difyai123456
@@ -138,12 +138,14 @@ SUPABASE_API_KEY=your-access-key
SUPABASE_URL=your-server-url
# CORS configuration
-WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
-CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
+WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
+CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Vector database configuration
-# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
+# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`, `vastbase`, `tidb`, `tidb_on_qdrant`, `baidu`, `lindorm`, `huawei_cloud`, `upstash`, `matrixone`.
VECTOR_STORE=weaviate
+# Prefix used to create collection name in vector database
+VECTOR_INDEX_NAME_PREFIX=Vector_index
# Weaviate configuration
WEAVIATE_ENDPOINT=http://localhost:8080
@@ -449,6 +451,19 @@ MAX_VARIABLE_SIZE=204800
# hybrid: Save new data to object storage, read from both object storage and RDBMS
WORKFLOW_NODE_EXECUTION_STORAGE=rdbms
+# Repository configuration
+# Core workflow execution repository implementation
+CORE_WORKFLOW_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository
+
+# Core workflow node execution repository implementation
+CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY=core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow node execution repository implementation
+API_WORKFLOW_NODE_EXECUTION_REPOSITORY=repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository
+
+# API workflow run repository implementation
+API_WORKFLOW_RUN_REPOSITORY=repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository
+
# App configuration
APP_MAX_EXECUTION_TIME=1200
APP_MAX_ACTIVE_REQUESTS=0
@@ -456,6 +471,16 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
+# Celery schedule tasks configuration
+ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
+ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
+ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
+ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
+ENABLE_CLEAN_MESSAGES=false
+ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
+ENABLE_DATASETS_QUEUE_MONITOR=false
+ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
+
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=
@@ -482,6 +507,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
# Reset password token expiry minutes
RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5
+CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5
+OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5
CREATE_TIDB_SERVICE_JOB_ENABLED=false
@@ -492,6 +519,8 @@ LOGIN_LOCKOUT_DURATION=86400
# Enable OpenTelemetry
ENABLE_OTEL=false
+OTLP_TRACE_ENDPOINT=
+OTLP_METRIC_ENDPOINT=
OTLP_BASE_ENDPOINT=http://localhost:4318
OTLP_API_KEY=
OTEL_EXPORTER_OTLP_PROTOCOL=
diff --git a/api/Dockerfile b/api/Dockerfile
index 7e4997507f..8c7a1717b9 100644
--- a/api/Dockerfile
+++ b/api/Dockerfile
@@ -47,6 +47,8 @@ RUN \
curl nodejs libgmp-dev libmpfr-dev libmpc-dev \
# For Security
expat libldap-2.5-0 perl libsqlite3-0 zlib1g \
+ # install fonts to support the use of tools like pypdfium2
+ fonts-noto-cjk \
# install a package to improve the accuracy of guessing mime type and file extension
media-types \
# install libmagic to support the use of python-magic guess MIMETYPE
diff --git a/api/README.md b/api/README.md
index 9308d5dc44..6ab923070e 100644
--- a/api/README.md
+++ b/api/README.md
@@ -74,7 +74,12 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
- uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
+ uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
+ ```
+
+ Additionally, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
+ ```bash
+ uv run celery -A app.celery beat
```
## Testing
diff --git a/api/commands.py b/api/commands.py
index 86769847c1..c2e62ec261 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -2,19 +2,22 @@ import base64
import json
import logging
import secrets
-from typing import Optional
+from typing import Any, Optional
import click
from flask import current_app
+from pydantic import TypeAdapter
from sqlalchemy import select
from werkzeug.exceptions import NotFound
from configs import dify_config
from constants.languages import languages
+from core.plugin.entities.plugin import ToolProviderID
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import Document
+from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation
from models.provider import Provider, ProviderModel
+from models.tools import ToolOAuthSystemClient
from services.account_service import AccountService, RegisterService, TenantService
from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs
from services.plugin.data_migration import PluginDataMigration
@@ -46,7 +50,7 @@ def reset_password(email, new_password, password_confirm):
click.echo(click.style("Passwords do not match.", fg="red"))
return
- account = db.session.query(Account).filter(Account.email == email).one_or_none()
+ account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@@ -85,7 +89,7 @@ def reset_email(email, new_email, email_confirm):
click.echo(click.style("New emails do not match.", fg="red"))
return
- account = db.session.query(Account).filter(Account.email == email).one_or_none()
+ account = db.session.query(Account).where(Account.email == email).one_or_none()
if not account:
click.echo(click.style("Account not found for email: {}".format(email), fg="red"))
@@ -132,8 +136,8 @@ def reset_encrypt_key_pair():
tenant.encrypt_public_key = generate_key_pair(tenant.id)
- db.session.query(Provider).filter(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
- db.session.query(ProviderModel).filter(ProviderModel.tenant_id == tenant.id).delete()
+ db.session.query(Provider).where(Provider.provider_type == "custom", Provider.tenant_id == tenant.id).delete()
+ db.session.query(ProviderModel).where(ProviderModel.tenant_id == tenant.id).delete()
db.session.commit()
click.echo(
@@ -168,7 +172,7 @@ def migrate_annotation_vector_database():
per_page = 50
apps = (
db.session.query(App)
- .filter(App.status == "normal")
+ .where(App.status == "normal")
.order_by(App.created_at.desc())
.limit(per_page)
.offset((page - 1) * per_page)
@@ -188,7 +192,7 @@ def migrate_annotation_vector_database():
try:
click.echo("Creating app annotation index: {}".format(app.id))
app_annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app.id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app.id).first()
)
if not app_annotation_setting:
@@ -198,13 +202,13 @@ def migrate_annotation_vector_database():
# get dataset_collection_binding info
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
- .filter(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
+ .where(DatasetCollectionBinding.id == app_annotation_setting.collection_binding_id)
.first()
)
if not dataset_collection_binding:
click.echo("App annotation collection binding not found: {}".format(app.id))
continue
- annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app.id).all()
+ annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app.id).all()
dataset = Dataset(
id=app.id,
tenant_id=app.tenant_id,
@@ -301,7 +305,7 @@ def migrate_knowledge_vector_database():
while True:
try:
stmt = (
- select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
+ select(Dataset).where(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
)
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@@ -328,7 +332,7 @@ def migrate_knowledge_vector_database():
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
- .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
+ .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
@@ -363,7 +367,7 @@ def migrate_knowledge_vector_database():
dataset_documents = (
db.session.query(DatasetDocument)
- .filter(
+ .where(
DatasetDocument.dataset_id == dataset.id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -377,7 +381,7 @@ def migrate_knowledge_vector_database():
for dataset_document in dataset_documents:
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
@@ -464,7 +468,7 @@ def convert_to_agent_apps():
app_id = str(i.id)
if app_id not in proceeded_app_ids:
proceeded_app_ids.append(app_id)
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if app is not None:
apps.append(app)
@@ -479,7 +483,7 @@ def convert_to_agent_apps():
db.session.commit()
# update conversation mode to agent
- db.session.query(Conversation).filter(Conversation.app_id == app.id).update(
+ db.session.query(Conversation).where(Conversation.app_id == app.id).update(
{Conversation.mode: AppMode.AGENT_CHAT.value}
)
@@ -556,7 +560,7 @@ def old_metadata_migration():
try:
stmt = (
select(DatasetDocument)
- .filter(DatasetDocument.doc_metadata.is_not(None))
+ .where(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc())
)
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
@@ -574,7 +578,7 @@ def old_metadata_migration():
else:
dataset_metadata = (
db.session.query(DatasetMetadata)
- .filter(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
+ .where(DatasetMetadata.dataset_id == document.dataset_id, DatasetMetadata.name == key)
.first()
)
if not dataset_metadata:
@@ -598,7 +602,7 @@ def old_metadata_migration():
else:
dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
- .filter(
+ .where(
DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id,
@@ -713,7 +717,7 @@ where sites.id is null limit 1000"""
continue
try:
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
print(f"App {app_id} not found")
continue
@@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
else:
click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
+
+
+@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
+@click.option("--provider", prompt=True, help="Provider name")
+@click.option("--client-params", prompt=True, help="Client Params")
+def setup_system_tool_oauth_client(provider, client_params):
+    """
+    Replace a tool provider's system OAuth client: parse `provider` with ToolProviderID,
+    validate `client_params` as a JSON object, encrypt it and store it (old rows are deleted).
+    """
+    provider_id = ToolProviderID(provider)
+    provider_name = provider_id.provider_name
+    plugin_id = provider_id.plugin_id
+
+    try:
+        # The params must parse as a JSON object before they are encrypted.
+        click.echo(click.style(f"Validating client params: {client_params}", fg="yellow"))
+        client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params)
+        click.echo(click.style("Client params validated successfully.", fg="green"))
+
+        click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
+        # Never echo dify_config.SECRET_KEY itself: terminal output lands in scrollback and logs.
+        click.echo(click.style("Using the configured SECRET_KEY (value hidden).", fg="yellow"))
+        oauth_client_params = encrypt_system_oauth_params(client_params_dict)
+        click.echo(click.style("Client params encrypted successfully.", fg="green"))
+    except Exception as e:
+        click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
+        return
+
+    # Remove existing rows for this provider/plugin before inserting the new one.
+    deleted_count = (
+        db.session.query(ToolOAuthSystemClient)
+        .filter_by(provider=provider_name, plugin_id=plugin_id)
+        .delete()
+    )
+    if deleted_count > 0:
+        click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow"))
+
+    oauth_client = ToolOAuthSystemClient(
+        provider=provider_name,
+        plugin_id=plugin_id,
+        encrypted_oauth_params=oauth_client_params,
+    )
+    db.session.add(oauth_client)
+    db.session.commit()
+    click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green"))
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index 963fcbedf9..9f1646ea7d 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings):
description="Duration in minutes for which a password reset token remains valid",
default=5,
)
+ CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
+ description="Duration in minutes for which a change email token remains valid",
+ default=5,
+ )
+
+ OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field(
+ description="Duration in minutes for which an owner transfer token remains valid",
+ default=5,
+ )
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
@@ -537,6 +546,33 @@ class WorkflowNodeExecutionConfig(BaseSettings):
)
+class RepositoryConfig(BaseSettings):
+    """
+    Settings naming, as module paths, the repository implementations to use
+    """
+
+    CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field(
+        default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
+        description="Repository implementation for WorkflowExecution. Specify as a module path",
+    )
+
+    CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+        default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
+        description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
+    )
+
+    API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
+        default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
+        description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
+        "Specify as a module path",
+    )
+
+    API_WORKFLOW_RUN_REPOSITORY: str = Field(
+        default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
+        description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
+    )
+
+
class AuthConfig(BaseSettings):
"""
Configuration for authentication and OAuth
@@ -587,6 +623,16 @@ class AuthConfig(BaseSettings):
default=86400,
)
+ CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field(
+ description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.",
+ default=86400,
+ )
+
+ OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field(
+ description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.",
+ default=86400,
+ )
+
class ModerationConfig(BaseSettings):
"""
@@ -786,6 +832,41 @@ class CeleryBeatConfig(BaseSettings):
)
+class CeleryScheduleTasksConfig(BaseSettings):  # boolean switches, one per scheduled task
+    ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field(
+        default=False,
+        description="Enable clean embedding cache task",
+    )
+    ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field(
+        default=False,
+        description="Enable clean unused datasets task",
+    )
+    ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field(
+        default=False,
+        description="Enable create tidb service job task",
+    )
+    ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field(
+        default=False,
+        description="Enable update tidb service job status task",
+    )
+    ENABLE_CLEAN_MESSAGES: bool = Field(
+        default=False,
+        description="Enable clean messages task",
+    )
+    ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
+        default=False,
+        description="Enable mail clean document notify task",
+    )
+    ENABLE_DATASETS_QUEUE_MONITOR: bool = Field(
+        default=False,
+        description="Enable queue monitor task",
+    )
+    ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
+        default=True,
+        description="Enable check upgradable plugin task",
+    )
+
+
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
@@ -903,6 +984,7 @@ class FeatureConfig(
MultiModalTransferConfig,
PositionConfig,
RagEtlConfig,
+ RepositoryConfig,
SecurityConfig,
ToolConfig,
UpdateConfig,
@@ -914,5 +996,6 @@ class FeatureConfig(
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
+ CeleryScheduleTasksConfig,
):
pass
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 427602676f..587ea55ca7 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings):
default=False,
)
+ VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field(
+ description="Prefix used to create collection name in vector database",
+ default="Vector_index",
+ )
+
class KeywordStoreConfig(BaseSettings):
KEYWORD_STORE: str = Field(
@@ -162,6 +167,11 @@ class DatabaseConfig(BaseSettings):
default=3600,
)
+ SQLALCHEMY_POOL_USE_LIFO: bool = Field(
+ description="If True, SQLAlchemy will use last-in-first-out way to retrieve connections from pool.",
+ default=False,
+ )
+
SQLALCHEMY_POOL_PRE_PING: bool = Field(
description="If True, enables connection pool pre-ping feature to check connections.",
default=False,
@@ -199,13 +209,14 @@ class DatabaseConfig(BaseSettings):
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": connect_args,
+ "pool_use_lifo": self.SQLALCHEMY_POOL_USE_LIFO,
}
class CeleryConfig(DatabaseConfig):
CELERY_BACKEND: str = Field(
description="Backend for Celery task results. Options: 'database', 'redis'.",
- default="database",
+ default="redis",
)
CELERY_BROKER_URL: Optional[str] = Field(
diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py
index 1b88ddcfe6..7572a696ce 100644
--- a/api/configs/observability/otel/otel_config.py
+++ b/api/configs/observability/otel/otel_config.py
@@ -12,6 +12,16 @@ class OTelConfig(BaseSettings):
default=False,
)
+ OTLP_TRACE_ENDPOINT: str = Field(
+ description="OTLP trace endpoint",
+ default="",
+ )
+
+ OTLP_METRIC_ENDPOINT: str = Field(
+ description="OTLP metric endpoint",
+ default="",
+ )
+
OTLP_BASE_ENDPOINT: str = Field(
description="OTLP base endpoint",
default="http://localhost:4318",
diff --git a/api/constants/__init__.py b/api/constants/__init__.py
index a84de0a451..9e052320ac 100644
--- a/api/constants/__init__.py
+++ b/api/constants/__init__.py
@@ -1,6 +1,7 @@
from configs import dify_config
HIDDEN_VALUE = "[__HIDDEN__]"
+UNKNOWN_VALUE = "[__UNKNOWN__]"
UUID_NIL = "00000000-0000-0000-0000-000000000000"
DEFAULT_FILE_NUMBER_LIMITS = 3
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index f5257fae79..8a55197fb6 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -56,7 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args()
- app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
+ app = db.session.execute(select(App).where(App.id == args["app_id"])).scalar_one_or_none()
if not app:
raise NotFound(f"App '{args['app_id']}' is not found")
@@ -74,7 +74,7 @@ class InsertExploreAppListApi(Resource):
with Session(db.engine) as session:
recommended_app = session.execute(
- select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
+ select(RecommendedApp).where(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none()
if not recommended_app:
@@ -117,21 +117,21 @@ class InsertExploreAppApi(Resource):
def delete(self, app_id):
with Session(db.engine) as session:
recommended_app = session.execute(
- select(RecommendedApp).filter(RecommendedApp.app_id == str(app_id))
+ select(RecommendedApp).where(RecommendedApp.app_id == str(app_id))
).scalar_one_or_none()
if not recommended_app:
return {"result": "success"}, 204
with Session(db.engine) as session:
- app = session.execute(select(App).filter(App.id == recommended_app.app_id)).scalar_one_or_none()
+ app = session.execute(select(App).where(App.id == recommended_app.app_id)).scalar_one_or_none()
if app:
app.is_public = False
with Session(db.engine) as session:
installed_apps = session.execute(
- select(InstalledApp).filter(
+ select(InstalledApp).where(
InstalledApp.app_id == recommended_app.app_id,
InstalledApp.tenant_id != InstalledApp.app_owner_tenant_id,
)
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index 47c93a15c6..d7500c415c 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -61,7 +61,7 @@ class BaseApiKeyListResource(Resource):
_get_resource(resource_id, current_user.current_tenant_id, self.resource_model)
keys = (
db.session.query(ApiToken)
- .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+ .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.all()
)
return {"items": keys}
@@ -76,7 +76,7 @@ class BaseApiKeyListResource(Resource):
current_key_count = (
db.session.query(ApiToken)
- .filter(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
+ .where(ApiToken.type == self.resource_type, getattr(ApiToken, self.resource_id_field) == resource_id)
.count()
)
@@ -117,7 +117,7 @@ class BaseApiKeyResource(Resource):
key = (
db.session.query(ApiToken)
- .filter(
+ .where(
getattr(ApiToken, self.resource_id_field) == resource_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@@ -128,7 +128,7 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")
- db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
+ db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 860166a61a..9fe32dde6d 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -151,6 +151,7 @@ class AppApi(Resource):
parser.add_argument("icon", type=str, location="json")
parser.add_argument("icon_background", type=str, location="json")
parser.add_argument("use_icon_as_answer_icon", type=bool, location="json")
+ parser.add_argument("max_active_requests", type=int, location="json")
args = parser.parse_args()
app_service = AppService()
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 70d6216497..b5b6d1f75b 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,4 +1,4 @@
-from datetime import UTC, datetime
+from datetime import datetime
import pytz # pip install pytz
from flask_login import current_user
@@ -19,6 +19,7 @@ from fields.conversation_fields import (
conversation_pagination_fields,
conversation_with_summary_pagination_fields,
)
+from libs.datetime_utils import naive_utc_now
from libs.helper import DatetimeString
from libs.login import login_required
from models import Conversation, EndUser, Message, MessageAnnotation
@@ -48,7 +49,7 @@ class CompletionConversationApi(Resource):
query = db.select(Conversation).where(Conversation.app_id == app_model.id, Conversation.mode == "completion")
if args["keyword"]:
- query = query.join(Message, Message.conversation_id == Conversation.id).filter(
+ query = query.join(Message, Message.conversation_id == Conversation.id).where(
or_(
Message.query.ilike("%{}%".format(args["keyword"])),
Message.answer.ilike("%{}%".format(args["keyword"])),
@@ -120,7 +121,7 @@ class CompletionConversationDetailApi(Resource):
conversation = (
db.session.query(Conversation)
- .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+ .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
@@ -180,7 +181,7 @@ class ChatConversationApi(Resource):
Message.conversation_id == Conversation.id,
)
.join(subquery, subquery.c.conversation_id == Conversation.id)
- .filter(
+ .where(
or_(
Message.query.ilike(keyword_filter),
Message.answer.ilike(keyword_filter),
@@ -285,7 +286,7 @@ class ChatConversationDetailApi(Resource):
conversation = (
db.session.query(Conversation)
- .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+ .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
@@ -307,7 +308,7 @@ api.add_resource(ChatConversationDetailApi, "/apps//chat-conversati
def _get_conversation(app_model, conversation_id):
conversation = (
db.session.query(Conversation)
- .filter(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
+ .where(Conversation.id == conversation_id, Conversation.app_id == app_model.id)
.first()
)
@@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id):
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
- conversation.read_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()
diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py
index 0f53860f56..2344fd5acb 100644
--- a/api/controllers/console/app/mcp_server.py
+++ b/api/controllers/console/app/mcp_server.py
@@ -26,7 +26,7 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_fields)
def get(self, app_model):
- server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == app_model.id).first()
+ server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == app_model.id).first()
return server
@setup_required
@@ -35,16 +35,20 @@ class AppMCPServerController(Resource):
@get_app_model
@marshal_with(app_server_fields)
def post(self, app_model):
- # The role of the current user in the ta table must be editor, admin, or owner
if not current_user.is_editor:
raise NotFound()
parser = reqparse.RequestParser()
- parser.add_argument("description", type=str, required=True, location="json")
+ parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
args = parser.parse_args()
+
+ description = args.get("description")
+ if not description:
+ description = app_model.description or ""
+
server = AppMCPServer(
name=app_model.name,
- description=args["description"],
+ description=description,
parameters=json.dumps(args["parameters"], ensure_ascii=False),
status=AppMCPServerStatus.ACTIVE,
app_id=app_model.id,
@@ -65,14 +69,22 @@ class AppMCPServerController(Resource):
raise NotFound()
parser = reqparse.RequestParser()
parser.add_argument("id", type=str, required=True, location="json")
- parser.add_argument("description", type=str, required=True, location="json")
+ parser.add_argument("description", type=str, required=False, location="json")
parser.add_argument("parameters", type=dict, required=True, location="json")
parser.add_argument("status", type=str, required=False, location="json")
args = parser.parse_args()
- server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first()
+ server = db.session.query(AppMCPServer).where(AppMCPServer.id == args["id"]).first()
if not server:
raise NotFound()
- server.description = args["description"]
+
+ description = args.get("description")
+ if description is None:
+ pass
+ elif not description:
+ server.description = app_model.description or ""
+ else:
+ server.description = description
+
server.parameters = json.dumps(args["parameters"], ensure_ascii=False)
if args["status"]:
if args["status"] not in [status.value for status in AppMCPServerStatus]:
@@ -92,8 +104,8 @@ class AppMCPServerRefreshController(Resource):
raise NotFound()
server = (
db.session.query(AppMCPServer)
- .filter(AppMCPServer.id == server_id)
- .filter(AppMCPServer.tenant_id == current_user.current_tenant_id)
+ .where(AppMCPServer.id == server_id)
+ .where(AppMCPServer.tenant_id == current_user.current_tenant_id)
.first()
)
if not server:
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index b7a4c31a15..5e79e8dece 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
+import services
from controllers.console import api
from controllers.console.app.error import (
CompletionRequestError,
@@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import login_required
-from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
+from models.model import AppMode, Conversation, Message, MessageAnnotation
from services.annotation_service import AppAnnotationService
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -55,7 +56,7 @@ class ChatMessageListApi(Resource):
conversation = (
db.session.query(Conversation)
- .filter(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
+ .where(Conversation.id == args["conversation_id"], Conversation.app_id == app_model.id)
.first()
)
@@ -65,7 +66,7 @@ class ChatMessageListApi(Resource):
if args["first_id"]:
first_message = (
db.session.query(Message)
- .filter(Message.conversation_id == conversation.id, Message.id == args["first_id"])
+ .where(Message.conversation_id == conversation.id, Message.id == args["first_id"])
.first()
)
@@ -74,7 +75,7 @@ class ChatMessageListApi(Resource):
history_messages = (
db.session.query(Message)
- .filter(
+ .where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
@@ -86,7 +87,7 @@ class ChatMessageListApi(Resource):
else:
history_messages = (
db.session.query(Message)
- .filter(Message.conversation_id == conversation.id)
+ .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(args["limit"])
.all()
@@ -97,7 +98,7 @@ class ChatMessageListApi(Resource):
current_page_first_message = history_messages[-1]
rest_count = (
db.session.query(Message)
- .filter(
+ .where(
Message.conversation_id == conversation.id,
Message.created_at < current_page_first_message.created_at,
Message.id != current_page_first_message.id,
@@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource):
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
args = parser.parse_args()
- message_id = str(args["message_id"])
-
- message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
-
- if not message:
- raise NotFound("Message Not Exists.")
-
- feedback = message.admin_feedback
-
- if not args["rating"] and feedback:
- db.session.delete(feedback)
- elif args["rating"] and feedback:
- feedback.rating = args["rating"]
- elif not args["rating"] and not feedback:
- raise ValueError("rating cannot be None when feedback not exists")
- else:
- feedback = MessageFeedback(
- app_id=app_model.id,
- conversation_id=message.conversation_id,
- message_id=message.id,
- rating=args["rating"],
- from_source="admin",
- from_account_id=current_user.id,
+ try:
+ MessageService.create_feedback(
+ app_model=app_model,
+ message_id=str(args["message_id"]),
+ user=current_user,
+ rating=args.get("rating"),
+ content=None,
)
- db.session.add(feedback)
-
- db.session.commit()
+ except services.errors.message.MessageNotExistsError:
+ raise NotFound("Message Not Exists.")
return {"result": "success"}
@@ -183,7 +167,7 @@ class MessageAnnotationCountApi(Resource):
@account_initialization_required
@get_app_model
def get(self, app_model):
- count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_model.id).count()
+ count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_model.id).count()
return {"count": count}
@@ -230,7 +214,7 @@ class MessageApi(Resource):
def get(self, app_model, message_id):
message_id = str(message_id)
- message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first()
+ message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app_model.id).first()
if not message:
raise NotFound("Message Not Exists.")
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index f30e3e893c..029138fb6b 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -42,7 +42,7 @@ class ModelConfigResource(Resource):
if app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent:
# get original app model config
original_app_model_config = (
- db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
+ db.session.query(AppModelConfig).where(AppModelConfig.id == app_model.app_model_config_id).first()
)
if original_app_model_config is None:
raise ValueError("Original app model config not found")
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index 3c3a359eeb..03418f1dd2 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,5 +1,3 @@
-from datetime import UTC, datetime
-
from flask_login import current_user
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
@@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from extensions.ext_database import db
from fields.app_fields import app_site_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Site
@@ -50,7 +49,7 @@ class AppSite(Resource):
if not current_user.is_editor:
raise Forbidden()
- site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
@@ -77,7 +76,7 @@ class AppSite(Resource):
setattr(site, attr_name, value)
site.updated_by = current_user.id
- site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ site.updated_at = naive_utc_now()
db.session.commit()
return site
@@ -94,14 +93,14 @@ class AppSiteAccessTokenReset(Resource):
if not current_user.is_admin_or_owner:
raise Forbidden()
- site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise NotFound
site.code = Site.generate_code(16)
site.updated_by = current_user.id
- site.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ site.updated_at = naive_utc_now()
db.session.commit()
return site
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index 86aed77412..32b64d10c5 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -2,6 +2,7 @@ from datetime import datetime
from decimal import Decimal
import pytz
+import sqlalchemy as sa
from flask import jsonify
from flask_login import current_user
from flask_restful import Resource, reqparse
@@ -9,10 +10,11 @@ from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
+from core.app.entities.app_invoke_entities import InvokeFrom
from extensions.ext_database import db
from libs.helper import DatetimeString
from libs.login import login_required
-from models.model import AppMode
+from models import AppMode, Message
class DailyMessageStatistic(Resource):
@@ -85,46 +87,41 @@ class DailyConversationStatistic(Resource):
parser.add_argument("end", type=DatetimeString("%Y-%m-%d %H:%M"), location="args")
args = parser.parse_args()
- sql_query = """SELECT
- DATE(DATE_TRUNC('day', created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz )) AS date,
- COUNT(DISTINCT messages.conversation_id) AS conversation_count
-FROM
- messages
-WHERE
- app_id = :app_id"""
- arg_dict = {"tz": account.timezone, "app_id": app_model.id}
-
timezone = pytz.timezone(account.timezone)
utc_timezone = pytz.utc
+ stmt = (
+ sa.select(
+ sa.func.date(
+ sa.func.date_trunc("day", sa.text("created_at AT TIME ZONE 'UTC' AT TIME ZONE :tz"))
+ ).label("date"),
+ sa.func.count(sa.distinct(Message.conversation_id)).label("conversation_count"),
+ )
+ .select_from(Message)
+ .where(Message.app_id == app_model.id, Message.invoke_from != InvokeFrom.DEBUGGER.value)
+ )
+
if args["start"]:
start_datetime = datetime.strptime(args["start"], "%Y-%m-%d %H:%M")
start_datetime = start_datetime.replace(second=0)
-
start_datetime_timezone = timezone.localize(start_datetime)
start_datetime_utc = start_datetime_timezone.astimezone(utc_timezone)
-
- sql_query += " AND created_at >= :start"
- arg_dict["start"] = start_datetime_utc
+ stmt = stmt.where(Message.created_at >= start_datetime_utc)
if args["end"]:
end_datetime = datetime.strptime(args["end"], "%Y-%m-%d %H:%M")
end_datetime = end_datetime.replace(second=0)
-
end_datetime_timezone = timezone.localize(end_datetime)
end_datetime_utc = end_datetime_timezone.astimezone(utc_timezone)
+ stmt = stmt.where(Message.created_at < end_datetime_utc)
- sql_query += " AND created_at < :end"
- arg_dict["end"] = end_datetime_utc
-
- sql_query += " GROUP BY date ORDER BY date"
+ stmt = stmt.group_by("date").order_by("date")
response_data = []
-
with db.engine.begin() as conn:
- rs = conn.execute(db.text(sql_query), arg_dict)
- for i in rs:
- response_data.append({"date": str(i.date), "conversation_count": i.conversation_count})
+ rs = conn.execute(stmt, {"tz": account.timezone})
+ for row in rs:
+ response_data.append({"date": str(row.date), "conversation_count": row.conversation_count})
return jsonify({"data": response_data})
diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py
index 00d6fa3cbf..ba93f82756 100644
--- a/api/controllers/console/app/workflow_draft_variable.py
+++ b/api/controllers/console/app/workflow_draft_variable.py
@@ -68,13 +68,18 @@ def _create_pagination_parser():
return parser
+def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
+ value_type = workflow_draft_var.value_type
+ return value_type.exposed_type().value
+
+
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
- "value_type": fields.String,
+ "value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
- "value_type": fields.String,
+ "value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
- "value_type": v.value_type.value,
+ "value_type": v.value_type.exposed_type().value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
diff --git a/api/controllers/console/app/wraps.py b/api/controllers/console/app/wraps.py
index 03b60610aa..132dc1f96b 100644
--- a/api/controllers/console/app/wraps.py
+++ b/api/controllers/console/app/wraps.py
@@ -11,7 +11,7 @@ from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model
@@ -35,8 +35,6 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
raise AppNotFoundError()
app_mode = AppMode.value_of(app_model.mode)
- if app_mode == AppMode.CHANNEL:
- raise AppNotFoundError()
if mode is not None:
if isinstance(mode, list):
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index 1795563ff7..2562fb5eb8 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,5 +1,3 @@
-import datetime
-
from flask import request
from flask_restful import Resource, reqparse
@@ -7,6 +5,7 @@ from constants.languages import supported_language
from controllers.console import api
from controllers.console.error import AlreadyActivateError
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.helper import StrLen, email, extract_remote_ip, timezone
from models.account import AccountStatus
from services.account_service import AccountService, RegisterService
@@ -65,7 +64,7 @@ class ActivateApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py
index b40934dbf5..8c5e23de58 100644
--- a/api/controllers/console/auth/error.py
+++ b/api/controllers/console/auth/error.py
@@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException):
class PasswordResetRateLimitExceededError(BaseHTTPException):
error_code = "password_reset_rate_limit_exceeded"
- description = "Too many password reset emails have been sent. Please try again in 1 minutes."
+ description = "Too many password reset emails have been sent. Please try again in 1 minute."
+ code = 429
+
+
+class EmailChangeRateLimitExceededError(BaseHTTPException):
+ error_code = "email_change_rate_limit_exceeded"
+ description = "Too many email change emails have been sent. Please try again in 1 minute."
+ code = 429
+
+
+class OwnerTransferRateLimitExceededError(BaseHTTPException):
+ error_code = "owner_transfer_rate_limit_exceeded"
+ description = "Too many owner transfer emails have been sent. Please try again in 1 minute."
code = 429
@@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException):
error_code = "email_password_reset_limit"
description = "Too many failed password reset attempts. Please try again in 24 hours."
code = 429
+
+
+class EmailChangeLimitError(BaseHTTPException):
+ error_code = "email_change_limit"
+ description = "Too many failed email change attempts. Please try again in 24 hours."
+ code = 429
+
+
+class EmailAlreadyInUseError(BaseHTTPException):
+ error_code = "email_already_in_use"
+ description = "A user with this email already exists."
+ code = 400
+
+
+class OwnerTransferLimitError(BaseHTTPException):
+ error_code = "owner_transfer_limit"
+ description = "Too many failed owner transfer attempts. Please try again in 24 hours."
+ code = 429
+
+
+class NotOwnerError(BaseHTTPException):
+ error_code = "not_owner"
+ description = "You are not the owner of the workspace."
+ code = 400
+
+
+class CannotTransferOwnerToSelfError(BaseHTTPException):
+ error_code = "cannot_transfer_owner_to_self"
+ description = "You cannot transfer ownership to yourself."
+ code = 400
+
+
+class MemberNotInTenantError(BaseHTTPException):
+ error_code = "member_not_in_tenant"
+ description = "The member is not in the workspace."
+ code = 400
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 395367c9e2..d0a4f3ff6d 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from typing import Optional
import requests
@@ -13,6 +12,7 @@ from configs import dify_config
from constants.languages import languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.helper import extract_remote_ip
from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo
from models import Account
@@ -110,7 +110,7 @@ class OAuthCallback(Resource):
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
try:
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 7b0d9373cf..39f8ab5787 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -1,4 +1,3 @@
-import datetime
import json
from flask import request
@@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.notion_extractor import NotionExtractor
from extensions.ext_database import db
from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import DataSourceOauthBinding, Document
from services.dataset_service import DatasetService, DocumentService
@@ -30,7 +30,7 @@ class DataSourceApi(Resource):
# get workspace data source integrates
data_source_integrates = (
db.session.query(DataSourceOauthBinding)
- .filter(
+ .where(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.disabled == False,
)
@@ -88,7 +88,7 @@ class DataSourceApi(Resource):
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
@@ -97,7 +97,7 @@ class DataSourceApi(Resource):
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
@@ -171,7 +171,7 @@ class DataSourceNotionApi(Resource):
page_id = str(page_id)
with Session(db.engine) as session:
data_source_binding = session.execute(
- select(DataSourceOauthBinding).filter(
+ select(DataSourceOauthBinding).where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 1611214cb3..f551bc2432 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -211,10 +211,6 @@ class DatasetApi(Resource):
else:
data["embedding_available"] = True
- if data.get("permission") == "partial_members":
- part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str)
- data.update({"partial_member_list": part_users_list})
-
return data, 200
@setup_required
@@ -416,7 +412,7 @@ class DatasetIndexingEstimateApi(Resource):
file_ids = args["info_list"]["file_info_list"]["file_ids"]
file_details = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
+ .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id.in_(file_ids))
.all()
)
@@ -521,14 +517,14 @@ class DatasetIndexingStatusApi(Resource):
dataset_id = str(dataset_id)
documents = (
db.session.query(Document)
- .filter(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
+ .where(Document.dataset_id == dataset_id, Document.tenant_id == current_user.current_tenant_id)
.all()
)
documents_status = []
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -537,7 +533,7 @@ class DatasetIndexingStatusApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+ .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@@ -572,7 +568,7 @@ class DatasetApiKeyApi(Resource):
def get(self):
keys = (
db.session.query(ApiToken)
- .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+ .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.all()
)
return {"items": keys}
@@ -588,7 +584,7 @@ class DatasetApiKeyApi(Resource):
current_key_count = (
db.session.query(ApiToken)
- .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
+ .where(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id)
.count()
)
@@ -624,7 +620,7 @@ class DatasetApiDeleteApi(Resource):
key = (
db.session.query(ApiToken)
- .filter(
+ .where(
ApiToken.tenant_id == current_user.current_tenant_id,
ApiToken.type == self.resource_type,
ApiToken.id == api_key_id,
@@ -635,7 +631,7 @@ class DatasetApiDeleteApi(Resource):
if key is None:
flask_restful.abort(404, message="API key not found")
- db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
+ db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()
return {"result": "success"}, 204
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index b2fcf3ce7b..d14b208a4b 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -1,6 +1,5 @@
import logging
from argparse import ArgumentTypeError
-from datetime import UTC, datetime
from typing import cast
from flask import request
@@ -49,6 +48,7 @@ from fields.document_fields import (
document_status_fields,
document_with_segments_fields,
)
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService
@@ -124,7 +124,7 @@ class GetProcessRuleApi(Resource):
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.dataset_id == document.dataset_id)
+ .where(DatasetProcessRule.dataset_id == document.dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@@ -176,7 +176,7 @@ class DatasetDocumentListApi(Resource):
if search:
search = f"%{search}%"
- query = query.filter(Document.name.like(search))
+ query = query.where(Document.name.like(search))
if sort.startswith("-"):
sort_logic = desc
@@ -212,7 +212,7 @@ class DatasetDocumentListApi(Resource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -221,7 +221,7 @@ class DatasetDocumentListApi(Resource):
)
total_segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+ .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments
@@ -417,7 +417,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
file = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
+ .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -492,7 +492,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
+ .where(UploadFile.tenant_id == current_user.current_tenant_id, UploadFile.id == file_id)
.first()
)
@@ -568,7 +568,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -577,7 +577,7 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+ .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
@@ -611,7 +611,7 @@ class DocumentIndexingStatusApi(DocumentResource):
completed_segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment",
@@ -620,7 +620,7 @@ class DocumentIndexingStatusApi(DocumentResource):
)
total_segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
+ .where(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
@@ -750,7 +750,7 @@ class DocumentProcessingApi(DocumentResource):
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
- document.paused_at = datetime.now(UTC).replace(tzinfo=None)
+ document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
@@ -830,7 +830,7 @@ class DocumentMetadataApi(DocumentResource):
document.doc_metadata[key] = value
document.doc_type = doc_type
- document.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ document.updated_at = naive_utc_now()
db.session.commit()
return {"result": "success", "message": "Document metadata updated."}, 200
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 48142dbe73..b3704ce8b1 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -78,7 +78,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = (
select(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
@@ -86,19 +86,19 @@ class DatasetDocumentSegmentListApi(Resource):
)
if status_list:
- query = query.filter(DocumentSegment.status.in_(status_list))
+ query = query.where(DocumentSegment.status.in_(status_list))
if hit_count_gte is not None:
- query = query.filter(DocumentSegment.hit_count >= hit_count_gte)
+ query = query.where(DocumentSegment.hit_count >= hit_count_gte)
if keyword:
query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
if args["enabled"].lower() != "all":
if args["enabled"].lower() == "true":
- query = query.filter(DocumentSegment.enabled == True)
+ query = query.where(DocumentSegment.enabled == True)
elif args["enabled"].lower() == "false":
- query = query.filter(DocumentSegment.enabled == False)
+ query = query.where(DocumentSegment.enabled == False)
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -285,7 +285,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -331,7 +331,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -436,7 +436,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -493,7 +493,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -540,7 +540,7 @@ class ChildChunkAddApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -586,7 +586,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -595,7 +595,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
- .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+ .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
@@ -635,7 +635,7 @@ class ChildChunkUpdateApi(Resource):
segment_id = str(segment_id)
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
+ .where(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment:
@@ -644,7 +644,7 @@ class ChildChunkUpdateApi(Resource):
child_chunk_id = str(child_chunk_id)
child_chunk = (
db.session.query(ChildChunk)
- .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
+ .where(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk:
diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py
index 2f00a84de6..cb68bb5e81 100644
--- a/api/controllers/console/datasets/error.py
+++ b/api/controllers/console/datasets/error.py
@@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
-class HighQualityDatasetOnlyError(BaseHTTPException):
- error_code = "high_quality_dataset_only"
- description = "Current operation only supports 'high-quality' datasets."
- code = 400
-
-
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index 4200a51709..fcdc91ec67 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -4,7 +4,7 @@ from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import login_required
-from services.website_service import WebsiteService
+from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService
class WebsiteCrawlApi(Resource):
@@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource):
parser.add_argument("url", type=str, required=True, nullable=True, location="json")
parser.add_argument("options", type=dict, required=True, nullable=True, location="json")
args = parser.parse_args()
- WebsiteService.document_create_args_validate(args)
- # crawl url
+
+ # Create typed request and validate
+ try:
+ api_request = WebsiteCrawlApiRequest.from_args(args)
+ except ValueError as e:
+ raise WebsiteCrawlError(str(e))
+
+ # Crawl URL using typed request
try:
- result = WebsiteService.crawl_url(args)
+ result = WebsiteService.crawl_url(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
@@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource):
"provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args"
)
args = parser.parse_args()
- # get crawl status
+
+ # Create typed request and validate
+ try:
+ api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id)
+ except ValueError as e:
+ raise WebsiteCrawlError(str(e))
+
+ # Get crawl status using typed request
try:
- result = WebsiteService.get_crawl_status(job_id, args["provider"])
+ result = WebsiteService.get_crawl_status_typed(api_request)
except Exception as e:
raise WebsiteCrawlError(str(e))
return result, 200
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index 4367da1162..4842fefc57 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from flask_login import current_user
from flask_restful import reqparse
@@ -27,6 +26,7 @@ from core.errors.error import (
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
from libs import helper
+from libs.datetime_utils import naive_utc_now
from libs.helper import uuid_value
from models.model import AppMode
from services.app_generate_service import AppGenerateService
@@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource):
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False
- installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+ installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:
@@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource):
args["auto_generate_name"] = False
- installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None)
+ installed_app.last_used_at = naive_utc_now()
db.session.commit()
try:
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 9d0c08564e..ffdf73c368 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -1,5 +1,4 @@
import logging
-from datetime import UTC, datetime
from typing import Any
from flask import request
@@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
+from libs.datetime_utils import naive_utc_now
from libs.login import login_required
from models import App, InstalledApp, RecommendedApp
from services.account_service import TenantService
@@ -34,11 +34,11 @@ class InstalledAppsListApi(Resource):
if app_id:
installed_apps = (
db.session.query(InstalledApp)
- .filter(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
+ .where(and_(InstalledApp.tenant_id == current_tenant_id, InstalledApp.app_id == app_id))
.all()
)
else:
- installed_apps = db.session.query(InstalledApp).filter(InstalledApp.tenant_id == current_tenant_id).all()
+ installed_apps = db.session.query(InstalledApp).where(InstalledApp.tenant_id == current_tenant_id).all()
current_user.role = TenantService.get_user_role(current_user, current_user.current_tenant)
installed_app_list: list[dict[str, Any]] = [
@@ -94,12 +94,12 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args()
- recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
+ recommended_app = db.session.query(RecommendedApp).where(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None:
raise NotFound("App not found")
current_tenant_id = current_user.current_tenant_id
- app = db.session.query(App).filter(App.id == args["app_id"]).first()
+ app = db.session.query(App).where(App.id == args["app_id"]).first()
if app is None:
raise NotFound("App not found")
@@ -109,7 +109,7 @@ class InstalledAppsListApi(Resource):
installed_app = (
db.session.query(InstalledApp)
- .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
+ .where(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)
@@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource):
tenant_id=current_tenant_id,
app_owner_tenant_id=app.tenant_id,
is_pinned=False,
- last_used_at=datetime.now(UTC).replace(tzinfo=None),
+ last_used_at=naive_utc_now(),
)
db.session.add(new_installed_app)
db.session.commit()
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index afbd78bd5b..de97fb149e 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -28,7 +28,7 @@ def installed_app_required(view=None):
installed_app = (
db.session.query(InstalledApp)
- .filter(
+ .where(
InstalledApp.id == str(installed_app_id), InstalledApp.tenant_id == current_user.current_tenant_id
)
.first()
diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py
index 072e904caf..ef814dd738 100644
--- a/api/controllers/console/workspace/__init__.py
+++ b/api/controllers/console/workspace/__init__.py
@@ -21,7 +21,7 @@ def plugin_permission_required(
with Session(db.engine) as session:
permission = (
session.query(TenantPluginPermission)
- .filter(
+ .where(
TenantPluginPermission.tenant_id == tenant_id,
)
.first()
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index a9dbf44456..657016e0a8 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -1,13 +1,21 @@
-import datetime
-
import pytz
from flask import request
from flask_login import current_user
from flask_restful import Resource, fields, marshal_with, reqparse
+from sqlalchemy import select
+from sqlalchemy.orm import Session
from configs import dify_config
from constants.languages import supported_language
from controllers.console import api
+from controllers.console.auth.error import (
+ EmailAlreadyInUseError,
+ EmailChangeLimitError,
+ EmailCodeError,
+ InvalidEmailError,
+ InvalidTokenError,
+)
+from controllers.console.error import AccountNotFound, EmailSendIpLimitError
from controllers.console.workspace.error import (
AccountAlreadyInitedError,
CurrentPasswordIncorrectError,
@@ -18,15 +26,18 @@ from controllers.console.workspace.error import (
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_enabled,
+ enable_change_email,
enterprise_license_required,
only_edition_cloud,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_fields
-from libs.helper import TimestampField, timezone
+from libs.datetime_utils import naive_utc_now
+from libs.helper import TimestampField, email, extract_remote_ip, timezone
from libs.login import login_required
from models import AccountIntegrate, InvitationCode
+from models.account import Account
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -57,7 +68,7 @@ class AccountInitApi(Resource):
# check invitation code
invitation_code = (
db.session.query(InvitationCode)
- .filter(
+ .where(
InvitationCode.code == args["invitation_code"],
InvitationCode.status == "unused",
)
@@ -68,7 +79,7 @@ class AccountInitApi(Resource):
raise InvalidInvitationCodeError()
invitation_code.status = "used"
- invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ invitation_code.used_at = naive_utc_now()
invitation_code.used_by_tenant_id = account.current_tenant_id
invitation_code.used_by_account_id = account.id
@@ -76,7 +87,7 @@ class AccountInitApi(Resource):
account.timezone = args["timezone"]
account.interface_theme = "light"
account.status = "active"
- account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
return {"result": "success"}
@@ -217,7 +228,7 @@ class AccountIntegrateApi(Resource):
def get(self):
account = current_user
- account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()
+ account_integrates = db.session.query(AccountIntegrate).where(AccountIntegrate.account_id == account.id).all()
base_url = request.url_root.rstrip("/")
oauth_base_path = "/console/api/oauth/login"
@@ -369,6 +380,138 @@ class EducationAutoCompleteApi(Resource):
return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"])
+class ChangeEmailSendEmailApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ parser.add_argument("language", type=str, required=False, location="json")
+ parser.add_argument("phase", type=str, required=False, location="json")
+ parser.add_argument("token", type=str, required=False, location="json")
+ args = parser.parse_args()
+
+ ip_address = extract_remote_ip(request)
+ if AccountService.is_email_send_ip_limit(ip_address):
+ raise EmailSendIpLimitError()
+
+ if args["language"] is not None and args["language"] == "zh-Hans":
+ language = "zh-Hans"
+ else:
+ language = "en-US"
+ account = None
+ user_email = args["email"]
+ if args["phase"] is not None and args["phase"] == "new_email":
+ if args["token"] is None:
+ raise InvalidTokenError()
+
+ reset_data = AccountService.get_change_email_data(args["token"])
+ if reset_data is None:
+ raise InvalidTokenError()
+ user_email = reset_data.get("email", "")
+
+ if user_email != current_user.email:
+ raise InvalidEmailError()
+ else:
+ with Session(db.engine) as session:
+ account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none()
+ if account is None:
+ raise AccountNotFound()
+
+ token = AccountService.send_change_email_email(
+ account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"]
+ )
+ return {"result": "success", "data": token}
+
+
+class ChangeEmailCheckApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ parser.add_argument("code", type=str, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ user_email = args["email"]
+
+ is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"])
+ if is_change_email_error_rate_limit:
+ raise EmailChangeLimitError()
+
+ token_data = AccountService.get_change_email_data(args["token"])
+ if token_data is None:
+ raise InvalidTokenError()
+
+ if user_email != token_data.get("email"):
+ raise InvalidEmailError()
+
+ if args["code"] != token_data.get("code"):
+ AccountService.add_change_email_error_rate_limit(args["email"])
+ raise EmailCodeError()
+
+ # Verified, revoke the first token
+ AccountService.revoke_change_email_token(args["token"])
+
+ # Refresh token data by generating a new token
+ _, new_token = AccountService.generate_change_email_token(
+ user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={}
+ )
+
+ AccountService.reset_change_email_error_rate_limit(args["email"])
+ return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+
+
+class ChangeEmailResetApi(Resource):
+ @enable_change_email
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @marshal_with(account_fields)
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("new_email", type=email, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ reset_data = AccountService.get_change_email_data(args["token"])
+ if not reset_data:
+ raise InvalidTokenError()
+
+ AccountService.revoke_change_email_token(args["token"])
+
+ if not AccountService.check_email_unique(args["new_email"]):
+ raise EmailAlreadyInUseError()
+
+ old_email = reset_data.get("old_email", "")
+ if current_user.email != old_email:
+ raise AccountNotFound()
+
+ updated_account = AccountService.update_account(current_user, email=args["new_email"])
+
+ AccountService.send_change_email_completed_notify_email(
+ email=args["new_email"],
+ )
+
+ return updated_account
+
+
+class CheckEmailUnique(Resource):
+ @setup_required
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("email", type=email, required=True, location="json")
+ args = parser.parse_args()
+ if not AccountService.check_email_unique(args["email"]):
+ raise EmailAlreadyInUseError()
+ return {"result": "success"}
+
+
# Register API resources
api.add_resource(AccountInitApi, "/account/init")
api.add_resource(AccountProfileApi, "/account/profile")
@@ -385,5 +528,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback")
api.add_resource(EducationVerifyApi, "/account/education/verify")
api.add_resource(EducationApi, "/account/education")
api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete")
+# Change email
+api.add_resource(ChangeEmailSendEmailApi, "/account/change-email")
+api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity")
+api.add_resource(ChangeEmailResetApi, "/account/change-email/reset")
+api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique")
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')
diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py
index 8b70ca62b9..4427d1ff72 100644
--- a/api/controllers/console/workspace/error.py
+++ b/api/controllers/console/workspace/error.py
@@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException):
code = 400
-class ProviderRequestFailedError(BaseHTTPException):
- error_code = "provider_request_failed"
- description = None
- code = 400
-
-
class InvalidInvitationCodeError(BaseHTTPException):
error_code = "invalid_invitation_code"
description = "Invalid invitation code."
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index 48225ac90d..f7424923b9 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,22 +1,34 @@
from urllib import parse
+from flask import request
from flask_login import current_user
from flask_restful import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
from controllers.console import api
-from controllers.console.error import WorkspaceMembersLimitExceeded
+from controllers.console.auth.error import (
+ CannotTransferOwnerToSelfError,
+ EmailCodeError,
+ InvalidEmailError,
+ InvalidTokenError,
+ MemberNotInTenantError,
+ NotOwnerError,
+ OwnerTransferLimitError,
+)
+from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
+ is_allow_transfer_owner,
setup_required,
)
from extensions.ext_database import db
from fields.member_fields import account_with_role_list_fields
+from libs.helper import extract_remote_ip
from libs.login import login_required
from models.account import Account, TenantAccountRole
-from services.account_service import RegisterService, TenantService
+from services.account_service import AccountService, RegisterService, TenantService
from services.errors.account import AccountAlreadyInTenantError
from services.feature_service import FeatureService
@@ -96,7 +108,7 @@ class MemberCancelInviteApi(Resource):
@login_required
@account_initialization_required
def delete(self, member_id):
- member = db.session.query(Account).filter(Account.id == str(member_id)).first()
+ member = db.session.query(Account).where(Account.id == str(member_id)).first()
if member is None:
abort(404)
else:
@@ -156,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource):
return {"result": "success", "accounts": members}, 200
+class SendOwnerTransferEmailApi(Resource):
+ """Send owner transfer email."""
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("language", type=str, required=False, location="json")
+ args = parser.parse_args()
+ ip_address = extract_remote_ip(request)
+ if AccountService.is_email_send_ip_limit(ip_address):
+ raise EmailSendIpLimitError()
+
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ if args["language"] is not None and args["language"] == "zh-Hans":
+ language = "zh-Hans"
+ else:
+ language = "en-US"
+
+ email = current_user.email
+
+ token = AccountService.send_owner_transfer_email(
+ account=current_user,
+ email=email,
+ language=language,
+ workspace_name=current_user.current_tenant.name,
+ )
+
+ return {"result": "success", "data": token}
+
+
+class OwnerTransferCheckApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self):
+ parser = reqparse.RequestParser()
+ parser.add_argument("code", type=str, required=True, location="json")
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ user_email = current_user.email
+
+ is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email)
+ if is_owner_transfer_error_rate_limit:
+ raise OwnerTransferLimitError()
+
+ token_data = AccountService.get_owner_transfer_data(args["token"])
+ if token_data is None:
+ raise InvalidTokenError()
+
+ if user_email != token_data.get("email"):
+ raise InvalidEmailError()
+
+ if args["code"] != token_data.get("code"):
+ AccountService.add_owner_transfer_error_rate_limit(user_email)
+ raise EmailCodeError()
+
+ # Verified, revoke the first token
+ AccountService.revoke_owner_transfer_token(args["token"])
+
+ # Refresh token data by generating a new token
+ _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={})
+
+ AccountService.reset_owner_transfer_error_rate_limit(user_email)
+ return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
+
+
+class OwnerTransfer(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ @is_allow_transfer_owner
+ def post(self, member_id):
+ parser = reqparse.RequestParser()
+ parser.add_argument("token", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ # check if the current user is the owner of the workspace
+ if not TenantService.is_owner(current_user, current_user.current_tenant):
+ raise NotOwnerError()
+
+ if current_user.id == str(member_id):
+ raise CannotTransferOwnerToSelfError()
+
+ transfer_token_data = AccountService.get_owner_transfer_data(args["token"])
+ if not transfer_token_data:
+ raise InvalidTokenError()
+
+ if transfer_token_data.get("email") != current_user.email:
+ raise InvalidEmailError()
+
+ AccountService.revoke_owner_transfer_token(args["token"])
+
+ member = db.session.get(Account, str(member_id))
+ if not member:
+ abort(404)
+ else:
+ member_account = member
+ if not TenantService.is_member(member_account, current_user.current_tenant):
+ raise MemberNotInTenantError()
+
+ try:
+ assert member is not None, "Member not found"
+ TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user)
+
+ AccountService.send_new_owner_transfer_notify_email(
+ account=member,
+ email=member.email,
+ workspace_name=current_user.current_tenant.name,
+ )
+
+ AccountService.send_old_owner_transfer_notify_email(
+ account=current_user,
+ email=current_user.email,
+ workspace_name=current_user.current_tenant.name,
+ new_owner_email=member.email,
+ )
+
+ except Exception as e:
+ raise ValueError(str(e))
+
+ return {"result": "success"}
+
+
api.add_resource(MemberListApi, "/workspaces/current/members")
api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email")
api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/")
api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role")
api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators")
+# owner transfer
+api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email")
+api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check")
+api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer")
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index c0a4734828..09846d5c94 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -12,7 +12,8 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
-from models.account import TenantPluginPermission
+from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
+from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@@ -534,6 +535,114 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
+class PluginChangePreferencesApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ user = current_user
+ if not user.is_admin_or_owner:
+ raise Forbidden()
+
+ req = reqparse.RequestParser()
+ req.add_argument("permission", type=dict, required=True, location="json")
+ req.add_argument("auto_upgrade", type=dict, required=True, location="json")
+ args = req.parse_args()
+
+ tenant_id = user.current_tenant_id
+
+ permission = args["permission"]
+
+ install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
+ debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
+
+ auto_upgrade = args["auto_upgrade"]
+
+ strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
+ auto_upgrade.get("strategy_setting", "fix_only")
+ )
+ upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
+ upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
+ exclude_plugins = auto_upgrade.get("exclude_plugins", [])
+ include_plugins = auto_upgrade.get("include_plugins", [])
+
+ # set permission
+ set_permission_result = PluginPermissionService.change_permission(
+ tenant_id,
+ install_permission,
+ debug_permission,
+ )
+ if not set_permission_result:
+ return jsonable_encoder({"success": False, "message": "Failed to set permission"})
+
+ # set auto upgrade strategy
+ set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
+ tenant_id,
+ strategy_setting,
+ upgrade_time_of_day,
+ upgrade_mode,
+ exclude_plugins,
+ include_plugins,
+ )
+ if not set_auto_upgrade_strategy_result:
+ return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
+
+ return jsonable_encoder({"success": True})
+
+
+class PluginFetchPreferencesApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self):
+ tenant_id = current_user.current_tenant_id
+
+ permission = PluginPermissionService.get_permission(tenant_id)
+ permission_dict = {
+ "install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
+ "debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
+ }
+
+ if permission:
+ permission_dict["install_permission"] = permission.install_permission
+ permission_dict["debug_permission"] = permission.debug_permission
+
+ auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
+ auto_upgrade_dict = {
+ "strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
+ "upgrade_time_of_day": 0,
+ "upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
+ "exclude_plugins": [],
+ "include_plugins": [],
+ }
+
+ if auto_upgrade:
+ auto_upgrade_dict = {
+ "strategy_setting": auto_upgrade.strategy_setting,
+ "upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
+ "upgrade_mode": auto_upgrade.upgrade_mode,
+ "exclude_plugins": auto_upgrade.exclude_plugins,
+ "include_plugins": auto_upgrade.include_plugins,
+ }
+
+ return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
+
+
+class PluginAutoUpgradeExcludePluginApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self):
+ # exclude one single plugin
+ tenant_id = current_user.current_tenant_id
+
+ req = reqparse.RequestParser()
+ req.add_argument("plugin_id", type=str, required=True, location="json")
+ args = req.parse_args()
+
+ return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
+
+
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@@ -560,3 +669,7 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
+
+api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
+api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
+api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index df50871a38..c4d1ef70d8 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,26 +1,35 @@
import io
from urllib.parse import urlparse
-from flask import redirect, send_file
+from flask import make_response, redirect, request, send_file
from flask_login import current_user
-from flask_restful import Resource, reqparse
-from sqlalchemy.orm import Session
+from flask_restful import (
+ Resource,
+ reqparse,
+)
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.console import api
-from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
+from controllers.console.wraps import (
+ account_initialization_required,
+ enterprise_license_required,
+ setup_required,
+)
from core.mcp.auth.auth_flow import auth, handle_callback
from core.mcp.auth.auth_provider import OAuthClientProvider
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.model_runtime.utils.encoders import jsonable_encoder
-from extensions.ext_database import db
-from libs.helper import alphanumeric, uuid_value
+from core.plugin.entities.plugin import ToolProviderID
+from core.plugin.impl.oauth import OAuthHandler
+from core.tools.entities.tool_entities import CredentialType
+from libs.helper import StrLen, alphanumeric, uuid_value
from libs.login import login_required
+from services.plugin.oauth_service import OAuthProxyService
from services.tools.api_tools_manage_service import ApiToolManageService
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
from services.tools.tool_labels_service import ToolLabelsService
from services.tools.tools_manage_service import ToolCommonService
from services.tools.tools_transform_service import ToolTransformService
@@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource):
user_id = user.id
tenant_id = user.current_tenant_id
- return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider))
+ return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider))
class ToolBuiltinProviderDeleteApi(Resource):
@@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource):
@account_initialization_required
def post(self, provider):
user = current_user
-
if not user.is_admin_or_owner:
raise Forbidden()
- user_id = user.id
tenant_id = user.current_tenant_id
+ req = reqparse.RequestParser()
+ req.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+ args = req.parse_args()
return BuiltinToolManageService.delete_builtin_tool_provider(
- user_id,
tenant_id,
provider,
+ args["credential_id"],
+ )
+
+
+class ToolBuiltinProviderAddApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ user = current_user
+
+ user_id = user.id
+ tenant_id = user.current_tenant_id
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json")
+ parser.add_argument("type", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+
+ if args["type"] not in CredentialType.values():
+ raise ValueError(f"Invalid credential type: {args['type']}")
+
+ return BuiltinToolManageService.add_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credentials=args["credentials"],
+ name=args["name"],
+ api_type=CredentialType.of(args["type"]),
)
@@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource):
tenant_id = user.current_tenant_id
parser = reqparse.RequestParser()
- parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
+ parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json")
+ parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json")
args = parser.parse_args()
- with Session(db.engine) as session:
- result = BuiltinToolManageService.update_builtin_tool_provider(
- session=session,
- user_id=user_id,
- tenant_id=tenant_id,
- provider_name=provider,
- credentials=args["credentials"],
- )
- session.commit()
+ result = BuiltinToolManageService.update_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credential_id=args["credential_id"],
+ credentials=args.get("credentials", None),
+ name=args.get("name", ""),
+ )
return result
@@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource):
def get(self, provider):
tenant_id = current_user.current_tenant_id
- return BuiltinToolManageService.get_builtin_tool_provider_credentials(
- tenant_id=tenant_id,
- provider_name=provider,
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_credentials(
+ tenant_id=tenant_id,
+ provider_name=provider,
+ )
)
@@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource):
@setup_required
@login_required
@account_initialization_required
- def get(self, provider):
+ def get(self, provider, credential_type):
user = current_user
-
tenant_id = user.current_tenant_id
- return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id)
+ return jsonable_encoder(
+ BuiltinToolManageService.list_builtin_provider_credentials_schema(
+ provider, CredentialType.of(credential_type), tenant_id
+ )
+ )
class ToolApiProviderSchemaApi(Resource):
@@ -586,15 +631,12 @@ class ToolApiListApi(Resource):
@account_initialization_required
def get(self):
user = current_user
-
- user_id = user.id
tenant_id = user.current_tenant_id
return jsonable_encoder(
[
provider.to_dict()
for provider in ApiToolManageService.list_api_tools(
- user_id,
tenant_id,
)
]
@@ -631,6 +673,183 @@ class ToolLabelsApi(Resource):
return jsonable_encoder(ToolLabelsService.list_tool_labels())
+class ToolPluginOAuthApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ tool_provider = ToolProviderID(provider)
+ plugin_id = tool_provider.plugin_id
+ provider_name = tool_provider.provider_name
+
+ # todo check permission
+ user = current_user
+
+ if not user.is_admin_or_owner:
+ raise Forbidden()
+
+ tenant_id = user.current_tenant_id
+ oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider)
+ if oauth_client_params is None:
+ raise Forbidden("no oauth available client config found for this tool provider")
+
+ oauth_handler = OAuthHandler()
+ context_id = OAuthProxyService.create_proxy_context(
+ user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name
+ )
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
+ authorization_url_response = oauth_handler.get_authorization_url(
+ tenant_id=tenant_id,
+ user_id=user.id,
+ plugin_id=plugin_id,
+ provider=provider_name,
+ redirect_uri=redirect_uri,
+ system_credentials=oauth_client_params,
+ )
+ response = make_response(jsonable_encoder(authorization_url_response))
+ response.set_cookie(
+ "context_id",
+ context_id,
+ httponly=True,
+ samesite="Lax",
+ max_age=OAuthProxyService.__MAX_AGE__,
+ )
+ return response
+
+
+class ToolOAuthCallback(Resource):
+ @setup_required
+ def get(self, provider):
+ context_id = request.cookies.get("context_id")
+ if not context_id:
+ raise Forbidden("context_id not found")
+
+ context = OAuthProxyService.use_proxy_context(context_id)
+ if context is None:
+ raise Forbidden("Invalid context_id")
+
+ tool_provider = ToolProviderID(provider)
+ plugin_id = tool_provider.plugin_id
+ provider_name = tool_provider.provider_name
+ user_id, tenant_id = context.get("user_id"), context.get("tenant_id")
+
+ oauth_handler = OAuthHandler()
+ oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider)
+ if oauth_client_params is None:
+ raise Forbidden("no oauth available client config found for this tool provider")
+
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
+ credentials_response = oauth_handler.get_credentials(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ plugin_id=plugin_id,
+ provider=provider_name,
+ redirect_uri=redirect_uri,
+ system_credentials=oauth_client_params,
+ request=request,
+ )
+
+ credentials = credentials_response.credentials
+ expires_at = credentials_response.expires_at
+
+ if not credentials:
+ raise Exception("the plugin credentials failed")
+
+ # add credentials to database
+ BuiltinToolManageService.add_builtin_tool_provider(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ provider=provider,
+ credentials=dict(credentials),
+ expires_at=expires_at,
+ api_type=CredentialType.OAUTH2,
+ )
+ return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
+
+
+class ToolBuiltinProviderSetDefaultApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ parser = reqparse.RequestParser()
+ parser.add_argument("id", type=str, required=True, nullable=False, location="json")
+ args = parser.parse_args()
+ return BuiltinToolManageService.set_default_provider(
+ tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"]
+ )
+
+
+class ToolOAuthCustomClient(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def post(self, provider):
+ parser = reqparse.RequestParser()
+ parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
+ parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json")
+ args = parser.parse_args()
+
+ user = current_user
+
+ if not user.is_admin_or_owner:
+ raise Forbidden()
+
+ return BuiltinToolManageService.save_custom_oauth_client_params(
+ tenant_id=user.current_tenant_id,
+ provider=provider,
+ client_params=args.get("client_params", {}),
+ enable_oauth_custom_client=args.get("enable_oauth_custom_client", True),
+ )
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.get_custom_oauth_client_params(
+ tenant_id=current_user.current_tenant_id, provider=provider
+ )
+ )
+
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def delete(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.delete_custom_oauth_client_params(
+ tenant_id=current_user.current_tenant_id, provider=provider
+ )
+ )
+
+
+class ToolBuiltinProviderGetOauthClientSchemaApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema(
+ tenant_id=current_user.current_tenant_id, provider_name=provider
+ )
+ )
+
+
+class ToolBuiltinProviderGetCredentialInfoApi(Resource):
+ @setup_required
+ @login_required
+ @account_initialization_required
+ def get(self, provider):
+ tenant_id = current_user.current_tenant_id
+
+ return jsonable_encoder(
+ BuiltinToolManageService.get_builtin_tool_provider_credential_info(
+ tenant_id=tenant_id,
+ provider=provider,
+ )
+ )
+
+
class ToolProviderMCPApi(Resource):
@setup_required
@login_required
@@ -794,17 +1013,33 @@ class ToolMCPCallbackApi(Resource):
# tool provider
api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers")
+# tool oauth
+api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url")
+api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback")
+api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client")
+
# builtin tool provider
api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools")
api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info")
+api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add")
api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete")
api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update")
+api.add_resource(
+ ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential"
+)
+api.add_resource(
+ ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info"
+)
api.add_resource(
ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials"
)
api.add_resource(
ToolBuiltinProviderCredentialsSchemaApi,
- "/workspaces/current/tool-provider/builtin//credentials_schema",
+ "/workspaces/current/tool-provider/builtin//credential/schema/",
+)
+api.add_resource(
+ ToolBuiltinProviderGetOauthClientSchemaApi,
+ "/workspaces/current/tool-provider/builtin//oauth/client-schema",
)
api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon")
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index ca122772de..d862dac373 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -235,3 +235,29 @@ def email_password_login_enabled(view):
abort(403)
return decorated
+
+
+def enable_change_email(view):
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_system_features()
+ if features.enable_change_email:
+ return view(*args, **kwargs)
+
+ # otherwise, return 403
+ abort(403)
+
+ return decorated
+
+
+def is_allow_transfer_owner(view):
+ @wraps(view)
+ def decorated(*args, **kwargs):
+ features = FeatureService.get_features(current_user.current_tenant_id)
+ if features.is_allow_transfer_workspace:
+ return view(*args, **kwargs)
+
+ # otherwise, return 403
+ abort(403)
+
+ return decorated
diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index 327e9ce834..5dfe41eb6b 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource):
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
+ credential_id=payload.credential_id,
),
)
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index 50408e0929..b533614d4d 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -22,7 +22,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
user_id = "DEFAULT-USER"
if user_id == "DEFAULT-USER":
- user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
+ user_model = session.query(EndUser).where(EndUser.session_id == "DEFAULT-USER").first()
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
@@ -36,7 +36,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
else:
user_model = AccountService.load_user(user_id)
if not user_model:
- user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
+ user_model = session.query(EndUser).where(EndUser.id == user_id).first()
if not user_model:
raise ValueError("user not found")
except Exception:
@@ -71,7 +71,7 @@ def get_user_tenant(view: Optional[Callable] = None):
try:
tenant_model = (
db.session.query(Tenant)
- .filter(
+ .where(
Tenant.id == tenant_id,
)
.first()
diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py
index f3a9312dd0..9e7b3d4f29 100644
--- a/api/controllers/inner_api/wraps.py
+++ b/api/controllers/inner_api/wraps.py
@@ -55,7 +55,7 @@ def enterprise_inner_api_user_auth(view):
if signature_base64 != token:
return view(*args, **kwargs)
- kwargs["user"] = db.session.query(EndUser).filter(EndUser.id == user_id).first()
+ kwargs["user"] = db.session.query(EndUser).where(EndUser.id == user_id).first()
return view(*args, **kwargs)
diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py
index ead728bfb0..87d678796f 100644
--- a/api/controllers/mcp/mcp.py
+++ b/api/controllers/mcp/mcp.py
@@ -30,7 +30,7 @@ class MCPAppApi(Resource):
request_id = args.get("id")
- server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
+ server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not server:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server Not Found")
@@ -41,7 +41,7 @@ class MCPAppApi(Resource):
create_mcp_error_response(request_id, types.INVALID_REQUEST, "Server is not active")
)
- app = db.session.query(App).filter(App.id == server.app_id).first()
+ app = db.session.query(App).where(App.id == server.app_id).first()
if not app:
return helper.compact_generate_response(
create_mcp_error_response(request_id, types.INVALID_REQUEST, "App Not Found")
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 1d9890199d..7762672494 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -1,5 +1,6 @@
import logging
+from flask import request
from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
@@ -23,6 +24,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
+from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from libs import helper
from libs.helper import uuid_value
@@ -111,6 +113,10 @@ class ChatApi(Resource):
args = parser.parse_args()
+ external_trace_id = get_external_trace_id(request)
+ if external_trace_id:
+ args["external_trace_id"] = external_trace_id
+
streaming = args["response_mode"] == "streaming"
try:
diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py
index e752dfee30..c157b39f6b 100644
--- a/api/controllers/service_api/app/site.py
+++ b/api/controllers/service_api/app/site.py
@@ -16,7 +16,7 @@ class AppSiteApi(Resource):
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
- site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index efb4acc5fb..370ff911b4 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -1,9 +1,10 @@
import logging
from dateutil.parser import isoparse
+from flask import request
from flask_restful import Resource, fields, marshal_with, reqparse
from flask_restful.inputs import int_range
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from werkzeug.exceptions import InternalServerError
from controllers.service_api import api
@@ -23,6 +24,7 @@ from core.errors.error import (
ProviderTokenNotInitError,
QuotaExceededError,
)
+from core.helper.trace_id_helper import get_external_trace_id
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus
from extensions.ext_database import db
@@ -30,7 +32,7 @@ from fields.workflow_app_log_fields import workflow_app_log_pagination_fields
from libs import helper
from libs.helper import TimestampField
from models.model import App, AppMode, EndUser
-from models.workflow import WorkflowRun
+from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.llm import InvokeRateLimitError
from services.workflow_app_service import WorkflowAppService
@@ -63,7 +65,15 @@ class WorkflowRunDetailApi(Resource):
if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
- workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
+ # Use repository to get workflow run
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ workflow_run = workflow_run_repo.get_workflow_run_by_id(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ run_id=workflow_run_id,
+ )
return workflow_run
@@ -82,7 +92,9 @@ class WorkflowRunApi(Resource):
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
args = parser.parse_args()
-
+ external_trace_id = get_external_trace_id(request)
+ if external_trace_id:
+ args["external_trace_id"] = external_trace_id
streaming = args.get("response_mode") == "streaming"
try:
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index d571b21a0a..ac85c0b38d 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -63,7 +63,7 @@ class DocumentAddByTextApi(DatasetApiResource):
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -136,7 +136,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args = parser.parse_args()
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -206,7 +206,7 @@ class DocumentAddByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -299,7 +299,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
# get dataset info
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -367,7 +367,7 @@ class DocumentDeleteApi(DatasetApiResource):
tenant_id = str(tenant_id)
# get dataset info
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise ValueError("Dataset does not exist.")
@@ -398,7 +398,7 @@ class DocumentListApi(DatasetApiResource):
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -406,7 +406,7 @@ class DocumentListApi(DatasetApiResource):
if search:
search = f"%{search}%"
- query = query.filter(Document.name.like(search))
+ query = query.where(Document.name.like(search))
query = query.order_by(desc(Document.created_at), desc(Document.position))
@@ -430,7 +430,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
batch = str(batch)
tenant_id = str(tenant_id)
# get dataset
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# get documents
@@ -441,7 +441,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
for document in documents:
completed_segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment",
@@ -450,7 +450,7 @@ class DocumentIndexingStatusApi(DatasetApiResource):
)
total_segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
+ .where(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
# Create a dictionary with document attributes and additional fields
diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py
index 5ff5e08c72..ecc47b40a1 100644
--- a/api/controllers/service_api/dataset/error.py
+++ b/api/controllers/service_api/dataset/error.py
@@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException):
code = 415
-class HighQualityDatasetOnlyError(BaseHTTPException):
- error_code = "high_quality_dataset_only"
- description = "Current operation only supports 'high-quality' datasets."
- code = 400
-
-
class DatasetNotInitializedError(BaseHTTPException):
error_code = "dataset_not_initialized"
description = "The dataset is still being initialized or indexing. Please wait a moment."
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index 403b7f0a0c..31f862dc8f 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -42,7 +42,7 @@ class SegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -89,7 +89,7 @@ class SegmentApi(DatasetApiResource):
tenant_id = str(tenant_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -146,7 +146,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -170,7 +170,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -216,7 +216,7 @@ class DatasetSegmentApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
@@ -246,7 +246,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -296,7 +296,7 @@ class ChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -343,7 +343,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -382,7 +382,7 @@ class DatasetChildChunkApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
diff --git a/api/controllers/service_api/dataset/upload_file.py b/api/controllers/service_api/dataset/upload_file.py
index 6382b63ea9..3b4721b5b0 100644
--- a/api/controllers/service_api/dataset/upload_file.py
+++ b/api/controllers/service_api/dataset/upload_file.py
@@ -17,7 +17,7 @@ class UploadFileApi(DatasetApiResource):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
# check document
@@ -31,7 +31,7 @@ class UploadFileApi(DatasetApiResource):
data_source_info = document.data_source_info_dict
if data_source_info and "upload_file_id" in data_source_info:
file_id = data_source_info["upload_file_id"]
- upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("UploadFile not found.")
else:
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 5b919a68d4..da81cc8bc3 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -1,6 +1,6 @@
import time
from collections.abc import Callable
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
from enum import Enum
from functools import wraps
from typing import Optional
@@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from libs.datetime_utils import naive_utc_now
from libs.login import _get_user
from models.account import Account, Tenant, TenantAccountJoin, TenantStatus
from models.dataset import Dataset, RateLimitLog
@@ -43,7 +44,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
def decorated_view(*args, **kwargs):
api_token = validate_and_get_api_token("app")
- app_model = db.session.query(App).filter(App.id == api_token.app_id).first()
+ app_model = db.session.query(App).where(App.id == api_token.app_id).first()
if not app_model:
raise Forbidden("The app no longer exists.")
@@ -53,7 +54,7 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
if not app_model.enable_api:
raise Forbidden("The app's API service has been disabled.")
- tenant = db.session.query(Tenant).filter(Tenant.id == app_model.tenant_id).first()
+ tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first()
if tenant is None:
raise ValueError("Tenant does not exist.")
if tenant.status == TenantStatus.ARCHIVE:
@@ -61,15 +62,15 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
- .filter(Tenant.id == api_token.tenant_id)
- .filter(TenantAccountJoin.tenant_id == Tenant.id)
- .filter(TenantAccountJoin.role.in_(["owner"]))
- .filter(Tenant.status == TenantStatus.NORMAL)
+ .where(Tenant.id == api_token.tenant_id)
+ .where(TenantAccountJoin.tenant_id == Tenant.id)
+ .where(TenantAccountJoin.role.in_(["owner"]))
+ .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
- account = db.session.query(Account).filter(Account.id == ta.account_id).first()
+ account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@@ -212,15 +213,15 @@ def validate_dataset_token(view=None):
api_token = validate_and_get_api_token("dataset")
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
- .filter(Tenant.id == api_token.tenant_id)
- .filter(TenantAccountJoin.tenant_id == Tenant.id)
- .filter(TenantAccountJoin.role.in_(["owner"]))
- .filter(Tenant.status == TenantStatus.NORMAL)
+ .where(Tenant.id == api_token.tenant_id)
+ .where(TenantAccountJoin.tenant_id == Tenant.id)
+ .where(TenantAccountJoin.role.in_(["owner"]))
+ .where(Tenant.status == TenantStatus.NORMAL)
.one_or_none()
) # TODO: only owner information is required, so only one is returned.
if tenant_account_join:
tenant, ta = tenant_account_join
- account = db.session.query(Account).filter(Account.id == ta.account_id).first()
+ account = db.session.query(Account).where(Account.id == ta.account_id).first()
# Login admin
if account:
account.current_tenant = tenant
@@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None):
if auth_scheme != "bearer":
raise Unauthorized("Authorization scheme must be 'Bearer'")
- current_time = datetime.now(UTC).replace(tzinfo=None)
+ current_time = naive_utc_now()
cutoff_time = current_time - timedelta(minutes=1)
with Session(db.engine, expire_on_commit=False) as session:
update_stmt = (
@@ -292,7 +293,7 @@ def create_or_update_end_user_for_user_id(app_model: App, user_id: Optional[str]
end_user = (
db.session.query(EndUser)
- .filter(
+ .where(
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
EndUser.session_id == user_id,
@@ -319,7 +320,7 @@ class DatasetApiResource(Resource):
method_decorators = [validate_dataset_token]
def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first()
if not dataset:
raise NotFound("Dataset not found.")
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index 10c3cdcf0e..acd3a8b539 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -3,6 +3,7 @@ from datetime import UTC, datetime, timedelta
from flask import request
from flask_restful import Resource
+from sqlalchemy import func, select
from werkzeug.exceptions import NotFound, Unauthorized
from configs import dify_config
@@ -42,17 +43,17 @@ class PassportResource(Resource):
raise WebAppAuthRequiredError()
# get site from db and check if it is normal
- site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
+ site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
# get app from db and check if it is normal and enable_site
- app_model = db.session.query(App).filter(App.id == site.app_id).first()
+ app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
if user_id:
- end_user = (
- db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
+ end_user = db.session.scalar(
+ select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if end_user:
@@ -121,11 +122,11 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
if not user_auth_type:
raise Unauthorized("Missing auth_type in the token.")
- site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first()
+ site = db.session.scalar(select(Site).where(Site.code == app_code, Site.status == "normal"))
if not site:
raise NotFound()
- app_model = db.session.query(App).filter(App.id == site.app_id).first()
+ app_model = db.session.scalar(select(App).where(App.id == site.app_id))
if not app_model or app_model.status != "normal" or not app_model.enable_site:
raise NotFound()
@@ -140,16 +141,14 @@ def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded:
end_user = None
if end_user_id:
- end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
+ end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if session_id:
- end_user = (
- db.session.query(EndUser)
- .filter(
+ end_user = db.session.scalar(
+ select(EndUser).where(
EndUser.session_id == session_id,
EndUser.tenant_id == app_model.tenant_id,
EndUser.app_id == app_model.id,
)
- .first()
)
if not end_user:
if not session_id:
@@ -187,8 +186,8 @@ def _exchange_for_public_app_token(app_model, site, token_decoded):
user_id = token_decoded.get("user_id")
end_user = None
if user_id:
- end_user = (
- db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first()
+ end_user = db.session.scalar(
+ select(EndUser).where(EndUser.app_id == app_model.id, EndUser.session_id == user_id)
)
if not end_user:
@@ -224,6 +223,8 @@ def generate_session_id():
"""
while True:
session_id = str(uuid.uuid4())
- existing_count = db.session.query(EndUser).filter(EndUser.session_id == session_id).count()
+ existing_count = db.session.scalar(
+ select(func.count()).select_from(EndUser).where(EndUser.session_id == session_id)
+ )
if existing_count == 0:
return session_id
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index 0564b15ea3..3c133499b7 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -57,7 +57,7 @@ class AppSiteApi(WebApiResource):
def get(self, app_model, end_user):
"""Retrieve app site info."""
# get site
- site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
+ site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index 154bddfc5c..ae6f14a689 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -3,6 +3,7 @@ from functools import wraps
from flask import request
from flask_restful import Resource
+from sqlalchemy import select
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError
@@ -48,8 +49,8 @@ def decode_jwt_token():
decoded = PassportService().verify(tk)
app_code = decoded.get("app_code")
app_id = decoded.get("app_id")
- app_model = db.session.query(App).filter(App.id == app_id).first()
- site = db.session.query(Site).filter(Site.code == app_code).first()
+ app_model = db.session.scalar(select(App).where(App.id == app_id))
+ site = db.session.scalar(select(Site).where(Site.code == app_code))
if not app_model:
raise NotFound()
if not app_code or not site:
@@ -57,7 +58,7 @@ def decode_jwt_token():
if app_model.enable_site is False:
raise BadRequest("Site is disabled.")
end_user_id = decoded.get("end_user_id")
- end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first()
+ end_user = db.session.scalar(select(EndUser).where(EndUser.id == end_user_id))
if not end_user:
raise NotFound()
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 0d304de97a..1f3c218d59 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -3,6 +3,8 @@ import logging
import uuid
from typing import Optional, Union, cast
+from sqlalchemy import select
+
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
@@ -97,7 +99,7 @@ class BaseAgentRunner(AppRunner):
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.query(MessageAgentThought)
- .filter(
+ .where(
MessageAgentThought.message_id == self.message.id,
)
.count()
@@ -334,7 +336,7 @@ class BaseAgentRunner(AppRunner):
Save agent thought
"""
updated_agent_thought = (
- db.session.query(MessageAgentThought).filter(MessageAgentThought.id == agent_thought.id).first()
+ db.session.query(MessageAgentThought).where(MessageAgentThought.id == agent_thought.id).first()
)
if not updated_agent_thought:
raise ValueError("agent thought not found")
@@ -417,12 +419,15 @@ class BaseAgentRunner(AppRunner):
if isinstance(prompt_message, SystemPromptMessage):
result.append(prompt_message)
- messages: list[Message] = (
- db.session.query(Message)
- .filter(
- Message.conversation_id == self.message.conversation_id,
+ messages = (
+ (
+ db.session.execute(
+ select(Message)
+ .where(Message.conversation_id == self.message.conversation_id)
+ .order_by(Message.created_at.desc())
+ )
)
- .order_by(Message.created_at.desc())
+ .scalars()
.all()
)
@@ -491,7 +496,7 @@ class BaseAgentRunner(AppRunner):
return result
def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
- files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
+ files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if not files:
return UserPromptMessage(content=message.query)
if message.app_model_config:
diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py
index 143a3a51aa..a31c1050bd 100644
--- a/api/core/agent/entities.py
+++ b/api/core/agent/entities.py
@@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel):
tool_name: str
tool_parameters: dict[str, Any] = Field(default_factory=dict)
plugin_unique_identifier: str | None = None
+ credential_id: str | None = None
class AgentPromptEntity(BaseModel):
diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py
index 3b48288710..a3438fc2c7 100644
--- a/api/core/agent/plugin_entities.py
+++ b/api/core/agent/plugin_entities.py
@@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter):
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
+ ANY = CommonParameterType.ANY.value
# deprecated, should not use.
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py
index ead81a7a0e..a52a1dfd7a 100644
--- a/api/core/agent/strategy/base.py
+++ b/api/core/agent/strategy/base.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyParameter
+from core.plugin.entities.request import InvokeCredentials
class BaseAgentStrategy(ABC):
@@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
"""
- yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
+ yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials)
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
"""
@@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
pass
diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py
index 4cfcfbf86a..04661581a7 100644
--- a/api/core/agent/strategy/plugin.py
+++ b/api/core/agent/strategy/plugin.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
+from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext
from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ credentials: Optional[InvokeCredentials] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent strategy.
@@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy):
conversation_id=conversation_id,
app_id=app_id,
message_id=message_id,
+ context=PluginInvokeContext(credentials=credentials or InvokeCredentials()),
)
diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
index 590b944c0d..8887d2500c 100644
--- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py
+++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py
@@ -39,6 +39,7 @@ class AgentConfigManager:
"provider_id": tool["provider_id"],
"tool_name": tool["tool_name"],
"tool_parameters": tool.get("tool_parameters", {}),
+ "credential_id": tool.get("credential_id", None),
}
agent_tools.append(AgentToolEntity(**agent_tool_properties))
diff --git a/api/core/app/apps/README.md b/api/core/app/apps/README.md
deleted file mode 100644
index 7a57bb3658..0000000000
--- a/api/core/app/apps/README.md
+++ /dev/null
@@ -1,48 +0,0 @@
-## Guidelines for Database Connection Management in App Runner and Task Pipeline
-
-Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
-
-Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
-
-Examples:
-
-1. Creating a new record:
-
- ```python
- app = App(id=1)
- db.session.add(app)
- db.session.commit()
- db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
-
- # Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
-
- db.session.close()
-
- return app.id
- ```
-
-2. Fetching a record from the table:
-
- ```python
- app = db.session.query(App).filter(App.id == app_id).first()
-
- created_at = app.created_at
-
- db.session.close()
-
- # Handle tasks (include long-running).
-
- ```
-
-3. Updating a table field:
-
- ```python
- app = db.session.query(App).filter(App.id == app_id).first()
-
- app.updated_at = time.utcnow()
- db.session.commit()
- db.session.close()
-
- return app_id
- ```
-
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index 7877408cef..610a5bb278 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@@ -17,16 +18,17 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter
from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
+from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
@@ -112,7 +114,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
query = query.replace("\x00", "")
inputs = args["inputs"]
- extras = {"auto_generate_conversation_name": args.get("auto_generate_name", False)}
+ extras = {
+ "auto_generate_conversation_name": args.get("auto_generate_name", False),
+ **extract_external_trace_id_from_args(args),
+ }
# get conversation
conversation = None
@@ -183,14 +188,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -260,14 +265,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -343,14 +348,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -482,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- try:
- # get conversation and message
- conversation = self._get_conversation(conversation_id)
- message = self._get_message(message_id)
-
- # chatbot app
- runner = AdvancedChatAppRunner(
- application_generate_entity=application_generate_entity,
- queue_manager=queue_manager,
- conversation=conversation,
- message=message,
- dialogue_count=self._dialogue_count,
- variable_loader=variable_loader,
+ # get conversation and message
+ conversation = self._get_conversation(conversation_id)
+ message = self._get_message(message_id)
+
+ with Session(db.engine, expire_on_commit=False) as session:
+ workflow = session.scalar(
+ select(Workflow).where(
+ Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
+ Workflow.app_id == application_generate_entity.app_config.app_id,
+ Workflow.id == application_generate_entity.app_config.workflow_id,
+ )
)
+ if workflow is None:
+ raise ValueError("Workflow not found")
+
+ # Determine system_user_id based on invocation source
+ is_external_api_call = application_generate_entity.invoke_from in {
+ InvokeFrom.WEB_APP,
+ InvokeFrom.SERVICE_API,
+ }
+
+ if is_external_api_call:
+ # For external API calls, use end user's session ID
+ end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
+ system_user_id = end_user.session_id if end_user else ""
+ else:
+ # For internal calls, use the original user ID
+ system_user_id = application_generate_entity.user_id
+
+ app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
+ if app is None:
+ raise ValueError("App not found")
+
+ runner = AdvancedChatAppRunner(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ conversation=conversation,
+ message=message,
+ dialogue_count=self._dialogue_count,
+ variable_loader=variable_loader,
+ workflow=workflow,
+ system_user_id=system_user_id,
+ app=app,
+ )
+ try:
runner.run()
except GenerateTaskStoppedError:
pass
diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py
index 840a3c9d3b..a75e17af64 100644
--- a/api/core/app/apps/advanced_chat/app_runner.py
+++ b/api/core/app/apps/advanced_chat/app_runner.py
@@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
-from typing import Any, cast
+from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -9,21 +9,29 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
-from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
+from core.app.entities.app_invoke_entities import (
+ AdvancedChatAppGenerateEntity,
+ AppGenerateEntity,
+ InvokeFrom,
+)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
+from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
+from core.moderation.input_moderation import InputModeration
+from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
+from models import Workflow
from models.enums import UserFrom
-from models.model import App, Conversation, EndUser, Message
+from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
@@ -36,42 +44,38 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
+ *,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
+ workflow: Workflow,
+ system_user_id: str,
+ app: App,
) -> None:
- super().__init__(queue_manager, variable_loader)
+ super().__init__(
+ queue_manager=queue_manager,
+ variable_loader=variable_loader,
+ app_id=application_generate_entity.app_config.app_id,
+ )
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
-
- def _get_app_id(self) -> str:
- return self.application_generate_entity.app_config.app_id
+ self._workflow = workflow
+ self.system_user_id = system_user_id
+ self._app = app
def run(self) -> None:
app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config)
- app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
+ app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
- workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
- if not workflow:
- raise ValueError("Workflow not initialized")
-
- user_id = None
- if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
- end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
- if end_user:
- user_id = end_user.session_id
- else:
- user_id = self.application_generate_entity.user_id
-
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@@ -79,14 +83,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
- workflow=workflow,
+ workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
- workflow=workflow,
+ workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
)
@@ -97,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
- app_record=app_record,
+ app_record=self._app,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
@@ -107,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# annotation reply
if self.handle_annotation_reply(
- app_record=app_record,
+ app_record=self._app,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
@@ -127,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
- for variable in workflow.conversation_variables
+ for variable in self._workflow.conversation_variables
]
session.add_all(db_conversation_variables)
# Convert database entities to variables.
@@ -136,38 +140,40 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
session.commit()
# Create a variable pool.
- system_inputs = {
- SystemVariableKey.QUERY: query,
- SystemVariableKey.FILES: files,
- SystemVariableKey.CONVERSATION_ID: self.conversation.id,
- SystemVariableKey.USER_ID: user_id,
- SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count,
- SystemVariableKey.APP_ID: app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id,
- }
+ system_inputs = SystemVariable(
+ query=query,
+ files=files,
+ conversation_id=self.conversation.id,
+ user_id=self.system_user_id,
+ dialogue_count=self._dialogue_count,
+ app_id=app_config.app_id,
+ workflow_id=app_config.workflow_id,
+ workflow_execution_id=self.application_generate_entity.workflow_run_id,
+ )
# init variable pool
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
- environment_variables=workflow.environment_variables,
- conversation_variables=conversation_variables,
+ environment_variables=self._workflow.environment_variables,
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph
- graph = self._init_graph(graph_config=workflow.graph_dict)
+ graph = self._init_graph(graph_config=self._workflow.graph_dict)
db.session.close()
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
- tenant_id=workflow.tenant_id,
- app_id=workflow.app_id,
- workflow_id=workflow.id,
- workflow_type=WorkflowType.value_of(workflow.type),
+ tenant_id=self._workflow.tenant_id,
+ app_id=self._workflow.app_id,
+ workflow_id=self._workflow.id,
+ workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
- graph_config=workflow.graph_dict,
+ graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
@@ -238,3 +244,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
+
+ def query_app_annotations_to_reply(
+ self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
+ ) -> Optional[MessageAnnotation]:
+ """
+ Query app annotations to reply
+ :param app_record: app record
+ :param message: message
+ :param query: query
+ :param user_id: user id
+ :param invoke_from: invoke from
+ :return:
+ """
+ annotation_reply_feature = AnnotationReplyFeature()
+ return annotation_reply_feature.query(
+ app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
+ )
+
+ def moderation_for_inputs(
+ self,
+ *,
+ app_id: str,
+ tenant_id: str,
+ app_generate_entity: AppGenerateEntity,
+ inputs: Mapping[str, Any],
+ query: str | None = None,
+ message_id: str,
+ ) -> tuple[bool, Mapping[str, Any], str]:
+ """
+ Process sensitive_word_avoidance.
+ :param app_id: app id
+ :param tenant_id: tenant id
+ :param app_generate_entity: app generate entity
+ :param inputs: inputs
+ :param query: query
+ :param message_id: message id
+ :return:
+ """
+ moderation_feature = InputModeration()
+ return moderation_feature.check(
+ app_id=app_id,
+ tenant_id=tenant_id,
+ app_config=app_generate_entity.app_config,
+ inputs=dict(inputs),
+ query=query or "",
+ message_id=message_id,
+ trace_manager=app_generate_entity.trace_manager,
+ )
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 4c52fc3e83..dc27076a4d 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -1,6 +1,7 @@
import logging
import time
-from collections.abc import Generator, Mapping
+from collections.abc import Callable, Generator, Mapping
+from contextlib import contextmanager
from threading import Thread
from typing import Any, Optional, Union
@@ -15,6 +16,7 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
+ MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
@@ -44,6 +46,7 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
+ WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ChatbotAppBlockingResponse,
@@ -52,6 +55,7 @@ from core.app.entities.task_entities import (
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
MessageEndStreamResponse,
+ PingStreamResponse,
StreamResponse,
WorkflowTaskState,
)
@@ -61,12 +65,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType
-from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
@@ -116,16 +120,16 @@ class AdvancedChatAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
- workflow_system_variables={
- SystemVariableKey.QUERY: message.query,
- SystemVariableKey.FILES: application_generate_entity.files,
- SystemVariableKey.CONVERSATION_ID: conversation.id,
- SystemVariableKey.USER_ID: user_session_id,
- SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
- SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: workflow.id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id,
- },
+ workflow_system_variables=SystemVariable(
+ query=message.query,
+ files=application_generate_entity.files,
+ conversation_id=conversation.id,
+ user_id=user_session_id,
+ dialogue_count=dialogue_count,
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=workflow.id,
+ workflow_execution_id=application_generate_entity.workflow_run_id,
+ ),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
@@ -162,7 +166,6 @@ class AdvancedChatAppGenerateTaskPipeline:
Process generate task pipeline.
:return:
"""
- # start generate conversation name thread
self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name(
conversation_id=self._conversation_id, query=self._application_generate_entity.query
)
@@ -254,15 +257,12 @@ class AdvancedChatAppGenerateTaskPipeline:
yield response
start_listener_time = time.time()
- # timeout
while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT:
try:
if not tts_publisher:
break
audio_trunk = tts_publisher.check_and_get_audio()
if audio_trunk is None:
- # release cpu
- # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file)
time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME)
continue
if audio_trunk.status == "finish":
@@ -276,403 +276,617 @@ class AdvancedChatAppGenerateTaskPipeline:
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
- def _process_stream_response(
+ @contextmanager
+ def _database_session(self):
+ """Context manager for database sessions."""
+ with Session(db.engine, expire_on_commit=False) as session:
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+
+ def _ensure_workflow_initialized(self) -> None:
+ """Fluent validation for workflow state."""
+ if not self._workflow_run_id:
+ raise ValueError("workflow run not initialized.")
+
+ def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+ """Fluent validation for graph runtime state."""
+ if not graph_runtime_state:
+ raise ValueError("graph runtime state not initialized.")
+ return graph_runtime_state
+
+ def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
+ """Handle ping events."""
+ yield self._base_task_pipeline._ping_stream_response()
+
+ def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
+ """Handle error events."""
+ with self._database_session() as session:
+ err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id)
+ yield self._base_task_pipeline._error_to_stream_response(err)
+
+ def _handle_workflow_started_event(
+ self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow started events."""
+ # Override graph runtime state - this is a side effect but necessary
+ graph_runtime_state = event.graph_runtime_state
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
+ self._workflow_run_id = workflow_execution.id_
+
+ message = self._get_message(session=session)
+ if not message:
+ raise ValueError(f"Message not found: {self._message_id}")
+
+ message.workflow_run_id = workflow_execution.id_
+ workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+
+ yield workflow_start_resp
+
+ def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle node retry events."""
+ self._ensure_workflow_initialized()
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+ node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_retry_resp:
+ yield node_retry_resp
+
+ def _handle_node_started_event(
+ self, event: QueueNodeStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node started events."""
+ self._ensure_workflow_initialized()
+
+ workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+
+ node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_start_resp:
+ yield node_start_resp
+
+ def _handle_node_succeeded_event(
+ self, event: QueueNodeSucceededEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node succeeded events."""
+ # Record files if it's an answer node or end node
+ if event.node_type in [NodeType.ANSWER, NodeType.END]:
+ self._recorded_files.extend(
+ self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
+ )
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
+ node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_finish_resp:
+ yield node_finish_resp
+
+ def _handle_node_failed_events(
self,
+ event: Union[
+ QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
+ ],
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle various node failure events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event)
+
+ node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if isinstance(event, QueueNodeExceptionEvent):
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_finish_resp:
+ yield node_finish_resp
+
+ def _handle_text_chunk_event(
+ self,
+ event: QueueTextChunkEvent,
+ *,
tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
- trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ **kwargs,
) -> Generator[StreamResponse, None, None]:
- """
- Process stream response.
- :return:
- """
- # init fake graph runtime state
- graph_runtime_state: Optional[GraphRuntimeState] = None
+ """Handle text chunk events."""
+ delta_text = event.text
+ if delta_text is None:
+ return
+
+ # Handle output moderation chunk
+ should_direct_answer = self._handle_output_moderation_chunk(delta_text)
+ if should_direct_answer:
+ return
+
+ # Only publish tts message at text chunk streaming
+ if tts_publisher and queue_message:
+ tts_publisher.publish(queue_message)
+
+ self._task_state.answer += delta_text
+ yield self._message_cycle_manager.message_to_stream_response(
+ answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
+ )
- for queue_message in self._base_task_pipeline._queue_manager.listen():
- event = queue_message.event
+ def _handle_parallel_branch_started_event(
+ self, event: QueueParallelBranchRunStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch started events."""
+ self._ensure_workflow_initialized()
- if isinstance(event, QueuePingEvent):
- yield self._base_task_pipeline._ping_stream_response()
- elif isinstance(event, QueueErrorEvent):
- with Session(db.engine, expire_on_commit=False) as session:
- err = self._base_task_pipeline._handle_error(
- event=event, session=session, message_id=self._message_id
- )
- session.commit()
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueWorkflowStartedEvent):
- # override graph runtime state
- graph_runtime_state = event.graph_runtime_state
-
- with Session(db.engine, expire_on_commit=False) as session:
- # init workflow run
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
- self._workflow_run_id = workflow_execution.id_
- message = self._get_message(session=session)
- if not message:
- raise ValueError(f"Message not found: {self._message_id}")
- message.workflow_run_id = workflow_execution.id_
- workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
+ parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_start_resp
- yield workflow_start_resp
- elif isinstance(
- event,
- QueueNodeRetryEvent,
- ):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
- workflow_execution_id=self._workflow_run_id, event=event
- )
- node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ def _handle_parallel_branch_finished_events(
+ self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch finished events."""
+ self._ensure_workflow_initialized()
- if node_retry_resp:
- yield node_retry_resp
- elif isinstance(event, QueueNodeStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_finish_resp
- workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
- workflow_execution_id=self._workflow_run_id, event=event
- )
+ def _handle_iteration_start_event(
+ self, event: QueueIterationStartEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration start events."""
+ self._ensure_workflow_initialized()
- node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_start_resp
- if node_start_resp:
- yield node_start_resp
- elif isinstance(event, QueueNodeSucceededEvent):
- # Record files if it's an answer node or end node
- if event.node_type in [NodeType.ANSWER, NodeType.END]:
- self._recorded_files.extend(
- self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {})
- )
+ def _handle_iteration_next_event(
+ self, event: QueueIterationNextEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration next events."""
+ self._ensure_workflow_initialized()
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
- event=event
- )
+ iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_next_resp
- node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
- self._save_output_for_event(event, workflow_node_execution.id)
+ def _handle_iteration_completed_event(
+ self, event: QueueIterationCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration completed events."""
+ self._ensure_workflow_initialized()
- if node_finish_resp:
- yield node_finish_resp
- elif isinstance(
- event,
- QueueNodeFailedEvent
- | QueueNodeInIterationFailedEvent
- | QueueNodeInLoopFailedEvent
- | QueueNodeExceptionEvent,
- ):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
- event=event
- )
+ iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_finish_resp
- node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- if isinstance(event, QueueNodeExceptionEvent):
- self._save_output_for_event(event, workflow_node_execution.id)
-
- if node_finish_resp:
- yield node_finish_resp
- elif isinstance(event, QueueParallelBranchRunStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
-
- parallel_start_resp = (
- self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop start events."""
+ self._ensure_workflow_initialized()
- yield parallel_start_resp
- elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_start_resp
- parallel_finish_resp = (
- self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop next events."""
+ self._ensure_workflow_initialized()
- yield parallel_finish_resp
- elif isinstance(event, QueueIterationStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_next_resp
- iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_loop_completed_event(
+ self, event: QueueLoopCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle loop completed events."""
+ self._ensure_workflow_initialized()
- yield iter_start_resp
- elif isinstance(event, QueueIterationNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_finish_resp
- iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_succeeded_event(
+ self,
+ event: QueueWorkflowSucceededEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow succeeded events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield iter_next_resp
- elif isinstance(event, QueueIterationCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
- iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_partial_success_event(
+ self,
+ event: QueueWorkflowPartialSuccessEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow partial success events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield iter_finish_resp
- elif isinstance(event, QueueLoopStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
- loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_failed_event(
+ self,
+ event: QueueWorkflowFailedEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow failed events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ status=WorkflowExecutionStatus.FAILED,
+ error_message=event.error,
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+ err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
+ err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id)
- yield loop_start_resp
- elif isinstance(event, QueueLoopNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ yield self._base_task_pipeline._error_to_stream_response(err)
- loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+ def _handle_stop_event(
+ self,
+ event: QueueStopEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle stop events."""
+ if self._workflow_run_id and graph_runtime_state:
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=graph_runtime_state.total_tokens,
+ total_steps=graph_runtime_state.node_run_steps,
+ status=WorkflowExecutionStatus.STOPPED,
+ error_message=event.get_stop_reason(),
+ conversation_id=self._conversation_id,
+ trace_manager=trace_manager,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
+ workflow_execution=workflow_execution,
)
+ # Save message
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- yield loop_next_resp
- elif isinstance(event, QueueLoopCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
+ elif event.stopped_by in (
+ QueueStopEvent.StopBy.INPUT_MODERATION,
+ QueueStopEvent.StopBy.ANNOTATION_REPLY,
+ ):
+ # When hitting input-moderation or annotation-reply, the workflow will not start
+ with self._database_session() as session:
+ # Save message
+ self._save_message(session=session)
- loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ yield self._message_end_to_stream_response()
- yield loop_finish_resp
- elif isinstance(event, QueueWorkflowSucceededEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_advanced_chat_message_end_event(
+ self,
+ event: QueueAdvancedChatMessageEndEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle advanced chat message end events."""
+ self._ensure_graph_runtime_initialized(graph_runtime_state)
- if not graph_runtime_state:
- raise ValueError("workflow run not initialized.")
+ output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
+ self._task_state.answer
+ )
+ if output_moderation_answer:
+ self._task_state.answer = output_moderation_answer
+ yield self._message_cycle_manager.message_replace_to_stream_response(
+ answer=output_moderation_answer,
+ reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
+ )
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- )
+ # Save message
+ with self._database_session() as session:
+ self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ yield self._message_end_to_stream_response()
- yield workflow_finish_resp
- self._base_task_pipeline._queue_manager.publish(
- QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
- )
- elif isinstance(event, QueueWorkflowPartialSuccessEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ def _handle_retriever_resources_event(
+ self, event: QueueRetrieverResourcesEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle retriever resources events."""
+ self._message_cycle_manager.handle_retriever_resources(event)
- yield workflow_finish_resp
- self._base_task_pipeline._queue_manager.publish(
- QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE
- )
- elif isinstance(event, QueueWorkflowFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.FAILED,
- error_message=event.error,
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}"))
- err = self._base_task_pipeline._handle_error(
- event=err_event, session=session, message_id=self._message_id
- )
+ with self._database_session() as session:
+ message = self._get_message(session=session)
+ message.message_metadata = self._task_state.metadata.model_dump_json()
+ return
+ yield # Make this a generator
- yield workflow_finish_resp
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueStopEvent):
- if self._workflow_run_id and graph_runtime_state:
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.STOPPED,
- error_message=event.get_stop_reason(),
- conversation_id=self._conversation_id,
- trace_manager=trace_manager,
- )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- # Save message
- self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- session.commit()
-
- yield workflow_finish_resp
- elif event.stopped_by in (
- QueueStopEvent.StopBy.INPUT_MODERATION,
- QueueStopEvent.StopBy.ANNOTATION_REPLY,
- ):
- # When hitting input-moderation or annotation-reply, the workflow will not start
- with Session(db.engine, expire_on_commit=False) as session:
- # Save message
- self._save_message(session=session)
- session.commit()
-
- yield self._message_end_to_stream_response()
- break
- elif isinstance(event, QueueRetrieverResourcesEvent):
- self._message_cycle_manager.handle_retriever_resources(event)
-
- with Session(db.engine, expire_on_commit=False) as session:
- message = self._get_message(session=session)
- message.message_metadata = self._task_state.metadata.model_dump_json()
- session.commit()
- elif isinstance(event, QueueAnnotationReplyEvent):
- self._message_cycle_manager.handle_annotation_reply(event)
-
- with Session(db.engine, expire_on_commit=False) as session:
- message = self._get_message(session=session)
- message.message_metadata = self._task_state.metadata.model_dump_json()
- session.commit()
- elif isinstance(event, QueueTextChunkEvent):
- delta_text = event.text
- if delta_text is None:
- continue
+ def _handle_annotation_reply_event(
+ self, event: QueueAnnotationReplyEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle annotation reply events."""
+ self._message_cycle_manager.handle_annotation_reply(event)
- # handle output moderation chunk
- should_direct_answer = self._handle_output_moderation_chunk(delta_text)
- if should_direct_answer:
- continue
+ with self._database_session() as session:
+ message = self._get_message(session=session)
+ message.message_metadata = self._task_state.metadata.model_dump_json()
+ return
+ yield # Make this a generator
- # only publish tts message at text chunk streaming
- if tts_publisher:
- tts_publisher.publish(queue_message)
+ def _handle_message_replace_event(
+ self, event: QueueMessageReplaceEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle message replace events."""
+ yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
- self._task_state.answer += delta_text
- yield self._message_cycle_manager.message_to_stream_response(
- answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
- )
- elif isinstance(event, QueueMessageReplaceEvent):
- # published by moderation
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=event.text, reason=event.reason
- )
- elif isinstance(event, QueueAdvancedChatMessageEndEvent):
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
+ def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle agent log events."""
+ yield self._workflow_response_converter.handle_agent_log(
+ task_id=self._application_generate_entity.task_id, event=event
+ )
- output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished(
- self._task_state.answer
- )
- if output_moderation_answer:
- self._task_state.answer = output_moderation_answer
- yield self._message_cycle_manager.message_replace_to_stream_response(
- answer=output_moderation_answer,
- reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
+ def _get_event_handlers(self) -> dict[type, Callable]:
+ """Return the mapping from queue event types to the bound methods that handle them."""
+ return {
+ # Basic events
+ QueuePingEvent: self._handle_ping_event,
+ QueueErrorEvent: self._handle_error_event,
+ QueueTextChunkEvent: self._handle_text_chunk_event,
+ # Workflow events
+ QueueWorkflowStartedEvent: self._handle_workflow_started_event,
+ QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
+ QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
+ QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
+ # Node events
+ QueueNodeRetryEvent: self._handle_node_retry_event,
+ QueueNodeStartedEvent: self._handle_node_started_event,
+ QueueNodeSucceededEvent: self._handle_node_succeeded_event,
+ # Parallel branch events
+ QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
+ # Iteration events
+ QueueIterationStartEvent: self._handle_iteration_start_event,
+ QueueIterationNextEvent: self._handle_iteration_next_event,
+ QueueIterationCompletedEvent: self._handle_iteration_completed_event,
+ # Loop events
+ QueueLoopStartEvent: self._handle_loop_start_event,
+ QueueLoopNextEvent: self._handle_loop_next_event,
+ QueueLoopCompletedEvent: self._handle_loop_completed_event,
+ # Control events
+ QueueStopEvent: self._handle_stop_event,
+ # Message events
+ QueueRetrieverResourcesEvent: self._handle_retriever_resources_event,
+ QueueAnnotationReplyEvent: self._handle_annotation_reply_event,
+ QueueMessageReplaceEvent: self._handle_message_replace_event,
+ QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
+ QueueAgentLogEvent: self._handle_agent_log_event,
+ }
+
+ def _dispatch_event(
+ self,
+ event: Any,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """Dispatch an event to its handler: exact-type table lookup first, then isinstance fallbacks."""
+ handlers = self._get_event_handlers()
+ event_type = type(event)
+
+ # Direct handler lookup
+ if handler := handlers.get(event_type):
+ yield from handler(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
+
+ # Handle node failure events with isinstance check
+ if isinstance(
+ event,
+ (
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ),
+ ):
+ yield from self._handle_node_failed_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
+
+ # Handle parallel branch finished events with isinstance check
+ if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
+ yield from self._handle_parallel_branch_finished_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
+
+ # Unhandled event types yield no responses (the original loop simply continued)
+ return
+
+ def _process_stream_response(
+ self,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """
+ Listen on the queue and turn each queued event into stream responses.
+ Error, workflow-failed and stop events end the loop; all other events keep it running.
+ """
+ # Initialize graph runtime state
+ graph_runtime_state: Optional[GraphRuntimeState] = None
+
+ for queue_message in self._base_task_pipeline._queue_manager.listen():
+ event = queue_message.event
+
+ match event:
+ case QueueWorkflowStartedEvent():
+ graph_runtime_state = event.graph_runtime_state
+ yield from self._handle_workflow_started_event(event)
+
+ case QueueTextChunkEvent():
+ yield from self._handle_text_chunk_event(
+ event, tts_publisher=tts_publisher, queue_message=queue_message
)
- # Save message
- with Session(db.engine, expire_on_commit=False) as session:
- self._save_message(session=session, graph_runtime_state=graph_runtime_state)
- session.commit()
-
- yield self._message_end_to_stream_response()
- elif isinstance(event, QueueAgentLogEvent):
- yield self._workflow_response_converter.handle_agent_log(
- task_id=self._application_generate_entity.task_id, event=event
- )
- else:
- continue
+ case QueueErrorEvent():
+ yield from self._handle_error_event(event)
+ break
+
+ case QueueWorkflowFailedEvent():
+ yield from self._handle_workflow_failed_event(
+ event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
+ )
+ break
+
+ case QueueStopEvent():
+ yield from self._handle_stop_event(
+ event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager
+ )
+ break
+
+ # Route every other event type through _dispatch_event
+ case _:
+ if responses := list(
+ self._dispatch_event(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ ):
+ yield from responses
- # publish None when task finished
if tts_publisher:
tts_publisher.publish(None)
@@ -744,7 +958,6 @@ class AdvancedChatAppGenerateTaskPipeline:
"""
if self._base_task_pipeline._output_moderation_handler:
if self._base_task_pipeline._output_moderation_handler.should_direct_output():
- # stop subscribe new token when output moderation should direct output
self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output()
self._base_task_pipeline._queue_manager.publish(
QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE
diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py
index edea6199d3..8665bc9d11 100644
--- a/api/core/app/apps/agent_chat/app_generator.py
+++ b/api/core/app/apps/agent_chat/app_generator.py
@@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager
from core.app.apps.agent_chat.app_runner import AgentChatAppRunner
from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom
diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py
index 71328f6d1b..39d6ba39f5 100644
--- a/api/core/app/apps/agent_chat/app_runner.py
+++ b/api/core/app/apps/agent_chat/app_runner.py
@@ -45,7 +45,7 @@ class AgentChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(AgentChatAppConfig, app_config)
- app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
+ app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
@@ -183,10 +183,10 @@ class AgentChatAppRunner(AppRunner):
if {ModelFeature.MULTI_TOOL_CALL, ModelFeature.TOOL_CALL}.intersection(model_schema.features or []):
agent_entity.strategy = AgentEntity.Strategy.FUNCTION_CALLING
- conversation_result = db.session.query(Conversation).filter(Conversation.id == conversation.id).first()
+ conversation_result = db.session.query(Conversation).where(Conversation.id == conversation.id).first()
if conversation_result is None:
raise ValueError("Conversation not found")
- message_result = db.session.query(Message).filter(Message.id == message.id).first()
+ message_result = db.session.query(Message).where(Message.id == message.id).first()
if message_result is None:
raise ValueError("Message not found")
db.session.close()
diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py
index 0ba33fbe0d..9da0bae56a 100644
--- a/api/core/app/apps/base_app_queue_manager.py
+++ b/api/core/app/apps/base_app_queue_manager.py
@@ -169,7 +169,3 @@ class AppQueueManager:
raise TypeError(
"Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed."
)
-
-
-class GenerateTaskStoppedError(Exception):
- pass
diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py
index a3f0cf7f9f..6e8c261a6a 100644
--- a/api/core/app/apps/base_app_runner.py
+++ b/api/core/app/apps/base_app_runner.py
@@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__)
class AppRunner:
- def get_pre_calculate_rest_tokens(
- self,
- app_record: App,
- model_config: ModelConfigWithCredentialsEntity,
- prompt_template_entity: PromptTemplateEntity,
- inputs: Mapping[str, str],
- files: Sequence["File"],
- query: Optional[str] = None,
- ) -> int:
- """
- Get pre calculate rest tokens
- :param app_record: app record
- :param model_config: model config entity
- :param prompt_template_entity: prompt template entity
- :param inputs: inputs
- :param files: files
- :param query: query
- :return:
- """
- # Invoke model
- model_instance = ModelInstance(
- provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
- )
-
- model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-
- max_tokens = 0
- for parameter_rule in model_config.model_schema.parameter_rules:
- if parameter_rule.name == "max_tokens" or (
- parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
- ):
- max_tokens = (
- model_config.parameters.get(parameter_rule.name)
- or model_config.parameters.get(parameter_rule.use_template or "")
- ) or 0
-
- if model_context_tokens is None:
- return -1
-
- if max_tokens is None:
- max_tokens = 0
-
- # get prompt messages without memory and context
- prompt_messages, stop = self.organize_prompt_messages(
- app_record=app_record,
- model_config=model_config,
- prompt_template_entity=prompt_template_entity,
- inputs=inputs,
- files=files,
- query=query,
- )
-
- prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
- rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens
- if rest_tokens < 0:
- raise InvokeBadRequestError(
- "Query or prefix prompt is too long, you can reduce the prefix prompt, "
- "or shrink the max token, or switch to a llm with a larger token limit size."
- )
-
- return rest_tokens
-
def recalc_llm_max_tokens(
self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage]
):
@@ -181,7 +118,7 @@ class AppRunner:
else:
memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False))
- model_mode = ModelMode.value_of(model_config.mode)
+ model_mode = ModelMode(model_config.mode)
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
if model_mode == ModelMode.COMPLETION:
advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template
diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py
index a28c106ce9..0c76cc39ae 100644
--- a/api/core/app/apps/chat/app_generator.py
+++ b/api/core/app/apps/chat/app_generator.py
@@ -11,10 +11,11 @@ from configs import dify_config
from constants import UUID_NIL
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.chat.app_config_manager import ChatAppConfigManager
from core.app.apps.chat.app_runner import ChatAppRunner
from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom
diff --git a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py
index 39597fc036..894d7906d5 100644
--- a/api/core/app/apps/chat/app_runner.py
+++ b/api/core/app/apps/chat/app_runner.py
@@ -43,7 +43,7 @@ class ChatAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(ChatAppConfig, app_config)
- app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
+ app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py
index 966a6f1d66..9356bd1cea 100644
--- a/api/core/app/apps/completion/app_generator.py
+++ b/api/core/app/apps/completion/app_generator.py
@@ -10,10 +10,11 @@ from pydantic import ValidationError
from configs import dify_config
from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.completion.app_config_manager import CompletionAppConfigManager
from core.app.apps.completion.app_runner import CompletionAppRunner
from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom
@@ -247,7 +248,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
"""
message = (
db.session.query(Message)
- .filter(
+ .where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py
index 80fdd0b80e..50d2a0036c 100644
--- a/api/core/app/apps/completion/app_runner.py
+++ b/api/core/app/apps/completion/app_runner.py
@@ -36,7 +36,7 @@ class CompletionAppRunner(AppRunner):
app_config = application_generate_entity.app_config
app_config = cast(CompletionAppConfig, app_config)
- app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
+ app_record = db.session.query(App).where(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
diff --git a/api/core/app/apps/exc.py b/api/core/app/apps/exc.py
new file mode 100644
index 0000000000..4187118b9b
--- /dev/null
+++ b/api/core/app/apps/exc.py
@@ -0,0 +1,2 @@
+class GenerateTaskStoppedError(Exception):
+ pass
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index e84d59209d..7dd9904eeb 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -1,12 +1,12 @@
import json
import logging
from collections.abc import Generator
-from datetime import UTC, datetime
from typing import Optional, Union, cast
from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError
+from core.app.apps.base_app_queue_manager import AppQueueManager
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
@@ -24,6 +24,7 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
@@ -84,7 +85,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
if conversation:
app_model_config = (
db.session.query(AppModelConfig)
- .filter(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
+ .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
@@ -150,13 +151,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
introduction = self._get_conversation_introduction(application_generate_entity)
# get conversation name
- if isinstance(application_generate_entity, AdvancedChatAppGenerateEntity):
- query = application_generate_entity.query or "New conversation"
- else:
- query = next(iter(application_generate_entity.inputs.values()), "New conversation")
- if isinstance(query, int):
- query = str(query)
- query = query or "New conversation"
+ query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query
if not conversation:
@@ -183,7 +178,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.commit()
db.session.refresh(conversation)
else:
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
message = Message(
@@ -258,7 +253,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param conversation_id: conversation id
:return: conversation
"""
- conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+ conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation:
raise ConversationNotExistsError("Conversation not exists")
@@ -271,7 +266,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
:param message_id: message id
:return: message
"""
- message = db.session.query(Message).filter(Message.id == message_id).first()
+ message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")
diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py
index 363c3c82bb..8507f23f17 100644
--- a/api/core/app/apps/message_based_app_queue_manager.py
+++ b/api/core/app/apps/message_based_app_queue_manager.py
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 40a1e272a7..4c36f63c71 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -7,13 +7,15 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
-from sqlalchemy.orm import sessionmaker
+from sqlalchemy import select
+from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_generator import BaseAppGenerator
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
@@ -21,10 +23,10 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
+from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@@ -123,6 +125,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
)
inputs: Mapping[str, Any] = args["inputs"]
+
+ extras = {
+ **extract_external_trace_id_from_args(args),
+ }
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = WorkflowAppGenerateEntity(
@@ -142,6 +148,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
call_depth=call_depth,
trace_manager=trace_manager,
workflow_execution_id=workflow_run_id,
+ extras=extras,
)
contexts.plugin_tool_providers.set({})
@@ -156,14 +163,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from,
)
# Create workflow node execution repository
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -306,16 +313,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -390,16 +395,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository
- workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
+ workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
)
# Create workflow node execution repository
- session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
-
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=user,
app_id=application_generate_entity.app_config.app_id,
@@ -443,17 +446,44 @@ class WorkflowAppGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
- try:
- # workflow app
- runner = WorkflowAppRunner(
- application_generate_entity=application_generate_entity,
- queue_manager=queue_manager,
- workflow_thread_pool_id=workflow_thread_pool_id,
- variable_loader=variable_loader,
+ with Session(db.engine, expire_on_commit=False) as session:
+ workflow = session.scalar(
+ select(Workflow).where(
+ Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
+ Workflow.app_id == application_generate_entity.app_config.app_id,
+ Workflow.id == application_generate_entity.app_config.workflow_id,
+ )
)
+ if workflow is None:
+ raise ValueError("Workflow not found")
+
+ # Determine system_user_id based on invocation source
+ is_external_api_call = application_generate_entity.invoke_from in {
+ InvokeFrom.WEB_APP,
+ InvokeFrom.SERVICE_API,
+ }
+
+ if is_external_api_call:
+ # For external API calls, use end user's session ID
+ end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
+ system_user_id = end_user.session_id if end_user else ""
+ else:
+ # For internal calls, use the original user ID
+ system_user_id = application_generate_entity.user_id
+
+ runner = WorkflowAppRunner(
+ application_generate_entity=application_generate_entity,
+ queue_manager=queue_manager,
+ workflow_thread_pool_id=workflow_thread_pool_id,
+ variable_loader=variable_loader,
+ workflow=workflow,
+ system_user_id=system_user_id,
+ )
+ try:
runner.run()
- except GenerateTaskStoppedError:
+ except GenerateTaskStoppedError as e:
+ logger.warning(f"Task stopped: {str(e)}")
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@@ -469,8 +499,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
- finally:
- db.session.close()
def _handle_response(
self,
diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py
index 349b8eb51b..40fc03afb7 100644
--- a/api/core/app/apps/workflow/app_queue_manager.py
+++ b/api/core/app/apps/workflow/app_queue_manager.py
@@ -1,4 +1,5 @@
-from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
+from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py
index 07aeb57fa3..4f4c1460ae 100644
--- a/api/core/app/apps/workflow/app_runner.py
+++ b/api/core/app/apps/workflow/app_runner.py
@@ -11,13 +11,11 @@ from core.app.entities.app_invoke_entities import (
)
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
-from extensions.ext_database import db
from models.enums import UserFrom
-from models.model import App, EndUser
-from models.workflow import WorkflowType
+from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
@@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
+ *,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
+ workflow: Workflow,
+ system_user_id: str,
) -> None:
- """
- :param application_generate_entity: application generate entity
- :param queue_manager: application queue manager
- :param workflow_thread_pool_id: workflow thread pool id
- """
- super().__init__(queue_manager, variable_loader)
+ super().__init__(
+ queue_manager=queue_manager,
+ variable_loader=variable_loader,
+ app_id=application_generate_entity.app_config.app_id,
+ )
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
-
- def _get_app_id(self) -> str:
- return self.application_generate_entity.app_config.app_id
+ self._workflow = workflow
+ self._sys_user_id = system_user_id
def run(self) -> None:
"""
@@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
- user_id = None
- if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
- end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
- if end_user:
- user_id = end_user.session_id
- else:
- user_id = self.application_generate_entity.user_id
-
- app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
- if not app_record:
- raise ValueError("App not found")
-
- workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
- if not workflow:
- raise ValueError("Workflow not initialized")
-
- db.session.close()
-
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
- workflow=workflow,
+ workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
- workflow=workflow,
+ workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
@@ -95,32 +76,33 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
files = self.application_generate_entity.files
# Create a variable pool.
- system_inputs = {
- SystemVariableKey.FILES: files,
- SystemVariableKey.USER_ID: user_id,
- SystemVariableKey.APP_ID: app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
- }
+
+ system_inputs = SystemVariable(
+ files=files,
+ user_id=self._sys_user_id,
+ app_id=app_config.app_id,
+ workflow_id=app_config.workflow_id,
+ workflow_execution_id=self.application_generate_entity.workflow_execution_id,
+ )
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
- environment_variables=workflow.environment_variables,
+ environment_variables=self._workflow.environment_variables,
conversation_variables=[],
)
# init graph
- graph = self._init_graph(graph_config=workflow.graph_dict)
+ graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
- tenant_id=workflow.tenant_id,
- app_id=workflow.app_id,
- workflow_id=workflow.id,
- workflow_type=WorkflowType.value_of(workflow.type),
+ tenant_id=self._workflow.tenant_id,
+ app_id=self._workflow.app_id,
+ workflow_id=self._workflow.id,
+ workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
- graph_config=workflow.graph_dict,
+ graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py
index 2a85cd5e3d..e31a316c56 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/app/apps/workflow/generate_task_pipeline.py
@@ -1,9 +1,9 @@
import logging
import time
-from collections.abc import Generator
-from typing import Optional, Union
+from collections.abc import Callable, Generator
+from contextlib import contextmanager
+from typing import Any, Optional, Union
-from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
@@ -14,6 +14,7 @@ from core.app.entities.app_invoke_entities import (
WorkflowAppGenerateEntity,
)
from core.app.entities.queue_entities import (
+ MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueIterationCompletedEvent,
@@ -39,11 +40,13 @@ from core.app.entities.queue_entities import (
QueueWorkflowPartialSuccessEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
+ WorkflowQueueMessage,
)
from core.app.entities.task_entities import (
ErrorStreamResponse,
MessageAudioEndStreamResponse,
MessageAudioStreamResponse,
+ PingStreamResponse,
StreamResponse,
TextChunkStreamResponse,
WorkflowAppBlockingResponse,
@@ -55,10 +58,11 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType
-from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
@@ -68,7 +72,6 @@ from models.workflow import (
Workflow,
WorkflowAppLog,
WorkflowAppLogCreatedFrom,
- WorkflowRun,
)
logger = logging.getLogger(__name__)
@@ -109,13 +112,13 @@ class WorkflowAppGenerateTaskPipeline:
self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
- workflow_system_variables={
- SystemVariableKey.FILES: application_generate_entity.files,
- SystemVariableKey.USER_ID: user_session_id,
- SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
- SystemVariableKey.WORKFLOW_ID: workflow.id,
- SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id,
- },
+ workflow_system_variables=SystemVariable(
+ files=application_generate_entity.files,
+ user_id=user_session_id,
+ app_id=application_generate_entity.app_config.app_id,
+ workflow_id=workflow.id,
+ workflow_execution_id=application_generate_entity.workflow_execution_id,
+ ),
workflow_info=CycleManagerWorkflowInfo(
workflow_id=workflow.id,
workflow_type=WorkflowType(workflow.type),
@@ -248,322 +251,500 @@ class WorkflowAppGenerateTaskPipeline:
if tts_publisher:
yield MessageAudioEndStreamResponse(audio="", task_id=task_id)
- def _process_stream_response(
+ @contextmanager
+ def _database_session(self):
+ """Yield a database session; commit on success, roll back and re-raise on any exception."""
+ with Session(db.engine, expire_on_commit=False) as session:
+ try:
+ yield session
+ session.commit()
+ except Exception:
+ session.rollback()
+ raise
+
+ def _ensure_workflow_initialized(self) -> None:
+ """Raise ValueError if the workflow run ID has not been set yet."""
+ if not self._workflow_run_id:
+ raise ValueError("workflow run not initialized.")
+
+ def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState:
+ """Return the given graph runtime state, raising ValueError if it is missing."""
+ if not graph_runtime_state:
+ raise ValueError("graph runtime state not initialized.")
+ return graph_runtime_state
+
+ def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]:
+ """Handle ping events."""
+ yield self._base_task_pipeline._ping_stream_response()
+
+ def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]:
+ """Handle error events."""
+ err = self._base_task_pipeline._handle_error(event=event)
+ yield self._base_task_pipeline._error_to_stream_response(err)
+
+ def _handle_workflow_started_event(
+ self, event: QueueWorkflowStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow started events."""
+ # init workflow run
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
+ self._workflow_run_id = workflow_execution.id_
+ start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
+ yield start_resp
+
+ def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle node retry events."""
+ self._ensure_workflow_initialized()
+
+ with self._database_session() as session:
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if response:
+ yield response
+
+ def _handle_node_started_event(
+ self, event: QueueNodeStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node started events."""
+ self._ensure_workflow_initialized()
+
+ workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
+ workflow_execution_id=self._workflow_run_id, event=event
+ )
+ node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ if node_start_response:
+ yield node_start_response
+
+ def _handle_node_succeeded_event(
+ self, event: QueueNodeSucceededEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle node succeeded events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event)
+ node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
+
+ self._save_output_for_event(event, workflow_node_execution.id)
+
+ if node_success_response:
+ yield node_success_response
+
+ def _handle_node_failed_events(
self,
- tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
- trace_manager: Optional[TraceQueueManager] = None,
+ event: Union[
+ QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent
+ ],
+ **kwargs,
) -> Generator[StreamResponse, None, None]:
- """
- Process stream response.
- :return:
- """
- graph_runtime_state = None
+ """Handle various node failure events."""
+ workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
+ event=event,
+ )
+ node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
+ event=event,
+ task_id=self._application_generate_entity.task_id,
+ workflow_node_execution=workflow_node_execution,
+ )
- for queue_message in self._base_task_pipeline._queue_manager.listen():
- event = queue_message.event
+ if isinstance(event, QueueNodeExceptionEvent):
+ self._save_output_for_event(event, workflow_node_execution.id)
- if isinstance(event, QueuePingEvent):
- yield self._base_task_pipeline._ping_stream_response()
- elif isinstance(event, QueueErrorEvent):
- err = self._base_task_pipeline._handle_error(event=event)
- yield self._base_task_pipeline._error_to_stream_response(err)
- break
- elif isinstance(event, QueueWorkflowStartedEvent):
- # override graph runtime state
- graph_runtime_state = event.graph_runtime_state
-
- # init workflow run
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start()
- self._workflow_run_id = workflow_execution.id_
- start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
+ if node_failed_response:
+ yield node_failed_response
- yield start_resp
- elif isinstance(
- event,
- QueueNodeRetryEvent,
- ):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried(
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- response = self._workflow_response_converter.workflow_node_retry_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- session.commit()
+ def _handle_parallel_branch_started_event(
+ self, event: QueueParallelBranchRunStartedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch started events."""
+ self._ensure_workflow_initialized()
- if response:
- yield response
- elif isinstance(event, QueueNodeStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_start_resp
- workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start(
- workflow_execution_id=self._workflow_run_id, event=event
- )
- node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ def _handle_parallel_branch_finished_events(
+ self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle parallel branch finished events."""
+ self._ensure_workflow_initialized()
- if node_start_response:
- yield node_start_response
- elif isinstance(event, QueueNodeSucceededEvent):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(
- event=event
- )
- node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
+ parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield parallel_finish_resp
- self._save_output_for_event(event, workflow_node_execution.id)
+ def _handle_iteration_start_event(
+ self, event: QueueIterationStartEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration start events."""
+ self._ensure_workflow_initialized()
- if node_success_response:
- yield node_success_response
- elif isinstance(
- event,
- QueueNodeFailedEvent
- | QueueNodeInIterationFailedEvent
- | QueueNodeInLoopFailedEvent
- | QueueNodeExceptionEvent,
- ):
- workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(
- event=event,
- )
- node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response(
- event=event,
- task_id=self._application_generate_entity.task_id,
- workflow_node_execution=workflow_node_execution,
- )
- if isinstance(event, QueueNodeExceptionEvent):
- self._save_output_for_event(event, workflow_node_execution.id)
+ iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_start_resp
- if node_failed_response:
- yield node_failed_response
+ def _handle_iteration_next_event(
+ self, event: QueueIterationNextEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration next events."""
+ self._ensure_workflow_initialized()
- elif isinstance(event, QueueParallelBranchRunStartedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_next_resp
- parallel_start_resp = (
- self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ def _handle_iteration_completed_event(
+ self, event: QueueIterationCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle iteration completed events."""
+ self._ensure_workflow_initialized()
- yield parallel_start_resp
+ iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield iter_finish_resp
- elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop start events."""
+ self._ensure_workflow_initialized()
- parallel_finish_resp = (
- self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
- )
+ loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_start_resp
- yield parallel_finish_resp
+ def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle loop next events."""
+ self._ensure_workflow_initialized()
- elif isinstance(event, QueueIterationStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_next_resp
- iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_loop_completed_event(
+ self, event: QueueLoopCompletedEvent, **kwargs
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle loop completed events."""
+ self._ensure_workflow_initialized()
- yield iter_start_resp
+ loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution_id=self._workflow_run_id,
+ event=event,
+ )
+ yield loop_finish_resp
- elif isinstance(event, QueueIterationNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_workflow_succeeded_event(
+ self,
+ event: QueueWorkflowSucceededEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow succeeded events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
- iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- yield iter_next_resp
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- elif isinstance(event, QueueIterationCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ yield workflow_finish_resp
- iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ def _handle_workflow_partial_success_event(
+ self,
+ event: QueueWorkflowPartialSuccessEvent,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow partial success events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ outputs=event.outputs,
+ exceptions_count=event.exceptions_count,
+ conversation_id=None,
+ trace_manager=trace_manager,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
- yield iter_finish_resp
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- elif isinstance(event, QueueLoopStartEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ yield workflow_finish_resp
- yield loop_start_resp
+ def _handle_workflow_failed_and_stop_events(
+ self,
+ event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle workflow failed and stop events."""
+ self._ensure_workflow_initialized()
+ validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state)
+
+ with self._database_session() as session:
+ workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
+ workflow_run_id=self._workflow_run_id,
+ total_tokens=validated_state.total_tokens,
+ total_steps=validated_state.node_run_steps,
+ status=WorkflowExecutionStatus.FAILED
+ if isinstance(event, QueueWorkflowFailedEvent)
+ else WorkflowExecutionStatus.STOPPED,
+ error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
+ conversation_id=None,
+ trace_manager=trace_manager,
+ exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
+ external_trace_id=self._application_generate_entity.extras.get("external_trace_id"),
+ )
- elif isinstance(event, QueueLoopNextEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ # save workflow app log
+ self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
- loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
+ session=session,
+ task_id=self._application_generate_entity.task_id,
+ workflow_execution=workflow_execution,
+ )
- yield loop_next_resp
+ yield workflow_finish_resp
- elif isinstance(event, QueueLoopCompletedEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
+ def _handle_text_chunk_event(
+ self,
+ event: QueueTextChunkEvent,
+ *,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ **kwargs,
+ ) -> Generator[StreamResponse, None, None]:
+ """Handle text chunk events."""
+ delta_text = event.text
+ if delta_text is None:
+ return
- loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response(
- task_id=self._application_generate_entity.task_id,
- workflow_execution_id=self._workflow_run_id,
- event=event,
- )
+ # only publish tts message at text chunk streaming
+ if tts_publisher and queue_message:
+ tts_publisher.publish(queue_message)
- yield loop_finish_resp
-
- elif isinstance(event, QueueWorkflowSucceededEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- conversation_id=None,
- trace_manager=trace_manager,
- )
+ yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
+ """Handle agent log events."""
+ yield self._workflow_response_converter.handle_agent_log(
+ task_id=self._application_generate_entity.task_id, event=event
+ )
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
-
- yield workflow_finish_resp
- elif isinstance(event, QueueWorkflowPartialSuccessEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- outputs=event.outputs,
- exceptions_count=event.exceptions_count,
- conversation_id=None,
- trace_manager=trace_manager,
- )
+ def _get_event_handlers(self) -> dict[type, Callable]:
+ """Return the mapping from exact event type to the bound handler method for that event."""
+ return {
+ # Basic events
+ QueuePingEvent: self._handle_ping_event,
+ QueueErrorEvent: self._handle_error_event,
+ QueueTextChunkEvent: self._handle_text_chunk_event,
+ # Workflow events
+ QueueWorkflowStartedEvent: self._handle_workflow_started_event,
+ QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
+ QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
+ # Node events
+ QueueNodeRetryEvent: self._handle_node_retry_event,
+ QueueNodeStartedEvent: self._handle_node_started_event,
+ QueueNodeSucceededEvent: self._handle_node_succeeded_event,
+ # Parallel branch events
+ QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event,
+ # Iteration events
+ QueueIterationStartEvent: self._handle_iteration_start_event,
+ QueueIterationNextEvent: self._handle_iteration_next_event,
+ QueueIterationCompletedEvent: self._handle_iteration_completed_event,
+ # Loop events
+ QueueLoopStartEvent: self._handle_loop_start_event,
+ QueueLoopNextEvent: self._handle_loop_next_event,
+ QueueLoopCompletedEvent: self._handle_loop_completed_event,
+ # Agent events
+ QueueAgentLogEvent: self._handle_agent_log_event,
+ }
+
+ def _dispatch_event(
+ self,
+ event: Any,
+ *,
+ graph_runtime_state: Optional[GraphRuntimeState] = None,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """Route an event to its handler: exact-type lookup first, then isinstance checks for grouped types."""
+ handlers = self._get_event_handlers()
+ event_type = type(event)
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ # Direct handler lookup
+ if handler := handlers.get(event_type):
+ yield from handler(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
-
- yield workflow_finish_resp
- elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
- if not self._workflow_run_id:
- raise ValueError("workflow run not initialized.")
- if not graph_runtime_state:
- raise ValueError("graph runtime state not initialized.")
-
- with Session(db.engine, expire_on_commit=False) as session:
- workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed(
- workflow_run_id=self._workflow_run_id,
- total_tokens=graph_runtime_state.total_tokens,
- total_steps=graph_runtime_state.node_run_steps,
- status=WorkflowExecutionStatus.FAILED
- if isinstance(event, QueueWorkflowFailedEvent)
- else WorkflowExecutionStatus.STOPPED,
- error_message=event.error
- if isinstance(event, QueueWorkflowFailedEvent)
- else event.get_stop_reason(),
- conversation_id=None,
- trace_manager=trace_manager,
- exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
- )
+ # Handle node failure events with isinstance check
+ if isinstance(
+ event,
+ (
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ),
+ ):
+ yield from self._handle_node_failed_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- # save workflow app log
- self._save_workflow_app_log(session=session, workflow_execution=workflow_execution)
+ # Handle parallel branch finished events with isinstance check
+ if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)):
+ yield from self._handle_parallel_branch_finished_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response(
- session=session,
- task_id=self._application_generate_entity.task_id,
- workflow_execution=workflow_execution,
- )
- session.commit()
+ # Handle workflow failed and stop events with isinstance check
+ if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)):
+ yield from self._handle_workflow_failed_and_stop_events(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ return
- yield workflow_finish_resp
- elif isinstance(event, QueueTextChunkEvent):
- delta_text = event.text
- if delta_text is None:
- continue
+ # Unhandled event types produce no response (the original loop simply continued).
+ return
- # only publish tts message at text chunk streaming
- if tts_publisher:
- tts_publisher.publish(queue_message)
+ def _process_stream_response(
+ self,
+ tts_publisher: Optional[AppGeneratorTTSPublisher] = None,
+ trace_manager: Optional[TraceQueueManager] = None,
+ ) -> Generator[StreamResponse, None, None]:
+ """
+ Process queued events and yield the stream response(s) produced for each one.
+ Stops listening after an error event; every other event is routed to its handler.
+ """
+ # Initialize graph runtime state
+ graph_runtime_state = None
- yield self._text_chunk_to_stream_response(
- delta_text, from_variable_selector=event.from_variable_selector
- )
- elif isinstance(event, QueueAgentLogEvent):
- yield self._workflow_response_converter.handle_agent_log(
- task_id=self._application_generate_entity.task_id, event=event
- )
- else:
- continue
+ for queue_message in self._base_task_pipeline._queue_manager.listen():
+ event = queue_message.event
+
+ match event:
+ case QueueWorkflowStartedEvent():
+ graph_runtime_state = event.graph_runtime_state
+ yield from self._handle_workflow_started_event(event)
+
+ case QueueTextChunkEvent():
+ yield from self._handle_text_chunk_event(
+ event, tts_publisher=tts_publisher, queue_message=queue_message
+ )
+
+ case QueueErrorEvent():
+ yield from self._handle_error_event(event)
+ break
+
+ # Route every other event type through _dispatch_event
+ case _:
+ if responses := list(
+ self._dispatch_event(
+ event,
+ graph_runtime_state=graph_runtime_state,
+ tts_publisher=tts_publisher,
+ trace_manager=trace_manager,
+ queue_message=queue_message,
+ )
+ ):
+ yield from responses
if tts_publisher:
tts_publisher.publish(None)
def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None:
- workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_))
- assert workflow_run is not None
invoke_from = self._application_generate_entity.invoke_from
if invoke_from == InvokeFrom.SERVICE_API:
created_from = WorkflowAppLogCreatedFrom.SERVICE_API
@@ -576,10 +757,10 @@ class WorkflowAppGenerateTaskPipeline:
return
workflow_app_log = WorkflowAppLog()
- workflow_app_log.tenant_id = workflow_run.tenant_id
- workflow_app_log.app_id = workflow_run.app_id
- workflow_app_log.workflow_id = workflow_run.workflow_id
- workflow_app_log.workflow_run_id = workflow_run.id
+ workflow_app_log.tenant_id = self._application_generate_entity.app_config.tenant_id
+ workflow_app_log.app_id = self._application_generate_entity.app_config.app_id
+ workflow_app_log.workflow_id = workflow_execution.workflow_id
+ workflow_app_log.workflow_run_id = workflow_execution.id_
workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user_id
diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py
index 17b9ac5827..948ea95e63 100644
--- a/api/core/app/apps/workflow_app_runner.py
+++ b/api/core/app/apps/workflow_app_runner.py
@@ -1,8 +1,7 @@
from collections.abc import Mapping
-from typing import Any, Optional, cast
+from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
-from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -62,20 +61,23 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
-from extensions.ext_database import db
-from models.model import App
from models.workflow import Workflow
-class WorkflowBasedAppRunner(AppRunner):
- def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
- self.queue_manager = queue_manager
+class WorkflowBasedAppRunner:
+ def __init__(
+ self,
+ *,
+ queue_manager: AppQueueManager,
+ variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
+ app_id: str,
+ ) -> None:
+ self._queue_manager = queue_manager
self._variable_loader = variable_loader
-
- def _get_app_id(self) -> str:
- raise NotImplementedError("not implemented")
+ self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@@ -166,7 +168,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
@@ -263,7 +265,7 @@ class WorkflowBasedAppRunner(AppRunner):
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=workflow.environment_variables,
)
@@ -692,21 +694,5 @@ class WorkflowBasedAppRunner(AppRunner):
)
)
- def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
- """
- Get workflow
- """
- # fetch workflow by workflow_id
- workflow = (
- db.session.query(Workflow)
- .filter(
- Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
- )
- .first()
- )
-
- # return workflow
- return workflow
-
def _publish_event(self, event: AppQueueEvent) -> None:
- self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
+ self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
diff --git a/api/core/app/features/annotation_reply/annotation_reply.py b/api/core/app/features/annotation_reply/annotation_reply.py
index 83fd3debad..54dc69302a 100644
--- a/api/core/app/features/annotation_reply/annotation_reply.py
+++ b/api/core/app/features/annotation_reply/annotation_reply.py
@@ -26,7 +26,7 @@ class AnnotationReplyFeature:
:return:
"""
annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_record.id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_record.id).first()
)
if not annotation_setting:
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 3c8c7bb5a2..888434798a 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -471,7 +471,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
:return:
"""
agent_thought: Optional[MessageAgentThought] = (
- db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first()
+ db.session.query(MessageAgentThought).where(MessageAgentThought.id == event.agent_thought_id).first()
)
if agent_thought:
diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py
index e4b4168d08..df62776977 100644
--- a/api/core/app/task_pipeline/exc.py
+++ b/api/core/app/task_pipeline/exc.py
@@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError):
class WorkflowRunNotFoundError(RecordNotFoundError):
def __init__(self, workflow_run_id: str):
super().__init__("WorkflowRun", workflow_run_id)
-
-
-class WorkflowNodeExecutionNotFoundError(RecordNotFoundError):
- def __init__(self, workflow_node_execution_id: str):
- super().__init__("WorkflowNodeExecution", workflow_node_execution_id)
diff --git a/api/core/app/task_pipeline/message_cycle_manager.py b/api/core/app/task_pipeline/message_cycle_manager.py
index 2343081eaf..824da0b934 100644
--- a/api/core/app/task_pipeline/message_cycle_manager.py
+++ b/api/core/app/task_pipeline/message_cycle_manager.py
@@ -81,7 +81,7 @@ class MessageCycleManager:
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():
# get conversation and message
- conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+ conversation = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation:
return
@@ -140,7 +140,7 @@ class MessageCycleManager:
:param event: event
:return:
"""
- message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
+ message_file = db.session.query(MessageFile).where(MessageFile.id == event.message_file_id).first()
if message_file and message_file.url is not None:
# get tool file id
diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py
index a3a7b4b812..c55ba5e0fe 100644
--- a/api/core/callback_handler/index_tool_callback_handler.py
+++ b/api/core/callback_handler/index_tool_callback_handler.py
@@ -49,7 +49,7 @@ class DatasetIndexToolCallbackHandler:
for document in documents:
if document.metadata is not None:
document_id = document.metadata["document_id"]
- dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+ dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
@@ -59,7 +59,7 @@ class DatasetIndexToolCallbackHandler:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
- .filter(
+ .where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@@ -69,18 +69,18 @@ class DatasetIndexToolCallbackHandler:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == child_chunk.segment_id)
+ .where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
- query = db.session.query(DocumentSegment).filter(
+ query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if "dataset_id" in document.metadata:
- query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+ query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py
index 2fa347c204..fbd62437e6 100644
--- a/api/core/entities/parameter_entities.py
+++ b/api/core/entities/parameter_entities.py
@@ -14,6 +14,7 @@ class CommonParameterType(StrEnum):
APP_SELECTOR = "app-selector"
MODEL_SELECTOR = "model-selector"
TOOLS_SELECTOR = "array[tools]"
+ ANY = "any"
# Dynamic select parameter
# Once you are not sure about the available options until authorization is done
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index 66d8d0f414..af5c18e267 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -191,7 +191,7 @@ class ProviderConfiguration(BaseModel):
provider_record = (
db.session.query(Provider)
- .filter(
+ .where(
Provider.tenant_id == self.tenant_id,
Provider.provider_type == ProviderType.CUSTOM.value,
Provider.provider_name.in_(provider_names),
@@ -351,7 +351,7 @@ class ProviderConfiguration(BaseModel):
provider_model_record = (
db.session.query(ProviderModel)
- .filter(
+ .where(
ProviderModel.tenant_id == self.tenant_id,
ProviderModel.provider_name.in_(provider_names),
ProviderModel.model_name == model,
@@ -481,7 +481,7 @@ class ProviderConfiguration(BaseModel):
return (
db.session.query(ProviderModelSetting)
- .filter(
+ .where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@@ -560,7 +560,7 @@ class ProviderConfiguration(BaseModel):
return (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@@ -583,7 +583,7 @@ class ProviderConfiguration(BaseModel):
load_balancing_config_count = (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == self.tenant_id,
LoadBalancingModelConfig.provider_name.in_(provider_names),
LoadBalancingModelConfig.model_type == model_type.to_origin_model_type(),
@@ -627,7 +627,7 @@ class ProviderConfiguration(BaseModel):
model_setting = (
db.session.query(ProviderModelSetting)
- .filter(
+ .where(
ProviderModelSetting.tenant_id == self.tenant_id,
ProviderModelSetting.provider_name.in_(provider_names),
ProviderModelSetting.model_type == model_type.to_origin_model_type(),
@@ -693,7 +693,7 @@ class ProviderConfiguration(BaseModel):
preferred_model_provider = (
db.session.query(TenantPreferredModelProvider)
- .filter(
+ .where(
TenantPreferredModelProvider.tenant_id == self.tenant_id,
TenantPreferredModelProvider.provider_name.in_(provider_names),
)
diff --git a/api/core/external_data_tool/api/api.py b/api/core/external_data_tool/api/api.py
index 53acdf075f..2099a9e34c 100644
--- a/api/core/external_data_tool/api/api.py
+++ b/api/core/external_data_tool/api/api.py
@@ -32,7 +32,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
- .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+ .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
@@ -56,7 +56,7 @@ class ApiExternalDataTool(ExternalDataTool):
# get api_based_extension
api_based_extension = (
db.session.query(APIBasedExtension)
- .filter(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
+ .where(APIBasedExtension.tenant_id == self.tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index ada19ef8ce..f8c050c2ac 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -7,6 +7,7 @@ from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
+ TextPromptMessageContent,
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
@@ -44,11 +45,44 @@ def to_prompt_message_content(
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
) -> PromptMessageContentUnionTypes:
+ """
+ Convert a file to prompt message content.
+
+ This function converts files to their appropriate prompt message content types.
+ For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the
+ corresponding message content with proper encoding/URL.
+
+ For unsupported file types, instead of raising an error, it returns a
+ TextPromptMessageContent with a descriptive message about the file.
+
+ Args:
+ f: The file to convert
+ image_detail_config: Optional detail configuration for image files
+
+ Returns:
+ PromptMessageContentUnionTypes: The appropriate message content type
+
+ Raises:
+ ValueError: If file extension or mime_type is missing
+ """
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
raise ValueError("Missing file mime_type")
+ prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
+ FileType.IMAGE: ImagePromptMessageContent,
+ FileType.AUDIO: AudioPromptMessageContent,
+ FileType.VIDEO: VideoPromptMessageContent,
+ FileType.DOCUMENT: DocumentPromptMessageContent,
+ }
+
+ # Check if file type is supported
+ if f.type not in prompt_class_map:
+ # For unsupported file types, return a text description
+ return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]")
+
+ # Process supported file types
params = {
"base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "",
"url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "",
@@ -58,17 +92,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
- prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
- FileType.IMAGE: ImagePromptMessageContent,
- FileType.AUDIO: AudioPromptMessageContent,
- FileType.VIDEO: VideoPromptMessageContent,
- FileType.DOCUMENT: DocumentPromptMessageContent,
- }
-
- try:
- return prompt_class_map[f.type].model_validate(params)
- except KeyError:
- raise ValueError(f"file type {f.type} is not supported")
+ return prompt_class_map[f.type].model_validate(params)
def download(f: File, /):
diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py
index 656c9d48ed..fac68beb0f 100644
--- a/api/core/file/tool_file_parser.py
+++ b/api/core/file/tool_file_parser.py
@@ -7,13 +7,6 @@ if TYPE_CHECKING:
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
-class ToolFileParser:
- @staticmethod
- def get_tool_file_manager() -> "ToolFileManager":
- assert _tool_file_manager_factory is not None
- return _tool_file_manager_factory()
-
-
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
global _tool_file_manager_factory
_tool_file_manager_factory = factory
diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py
index 84f212a9c1..b416e48ce4 100644
--- a/api/core/helper/code_executor/template_transformer.py
+++ b/api/core/helper/code_executor/template_transformer.py
@@ -5,6 +5,8 @@ from base64 import b64encode
from collections.abc import Mapping
from typing import Any
+from core.variables.utils import SegmentJSONEncoder
+
class TemplateTransformer(ABC):
_code_placeholder: str = "{{code}}"
@@ -43,17 +45,13 @@ class TemplateTransformer(ABC):
result_str = cls.extract_result_str_from_response(response)
result = json.loads(result_str)
except json.JSONDecodeError as e:
- raise ValueError(f"Failed to parse JSON response: {str(e)}. Response content: {result_str[:200]}...")
+ raise ValueError(f"Failed to parse JSON response: {str(e)}.")
except ValueError as e:
# Re-raise ValueError from extract_result_str_from_response
raise e
except Exception as e:
raise ValueError(f"Unexpected error during response transformation: {str(e)}")
- # Check if the result contains an error
- if isinstance(result, dict) and "error" in result:
- raise ValueError(f"JavaScript execution error: {result['error']}")
-
if not isinstance(result, dict):
raise ValueError(f"Result must be a dict, got {type(result).__name__}")
if not all(isinstance(k, str) for k in result):
@@ -95,7 +93,7 @@ class TemplateTransformer(ABC):
@classmethod
def serialize_inputs(cls, inputs: Mapping[str, Any]) -> str:
- inputs_json_str = json.dumps(inputs, ensure_ascii=False).encode()
+ inputs_json_str = json.dumps(inputs, ensure_ascii=False, cls=SegmentJSONEncoder).encode()
input_base64_encoded = b64encode(inputs_json_str).decode("utf-8")
return input_base64_encoded
diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py
index 744fce1cf9..f761d20374 100644
--- a/api/core/helper/encrypter.py
+++ b/api/core/helper/encrypter.py
@@ -15,13 +15,13 @@ def encrypt_token(tenant_id: str, token: str):
from models.account import Tenant
from models.engine import db
- if not (tenant := db.session.query(Tenant).filter(Tenant.id == tenant_id).first()):
+ if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
return base64.b64encode(encrypted_token).decode()
-def decrypt_token(tenant_id: str, token: str):
+def decrypt_token(tenant_id: str, token: str) -> str:
return rsa.decrypt(base64.b64decode(token), tenant_id)
diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py
index 65bf4fc1db..fe3078923d 100644
--- a/api/core/helper/marketplace.py
+++ b/api/core/helper/marketplace.py
@@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
+
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
+def batch_fetch_plugin_manifests_ignore_deserialization_error(
+ plugin_ids: list[str],
+) -> Sequence[MarketplacePluginDeclaration]:
+ if len(plugin_ids) == 0:
+ return []
+
+ url = str(marketplace_api_url / "api/v1/plugins/batch")
+ response = requests.post(url, json={"plugin_ids": plugin_ids})
+ response.raise_for_status()
+ result: list[MarketplacePluginDeclaration] = []
+ for plugin in response.json()["data"]["plugins"]:
+ try:
+ result.append(MarketplacePluginDeclaration(**plugin))
+ except Exception as e:
+ pass
+
+ return result
+
+
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})
diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py
new file mode 100644
index 0000000000..48ec3be5c8
--- /dev/null
+++ b/api/core/helper/provider_cache.py
@@ -0,0 +1,84 @@
+import json
+from abc import ABC, abstractmethod
+from json import JSONDecodeError
+from typing import Any, Optional
+
+from extensions.ext_redis import redis_client
+
+
+class ProviderCredentialsCache(ABC):
+ """Base class for provider credentials cache"""
+
+ def __init__(self, **kwargs):
+ self.cache_key = self._generate_cache_key(**kwargs)
+
+ @abstractmethod
+ def _generate_cache_key(self, **kwargs) -> str:
+ """Generate cache key based on subclass implementation"""
+ pass
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider credentials"""
+ cached_credentials = redis_client.get(self.cache_key)
+ if cached_credentials:
+ try:
+ cached_credentials = cached_credentials.decode("utf-8")
+ return dict(json.loads(cached_credentials))
+ except JSONDecodeError:
+ return None
+ return None
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider credentials"""
+ redis_client.setex(self.cache_key, 86400, json.dumps(config))
+
+ def delete(self) -> None:
+ """Delete cached provider credentials"""
+ redis_client.delete(self.cache_key)
+
+
+class SingletonProviderCredentialsCache(ProviderCredentialsCache):
+ """Cache for tool single provider credentials"""
+
+ def __init__(self, tenant_id: str, provider_type: str, provider_identity: str):
+ super().__init__(
+ tenant_id=tenant_id,
+ provider_type=provider_type,
+ provider_identity=provider_identity,
+ )
+
+ def _generate_cache_key(self, **kwargs) -> str:
+ tenant_id = kwargs["tenant_id"]
+ provider_type = kwargs["provider_type"]
+ identity_name = kwargs["provider_identity"]
+ identity_id = f"{provider_type}.{identity_name}"
+ return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
+
+
+class ToolProviderCredentialsCache(ProviderCredentialsCache):
+ """Cache for tool provider credentials"""
+
+ def __init__(self, tenant_id: str, provider: str, credential_id: str):
+ super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id)
+
+ def _generate_cache_key(self, **kwargs) -> str:
+ tenant_id = kwargs["tenant_id"]
+ provider = kwargs["provider"]
+ credential_id = kwargs["credential_id"]
+ return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}"
+
+
+class NoOpProviderCredentialCache:
+ """No-op provider credential cache"""
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider credentials"""
+ return None
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider credentials"""
+ pass
+
+ def delete(self) -> None:
+ """Delete cached provider credentials"""
+ pass
diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py
deleted file mode 100644
index 2e4a04c579..0000000000
--- a/api/core/helper/tool_provider_cache.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import json
-from enum import Enum
-from json import JSONDecodeError
-from typing import Optional
-
-from extensions.ext_redis import redis_client
-
-
-class ToolProviderCredentialsCacheType(Enum):
- PROVIDER = "tool_provider"
- ENDPOINT = "endpoint"
-
-
-class ToolProviderCredentialsCache:
- def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType):
- self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}"
-
- def get(self) -> Optional[dict]:
- """
- Get cached model provider credentials.
-
- :return:
- """
- cached_provider_credentials = redis_client.get(self.cache_key)
- if cached_provider_credentials:
- try:
- cached_provider_credentials = cached_provider_credentials.decode("utf-8")
- cached_provider_credentials = json.loads(cached_provider_credentials)
- except JSONDecodeError:
- return None
-
- return dict(cached_provider_credentials)
- else:
- return None
-
- def set(self, credentials: dict) -> None:
- """
- Cache model provider credentials.
-
- :param credentials: provider credentials
- :return:
- """
- redis_client.setex(self.cache_key, 86400, json.dumps(credentials))
-
- def delete(self) -> None:
- """
- Delete cached model provider credentials.
-
- :return:
- """
- redis_client.delete(self.cache_key)
diff --git a/api/core/helper/trace_id_helper.py b/api/core/helper/trace_id_helper.py
new file mode 100644
index 0000000000..e90c3194f2
--- /dev/null
+++ b/api/core/helper/trace_id_helper.py
@@ -0,0 +1,42 @@
+import re
+from collections.abc import Mapping
+from typing import Any, Optional
+
+
+def is_valid_trace_id(trace_id: str) -> bool:
+ """
+ Check if the trace_id is valid.
+
+ Requirements: 1-128 characters, only letters, numbers, '-', and '_'.
+ """
+ return bool(re.match(r"^[a-zA-Z0-9\-_]{1,128}$", trace_id))
+
+
+def get_external_trace_id(request: Any) -> Optional[str]:
+ """
+ Retrieve the trace_id from the request.
+
+ Priority: header ('X-Trace-Id'), then parameters, then JSON body. Returns None if not provided or invalid.
+ """
+ trace_id = request.headers.get("X-Trace-Id")
+ if not trace_id:
+ trace_id = request.args.get("trace_id")
+ if not trace_id and getattr(request, "is_json", False):
+ json_data = getattr(request, "json", None)
+ if json_data:
+ trace_id = json_data.get("trace_id")
+ if isinstance(trace_id, str) and is_valid_trace_id(trace_id):
+ return trace_id
+ return None
+
+
+def extract_external_trace_id_from_args(args: Mapping[str, Any]) -> dict:
+ """
+ Extract 'external_trace_id' from args.
+
+ Returns a dict suitable for use in extras. Returns an empty dict if not found.
+ """
+ trace_id = args.get("external_trace_id")
+ if trace_id:
+ return {"external_trace_id": trace_id}
+ return {}
diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py
deleted file mode 100644
index dfb143f4c4..0000000000
--- a/api/core/helper/url_signer.py
+++ /dev/null
@@ -1,52 +0,0 @@
-import base64
-import hashlib
-import hmac
-import os
-import time
-
-from pydantic import BaseModel, Field
-
-from configs import dify_config
-
-
-class SignedUrlParams(BaseModel):
- sign_key: str = Field(..., description="The sign key")
- timestamp: str = Field(..., description="Timestamp")
- nonce: str = Field(..., description="Nonce")
- sign: str = Field(..., description="Signature")
-
-
-class UrlSigner:
- @classmethod
- def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str:
- signed_url_params = cls.get_signed_url_params(sign_key, prefix)
- return (
- f"{url}?timestamp={signed_url_params.timestamp}"
- f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}"
- )
-
- @classmethod
- def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams:
- timestamp = str(int(time.time()))
- nonce = os.urandom(16).hex()
- sign = cls._sign(sign_key, timestamp, nonce, prefix)
-
- return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign)
-
- @classmethod
- def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool:
- recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix)
-
- return sign == recalculated_sign
-
- @classmethod
- def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str:
- if not dify_config.SECRET_KEY:
- raise Exception("SECRET_KEY is not set")
-
- data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}"
- secret_key = dify_config.SECRET_KEY.encode()
- sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
- encoded_sign = base64.urlsafe_b64encode(sign).decode()
-
- return encoded_sign
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index 305a9190d5..fc5d0547fc 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -59,7 +59,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -119,12 +119,12 @@ class IndexingRunner:
db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
- db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
+ db.session.query(ChildChunk).where(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -212,7 +212,7 @@ class IndexingRunner:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
@@ -316,7 +316,7 @@ class IndexingRunner:
# delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content)
for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
@@ -346,7 +346,7 @@ class IndexingRunner:
raise ValueError("no upload file found")
file_detail = (
- db.session.query(UploadFile).filter(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
+ db.session.query(UploadFile).where(UploadFile.id == data_source_info["upload_file_id"]).one_or_none()
)
if file_detail:
@@ -599,7 +599,7 @@ class IndexingRunner:
keyword.create(documents)
if dataset.indexing_technique != "high_quality":
document_ids = [document.metadata["doc_id"] for document in documents]
- db.session.query(DocumentSegment).filter(
+ db.session.query(DocumentSegment).where(
DocumentSegment.document_id == document_id,
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.index_node_id.in_(document_ids),
@@ -630,7 +630,7 @@ class IndexingRunner:
index_processor.load(dataset, chunk_documents, with_keywords=False)
document_ids = [document.metadata["doc_id"] for document in chunk_documents]
- db.session.query(DocumentSegment).filter(
+ db.session.query(DocumentSegment).where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(document_ids),
@@ -672,8 +672,7 @@ class IndexingRunner:
if extra_update_params:
update_params.update(extra_update_params)
-
- db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
+ db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) # type: ignore
db.session.commit()
@staticmethod
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index e01896a491..331ac933c8 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -114,7 +114,8 @@ class LLMGenerator:
),
)
- questions = output_parser.parse(cast(str, response.message.content))
+ text_content = response.message.get_text_content()
+ questions = output_parser.parse(text_content) if text_content else []
except InvokeError:
questions = []
except Exception:
@@ -148,9 +149,11 @@ class LLMGenerator:
model_manager = ModelManager()
- model_instance = model_manager.get_default_model_instance(
+ model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
+ provider=model_config.get("provider", ""),
+ model=model_config.get("name", ""),
)
try:
diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
index c451bf514c..98cdc4c8b7 100644
--- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
+++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py
@@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
-
return json_obj
diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py
index b63478e822..bcb31a816f 100644
--- a/api/core/mcp/auth/auth_flow.py
+++ b/api/core/mcp/auth/auth_flow.py
@@ -240,7 +240,7 @@ def refresh_authorization(
response = requests.post(token_url, data=params)
if not response.ok:
raise ValueError(f"Token refresh failed: HTTP {response.status_code}")
- return OAuthTokens.parse_obj(response.json())
+ return OAuthTokens.model_validate(response.json())
def register_client(
diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py
index cd55dbf64f..00d5a25956 100644
--- a/api/core/mcp/auth/auth_provider.py
+++ b/api/core/mcp/auth/auth_provider.py
@@ -8,7 +8,7 @@ from core.mcp.types import (
OAuthTokens,
)
from models.tools import MCPToolProvider
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
LATEST_PROTOCOL_VERSION = "1.0"
diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py
index e9036de8c6..5fe52c008a 100644
--- a/api/core/mcp/mcp_client.py
+++ b/api/core/mcp/mcp_client.py
@@ -68,15 +68,17 @@ class MCPClient:
}
parsed_url = urlparse(self.server_url)
- path = parsed_url.path
+ path = parsed_url.path or ""
method_name = path.rstrip("/").split("/")[-1] if path else ""
- try:
+ if method_name in connection_methods:
client_factory = connection_methods[method_name]
self.connect_server(client_factory, method_name)
- except KeyError:
+ else:
try:
+ logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.")
self.connect_server(sse_client, "sse")
except MCPConnectionError:
+ logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.")
self.connect_server(streamablehttp_client, "mcp")
def connect_server(
@@ -91,7 +93,7 @@ class MCPClient:
else {}
)
self._streams_context = client_factory(url=self.server_url, headers=headers)
- if self._streams_context is None:
+ if not self._streams_context:
raise MCPConnectionError("Failed to create connection context")
# Use exit_stack to manage context managers properly
@@ -141,10 +143,11 @@ class MCPClient:
try:
# ExitStack will handle proper cleanup of all managed context managers
self.exit_stack.close()
+ except Exception as e:
+ logging.exception("Error during cleanup")
+ raise ValueError(f"Error during cleanup: {e}")
+ finally:
self._session = None
self._session_context = None
self._streams_context = None
self._initialized = False
- except Exception as e:
- logging.exception("Error during cleanup")
- raise ValueError(f"Error during cleanup: {e}")
diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py
index 1c2cf570e2..496b5432a0 100644
--- a/api/core/mcp/server/streamable_http.py
+++ b/api/core/mcp/server/streamable_http.py
@@ -28,7 +28,7 @@ class MCPServerStreamableHTTPRequestHandler:
):
self.app = app
self.request = request
- mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.app_id == self.app.id).first()
+ mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.app_id == self.app.id).first()
if not mcp_server:
raise ValueError("MCP server not found")
self.mcp_server: AppMCPServer = mcp_server
@@ -148,9 +148,7 @@ class MCPServerStreamableHTTPRequestHandler:
if not self.end_user:
raise ValueError("User not found")
request = cast(types.CallToolRequest, self.request.root)
- args = request.params.arguments
- if not args:
- raise ValueError("No arguments provided")
+ args = request.params.arguments or {}
if self.app.mode in {AppMode.WORKFLOW.value}:
args = {"inputs": args}
elif self.app.mode in {AppMode.COMPLETION.value}:
@@ -194,7 +192,7 @@ class MCPServerStreamableHTTPRequestHandler:
def retrieve_end_user(self):
return (
db.session.query(EndUser)
- .filter(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
+ .where(EndUser.external_user_id == self.mcp_server.id, EndUser.type == "mcp")
.first()
)
diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py
index 1c0f582501..7734b8fdd9 100644
--- a/api/core/mcp/session/base_session.py
+++ b/api/core/mcp/session/base_session.py
@@ -1,7 +1,7 @@
import logging
import queue
from collections.abc import Callable
-from concurrent.futures import ThreadPoolExecutor
+from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError
from contextlib import ExitStack
from datetime import timedelta
from types import TracebackType
@@ -171,23 +171,41 @@ class BaseSession(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._exit_stack = ExitStack()
+ # Initialize executor and future to None for proper cleanup checks
+ self._executor: ThreadPoolExecutor | None = None
+ self._receiver_future: Future | None = None
def __enter__(self) -> Self:
- self._executor = ThreadPoolExecutor()
+ # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1
+ # ensures no unnecessary threads are created.
+ self._executor = ThreadPoolExecutor(max_workers=1)
self._receiver_future = self._executor.submit(self._receive_loop)
return self
def check_receiver_status(self) -> None:
- if self._receiver_future.done():
+ """`check_receiver_status` ensures that any exceptions raised during the
+ execution of `_receive_loop` are retrieved and propagated."""
+ if self._receiver_future and self._receiver_future.done():
self._receiver_future.result()
def __exit__(
self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None
) -> None:
- self._exit_stack.close()
self._read_stream.put(None)
self._write_stream.put(None)
+ # Wait for the receiver loop to finish
+ if self._receiver_future:
+ try:
+ self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds
+ except TimeoutError:
+ # If the receiver loop is still running after timeout, we'll force shutdown
+ pass
+
+ # Shutdown the executor
+ if self._executor:
+ self._executor.shutdown(wait=True)
+
def send_request(
self,
request: SendRequestT,
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 2254b3d4d5..7ce124594a 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -1,6 +1,8 @@
from collections.abc import Sequence
from typing import Optional
+from sqlalchemy import select
+
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.model_manager import ModelInstance
@@ -17,11 +19,15 @@ from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
from models.model import AppMode, Conversation, Message, MessageFile
-from models.workflow import WorkflowRun
+from models.workflow import Workflow, WorkflowRun
class TokenBufferMemory:
- def __init__(self, conversation: Conversation, model_instance: ModelInstance) -> None:
+ def __init__(
+ self,
+ conversation: Conversation,
+ model_instance: ModelInstance,
+ ) -> None:
self.conversation = conversation
self.model_instance = model_instance
@@ -36,20 +42,8 @@ class TokenBufferMemory:
app_record = self.conversation.app
# fetch limited messages, and return reversed
- query = (
- db.session.query(
- Message.id,
- Message.query,
- Message.answer,
- Message.created_at,
- Message.workflow_run_id,
- Message.parent_message_id,
- Message.answer_tokens,
- )
- .filter(
- Message.conversation_id == self.conversation.id,
- )
- .order_by(Message.created_at.desc())
+ stmt = (
+ select(Message).where(Message.conversation_id == self.conversation.id).order_by(Message.created_at.desc())
)
if message_limit and message_limit > 0:
@@ -57,7 +51,9 @@ class TokenBufferMemory:
else:
message_limit = 500
- messages = query.limit(message_limit).all()
+ stmt = stmt.limit(message_limit)
+
+ messages = db.session.scalars(stmt).all()
# instead of all messages from the conversation, we only need to extract messages
# that belong to the thread of last message
@@ -71,21 +67,23 @@ class TokenBufferMemory:
prompt_messages: list[PromptMessage] = []
for message in messages:
- files = db.session.query(MessageFile).filter(MessageFile.message_id == message.id).all()
+ files = db.session.query(MessageFile).where(MessageFile.message_id == message.id).all()
if files:
file_extra_config = None
- if self.conversation.mode not in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ if self.conversation.mode in {AppMode.AGENT_CHAT, AppMode.COMPLETION, AppMode.CHAT}:
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
+ elif self.conversation.mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
+ workflow_run = db.session.scalar(
+ select(WorkflowRun).where(WorkflowRun.id == message.workflow_run_id)
+ )
+ if not workflow_run:
+ raise ValueError(f"Workflow run not found: {message.workflow_run_id}")
+ workflow = db.session.scalar(select(Workflow).where(Workflow.id == workflow_run.workflow_id))
+ if not workflow:
+ raise ValueError(f"Workflow not found: {workflow_run.workflow_id}")
+ file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
else:
- if message.workflow_run_id:
- workflow_run = (
- db.session.query(WorkflowRun).filter(WorkflowRun.id == message.workflow_run_id).first()
- )
-
- if workflow_run and workflow_run.workflow:
- file_extra_config = FileUploadConfigManager.convert(
- workflow_run.workflow.features_dict, is_vision=False
- )
+ raise AssertionError(f"Invalid app mode: {self.conversation.mode}")
detail = ImagePromptMessageContent.DETAIL.LOW
if file_extra_config and app_record:
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 9d010ae28d..83dc7f0525 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -156,6 +156,23 @@ class PromptMessage(ABC, BaseModel):
"""
return not self.content
+ def get_text_content(self) -> str:
+ """
+ Get text content from prompt message.
+
+ :return: Text content as string, empty string if no text content
+ """
+ if isinstance(self.content, str):
+ return self.content
+ elif isinstance(self.content, list):
+ text_parts = []
+ for item in self.content:
+ if isinstance(item, TextPromptMessageContent):
+ text_parts.append(item.data)
+ return "".join(text_parts)
+ else:
+ return ""
+
@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):
diff --git a/api/core/moderation/api/api.py b/api/core/moderation/api/api.py
index c65a3885fd..332381555b 100644
--- a/api/core/moderation/api/api.py
+++ b/api/core/moderation/api/api.py
@@ -89,7 +89,7 @@ class ApiModeration(Moderation):
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
extension = (
db.session.query(APIBasedExtension)
- .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+ .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py
index b18a6905fe..cf367efdf0 100644
--- a/api/core/ops/aliyun_trace/aliyun_trace.py
+++ b/api/core/ops/aliyun_trace/aliyun_trace.py
@@ -101,7 +101,8 @@ class AliyunDataTrace(BaseTraceInstance):
raise ValueError(f"Aliyun get run url failed: {str(e)}")
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- trace_id = convert_to_trace_id(trace_info.workflow_run_id)
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ trace_id = external_trace_id or convert_to_trace_id(trace_info.workflow_run_id)
workflow_span_id = convert_to_span_id(trace_info.workflow_run_id, "workflow")
self.add_workflow_span(trace_id, workflow_span_id, trace_info)
@@ -119,7 +120,7 @@ class AliyunDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
@@ -243,14 +244,14 @@ class AliyunDataTrace(BaseTraceInstance):
if not app_id:
raise ValueError("No app_id found in trace_info metadata")
- app = session.query(App).filter(App.id == app_id).first()
+ app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
- service_account = session.query(Account).filter(Account.id == app.created_by).first()
+ service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
current_tenant = (
@@ -284,7 +285,8 @@ class AliyunDataTrace(BaseTraceInstance):
else:
node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution)
return node_span
- except Exception:
+ except Exception as e:
+ logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True)
return None
def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status:
@@ -306,7 +308,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
- GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value,
GEN_AI_FRAMEWORK: "dify",
INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False),
@@ -381,7 +383,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(node_execution.created_at),
end_time=convert_datetime_to_nanoseconds(node_execution.finished_at),
attributes={
- GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value,
GEN_AI_FRAMEWORK: "dify",
GEN_AI_MODEL_NAME: process_data.get("model_name", ""),
@@ -415,7 +417,7 @@ class AliyunDataTrace(BaseTraceInstance):
start_time=convert_datetime_to_nanoseconds(trace_info.start_time),
end_time=convert_datetime_to_nanoseconds(trace_info.end_time),
attributes={
- GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""),
+ GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "",
GEN_AI_USER_ID: str(user_id),
GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value,
GEN_AI_FRAMEWORK: "dify",
diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
index ffda0885d4..1b72a4775a 100644
--- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
+++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py
@@ -3,7 +3,7 @@ import json
import logging
import os
from datetime import datetime, timedelta
-from typing import Optional, Union, cast
+from typing import Any, Optional, Union, cast
from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes
from opentelemetry import trace
@@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
raise
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- if trace_info.message_data is None:
- return
-
workflow_metadata = {
- "workflow_id": trace_info.workflow_run_id or "",
+ "workflow_run_id": trace_info.workflow_run_id or "",
"message_id": trace_info.message_id or "",
"workflow_app_log_id": trace_info.workflow_app_log_id or "",
"status": trace_info.workflow_run_status or "",
@@ -156,7 +153,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
}
workflow_metadata.update(trace_info.metadata)
- trace_id = uuid_to_trace_id(trace_info.message_id)
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ trace_id = external_trace_id or uuid_to_trace_id(trace_info.workflow_run_id)
span_id = RandomIdGenerator().generate_span_id()
context = SpanContext(
trace_id=trace_id,
@@ -213,7 +211,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
if model:
node_metadata["ls_model_name"] = model
- outputs = json.loads(node_execution.outputs).get("usage", {})
+ outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
if usage_data:
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
@@ -236,31 +234,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.SESSION_ID: trace_info.conversation_id or "",
},
start_time=datetime_to_nanos(created_at),
+ context=trace.set_span_in_context(trace.NonRecordingSpan(context)),
)
try:
if node_execution.node_type == "llm":
+ llm_attributes: dict[str, Any] = {
+ SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False),
+ }
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
- node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider)
+ llm_attributes[SpanAttributes.LLM_PROVIDER] = provider
if model:
- node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model)
-
- outputs = json.loads(node_execution.outputs).get("usage", {})
+ llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model
+ outputs = (
+ json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {}
+ )
usage_data = (
process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {})
)
if usage_data:
- node_span.set_attribute(
- SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0)
- )
- node_span.set_attribute(
- SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0)
- )
- node_span.set_attribute(
- SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0)
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0)
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0)
+ llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get(
+ "completion_tokens", 0
)
+ llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", [])))
+ node_span.set_attributes(llm_attributes)
finally:
node_span.end(end_time=datetime_to_nanos(finished_at))
finally:
@@ -296,7 +297,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
# Add end user data if available
if trace_info.message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == trace_info.message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first()
)
if end_user_data is not None:
message_metadata["end_user_id"] = end_user_data.session_id
@@ -352,25 +353,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False),
SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id,
}
-
- if isinstance(trace_info.inputs, list):
- for i, msg in enumerate(trace_info.inputs):
- if isinstance(msg, dict):
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get(
- "role", "user"
- )
- # todo: handle assistant and tool role messages, as they don't always
- # have a text field, but may have a tool_calls field instead
- # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
- # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
- elif isinstance(trace_info.inputs, dict):
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs)
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
- elif isinstance(trace_info.inputs, str):
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs
- llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
-
+ llm_attributes.update(self._construct_llm_attributes(trace_info.inputs))
if trace_info.total_tokens is not None and trace_info.total_tokens > 0:
llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens
if trace_info.message_tokens is not None and trace_info.message_tokens > 0:
@@ -720,7 +703,28 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
WorkflowNodeExecutionModel.process_data,
WorkflowNodeExecutionModel.execution_metadata,
)
- .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
+ .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id)
.all()
)
return workflow_nodes
+
+ def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]:
+ """Helper method to construct LLM attributes with passed prompts."""
+ attributes = {}
+ if isinstance(prompts, list):
+ for i, msg in enumerate(prompts):
+ if isinstance(msg, dict):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "")
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user")
+ # todo: handle assistant and tool role messages, as they don't always
+ # have a text field, but may have a tool_calls field instead
+ # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58',
+ # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]}
+ elif isinstance(prompts, dict):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts)
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+ elif isinstance(prompts, str):
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts
+ attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user"
+
+ return attributes
diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py
index 8593198bc2..f8e428daf1 100644
--- a/api/core/ops/base_trace_instance.py
+++ b/api/core/ops/base_trace_instance.py
@@ -44,14 +44,14 @@ class BaseTraceInstance(ABC):
"""
with Session(db.engine, expire_on_commit=False) as session:
# Get the app to find its creator
- app = session.query(App).filter(App.id == app_id).first()
+ app = session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError(f"App with id {app_id} not found")
if not app.created_by:
raise ValueError(f"App with id {app_id} has no creator (created_by is None)")
- service_account = session.query(Account).filter(Account.id == app.created_by).first()
+ service_account = session.query(Account).where(Account.id == app.created_by).first()
if not service_account:
raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}")
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index a3dbce0e59..f4a59ef3a7 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@@ -67,13 +67,14 @@ class LangFuseDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- trace_id = trace_info.workflow_run_id
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ trace_id = external_trace_id or trace_info.workflow_run_id
user_id = trace_info.metadata.get("user_id")
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
- trace_id = trace_info.message_id
+ trace_id = external_trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE.value
trace_data = LangfuseTrace(
id=trace_id,
@@ -123,10 +124,10 @@ class LangFuseDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -243,7 +244,7 @@ class LangFuseDataTrace(BaseTraceInstance):
user_id = message_data.from_account_id
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
user_id = end_user_data.session_id
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index f94e5e49d7..c97846dc9b 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -65,7 +65,8 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- trace_id = trace_info.message_id or trace_info.workflow_run_id
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
@@ -145,10 +146,10 @@ class LangSmithDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -261,7 +262,7 @@ class LangSmithDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index 8bedea20fb..6079b2faef 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -96,7 +96,8 @@ class OpikDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- dify_trace_id = trace_info.workflow_run_id
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ dify_trace_id = external_trace_id or trace_info.workflow_run_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
workflow_metadata = wrap_metadata(
trace_info.metadata, message_id=trace_info.message_id, workflow_app_log_id=trace_info.workflow_app_log_id
@@ -104,7 +105,7 @@ class OpikDataTrace(BaseTraceInstance):
root_span_id = None
if trace_info.message_id:
- dify_trace_id = trace_info.message_id
+ dify_trace_id = external_trace_id or trace_info.message_id
opik_trace_id = prepare_opik_uuid(trace_info.start_time, dify_trace_id)
trace_data = {
@@ -160,10 +161,10 @@ class OpikDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -241,7 +242,7 @@ class OpikDataTrace(BaseTraceInstance):
"trace_id": opik_trace_id,
"id": prepare_opik_uuid(created_at, node_execution_id),
"parent_span_id": prepare_opik_uuid(trace_info.start_time, parent_span_id),
- "name": node_type,
+ "name": node_name,
"type": run_type,
"start_time": created_at,
"end_time": finished_at,
@@ -283,7 +284,7 @@ class OpikDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 5c9b9d27b7..2b546b47cc 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -218,7 +218,7 @@ class OpsTraceManager:
"""
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
- .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+ .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -226,7 +226,7 @@ class OpsTraceManager:
return None
# decrypt_token
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
@@ -253,7 +253,7 @@ class OpsTraceManager:
if app_id is None:
return None
- app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+ app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if app is None:
return None
@@ -293,18 +293,18 @@ class OpsTraceManager:
@classmethod
def get_app_config_through_message_id(cls, message_id: str):
app_model_config = None
- message_data = db.session.query(Message).filter(Message.id == message_id).first()
+ message_data = db.session.query(Message).where(Message.id == message_id).first()
if not message_data:
return None
conversation_id = message_data.conversation_id
- conversation_data = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
+ conversation_data = db.session.query(Conversation).where(Conversation.id == conversation_id).first()
if not conversation_data:
return None
if conversation_data.app_model_config_id:
app_model_config = (
db.session.query(AppModelConfig)
- .filter(AppModelConfig.id == conversation_data.app_model_config_id)
+ .where(AppModelConfig.id == conversation_data.app_model_config_id)
.first()
)
elif conversation_data.app_model_config_id is None and conversation_data.override_model_configs:
@@ -331,7 +331,7 @@ class OpsTraceManager:
if tracing_provider is not None:
raise ValueError(f"Invalid tracing provider: {tracing_provider}")
- app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+ app_config: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app_config:
raise ValueError("App not found")
app_config.tracing = json.dumps(
@@ -349,7 +349,7 @@ class OpsTraceManager:
:param app_id: app id
:return:
"""
- app: Optional[App] = db.session.query(App).filter(App.id == app_id).first()
+ app: Optional[App] = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("App not found")
if not app.tracing:
@@ -520,6 +520,10 @@ class TraceTask:
"app_id": workflow_run.app_id,
}
+ external_trace_id = self.kwargs.get("external_trace_id")
+ if external_trace_id:
+ metadata["external_trace_id"] = external_trace_id
+
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
diff --git a/api/core/ops/utils.py b/api/core/ops/utils.py
index 36d060afd2..573e8cac88 100644
--- a/api/core/ops/utils.py
+++ b/api/core/ops/utils.py
@@ -3,6 +3,8 @@ from datetime import datetime
from typing import Optional, Union
from urllib.parse import urlparse
+from sqlalchemy import select
+
from extensions.ext_database import db
from models.model import Message
@@ -20,7 +22,7 @@ def filter_none_values(data: dict):
def get_message_data(message_id: str):
- return db.session.query(Message).filter(Message.id == message_id).first()
+ return db.session.scalar(select(Message).where(Message.id == message_id))
@contextmanager
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
index 3917348a91..a34b3b780c 100644
--- a/api/core/ops/weave_trace/weave_trace.py
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db
@@ -87,7 +87,8 @@ class WeaveDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
- trace_id = trace_info.message_id or trace_info.workflow_run_id
+ external_trace_id = trace_info.metadata.get("external_trace_id")
+ trace_id = external_trace_id or trace_info.message_id or trace_info.workflow_run_id
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
@@ -144,10 +145,10 @@ class WeaveDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
- workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
- app_id=trace_info.metadata.get("app_id"),
+ app_id=app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
)
@@ -234,7 +235,7 @@ class WeaveDataTrace(BaseTraceInstance):
if message_data.from_end_user_id:
end_user_data: Optional[EndUser] = (
- db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first()
)
if end_user_data is not None:
end_user_id = end_user_data.session_id
diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py
index 4e43561a15..e8c9bed099 100644
--- a/api/core/plugin/backwards_invocation/app.py
+++ b/api/core/plugin/backwards_invocation/app.py
@@ -193,9 +193,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get the user by user id
"""
- user = db.session.query(EndUser).filter(EndUser.id == user_id).first()
+ user = db.session.query(EndUser).where(EndUser.id == user_id).first()
if not user:
- user = db.session.query(Account).filter(Account.id == user_id).first()
+ user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
raise ValueError("user not found")
@@ -208,7 +208,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
- app = db.session.query(App).filter(App.id == app_id).filter(App.tenant_id == tenant_id).first()
+ app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
except Exception:
raise ValueError("app not found")
diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py
index 81a5d033a0..213f5c726a 100644
--- a/api/core/plugin/backwards_invocation/encrypt.py
+++ b/api/core/plugin/backwards_invocation/encrypt.py
@@ -1,16 +1,20 @@
+from core.helper.provider_cache import SingletonProviderCredentialsCache
from core.plugin.entities.request import RequestInvokeEncrypt
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter
from models.account import Tenant
class PluginEncrypter:
@classmethod
def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict:
- encrypter = ProviderConfigEncrypter(
+ encrypter, cache = create_provider_encrypter(
tenant_id=tenant.id,
config=payload.config,
- provider_type=payload.namespace,
- provider_identity=payload.identity,
+ cache=SingletonProviderCredentialsCache(
+ tenant_id=tenant.id,
+ provider_type=payload.namespace,
+ provider_identity=payload.identity,
+ ),
)
if payload.opt == "encrypt":
@@ -22,7 +26,7 @@ class PluginEncrypter:
"data": encrypter.decrypt(payload.data),
}
elif payload.opt == "clear":
- encrypter.delete_tool_credentials_cache()
+ cache.delete()
return {
"data": {},
}
diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py
index 1d62743f13..06773504d9 100644
--- a/api/core/plugin/backwards_invocation/tool.py
+++ b/api/core/plugin/backwards_invocation/tool.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from typing import Any
+from typing import Any, Optional
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
@@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
+ credential_id: Optional[str] = None,
) -> Generator[ToolInvokeMessage, None, None]:
"""
invoke tool
@@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation):
# get tool runtime
try:
tool_runtime = ToolManager.get_tool_runtime_from_plugin(
- tool_type, tenant_id, provider, tool_name, tool_parameters
+ tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id
)
response = ToolEngine.generic_invoke(
tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1
diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py
index a19a44aa3c..1c13a621d4 100644
--- a/api/core/plugin/entities/marketplace.py
+++ b/api/core/plugin/entities/marketplace.py
@@ -32,6 +32,13 @@ class MarketplacePluginDeclaration(BaseModel):
latest_package_identifier: str = Field(
..., description="Unique identifier for the latest package release of the plugin"
)
+ status: str = Field(..., description="Indicate the status of marketplace plugin, enum from `active` `deleted`")
+ deprecated_reason: str = Field(
+ ..., description="Not empty when status='deleted', indicates the reason why this plugin is deleted(deprecated)"
+ )
+ alternative_plugin_id: str = Field(
+ ..., description="Optional, indicates the alternative plugin for user to switch to"
+ )
@model_validator(mode="before")
@classmethod
diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py
index 2be65d67a0..47290ee613 100644
--- a/api/core/plugin/entities/parameters.py
+++ b/api/core/plugin/entities/parameters.py
@@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator
from core.entities.parameter_entities import CommonParameterType
from core.tools.entities.common_entities import I18nObject
+from core.workflow.nodes.base.entities import NumberType
class PluginParameterOption(BaseModel):
@@ -38,6 +39,7 @@ class PluginParameterType(enum.StrEnum):
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
+ ANY = CommonParameterType.ANY.value
DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value
# deprecated, should not use.
@@ -151,6 +153,10 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /):
if value and not isinstance(value, list):
raise ValueError("The tools selector must be a list.")
return value
+ case PluginParameterType.ANY:
+ if value and not isinstance(value, str | dict | list | NumberType):
+ raise ValueError("The var selector must be a string, dictionary, list or number.")
+ return value
case PluginParameterType.ARRAY:
if not isinstance(value, list):
# Try to parse JSON string for arrays
diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py
index e5cf7ee03a..a07b58d9ea 100644
--- a/api/core/plugin/entities/plugin.py
+++ b/api/core/plugin/entities/plugin.py
@@ -135,17 +135,6 @@ class PluginEntity(PluginInstallation):
return self
-class GithubPackage(BaseModel):
- repo: str
- version: str
- package: str
-
-
-class GithubVersion(BaseModel):
- repo: str
- version: str
-
-
class GenericProviderID:
organization: str
plugin_name: str
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 00253b8a11..16ab661092 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
class PluginOAuthCredentialsResponse(BaseModel):
+ metadata: Mapping[str, Any] = Field(
+ default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
+ )
+ expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py
index 89f595ec46..3a783dad3e 100644
--- a/api/core/plugin/entities/request.py
+++ b/api/core/plugin/entities/request.py
@@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import (
)
+class InvokeCredentials(BaseModel):
+ tool_credentials: dict[str, str] = Field(
+ default_factory=dict,
+ description="Map of tool provider to credential id, used to store the credential id for the tool provider.",
+ )
+
+
+class PluginInvokeContext(BaseModel):
+ credentials: Optional[InvokeCredentials] = Field(
+ default_factory=InvokeCredentials,
+ description="Credentials context for the plugin invocation or backward invocation.",
+ )
+
+
class RequestInvokeTool(BaseModel):
"""
Request to invoke a tool
@@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel):
provider: str
tool: str
tool_parameters: dict
+ credential_id: Optional[str] = None
class BaseRequestInvokeModel(BaseModel):
diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py
index 66b77c7489..9575c57ac8 100644
--- a/api/core/plugin/impl/agent.py
+++ b/api/core/plugin/impl/agent.py
@@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
+from core.plugin.entities.request import PluginInvokeContext
from core.plugin.impl.base import BasePluginClient
@@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient):
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
message_id: Optional[str] = None,
+ context: Optional[PluginInvokeContext] = None,
) -> Generator[AgentInvokeMessage, None, None]:
"""
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
@@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient):
"conversation_id": conversation_id,
"app_id": app_id,
"message_id": message_id,
+ "context": context.model_dump() if context else {},
"data": {
"agent_strategy_provider": agent_provider_id.provider_name,
"agent_strategy": agent_strategy,
diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py
index b006bf1d4b..7f022992ff 100644
--- a/api/core/plugin/impl/oauth.py
+++ b/api/core/plugin/impl/oauth.py
@@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
+ redirect_uri: str,
system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse:
- response = self._request_with_plugin_daemon_response_stream(
- "POST",
- f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
- PluginOAuthAuthorizationUrlResponse,
- data={
- "user_id": user_id,
- "data": {
- "provider": provider,
- "system_credentials": system_credentials,
+ try:
+ response = self._request_with_plugin_daemon_response_stream(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
+ PluginOAuthAuthorizationUrlResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "redirect_uri": redirect_uri,
+ "system_credentials": system_credentials,
+ },
},
- },
- headers={
- "X-Plugin-ID": plugin_id,
- "Content-Type": "application/json",
- },
- )
- for resp in response:
- return resp
- raise ValueError("No response received from plugin daemon for authorization URL request.")
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
+ },
+ )
+ for resp in response:
+ return resp
+ raise ValueError("No response received from plugin daemon for authorization URL request.")
+ except Exception as e:
+ raise ValueError(f"Error getting authorization URL: {e}")
def get_credentials(
self,
@@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient):
user_id: str,
plugin_id: str,
provider: str,
+ redirect_uri: str,
system_credentials: Mapping[str, Any],
request: Request,
) -> PluginOAuthCredentialsResponse:
@@ -50,30 +56,68 @@ class OAuthHandler(BasePluginClient):
Get credentials from the given request.
"""
- # encode request to raw http request
- raw_request_bytes = self._convert_request_to_raw_data(request)
+        try:
+            # encode the incoming request to raw HTTP bytes before handing it to the plugin daemon
+            raw_request_bytes = self._convert_request_to_raw_data(request)
+            response = self._request_with_plugin_daemon_response_stream(
+                "POST",
+                f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
+                PluginOAuthCredentialsResponse,
+                data={
+                    "user_id": user_id,
+                    "data": {
+                        "provider": provider,
+                        "redirect_uri": redirect_uri,
+                        "system_credentials": system_credentials,
+                        # hex-encode the raw bytes so they can be carried in the JSON body
+                        "raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
+                    },
+                },
+                headers={
+                    "X-Plugin-ID": plugin_id,
+                    "Content-Type": "application/json",
+                },
+            )
+            for resp in response:
+                return resp
+            raise ValueError("No response received from plugin daemon for credentials request.")
+        except Exception as e:
+            raise ValueError(f"Error getting credentials: {e}") from e  # keep original error as __cause__
- response = self._request_with_plugin_daemon_response_stream(
- "POST",
- f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
- PluginOAuthCredentialsResponse,
- data={
- "user_id": user_id,
- "data": {
- "provider": provider,
- "system_credentials": system_credentials,
- # for json serialization
- "raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
+ def refresh_credentials(
+ self,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ redirect_uri: str,
+ system_credentials: Mapping[str, Any],
+ credentials: Mapping[str, Any],
+ ) -> PluginOAuthCredentialsResponse:
+ try:
+ response = self._request_with_plugin_daemon_response_stream(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
+ PluginOAuthCredentialsResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "redirect_uri": redirect_uri,
+ "system_credentials": system_credentials,
+ "credentials": credentials,
+ },
+ },
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
},
- },
- headers={
- "X-Plugin-ID": plugin_id,
- "Content-Type": "application/json",
- },
- )
- for resp in response:
- return resp
- raise ValueError("No response received from plugin daemon for authorization URL request.")
+ )
+ for resp in response:
+ return resp
+ raise ValueError("No response received from plugin daemon for refresh credentials request.")
+ except Exception as e:
+ raise ValueError(f"Error refreshing credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py
index b7f7b31655..04ac8c9649 100644
--- a/api/core/plugin/impl/plugin.py
+++ b/api/core/plugin/impl/plugin.py
@@ -36,7 +36,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/list",
PluginListResponse,
- params={"page": 1, "page_size": 256},
+ params={"page": 1, "page_size": 256, "response_type": "paged"},
)
return result.list
@@ -45,7 +45,7 @@ class PluginInstaller(BasePluginClient):
"GET",
f"plugin/{tenant_id}/management/list",
PluginListResponse,
- params={"page": page, "page_size": page_size},
+ params={"page": page, "page_size": page_size, "response_type": "paged"},
)
def upload_pkg(
diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py
index 19b26c8fe3..04225f95ee 100644
--- a/api/core/plugin/impl/tool.py
+++ b/api/core/plugin/impl/tool.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
+from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginClient):
@@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient):
tool_provider: str,
tool_name: str,
credentials: dict[str, Any],
+ credential_type: CredentialType,
tool_parameters: dict[str, Any],
conversation_id: Optional[str] = None,
app_id: Optional[str] = None,
@@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient):
"provider": tool_provider_id.provider_name,
"tool": tool_name,
"credentials": credentials,
+ "credential_type": credential_type,
"tool_parameters": tool_parameters,
},
},
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index 25964ae063..0f0fe65f27 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform):
if prompt_item.edition_type == "basic" or not prompt_item.edition_type:
if self.with_variable_tmpl:
- vp = VariablePool()
+ vp = VariablePool.empty()
for k, v in inputs.items():
if k.startswith("#"):
vp.add(k[1:-1].split("."), v)
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index 47808928f7..e19c6419ca 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum):
COMPLETION = "completion"
CHAT = "chat"
- @classmethod
- def value_of(cls, value: str) -> "ModelMode":
- """
- Get value of given mode.
-
- :param value: mode value
- :return: mode
- """
- for mode in cls:
- if mode.value == value:
- return mode
- raise ValueError(f"invalid mode value {value}")
-
prompt_file_contents: dict[str, Any] = {}
@@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform):
) -> tuple[list[PromptMessage], Optional[list[str]]]:
inputs = {key: str(value) for key, value in inputs.items()}
- model_mode = ModelMode.value_of(model_config.mode)
+ model_mode = ModelMode(model_config.mode)
if model_mode == ModelMode.CHAT:
prompt_messages, stops = self._get_chat_model_prompt_messages(
app_mode=app_mode,
diff --git a/api/core/prompt/utils/extract_thread_messages.py b/api/core/prompt/utils/extract_thread_messages.py
index f7aef76c87..4b883622a7 100644
--- a/api/core/prompt/utils/extract_thread_messages.py
+++ b/api/core/prompt/utils/extract_thread_messages.py
@@ -1,10 +1,11 @@
-from typing import Any
+from collections.abc import Sequence
from constants import UUID_NIL
+from models import Message
-def extract_thread_messages(messages: list[Any]):
- thread_messages = []
+def extract_thread_messages(messages: Sequence[Message]):
+ thread_messages: list[Message] = []
next_message = None
for message in messages:
diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py
index f49466db6d..de64c27a73 100644
--- a/api/core/prompt/utils/get_thread_messages_length.py
+++ b/api/core/prompt/utils/get_thread_messages_length.py
@@ -1,3 +1,5 @@
+from sqlalchemy import select
+
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
@@ -8,19 +10,9 @@ def get_thread_messages_length(conversation_id: str) -> int:
Get the number of thread messages based on the parent message id.
"""
# Fetch all messages related to the conversation
- query = (
- db.session.query(
- Message.id,
- Message.parent_message_id,
- Message.answer,
- )
- .filter(
- Message.conversation_id == conversation_id,
- )
- .order_by(Message.created_at.desc())
- )
+ stmt = select(Message).where(Message.conversation_id == conversation_id).order_by(Message.created_at.desc())
- messages = query.all()
+ messages = db.session.scalars(stmt).all()
# Extract thread messages
thread_messages = extract_thread_messages(messages)
diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py
index 488a394679..6de4f3a303 100644
--- a/api/core/provider_manager.py
+++ b/api/core/provider_manager.py
@@ -275,7 +275,7 @@ class ProviderManager:
# Get the corresponding TenantDefaultModel record
default_model = (
db.session.query(TenantDefaultModel)
- .filter(
+ .where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@@ -367,7 +367,7 @@ class ProviderManager:
# Get the list of available models from get_configurations and check if it is LLM
default_model = (
db.session.query(TenantDefaultModel)
- .filter(
+ .where(
TenantDefaultModel.tenant_id == tenant_id,
TenantDefaultModel.model_type == model_type.to_origin_model_type(),
)
@@ -541,7 +541,7 @@ class ProviderManager:
db.session.rollback()
existed_provider_record = (
db.session.query(Provider)
- .filter(
+ .where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
deleted file mode 100644
index 167a919e69..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.core import clean_extra_whitespace
-
- # Returns "ITEM 1A: RISK FACTORS"
- return clean_extra_whitespace(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
deleted file mode 100644
index 9c682d29db..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- import re
-
- from unstructured.cleaners.core import group_broken_paragraphs
-
- para_split_re = re.compile(r"(\s*\n\s*){3}")
-
- return group_broken_paragraphs(content, paragraph_split=para_split_re)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
deleted file mode 100644
index 0cdbb171e1..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.core import clean_non_ascii_chars
-
- # Returns "This text contains non-ascii characters!"
- return clean_non_ascii_chars(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
deleted file mode 100644
index 9f42044a2d..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredNonAsciiCharsCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """Replaces unicode quote characters, such as the \x91 character in a string."""
-
- from unstructured.cleaners.core import replace_unicode_quotes
-
- return replace_unicode_quotes(content)
diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
deleted file mode 100644
index 32ae7217e8..0000000000
--- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py
+++ /dev/null
@@ -1,11 +0,0 @@
-"""Abstract interface for document clean implementations."""
-
-from core.rag.cleaner.cleaner_base import BaseCleaner
-
-
-class UnstructuredTranslateTextCleaner(BaseCleaner):
- def clean(self, content) -> str:
- """clean document content."""
- from unstructured.cleaners.translate import translate_text
-
- return translate_text(content)
diff --git a/api/core/rag/datasource/keyword/jieba/jieba.py b/api/core/rag/datasource/keyword/jieba/jieba.py
index d6d0bd88b2..ec3a23bd96 100644
--- a/api/core/rag/datasource/keyword/jieba/jieba.py
+++ b/api/core/rag/datasource/keyword/jieba/jieba.py
@@ -93,11 +93,11 @@ class Jieba(BaseKeyword):
documents = []
for chunk_index in sorted_chunk_indices:
- segment_query = db.session.query(DocumentSegment).filter(
+ segment_query = db.session.query(DocumentSegment).where(
DocumentSegment.dataset_id == self.dataset.id, DocumentSegment.index_node_id == chunk_index
)
if document_ids_filter:
- segment_query = segment_query.filter(DocumentSegment.document_id.in_(document_ids_filter))
+ segment_query = segment_query.where(DocumentSegment.document_id.in_(document_ids_filter))
segment = segment_query.first()
if segment:
@@ -214,7 +214,7 @@ class Jieba(BaseKeyword):
def _update_segment_keywords(self, dataset_id: str, node_id: str, keywords: list[str]):
document_segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
+ .where(DocumentSegment.dataset_id == dataset_id, DocumentSegment.index_node_id == node_id)
.first()
)
if document_segment:
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 2c5178241c..e872a4e375 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -3,7 +3,7 @@ from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from flask import Flask, current_app
-from sqlalchemy.orm import load_only
+from sqlalchemy.orm import Session, load_only
from configs import dify_config
from core.rag.data_post_processor.data_post_processor import DataPostProcessor
@@ -127,7 +127,7 @@ class RetrievalService:
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
metadata_condition = (
@@ -144,7 +144,8 @@ class RetrievalService:
@classmethod
def _get_dataset(cls, dataset_id: str) -> Optional[Dataset]:
- return db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ with Session(db.engine) as session:
+ return session.query(Dataset).where(Dataset.id == dataset_id).first()
@classmethod
def keyword_search(
@@ -293,7 +294,7 @@ class RetrievalService:
dataset_documents = {
doc.id: doc
for doc in db.session.query(DatasetDocument)
- .filter(DatasetDocument.id.in_(document_ids))
+ .where(DatasetDocument.id.in_(document_ids))
.options(load_only(DatasetDocument.id, DatasetDocument.doc_form, DatasetDocument.dataset_id))
.all()
}
@@ -317,7 +318,7 @@ class RetrievalService:
child_index_node_id = document.metadata.get("doc_id")
child_chunk = (
- db.session.query(ChildChunk).filter(ChildChunk.index_node_id == child_index_node_id).first()
+ db.session.query(ChildChunk).where(ChildChunk.index_node_id == child_index_node_id).first()
)
if not child_chunk:
@@ -325,7 +326,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
@@ -380,7 +381,7 @@ class RetrievalService:
segment = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
index 095752ea8e..6f3e15d166 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
@@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI:
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = None  # no document filter: keep sending None, as the request did before filtering existed
+        if document_ids_filter:
+            quoted = ("'" + str(doc_id).replace("'", "''") + "'" for doc_id in document_ids_filter)  # escape quotes
+            where_clause = f"metadata_->>'document_id' IN ({', '.join(quoted)})"
+
score_threshold = kwargs.get("score_threshold") or 0.0
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI:
vector=query_vector,
content=None,
top_k=kwargs.get("top_k", 4),
- filter=None,
+ filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
@@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
from alibabacloud_gpdb20160503 import models as gpdb_20160503_models
+        document_ids_filter = kwargs.get("document_ids_filter")
+        where_clause = None  # no document filter: keep sending None, as the request did before filtering existed
+        if document_ids_filter:
+            quoted = ("'" + str(doc_id).replace("'", "''") + "'" for doc_id in document_ids_filter)  # escape quotes
+            where_clause = f"metadata_->>'document_id' IN ({', '.join(quoted)})"
score_threshold = float(kwargs.get("score_threshold") or 0.0)
request = gpdb_20160503_models.QueryCollectionDataRequest(
dbinstance_id=self.config.instance_id,
@@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI:
vector=None,
content=query,
top_k=kwargs.get("top_k", 4),
- filter=None,
+ filter=where_clause,
)
response = self._client.query_collection_data(request)
documents = []
diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
index 44cc5d3e98..ad39717183 100644
--- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
+++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
@@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector):
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- query_str = {"match": {Field.CONTENT_KEY.value: query}}
+ query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}}
document_ids_filter = kwargs.get("document_ids_filter")
+
if document_ids_filter:
- query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore
+ query_str = {
+ "bool": {
+ "must": {"match": {Field.CONTENT_KEY.value: query}},
+ "filter": {"terms": {"metadata.document_id": document_ids_filter}},
+ }
+ }
+
results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
docs = []
for hit in results["hits"]["hits"]:
diff --git a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
index 46aefef11d..b0f0eeca38 100644
--- a/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
+++ b/api/core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
@@ -6,7 +6,7 @@ from uuid import UUID, uuid4
from numpy import ndarray
from pgvecto_rs.sqlalchemy import VECTOR # type: ignore
from pydantic import BaseModel, model_validator
-from sqlalchemy import Float, String, create_engine, insert, select, text
+from sqlalchemy import Float, create_engine, insert, select, text
from sqlalchemy import text as sql_text
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import Mapped, Session, mapped_column
@@ -67,7 +67,7 @@ class PGVectoRS(BaseVector):
postgresql.UUID(as_uuid=True),
primary_key=True,
)
- text: Mapped[str] = mapped_column(String)
+ text: Mapped[str]
meta: Mapped[dict] = mapped_column(postgresql.JSONB)
vector: Mapped[ndarray] = mapped_column(VECTOR(dim))
diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
index 05fa73011a..dfb95a1839 100644
--- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py
@@ -443,7 +443,7 @@ class QdrantVectorFactory(AbstractVectorFactory):
if dataset.collection_binding_id:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
- .filter(DatasetCollectionBinding.id == dataset.collection_binding_id)
+ .where(DatasetCollectionBinding.id == dataset.collection_binding_id)
.one_or_none()
)
if dataset_collection_binding:
diff --git a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
index a124faa503..9ed6e7369b 100644
--- a/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
+++ b/api/core/rag/datasource/vdb/tablestore/tablestore_vector.py
@@ -4,6 +4,7 @@ from typing import Any, Optional
import tablestore # type: ignore
from pydantic import BaseModel, model_validator
+from tablestore import BatchGetRowRequest, TableInBatchGetRowItem
from configs import dify_config
from core.rag.datasource.vdb.field import Field
@@ -50,6 +51,29 @@ class TableStoreVector(BaseVector):
self._index_name = f"{collection_name}_idx"
self._tags_field = f"{Field.METADATA_KEY.value}_tags"
+ def create_collection(self, embeddings: list[list[float]], **kwargs):
+ # The vector dimension is taken from the length of the first embedding and
+ # handed to _create_collection; the remaining embeddings and kwargs are not used here.
+ # NOTE(review): embeddings[0] raises IndexError for an empty list - confirm callers never pass one.
+ dimension = len(embeddings[0])
+ self._create_collection(dimension)
+
+    def get_by_ids(self, ids: list[str]) -> list[Document]:
+        """Load documents from the table by primary key.
+
+        Args:
+            ids: Values of the "id" primary-key column of the rows to read.
+
+        Returns:
+            One Document per row that was read successfully. Rows that are
+            missing or whose read failed are skipped. A missing content
+            column yields an empty page_content and a missing metadata
+            column yields empty metadata, matching the search methods.
+        """
+        docs: list[Document] = []
+        if not ids:
+            # Nothing to look up; do not send an empty batch request.
+            return docs
+
+        columns_to_get = [Field.METADATA_KEY.value, Field.CONTENT_KEY.value]
+        # BatchGetRow limits the number of rows per request (100 according to the
+        # service limits), so long id lists are split into several requests.
+        max_rows_per_request = 100
+        for start in range(0, len(ids), max_rows_per_request):
+            rows_to_get = [[("id", _id)] for _id in ids[start : start + max_rows_per_request]]
+            request = BatchGetRowRequest()
+            request.add(TableInBatchGetRowItem(self._table_name, rows_to_get, columns_to_get, None, 1))
+
+            result = self._tablestore_client.batch_get_row(request)
+            table_result = result.get_result_by_table(self._table_name)
+            for item in table_result:
+                if not (item.is_ok and item.row):
+                    continue
+                # Each attribute column is a 3-tuple; only the name and value are used.
+                kv = {k: v for k, v, t in item.row.attribute_columns}
+                metadata_str = kv.get(Field.METADATA_KEY.value)
+                docs.append(
+                    Document(
+                        page_content=kv.get(Field.CONTENT_KEY.value) or "",
+                        metadata=json.loads(metadata_str) if metadata_str else {},
+                    )
+                )
+        return docs
+
def get_type(self) -> str:
return VectorType.TABLESTORE
@@ -94,10 +118,21 @@ class TableStoreVector(BaseVector):
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 4)
- return self._search_by_vector(query_vector, top_k)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filtered_list = None
+ if document_ids_filter:
+ filtered_list = ["document_id=" + item for item in document_ids_filter]
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ return self._search_by_vector(query_vector, filtered_list, top_k, score_threshold)
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
- return self._search_by_full_text(query)
+ top_k = kwargs.get("top_k", 4)
+ document_ids_filter = kwargs.get("document_ids_filter")
+ filtered_list = None
+ if document_ids_filter:
+ filtered_list = ["document_id=" + item for item in document_ids_filter]
+
+ return self._search_by_full_text(query, filtered_list, top_k)
def delete(self) -> None:
self._delete_table_if_exist()
@@ -206,32 +241,51 @@ class TableStoreVector(BaseVector):
primary_key = [("id", id)]
row = tablestore.Row(primary_key)
self._tablestore_client.delete_row(self._table_name, row, None)
- logging.info("Tablestore delete row successfully. id:%s", id)
def _search_by_metadata(self, key: str, value: str) -> list[str]:
query = tablestore.SearchQuery(
tablestore.TermQuery(self._tags_field, str(key) + "=" + str(value)),
- limit=100,
+ limit=1000,
get_total_count=False,
)
+ # Collect the primary-key value of every matching row, following next_token across pages.
+ rows: list[str] = []
+ next_token = None
+ while True:
+ if next_token is not None:
+ query.next_token = next_token
+
+ search_response = self._tablestore_client.search(
+ table_name=self._table_name,
+ index_name=self._index_name,
+ search_query=query,
+ columns_to_get=tablestore.ColumnsToGet(
+ column_names=[Field.PRIMARY_KEY.value], return_type=tablestore.ColumnReturnType.SPECIFIED
+ ),
+ )
- search_response = self._tablestore_client.search(
- table_name=self._table_name,
- index_name=self._index_name,
- search_query=query,
- columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
- )
+ if search_response is not None:
+ rows.extend([row[0][0][1] for row in search_response.rows])
- return [row[0][0][1] for row in search_response.rows]
+ # Stop on any empty token (b"" or None); a None token would otherwise re-run the first page forever.
+ if search_response is None or not search_response.next_token:
+ break
+ else:
+ next_token = search_response.next_token
- def _search_by_vector(self, query_vector: list[float], top_k: int) -> list[Document]:
- ots_query = tablestore.KnnVectorQuery(
+ return rows
+
+ def _search_by_vector(
+ self, query_vector: list[float], document_ids_filter: list[str] | None, top_k: int, score_threshold: float
+ ) -> list[Document]:
+ knn_vector_query = tablestore.KnnVectorQuery(
field_name=Field.VECTOR.value,
top_k=top_k,
float32_query_vector=query_vector,
)
+ if document_ids_filter:
+ knn_vector_query.filter = tablestore.TermsQuery(self._tags_field, document_ids_filter)
+
sort = tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)])
- search_query = tablestore.SearchQuery(ots_query, limit=top_k, get_total_count=False, sort=sort)
+ search_query = tablestore.SearchQuery(knn_vector_query, limit=top_k, get_total_count=False, sort=sort)
search_response = self._tablestore_client.search(
table_name=self._table_name,
@@ -239,30 +293,42 @@ class TableStoreVector(BaseVector):
search_query=search_query,
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
- logging.info(
- "Tablestore search successfully. request_id:%s",
- search_response.request_id,
- )
- return self._to_query_result(search_response)
-
- def _to_query_result(self, search_response: tablestore.SearchResponse) -> list[Document]:
documents = []
- for row in search_response.rows:
- documents.append(
- Document(
- page_content=row[1][2][1],
- vector=json.loads(row[1][3][1]),
- metadata=json.loads(row[1][0][1]),
+ for search_hit in search_response.search_hits:
+ if search_hit.score > score_threshold:
+ ots_column_map = {}
+ for col in search_hit.row[1]:
+ ots_column_map[col[0]] = col[1]
+
+ vector_str = ots_column_map.get(Field.VECTOR.value)
+ metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
+
+ vector = json.loads(vector_str) if vector_str else None
+ metadata = json.loads(metadata_str) if metadata_str else {}
+
+ metadata["score"] = search_hit.score
+
+ documents.append(
+ Document(
+ page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
+ vector=vector,
+ metadata=metadata,
+ )
)
- )
-
+ documents = sorted(documents, key=lambda x: x.metadata["score"] if x.metadata else 0, reverse=True)
return documents
- def _search_by_full_text(self, query: str) -> list[Document]:
+ def _search_by_full_text(self, query: str, document_ids_filter: list[str] | None, top_k: int) -> list[Document]:
+ bool_query = tablestore.BoolQuery(must_queries=[], filter_queries=[], should_queries=[], must_not_queries=[])
+ bool_query.must_queries.append(tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value))
+
+ if document_ids_filter:
+ bool_query.filter_queries.append(tablestore.TermsQuery(self._tags_field, document_ids_filter))
+
search_query = tablestore.SearchQuery(
- query=tablestore.MatchQuery(text=query, field_name=Field.CONTENT_KEY.value),
+ query=bool_query,
sort=tablestore.Sort(sorters=[tablestore.ScoreSort(sort_order=tablestore.SortOrder.DESC)]),
- limit=100,
+ limit=top_k,
)
search_response = self._tablestore_client.search(
table_name=self._table_name,
@@ -271,7 +337,25 @@ class TableStoreVector(BaseVector):
columns_to_get=tablestore.ColumnsToGet(return_type=tablestore.ColumnReturnType.ALL_FROM_INDEX),
)
- return self._to_query_result(search_response)
+ documents = []
+ for search_hit in search_response.search_hits:
+ ots_column_map = {}
+ for col in search_hit.row[1]:
+ ots_column_map[col[0]] = col[1]
+
+ vector_str = ots_column_map.get(Field.VECTOR.value)
+ metadata_str = ots_column_map.get(Field.METADATA_KEY.value)
+ vector = json.loads(vector_str) if vector_str else None
+ metadata = json.loads(metadata_str) if metadata_str else {}
+
+ documents.append(
+ Document(
+ page_content=ots_column_map.get(Field.CONTENT_KEY.value) or "",
+ vector=vector,
+ metadata=metadata,
+ )
+ )
+ return documents
class TableStoreVectorFactory(AbstractVectorFactory):
diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
index 75afe0cdb8..23ed8a3344 100644
--- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py
+++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py
@@ -206,9 +206,19 @@ class TencentVector(BaseVector):
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
- self._client.delete(
- database_name=self._client_config.database, collection_name=self.collection_name, document_ids=ids
- )
+
+ total_count = len(ids)
+ batch_size = self._client_config.max_upsert_batch_size
+ batch = math.ceil(total_count / batch_size)
+
+ for j in range(batch):
+ start_idx = j * batch_size
+ end_idx = min(total_count, (j + 1) * batch_size)
+ batch_ids = ids[start_idx:end_idx]
+
+ self._client.delete(
+ database_name=self._client_config.database, collection_name=self.collection_name, document_ids=batch_ids
+ )
def delete_by_metadata_field(self, key: str, value: str) -> None:
self._client.delete(
@@ -274,7 +284,8 @@ class TencentVector(BaseVector):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
- score = result.get("score", 0.0)
+ else:
+ score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
deleted file mode 100644
index 1e62b3c589..0000000000
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py
+++ /dev/null
@@ -1,17 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-
-class ClusterEntity(BaseModel):
- """
- Model Config Entity.
- """
-
- name: str
- cluster_id: str
- displayName: str
- region: str
- spendingLimit: Optional[int] = 1000
- version: str
- createdBy: str
diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
index 6f895b12af..ba6a9654f0 100644
--- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
@@ -418,13 +418,13 @@ class TidbOnQdrantVector(BaseVector):
class TidbOnQdrantVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> TidbOnQdrantVector:
tidb_auth_binding = (
- db.session.query(TidbAuthBinding).filter(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
+ db.session.query(TidbAuthBinding).where(TidbAuthBinding.tenant_id == dataset.tenant_id).one_or_none()
)
if not tidb_auth_binding:
with redis_client.lock("create_tidb_serverless_cluster_lock", timeout=900):
tidb_auth_binding = (
db.session.query(TidbAuthBinding)
- .filter(TidbAuthBinding.tenant_id == dataset.tenant_id)
+ .where(TidbAuthBinding.tenant_id == dataset.tenant_id)
.one_or_none()
)
if tidb_auth_binding:
@@ -433,7 +433,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory):
else:
idle_tidb_auth_binding = (
db.session.query(TidbAuthBinding)
- .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
+ .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "ACTIVE")
.limit(1)
.one_or_none()
)
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 00080b0fae..e018f7d3d4 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -47,7 +47,7 @@ class Vector:
if dify_config.VECTOR_STORE_WHITELIST_ENABLE:
whitelist = (
db.session.query(Whitelist)
- .filter(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
+ .where(Whitelist.tenant_id == self._dataset.tenant_id, Whitelist.category == "vector_db")
.one_or_none()
)
if whitelist:
diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py
index 398b0daad9..f844770a20 100644
--- a/api/core/rag/docstore/dataset_docstore.py
+++ b/api/core/rag/docstore/dataset_docstore.py
@@ -42,7 +42,7 @@ class DatasetDocumentStore:
@property
def docs(self) -> dict[str, Document]:
document_segments = (
- db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == self._dataset.id).all()
+ db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == self._dataset.id).all()
)
output = {}
@@ -63,7 +63,7 @@ class DatasetDocumentStore:
def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = (
db.session.query(func.max(DocumentSegment.position))
- .filter(DocumentSegment.document_id == self._document_id)
+ .where(DocumentSegment.document_id == self._document_id)
.scalar()
)
@@ -147,7 +147,7 @@ class DatasetDocumentStore:
segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
- db.session.query(ChildChunk).filter(
+ db.session.query(ChildChunk).where(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
@@ -230,7 +230,7 @@ class DatasetDocumentStore:
def get_document_segment(self, doc_id: str) -> Optional[DocumentSegment]:
document_segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
+ .where(DocumentSegment.dataset_id == self._dataset.id, DocumentSegment.index_node_id == doc_id)
.first()
)
diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py
index e46ab8b7fd..01003a13b6 100644
--- a/api/core/rag/extractor/blob/blob.py
+++ b/api/core/rag/extractor/blob/blob.py
@@ -9,8 +9,7 @@ from __future__ import annotations
import contextlib
import mimetypes
-from abc import ABC, abstractmethod
-from collections.abc import Generator, Iterable, Mapping
+from collections.abc import Generator, Mapping
from io import BufferedReader, BytesIO
from pathlib import Path, PurePath
from typing import Any, Optional, Union
@@ -143,21 +142,3 @@ class Blob(BaseModel):
if self.source:
str_repr += f" {self.source}"
return str_repr
-
-
-class BlobLoader(ABC):
- """Abstract interface for blob loaders implementation.
-
- Implementer should be able to load raw content from a datasource system according
- to some criteria and return the raw content lazily as a stream of blobs.
- """
-
- @abstractmethod
- def yield_blobs(
- self,
- ) -> Iterable[Blob]:
- """A lazy loader for raw data represented by Blob object.
-
- Returns:
- A generator over blobs
- """
diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py
index eca955ddd1..875626eb34 100644
--- a/api/core/rag/extractor/notion_extractor.py
+++ b/api/core/rag/extractor/notion_extractor.py
@@ -331,9 +331,10 @@ class NotionExtractor(BaseExtractor):
last_edited_time = self.get_notion_last_edited_time()
data_source_info = document_model.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
- update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
- db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
+ db.session.query(DocumentModel).filter_by(id=document_model.id).update(
+ {DocumentModel.data_source_info: json.dumps(data_source_info)}
+ ) # type: ignore
db.session.commit()
def get_notion_last_edited_time(self) -> str:
@@ -365,7 +366,7 @@ class NotionExtractor(BaseExtractor):
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = (
db.session.query(DataSourceOauthBinding)
- .filter(
+ .where(
db.and_(
DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion",
diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py
deleted file mode 100644
index dd8a979e70..0000000000
--- a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import logging
-
-from core.rag.extractor.extractor_base import BaseExtractor
-from core.rag.models.document import Document
-
-logger = logging.getLogger(__name__)
-
-
-class UnstructuredPDFExtractor(BaseExtractor):
- """Load pdf files.
-
-
- Args:
- file_path: Path to the file to load.
-
- api_url: Unstructured API URL
-
- api_key: Unstructured API Key
- """
-
- def __init__(self, file_path: str, api_url: str, api_key: str):
- """Initialize with file path."""
- self._file_path = file_path
- self._api_url = api_url
- self._api_key = api_key
-
- def extract(self) -> list[Document]:
- if self._api_url:
- from unstructured.partition.api import partition_via_api
-
- elements = partition_via_api(
- filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto"
- )
- else:
- from unstructured.partition.pdf import partition_pdf
-
- elements = partition_pdf(filename=self._file_path, strategy="auto")
-
- from unstructured.chunking.title import chunk_by_title
-
- chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
- documents = []
- for chunk in chunks:
- text = chunk.text.strip()
- documents.append(Document(page_content=text))
-
- return documents
diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py
deleted file mode 100644
index 22dfdd2075..0000000000
--- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py
+++ /dev/null
@@ -1,34 +0,0 @@
-import logging
-
-from core.rag.extractor.extractor_base import BaseExtractor
-from core.rag.models.document import Document
-
-logger = logging.getLogger(__name__)
-
-
-class UnstructuredTextExtractor(BaseExtractor):
- """Load msg files.
-
-
- Args:
- file_path: Path to the file to load.
- """
-
- def __init__(self, file_path: str, api_url: str):
- """Initialize with file path."""
- self._file_path = file_path
- self._api_url = api_url
-
- def extract(self) -> list[Document]:
- from unstructured.partition.text import partition_text
-
- elements = partition_text(filename=self._file_path)
- from unstructured.chunking.title import chunk_by_title
-
- chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000)
- documents = []
- for chunk in chunks:
- text = chunk.text.strip()
- documents.append(Document(page_content=text))
-
- return documents
diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py
index bff0acc48f..14363de7d4 100644
--- a/api/core/rag/extractor/word_extractor.py
+++ b/api/core/rag/extractor/word_extractor.py
@@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor):
paragraph_content = []
for run in paragraph.runs:
if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"):
+ # Process drawing type images
drawing_elements = run.element.findall(
".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing"
)
+ has_drawing = False
for drawing in drawing_elements:
blip_elements = drawing.findall(
".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip"
@@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor):
if embed_id:
image_part = doc.part.related_parts.get(embed_id)
if image_part in image_map:
+ has_drawing = True
+ paragraph_content.append(image_map[image_part])
+ # Process pict type images
+ shape_elements = run.element.findall(
+ ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict"
+ )
+ for shape in shape_elements:
+ # Find image data in VML
+ shape_image = shape.find(
+ ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData"
+ )
+ if shape_image is not None and shape_image.text:
+ image_id = shape_image.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+ )
+ if image_id and image_id in doc.part.rels:
+ image_part = doc.part.rels[image_id].target_part
+ if image_part in image_map and not has_drawing:
+ paragraph_content.append(image_map[image_part])
+ # Find imagedata element in VML
+ image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata")
+ if image_data is not None:
+ image_id = image_data.get("id") or image_data.get(
+ "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id"
+ )
+ if image_id and image_id in doc.part.rels:
+ image_part = doc.part.rels[image_id].target_part
+ if image_part in image_map and not has_drawing:
paragraph_content.append(image_map[image_part])
if run.text.strip():
paragraph_content.append(run.text.strip())
diff --git a/api/core/rag/index_processor/processor/parent_child_index_processor.py b/api/core/rag/index_processor/processor/parent_child_index_processor.py
index 1cde5e1c8f..52756fbacd 100644
--- a/api/core/rag/index_processor/processor/parent_child_index_processor.py
+++ b/api/core/rag/index_processor/processor/parent_child_index_processor.py
@@ -118,7 +118,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
- .filter(
+ .where(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
@@ -128,7 +128,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
- db.session.query(ChildChunk).filter(
+ db.session.query(ChildChunk).where(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
@@ -136,7 +136,7 @@ class ParentChildIndexProcessor(BaseIndexProcessor):
vector.delete()
if delete_child_chunks:
- db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
+ db.session.query(ChildChunk).where(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 3fca48be22..a25bc65646 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -9,6 +9,7 @@ from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
+from sqlalchemy.orm import Session
from core.app.app_config.entities import (
DatasetEntity,
@@ -134,7 +135,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@@ -241,7 +242,7 @@ class DatasetRetrieval:
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument)
- .filter(
+ .where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
@@ -326,7 +327,7 @@ class DatasetRetrieval:
if dataset_id:
# get retrieval model config
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
results = []
if dataset.provider == "external":
@@ -515,14 +516,14 @@ class DatasetRetrieval:
if document.metadata is not None:
dataset_document = (
db.session.query(DatasetDocument)
- .filter(DatasetDocument.id == document.metadata["document_id"])
+ .where(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = (
db.session.query(ChildChunk)
- .filter(
+ .where(
ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id,
@@ -532,7 +533,7 @@ class DatasetRetrieval:
if child_chunk:
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == child_chunk.segment_id)
+ .where(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
@@ -540,13 +541,13 @@ class DatasetRetrieval:
)
db.session.commit()
else:
- query = db.session.query(DocumentSegment).filter(
+ query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
# if 'dataset_id' in document.metadata:
if "dataset_id" in document.metadata:
- query = query.filter(DocumentSegment.dataset_id == document.metadata["dataset_id"])
+ query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
query.update(
@@ -598,7 +599,8 @@ class DatasetRetrieval:
metadata_condition: Optional[MetadataCondition] = None,
):
with flask_app.app_context():
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ with Session(db.engine) as session:
+ dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return []
@@ -683,7 +685,7 @@ class DatasetRetrieval:
available_datasets = []
for dataset_id in dataset_ids:
# get dataset from dataset id
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
# pass if dataset is not available
if not dataset:
@@ -860,7 +862,7 @@ class DatasetRetrieval:
metadata_filtering_conditions: Optional[MetadataFilteringCondition],
inputs: dict,
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
- document_query = db.session.query(DatasetDocument).filter(
+ document_query = db.session.query(DatasetDocument).where(
DatasetDocument.dataset_id.in_(dataset_ids),
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -928,9 +930,9 @@ class DatasetRetrieval:
raise ValueError("Invalid metadata filtering mode")
if filters:
if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
- document_query = document_query.filter(and_(*filters))
+ document_query = document_query.where(and_(*filters))
else:
- document_query = document_query.filter(or_(*filters))
+ document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@@ -956,7 +958,7 @@ class DatasetRetrieval:
self, dataset_ids: list, query: str, tenant_id: str, user_id: str, metadata_model_config: ModelConfig
) -> Optional[list[dict[str, Any]]]:
# get all metadata field
- metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+ metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
# get metadata model config
if metadata_model_config is None:
@@ -1135,7 +1137,7 @@ class DatasetRetrieval:
def _get_prompt_template(
self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str
):
- model_mode = ModelMode.value_of(mode)
+ model_mode = ModelMode(mode)
input_text = query
prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]]
diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py
index 0fb1bcb2e0..bcaf299892 100644
--- a/api/core/rag/splitter/fixed_text_splitter.py
+++ b/api/core/rag/splitter/fixed_text_splitter.py
@@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
splits = text.split()
else:
splits = text.split(separator)
+ # Re-attach the separator to every piece except the last, which had no separator after it in the text.
+ splits = [item + separator if i < len(splits) - 1 else item for i, item in enumerate(splits)]
else:
splits = list(text)
splits = [s for s in splits if (s not in {"", "\n"})]
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index b711e8434a..529d8ccd27 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -10,7 +10,6 @@ from typing import (
Any,
Literal,
Optional,
- TypedDict,
TypeVar,
Union,
)
@@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
raise NotImplementedError
-class CharacterTextSplitter(TextSplitter):
- """Splitting text that looks at characters."""
-
- def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None:
- """Create a new TextSplitter."""
- super().__init__(**kwargs)
- self._separator = separator
-
- def split_text(self, text: str) -> list[str]:
- """Split incoming text and return chunks."""
- # First we naively split the large input into a bunch of smaller ones.
- splits = _split_text_with_regex(text, self._separator, self._keep_separator)
- _separator = "" if self._keep_separator else self._separator
- _good_splits_lengths = [] # cache the lengths of the splits
- if splits:
- _good_splits_lengths.extend(self._length_function(splits))
- return self._merge_splits(splits, _separator, _good_splits_lengths)
-
-
-class LineType(TypedDict):
- """Line type as typed dict."""
-
- metadata: dict[str, str]
- content: str
-
-
-class HeaderType(TypedDict):
- """Header type as typed dict."""
-
- level: int
- name: str
- data: str
-
-
-class MarkdownHeaderTextSplitter:
- """Splitting markdown files based on specified headers."""
-
- def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False):
- """Create a new MarkdownHeaderTextSplitter.
-
- Args:
- headers_to_split_on: Headers we want to track
- return_each_line: Return each line w/ associated headers
- """
- # Output line-by-line or aggregated into chunks w/ common headers
- self.return_each_line = return_each_line
- # Given the headers we want to split on,
- # (e.g., "#, ##, etc") order by length
- self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True)
-
- def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]:
- """Combine lines with common metadata into chunks
- Args:
- lines: Line of text / associated header metadata
- """
- aggregated_chunks: list[LineType] = []
-
- for line in lines:
- if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]:
- # If the last line in the aggregated list
- # has the same metadata as the current line,
- # append the current content to the last lines's content
- aggregated_chunks[-1]["content"] += " \n" + line["content"]
- else:
- # Otherwise, append the current line to the aggregated list
- aggregated_chunks.append(line)
-
- return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks]
-
- def split_text(self, text: str) -> list[Document]:
- """Split markdown file
- Args:
- text: Markdown file"""
-
- # Split the input text by newline character ("\n").
- lines = text.split("\n")
- # Final output
- lines_with_metadata: list[LineType] = []
- # Content and metadata of the chunk currently being processed
- current_content: list[str] = []
- current_metadata: dict[str, str] = {}
- # Keep track of the nested header structure
- # header_stack: List[Dict[str, Union[int, str]]] = []
- header_stack: list[HeaderType] = []
- initial_metadata: dict[str, str] = {}
-
- for line in lines:
- stripped_line = line.strip()
- # Check each line against each of the header types (e.g., #, ##)
- for sep, name in self.headers_to_split_on:
- # Check if line starts with a header that we intend to split on
- if stripped_line.startswith(sep) and (
- # Header with no text OR header is followed by space
- # Both are valid conditions that sep is being used a header
- len(stripped_line) == len(sep) or stripped_line[len(sep)] == " "
- ):
- # Ensure we are tracking the header as metadata
- if name is not None:
- # Get the current header level
- current_header_level = sep.count("#")
-
- # Pop out headers of lower or same level from the stack
- while header_stack and header_stack[-1]["level"] >= current_header_level:
- # We have encountered a new header
- # at the same or higher level
- popped_header = header_stack.pop()
- # Clear the metadata for the
- # popped header in initial_metadata
- if popped_header["name"] in initial_metadata:
- initial_metadata.pop(popped_header["name"])
-
- # Push the current header to the stack
- header: HeaderType = {
- "level": current_header_level,
- "name": name,
- "data": stripped_line[len(sep) :].strip(),
- }
- header_stack.append(header)
- # Update initial_metadata with the current header
- initial_metadata[name] = header["data"]
-
- # Add the previous line to the lines_with_metadata
- # only if current_content is not empty
- if current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
-
- break
- else:
- if stripped_line:
- current_content.append(stripped_line)
- elif current_content:
- lines_with_metadata.append(
- {
- "content": "\n".join(current_content),
- "metadata": current_metadata.copy(),
- }
- )
- current_content.clear()
-
- current_metadata = initial_metadata.copy()
-
- if current_content:
- lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata})
-
- # lines_with_metadata has each line with associated header metadata
- # aggregate these into chunks based on common metadata
- if not self.return_each_line:
- return self.aggregate_lines_to_chunks(lines_with_metadata)
- else:
- return [
- Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata
- ]
-
-
-# should be in newer Python versions (3.10+)
# @dataclass(frozen=True, kw_only=True, slots=True)
@dataclass(frozen=True)
class Tokenizer:
diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py
index 6452317120..052ba1c2cb 100644
--- a/api/core/repositories/__init__.py
+++ b/api/core/repositories/__init__.py
@@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
+from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
+ "DifyCoreRepositoryFactory",
+ "RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository",
]
diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py
new file mode 100644
index 0000000000..4118aa61c7
--- /dev/null
+++ b/api/core/repositories/factory.py
@@ -0,0 +1,224 @@
+"""
+Repository factory for dynamically creating repository instances based on configuration.
+
+This module provides a Django-like settings system for repository implementations,
+allowing users to configure different repository backends through string paths.
+"""
+
+import importlib
+import inspect
+import logging
+from typing import Protocol, Union
+
+from sqlalchemy.engine import Engine
+from sqlalchemy.orm import sessionmaker
+
+from configs import dify_config
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models import Account, EndUser
+from models.enums import WorkflowRunTriggeredFrom
+from models.workflow import WorkflowNodeExecutionTriggeredFrom
+
+logger = logging.getLogger(__name__)
+
+
+class RepositoryImportError(Exception):
+    """Signals that a repository implementation could not be imported or instantiated."""
+
+
+class DifyCoreRepositoryFactory:
+    """
+    Factory for creating repository instances based on configuration.
+
+    This factory supports Django-like settings where repository implementations
+    are specified as module paths (e.g., 'module.submodule.ClassName'). The configured
+    class is imported, checked against the expected repository interface and the
+    expected constructor parameters, and then instantiated.
+    """
+
+    @staticmethod
+    def _import_class(class_path: str) -> type:
+        """
+        Import a class from a module path string.
+
+        Args:
+            class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
+
+        Returns:
+            The imported class
+
+        Raises:
+            RepositoryImportError: If the path is malformed, the module or attribute cannot
+                be imported, or the resolved attribute is not a class
+        """
+        try:
+            module_path, class_name = class_path.rsplit(".", 1)
+            module = importlib.import_module(module_path)
+            repo_class = getattr(module, class_name)
+        except (ValueError, ImportError, AttributeError) as e:
+            raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
+
+        # Explicit check rather than `assert`: assertions are skipped under `python -O`, and an
+        # AssertionError would escape the RepositoryImportError contract documented above.
+        if not isinstance(repo_class, type):
+            raise RepositoryImportError(
+                f"Cannot import repository class '{class_path}': '{class_name}' is not a class"
+            )
+        return repo_class
+
+    @staticmethod
+    def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None:  # type: ignore
+        """
+        Validate that a class implements the expected repository interface.
+
+        Only the presence of each public callable attribute of the interface is checked
+        (via ``hasattr``); signatures are not compared.
+
+        Args:
+            repository_class: The class to validate
+            expected_interface: The expected interface/protocol
+
+        Raises:
+            RepositoryImportError: If the class doesn't implement the interface
+        """
+        # Public (non-underscore) callables declared on the protocol are the required methods
+        required_methods = [
+            method
+            for method in dir(expected_interface)
+            if not method.startswith("_") and callable(getattr(expected_interface, method, None))
+        ]
+
+        missing_methods = [
+            method_name for method_name in required_methods if not hasattr(repository_class, method_name)
+        ]
+
+        if missing_methods:
+            raise RepositoryImportError(
+                f"Repository class '{repository_class.__name__}' does not implement required methods "
+                f"{missing_methods} from interface '{expected_interface.__name__}'"
+            )
+
+    @staticmethod
+    def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
+        """
+        Validate that a repository class constructor accepts required parameters.
+
+        Args:
+            repository_class: The class to validate
+            required_params: List of required parameter names
+
+        Raises:
+            RepositoryImportError: If the constructor signature cannot be inspected, or if it
+                doesn't accept the required parameters
+        """
+        try:
+            # MyPy may flag the line below with the following error:
+            #
+            # > Accessing "__init__" on an instance is unsound, since
+            # > instance.__init__ could be from an incompatible subclass.
+            #
+            # Despite this, we need to ensure that the constructor of `repository_class`
+            # has a compatible signature.
+            signature = inspect.signature(repository_class.__init__)  # type: ignore[misc]
+        except Exception as e:
+            raise RepositoryImportError(
+                f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
+            ) from e
+
+        # The comparison happens outside the `try` so that the error raised below is not
+        # caught and re-wrapped in a second "Failed to validate ..." RepositoryImportError.
+        # 'self' is dropped because callers never pass it.
+        param_names = [name for name in signature.parameters if name != "self"]
+
+        missing_params = [param for param in required_params if param not in param_names]
+        if missing_params:
+            raise RepositoryImportError(
+                f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
+                f"{missing_params}. Expected parameters: {required_params}"
+            )
+
+    @classmethod
+    def create_workflow_execution_repository(
+        cls,
+        session_factory: Union[sessionmaker, Engine],
+        user: Union[Account, EndUser],
+        app_id: str,
+        triggered_from: WorkflowRunTriggeredFrom,
+    ) -> WorkflowExecutionRepository:
+        """
+        Create a WorkflowExecutionRepository instance based on configuration.
+
+        The implementation class is read from ``dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY``.
+
+        Args:
+            session_factory: SQLAlchemy sessionmaker or engine
+            user: Account or EndUser object
+            app_id: Application ID
+            triggered_from: Source of the execution trigger
+
+        Returns:
+            Configured WorkflowExecutionRepository instance
+
+        Raises:
+            RepositoryImportError: If the configured repository cannot be created
+        """
+        class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
+        logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
+
+        try:
+            repository_class = cls._import_class(class_path)
+            cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
+            cls._validate_constructor_signature(
+                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+            )
+
+            return repository_class(  # type: ignore[no-any-return]
+                session_factory=session_factory,
+                user=user,
+                app_id=app_id,
+                triggered_from=triggered_from,
+            )
+        except RepositoryImportError:
+            # Re-raise our custom errors as-is
+            raise
+        except Exception as e:
+            # Anything else (e.g. the constructor itself failing) is logged and wrapped
+            logger.exception("Failed to create WorkflowExecutionRepository")
+            raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
+
+    @classmethod
+    def create_workflow_node_execution_repository(
+        cls,
+        session_factory: Union[sessionmaker, Engine],
+        user: Union[Account, EndUser],
+        app_id: str,
+        triggered_from: WorkflowNodeExecutionTriggeredFrom,
+    ) -> WorkflowNodeExecutionRepository:
+        """
+        Create a WorkflowNodeExecutionRepository instance based on configuration.
+
+        The implementation class is read from ``dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY``.
+
+        Args:
+            session_factory: SQLAlchemy sessionmaker or engine
+            user: Account or EndUser object
+            app_id: Application ID
+            triggered_from: Source of the execution trigger
+
+        Returns:
+            Configured WorkflowNodeExecutionRepository instance
+
+        Raises:
+            RepositoryImportError: If the configured repository cannot be created
+        """
+        class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
+        logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
+
+        try:
+            repository_class = cls._import_class(class_path)
+            cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
+            cls._validate_constructor_signature(
+                repository_class, ["session_factory", "user", "app_id", "triggered_from"]
+            )
+
+            return repository_class(  # type: ignore[no-any-return]
+                session_factory=session_factory,
+                user=user,
+                app_id=app_id,
+                triggered_from=triggered_from,
+            )
+        except RepositoryImportError:
+            # Re-raise our custom errors as-is
+            raise
+        except Exception as e:
+            # Anything else (e.g. the constructor itself failing) is logged and wrapped
+            logger.exception("Failed to create WorkflowNodeExecutionRepository")
+            raise RepositoryImportError(
+                f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
+            ) from e
diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
index 0b3e5eb424..c579ff4028 100644
--- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -6,7 +6,6 @@ import json
import logging
from typing import Optional, Union
-from sqlalchemy import select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
# Update the in-memory cache for faster subsequent lookups
logger.debug(f"Updating cache for execution_id: {db_model.id}")
self._execution_cache[db_model.id] = db_model
-
- def get(self, execution_id: str) -> Optional[WorkflowExecution]:
- """
- Retrieve a WorkflowExecution by its ID.
-
- First checks the in-memory cache, and if not found, queries the database.
- If found in the database, adds it to the cache for future lookups.
-
- Args:
- execution_id: The workflow execution ID
-
- Returns:
- The WorkflowExecution instance if found, None otherwise
- """
- # First check the cache
- if execution_id in self._execution_cache:
- logger.debug(f"Cache hit for execution_id: {execution_id}")
- # Convert cached DB model to domain model
- cached_db_model = self._execution_cache[execution_id]
- return self._to_domain_model(cached_db_model)
-
- # If not in cache, query the database
- logger.debug(f"Cache miss for execution_id: {execution_id}, querying database")
- with self._session_factory() as session:
- stmt = select(WorkflowRun).where(
- WorkflowRun.id == execution_id,
- WorkflowRun.tenant_id == self._tenant_id,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowRun.app_id == self._app_id)
-
- db_model = session.scalar(stmt)
- if db_model:
- # Add DB model to cache
- self._execution_cache[execution_id] = db_model
-
- # Convert to domain model and return
- return self._to_domain_model(db_model)
-
- return None
diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
index a5feeb0d7c..d4a31390f8 100644
--- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
@@ -7,7 +7,7 @@ import logging
from collections.abc import Sequence
from typing import Optional, Union
-from sqlalchemy import UnaryExpression, asc, delete, desc, select
+from sqlalchemy import UnaryExpression, asc, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
@@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}")
self._node_execution_cache[db_model.node_execution_id] = db_model
- def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
- """
- Retrieve a NodeExecution by its node_execution_id.
-
- First checks the in-memory cache, and if not found, queries the database.
- If found in the database, adds it to the cache for future lookups.
-
- Args:
- node_execution_id: The node execution ID
-
- Returns:
- The NodeExecution instance if found, None otherwise
- """
- # First check the cache
- if node_execution_id in self._node_execution_cache:
- logger.debug(f"Cache hit for node_execution_id: {node_execution_id}")
- # Convert cached DB model to domain model
- cached_db_model = self._node_execution_cache[node_execution_id]
- return self._to_domain_model(cached_db_model)
-
- # If not in cache, query the database
- logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database")
- with self._session_factory() as session:
- stmt = select(WorkflowNodeExecutionModel).where(
- WorkflowNodeExecutionModel.node_execution_id == node_execution_id,
- WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- db_model = session.scalar(stmt)
- if db_model:
- # Add DB model to cache
- self._node_execution_cache[node_execution_id] = db_model
-
- # Convert to domain model and return
- return self._to_domain_model(db_model)
-
- return None
-
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,
@@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
domain_models.append(domain_model)
return domain_models
-
- def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
- """
- Retrieve all running NodeExecution instances for a specific workflow run.
-
- This method queries the database directly and updates the cache with any
- retrieved executions that have a node_execution_id.
-
- Args:
- workflow_run_id: The workflow run ID
-
- Returns:
- A list of running NodeExecution instances
- """
- with self._session_factory() as session:
- stmt = select(WorkflowNodeExecutionModel).where(
- WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
- WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
- WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING,
- WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
- )
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- db_models = session.scalars(stmt).all()
- domain_models = []
-
- for model in db_models:
- # Update cache if node_execution_id is present
- if model.node_execution_id:
- self._node_execution_cache[model.node_execution_id] = model
-
- # Convert to domain model
- domain_model = self._to_domain_model(model)
- domain_models.append(domain_model)
-
- return domain_models
-
- def clear(self) -> None:
- """
- Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
-
- This method deletes all WorkflowNodeExecution records that match the tenant_id
- and app_id (if provided) associated with this repository instance.
- It also clears the in-memory cache.
- """
- with self._session_factory() as session:
- stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id)
-
- if self._app_id:
- stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id)
-
- result = session.execute(stmt)
- session.commit()
-
- deleted_count = result.rowcount
- logger.info(
- f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
- + (f" and app {self._app_id}" if self._app_id else "")
- )
-
- # Clear the in-memory cache
- self._node_execution_cache.clear()
- logger.info("Cleared in-memory node execution cache")
diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py
index c9e157cb77..ddec7b1329 100644
--- a/api/core/tools/__base/tool_runtime.py
+++ b/api/core/tools/__base/tool_runtime.py
@@ -4,7 +4,7 @@ from openai import BaseModel
from pydantic import Field
from core.app.entities.app_invoke_entities import InvokeFrom
-from core.tools.entities.tool_entities import ToolInvokeFrom
+from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom
class ToolRuntime(BaseModel):
@@ -17,6 +17,7 @@ class ToolRuntime(BaseModel):
invoke_from: Optional[InvokeFrom] = None
tool_invoke_from: Optional[ToolInvokeFrom] = None
credentials: dict[str, Any] = Field(default_factory=dict)
+ credential_type: CredentialType = Field(default=CredentialType.API_KEY)
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py
index cf75bd3d7e..a70ded9efd 100644
--- a/api/core/tools/builtin_tool/provider.py
+++ b/api/core/tools/builtin_tool/provider.py
@@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.tool import BuiltinTool
-from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType
+from core.tools.entities.tool_entities import (
+ CredentialType,
+ OAuthSchema,
+ ToolEntity,
+ ToolProviderEntity,
+ ToolProviderType,
+)
from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict
from core.tools.errors import (
ToolProviderNotFoundError,
@@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController):
credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
credentials_schema.append(credential_dict)
+ oauth_schema = None
+ if provider_yaml.get("oauth_schema", None) is not None:
+ oauth_schema = OAuthSchema(
+ client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []),
+ credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []),
+ )
+
super().__init__(
entity=ToolProviderEntity(
identity=provider_yaml["identity"],
credentials_schema=credentials_schema,
+ oauth_schema=oauth_schema,
),
)
@@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController):
:return: the credentials schema
"""
- if not self.entity.credentials_schema:
- return []
+ return self.get_credentials_schema_by_type(CredentialType.API_KEY.value)
+
+ def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]:
+ """
+ returns the credentials schema of the provider
- return self.entity.credentials_schema.copy()
+ :param credential_type: the type of the credential
+ :return: the credentials schema of the provider
+ """
+ if credential_type == CredentialType.OAUTH2.value:
+ return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else []
+ if credential_type == CredentialType.API_KEY.value:
+ return self.entity.credentials_schema.copy() if self.entity.credentials_schema else []
+ raise ValueError(f"Invalid credential type: {credential_type}")
+
+    def get_oauth_client_schema(self) -> list[ProviderConfig]:
+        """
+        Return the provider's OAuth client schema.
+
+        :return: a copy of ``oauth_schema.client_schema``, or an empty list when the
+            provider entity has no OAuth schema
+        """
+        oauth_schema = self.entity.oauth_schema
+        if not oauth_schema:
+            return []
+        return oauth_schema.client_schema.copy()
+
+    def get_supported_credential_types(self) -> list[str]:
+        """
+        List the credential types this provider declares a schema for.
+
+        :return: the ``CredentialType`` values, API key first and then OAuth2, whose
+            credentials schema on the provider entity is non-empty
+        """
+        entity = self.entity
+        has_api_key = entity.credentials_schema is not None and len(entity.credentials_schema) > 0
+        has_oauth = entity.oauth_schema is not None and len(entity.oauth_schema.credentials_schema) > 0
+
+        # Order matters to callers that rely on list position: API key precedes OAuth2.
+        candidates = (
+            (has_api_key, CredentialType.API_KEY),
+            (has_oauth, CredentialType.OAUTH2),
+        )
+        return [credential_type.value for enabled, credential_type in candidates if enabled]
def get_tools(self) -> list[BuiltinTool]:
"""
@@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController):
:return: whether the provider needs credentials
"""
- return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
+ return (
+ self.entity.credentials_schema is not None
+ and len(self.entity.credentials_schema) != 0
+ or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0)
+ )
@property
def provider_type(self) -> ToolProviderType:
diff --git a/api/core/tools/custom_tool/provider.py b/api/core/tools/custom_tool/provider.py
index fbe1d79137..95fab6151a 100644
--- a/api/core/tools/custom_tool/provider.py
+++ b/api/core/tools/custom_tool/provider.py
@@ -178,7 +178,7 @@ class ApiToolProviderController(ToolProviderController):
# get tenant api providers
db_providers: list[ApiToolProvider] = (
db.session.query(ApiToolProvider)
- .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
+ .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == self.entity.identity.name)
.all()
)
diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py
index 90134ba71d..27ce96b90e 100644
--- a/api/core/tools/entities/api_entities.py
+++ b/api/core/tools/entities/api_entities.py
@@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.__base.tool import ToolParameter
from core.tools.entities.common_entities import I18nObject
-from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.entities.tool_entities import CredentialType, ToolProviderType
class ToolApiEntity(BaseModel):
@@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel):
def optional_field(self, key: str, value: Any) -> dict:
"""Return dict with key-value if value is truthy, empty dict otherwise."""
return {key: value} if value else {}
+
+
+# API-facing model describing one credential of a tool provider: its id, display name,
+# owning provider, credential type, default flag and the credential payload itself.
+class ToolProviderCredentialApiEntity(BaseModel):
+ id: str = Field(description="The unique id of the credential")
+ name: str = Field(description="The name of the credential")
+ provider: str = Field(description="The provider of the credential")
+ credential_type: CredentialType = Field(description="The type of the credential")
+ # Defaults to False when the caller does not say otherwise.
+ is_default: bool = Field(
+ default=False, description="Whether the credential is the default credential for the provider in the workspace"
+ )
+ # NOTE(review): untyped dict; whether the values are masked or encrypted is not visible here - confirm at the call site.
+ credentials: dict = Field(description="The credentials of the provider")
+
+
+# API-facing summary of a provider's credential setup: which credential types it supports,
+# whether an OAuth custom client is enabled, and the list of its credentials.
+class ToolProviderCredentialInfoApiEntity(BaseModel):
+ # Plain strings; presumably CredentialType values - verify against the producer.
+ supported_credential_types: list[str] = Field(description="The supported credential types of the provider")
+ is_oauth_custom_client_enabled: bool = Field(
+ default=False, description="Whether the OAuth custom client is enabled for the provider"
+ )
+ credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider")
diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py
index b5148e245f..5377cbbb69 100644
--- a/api/core/tools/entities/tool_entities.py
+++ b/api/core/tools/entities/tool_entities.py
@@ -16,6 +16,7 @@ from core.plugin.entities.parameters import (
cast_parameter_value,
init_frontend_parameter,
)
+from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY
@@ -179,6 +180,10 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
+ # Message payload carrying a list of RetrievalSourceMetadata plus a context string; both fields are required.
+ # Presumably paired with MessageType.RETRIEVER_RESOURCES - confirm where messages are built.
+ class RetrieverResourceMessage(BaseModel):
+ retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources")
+ context: str = Field(..., description="context")
+
class MessageType(Enum):
TEXT = "text"
IMAGE = "image"
@@ -191,13 +196,22 @@ class ToolInvokeMessage(BaseModel):
FILE = "file"
LOG = "log"
BLOB_CHUNK = "blob_chunk"
+ RETRIEVER_RESOURCES = "retriever_resources"
type: MessageType = MessageType.TEXT
"""
plain text, image url or link url
"""
message: (
- JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage
+ JsonMessage
+ | TextMessage
+ | BlobChunkMessage
+ | BlobMessage
+ | LogMessage
+ | FileMessage
+ | None
+ | VariableMessage
+ | RetrieverResourceMessage
)
meta: dict[str, Any] | None = None
@@ -243,6 +257,7 @@ class ToolParameter(PluginParameter):
FILES = PluginParameterType.FILES.value
APP_SELECTOR = PluginParameterType.APP_SELECTOR.value
MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value
+ ANY = PluginParameterType.ANY.value
DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value
# MCP object and array type parameters
@@ -355,10 +370,18 @@ class ToolEntity(BaseModel):
return v or []
+# OAuth configuration of a tool provider, split into two ProviderConfig lists:
+# one describing the OAuth client and one describing the OAuth credentials.
+# Both lists default to empty.
+class OAuthSchema(BaseModel):
+ client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client")
+ credentials_schema: list[ProviderConfig] = Field(
+ default_factory=list, description="The schema of the OAuth credentials"
+ )
+
+
class ToolProviderEntity(BaseModel):
identity: ToolProviderIdentity
plugin_id: Optional[str] = None
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
+ oauth_schema: Optional[OAuthSchema] = None
class ToolProviderEntityWithPlugin(ToolProviderEntity):
@@ -438,6 +461,7 @@ class ToolSelector(BaseModel):
options: Optional[list[PluginParameterOption]] = None
provider_id: str = Field(..., description="The id of the provider")
+ credential_id: Optional[str] = Field(default=None, description="The id of the credential")
tool_name: str = Field(..., description="The name of the tool")
tool_description: str = Field(..., description="The description of the tool")
tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form")
@@ -445,3 +469,36 @@ class ToolSelector(BaseModel):
def to_plugin_parameter(self) -> dict[str, Any]:
return self.model_dump()
+
+
+class CredentialType(enum.StrEnum):
+    """Kinds of credential a tool provider can be configured with."""
+
+    API_KEY = "api-key"
+    OAUTH2 = "oauth2"
+
+    def get_name(self):
+        """Return the display name: a fixed label for known members, else the upper-cased value."""
+        display_names = {
+            CredentialType.API_KEY: "API KEY",
+            CredentialType.OAUTH2: "AUTH",
+        }
+        fallback = self.value.replace("-", " ").upper()
+        return display_names.get(self, fallback)
+
+    def is_editable(self):
+        """Return True only for API-key credentials."""
+        return self is CredentialType.API_KEY
+
+    def is_validate_allowed(self):
+        """Return True only for API-key credentials."""
+        return self is CredentialType.API_KEY
+
+    @classmethod
+    def values(cls):
+        """Return the string values of all members, in definition order."""
+        return [member.value for member in cls]
+
+    @classmethod
+    def of(cls, credential_type: str) -> "CredentialType":
+        """
+        Resolve a credential type string, ignoring case, to its member.
+
+        Raises:
+            ValueError: If the string matches no member value
+        """
+        normalized = credential_type.lower()
+        for member in cls:
+            if member.value == normalized:
+                return member
+        raise ValueError(f"Invalid credential type: {credential_type}")
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index d21e3d7d1c..aef2677c36 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -44,6 +44,7 @@ class PluginTool(Tool):
tool_provider=self.entity.identity.provider,
tool_name=self.entity.identity.name,
credentials=self.runtime.credentials,
+ credential_type=self.runtime.credential_type,
tool_parameters=tool_parameters,
conversation_id=conversation_id,
app_id=app_id,
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index ece02f9d59..ff054041cf 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -160,7 +160,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
- .filter(
+ .where(
ToolFile.id == id,
)
.first()
@@ -184,7 +184,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
session.query(MessageFile)
- .filter(
+ .where(
MessageFile.id == id,
)
.first()
@@ -204,7 +204,7 @@ class ToolFileManager:
tool_file: ToolFile | None = (
session.query(ToolFile)
- .filter(
+ .where(
ToolFile.id == tool_file_id,
)
.first()
@@ -228,7 +228,7 @@ class ToolFileManager:
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
session.query(ToolFile)
- .filter(
+ .where(
ToolFile.id == tool_file_id,
)
.first()
diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py
index 4787d7d79c..cdfefbadb3 100644
--- a/api/core/tools/tool_label_manager.py
+++ b/api/core/tools/tool_label_manager.py
@@ -29,7 +29,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
- db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id == provider_id).delete()
+ db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:
@@ -57,7 +57,7 @@ class ToolLabelManager:
labels = (
db.session.query(ToolLabelBinding.label_name)
- .filter(
+ .where(
ToolLabelBinding.tool_id == provider_id,
ToolLabelBinding.tool_type == controller.provider_type.value,
)
@@ -90,7 +90,7 @@ class ToolLabelManager:
provider_ids.append(controller.provider_id)
labels: list[ToolLabelBinding] = (
- db.session.query(ToolLabelBinding).filter(ToolLabelBinding.tool_id.in_(provider_ids)).all()
+ db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id.in_(provider_ids)).all()
)
tool_labels: dict[str, list[str]] = {label.tool_id: [] for label in labels}
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index 22a9853b41..f286466de0 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -1,15 +1,19 @@
import json
import logging
import mimetypes
-from collections.abc import Generator
+import time
+from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
+from pydantic import TypeAdapter
from yarl import URL
import contexts
+from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
+from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@@ -17,14 +21,14 @@ from core.tools.mcp_tool.provider import MCPToolProviderController
from core.tools.mcp_tool.tool import MCPTool
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.plugin_tool.tool import PluginTool
+from core.tools.utils.uuid_utils import is_valid_uuid
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.workflow.entities.variable_pool import VariablePool
-from services.tools.mcp_tools_mange_service import MCPToolManageService
+from services.tools.mcp_tools_manage_service import MCPToolManageService
if TYPE_CHECKING:
from core.workflow.nodes.tool.entities import ToolEntity
-
from configs import dify_config
from core.agent.entities import AgentToolEntity
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -41,16 +45,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
+ CredentialType,
ToolInvokeFrom,
ToolParameter,
ToolProviderType,
)
-from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError
+from core.tools.errors import ToolProviderNotFoundError
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.utils.configuration import (
- ProviderConfigEncrypter,
ToolParameterConfigurationManager,
)
+from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.tool import WorkflowTool
from extensions.ext_database import db
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -68,8 +73,11 @@ class ToolManager:
@classmethod
def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController:
"""
+
get the hardcoded provider
+
"""
+
if len(cls._hardcoded_providers) == 0:
# init the builtin providers
cls.load_hardcoded_providers_cache()
@@ -113,7 +121,12 @@ class ToolManager:
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(Lock())
+ plugin_tool_providers = contexts.plugin_tool_providers.get()
+ if provider in plugin_tool_providers:
+ return plugin_tool_providers[provider]
+
with contexts.plugin_tool_providers_lock.get():
+ # double check
plugin_tool_providers = contexts.plugin_tool_providers.get()
if provider in plugin_tool_providers:
return plugin_tool_providers[provider]
@@ -131,25 +144,7 @@ class ToolManager:
)
plugin_tool_providers[provider] = controller
-
- return controller
-
- @classmethod
- def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None:
- """
- get the builtin tool
-
- :param provider: the name of the provider
- :param tool_name: the name of the tool
- :param tenant_id: the id of the tenant
- :return: the provider, the tool
- """
- provider_controller = cls.get_builtin_provider(provider, tenant_id)
- tool = provider_controller.get_tool(tool_name)
- if tool is None:
- raise ToolNotFoundError(f"tool {tool_name} not found")
-
- return tool
+ return controller
@classmethod
def get_tool_runtime(
@@ -160,6 +155,7 @@ class ToolManager:
tenant_id: str,
invoke_from: InvokeFrom = InvokeFrom.DEBUGGER,
tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT,
+ credential_id: Optional[str] = None,
) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]:
"""
get the tool runtime
@@ -170,6 +166,7 @@ class ToolManager:
:param tenant_id: the tenant id
:param invoke_from: invoke from
:param tool_invoke_from: the tool invoke from
+ :param credential_id: the credential id
:return: the tool
"""
@@ -193,49 +190,105 @@ class ToolManager:
)
),
)
-
+ builtin_provider = None
if isinstance(provider_controller, PluginToolProviderController):
provider_id_entity = ToolProviderID(provider_id)
- # get credentials
- builtin_provider: BuiltinToolProvider | None = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == str(provider_id_entity))
- | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
- )
- .first()
- )
-
+ # get specific credentials
+ if is_valid_uuid(credential_id):
+ try:
+ builtin_provider = (
+ db.session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
+ except Exception as e:
+ builtin_provider = None
+ logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True)
+ # if the provider has been deleted, raise an error
+ if builtin_provider is None:
+ raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}")
+
+ # fallback to the default provider
if builtin_provider is None:
- raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
+ # use the default provider
+ builtin_provider = (
+ db.session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ (BuiltinToolProvider.provider == str(provider_id_entity))
+ | (BuiltinToolProvider.provider == provider_id_entity.provider_name),
+ )
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
+ .first()
+ )
+ if builtin_provider is None:
+ raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = (
db.session.query(BuiltinToolProvider)
- .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
+ .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.first()
)
if builtin_provider is None:
raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found")
- # decrypt the credentials
- credentials = builtin_provider.credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type)
+ ],
+ cache=ToolProviderCredentialsCache(
+ tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
+ ),
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
+ # decrypt the credentials
+ decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
+
+ # check if the credentials is expired
+ if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
+ # TODO: circular import
+ from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+
+ # refresh the credentials
+ tool_provider = ToolProviderID(provider_id)
+ provider_name = tool_provider.provider_name
+ redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
+ system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
+ oauth_handler = OAuthHandler()
+ # refresh the credentials
+ refreshed_credentials = oauth_handler.refresh_credentials(
+ tenant_id=tenant_id,
+ user_id=builtin_provider.user_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=provider_name,
+ redirect_uri=redirect_uri,
+ system_credentials=system_credentials or {},
+ credentials=decrypted_credentials,
+ )
+ # update the credentials
+ builtin_provider.encrypted_credentials = (
+ TypeAdapter(dict[str, Any])
+ .dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
+ .decode("utf-8")
+ )
+ builtin_provider.expires_at = refreshed_credentials.expires_at
+ db.session.commit()
+ decrypted_credentials = refreshed_credentials.credentials
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
- credentials=decrypted_credentials,
+ credentials=dict(decrypted_credentials),
+ credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
@@ -245,22 +298,16 @@ class ToolManager:
elif provider_type == ToolProviderType.API:
api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id)
-
- # decrypt the credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()],
- provider_type=api_provider.provider_type.value,
- provider_identity=api_provider.entity.identity.name,
+ controller=api_provider,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
-
return cast(
ApiTool,
api_provider.get_tool(tool_name).fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
- credentials=decrypted_credentials,
+ credentials=encrypter.decrypt(credentials),
invoke_from=invoke_from,
tool_invoke_from=tool_invoke_from,
)
@@ -269,7 +316,7 @@ class ToolManager:
elif provider_type == ToolProviderType.WORKFLOW:
workflow_provider = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@@ -320,6 +367,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.AGENT,
+ credential_id=agent_tool.credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -362,6 +410,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=invoke_from,
tool_invoke_from=ToolInvokeFrom.WORKFLOW,
+ credential_id=workflow_tool.credential_id,
)
parameters = tool_runtime.get_merged_runtime_parameters()
@@ -391,6 +440,7 @@ class ToolManager:
provider: str,
tool_name: str,
tool_parameters: dict[str, Any],
+ credential_id: Optional[str] = None,
) -> Tool:
"""
get tool runtime from plugin
@@ -402,6 +452,7 @@ class ToolManager:
tenant_id=tenant_id,
invoke_from=InvokeFrom.SERVICE_API,
tool_invoke_from=ToolInvokeFrom.PLUGIN,
+ credential_id=credential_id,
)
runtime_parameters = {}
parameters = tool_entity.get_merged_runtime_parameters()
@@ -551,6 +602,22 @@ class ToolManager:
return cls._builtin_tools_labels[tool_name]
+ @classmethod
+ def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]:
+ """
+ list all the builtin providers
+ """
+ # according to multi credentials, select the one with is_default=True first, then created_at oldest
+ # for compatibility with old version
+ sql = """
+ SELECT DISTINCT ON (tenant_id, provider) id
+ FROM tool_builtin_providers
+ WHERE tenant_id = :tenant_id
+ ORDER BY tenant_id, provider, is_default DESC, created_at DESC
+ """
+ ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()]
+ return db.session.query(BuiltinToolProvider).where(BuiltinToolProvider.id.in_(ids)).all()
+
@classmethod
def list_providers_from_api(
cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral
@@ -565,21 +632,13 @@ class ToolManager:
with db.session.no_autoflush:
if "builtin" in filters:
- # get builtin providers
builtin_providers = cls.list_builtin_providers(tenant_id)
- # get db builtin providers
- db_builtin_providers: list[BuiltinToolProvider] = (
- db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
- )
-
- # rewrite db_builtin_providers
- for db_provider in db_builtin_providers:
- tool_provider_id = str(ToolProviderID(db_provider.provider))
- db_provider.provider = tool_provider_id
-
- def find_db_builtin_provider(provider):
- return next((x for x in db_builtin_providers if x.provider == provider), None)
+ # key: provider name, value: provider
+ db_builtin_providers = {
+ str(ToolProviderID(provider.provider)): provider
+ for provider in cls.list_default_builtin_providers(tenant_id)
+ }
# append builtin providers
for provider in builtin_providers:
@@ -591,10 +650,9 @@ class ToolManager:
name_func=lambda x: x.identity.name,
):
continue
-
user_provider = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider,
- db_provider=find_db_builtin_provider(provider.entity.identity.name),
+ db_provider=db_builtin_providers.get(provider.entity.identity.name),
decrypt_credentials=False,
)
@@ -604,10 +662,9 @@ class ToolManager:
result_providers[f"builtin_provider.{user_provider.name}"] = user_provider
# get db api providers
-
if "api" in filters:
db_api_providers: list[ApiToolProvider] = (
- db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all()
+ db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all()
)
api_provider_controllers: list[dict[str, Any]] = [
@@ -630,7 +687,7 @@ class ToolManager:
if "workflow" in filters:
# get workflow providers
workflow_providers: list[WorkflowToolProvider] = (
- db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
+ db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
)
workflow_provider_controllers: list[WorkflowToolProviderController] = []
@@ -674,7 +731,7 @@ class ToolManager:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
@@ -711,7 +768,7 @@ class ToolManager:
"""
provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
- .filter(
+ .where(
MCPToolProvider.server_identifier == provider_id,
MCPToolProvider.tenant_id == tenant_id,
)
@@ -736,7 +793,7 @@ class ToolManager:
provider_name = provider
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
@@ -764,15 +821,12 @@ class ToolManager:
auth_type,
)
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
- provider_type=controller.provider_type.value,
- provider_identity=controller.entity.identity.name,
+ controller=controller,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+ masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials))
try:
icon = json.loads(provider_obj.icon)
@@ -831,7 +885,7 @@ class ToolManager:
try:
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.first()
)
@@ -848,7 +902,7 @@ class ToolManager:
try:
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
- .filter(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
+ .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.first()
)
@@ -865,7 +919,7 @@ class ToolManager:
try:
mcp_provider: MCPToolProvider | None = (
db.session.query(MCPToolProvider)
- .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
+ .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == provider_id)
.first()
)
diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py
index 251fedf56e..aceba6e69f 100644
--- a/api/core/tools/utils/configuration.py
+++ b/api/core/tools/utils/configuration.py
@@ -1,12 +1,8 @@
from copy import deepcopy
from typing import Any
-from pydantic import BaseModel
-
-from core.entities.provider_entities import BasicProviderConfig
from core.helper import encrypter
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
-from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import (
ToolParameter,
@@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import (
)
-class ProviderConfigEncrypter(BaseModel):
- tenant_id: str
- config: list[BasicProviderConfig]
- provider_type: str
- provider_identity: str
-
- def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
- """
- deep copy data
- """
- return deepcopy(data)
-
- def encrypt(self, data: dict[str, str]) -> dict[str, str]:
- """
- encrypt tool credentials with tenant id
-
- return a deep copy of credentials with encrypted values
- """
- data = self._deep_copy(data)
-
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
- data[field_name] = encrypted
-
- return data
-
- def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
- """
- mask tool credentials
-
- return a deep copy of credentials with masked values
- """
- data = self._deep_copy(data)
-
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- if len(data[field_name]) > 6:
- data[field_name] = (
- data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
- )
- else:
- data[field_name] = "*" * len(data[field_name])
-
- return data
-
- def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]:
- """
- decrypt tool credentials with tenant id
-
- return a deep copy of credentials with decrypted values
- """
- if use_cache:
- cache = ToolProviderCredentialsCache(
- tenant_id=self.tenant_id,
- identity_id=f"{self.provider_type}.{self.provider_identity}",
- cache_type=ToolProviderCredentialsCacheType.PROVIDER,
- )
- cached_credentials = cache.get()
- if cached_credentials:
- return cached_credentials
- data = self._deep_copy(data)
- # get fields need to be decrypted
- fields = dict[str, BasicProviderConfig]()
- for credential in self.config:
- fields[credential.name] = credential
-
- for field_name, field in fields.items():
- if field.type == BasicProviderConfig.Type.SECRET_INPUT:
- if field_name in data:
- try:
- # if the value is None or empty string, skip decrypt
- if not data[field_name]:
- continue
-
- data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
- except Exception:
- pass
-
- if use_cache:
- cache.set(data)
- return data
-
- def delete_tool_credentials_cache(self):
- cache = ToolProviderCredentialsCache(
- tenant_id=self.tenant_id,
- identity_id=f"{self.provider_type}.{self.provider_identity}",
- cache_type=ToolProviderCredentialsCacheType.PROVIDER,
- )
- cache.delete()
-
-
class ToolParameterConfigurationManager:
"""
Tool parameter configuration manager
diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
index 2cbc4b9821..7eb4bc017a 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py
@@ -87,7 +87,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
@@ -114,7 +114,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(Document)
- .filter(
+ .where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@@ -163,7 +163,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
):
with flask_app.app_context():
dataset = (
- db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
+ db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == dataset_id).first()
)
if not dataset:
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
index a4d2de3b1c..567275531e 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_base_tool.py
@@ -1,5 +1,5 @@
from abc import abstractmethod
-from typing import Any, Optional
+from typing import Optional
from msal_extensions.persistence import ABC # type: ignore
from pydantic import BaseModel, ConfigDict
@@ -21,11 +21,7 @@ class DatasetRetrieverBaseTool(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
- def _run(
- self,
- *args: Any,
- **kwargs: Any,
- ) -> Any:
+ def _run(self, query: str) -> str:
"""Use the tool.
Add run_manager: Optional[CallbackManagerForToolRun] = None
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index ff1d9021ce..f7689d7707 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -57,7 +57,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
def _run(self, query: str) -> str:
dataset = (
- db.session.query(Dataset).filter(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
+ db.session.query(Dataset).where(Dataset.tenant_id == self.tenant_id, Dataset.id == self.dataset_id).first()
)
if not dataset:
@@ -190,7 +190,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = (
db.session.query(DatasetDocument) # type: ignore
- .filter(
+ .where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py
new file mode 100644
index 0000000000..5fdfd3b9d1
--- /dev/null
+++ b/api/core/tools/utils/encryption.py
@@ -0,0 +1,142 @@
+from copy import deepcopy
+from typing import Any, Optional, Protocol
+
+from core.entities.provider_entities import BasicProviderConfig
+from core.helper import encrypter
+from core.helper.provider_cache import SingletonProviderCredentialsCache
+from core.tools.__base.tool_provider import ToolProviderController
+
+
+class ProviderConfigCache(Protocol):
+ """
+ Interface for provider configuration cache operations
+ """
+
+ def get(self) -> Optional[dict]:
+ """Get cached provider configuration"""
+ ...
+
+ def set(self, config: dict[str, Any]) -> None:
+ """Cache provider configuration"""
+ ...
+
+ def delete(self) -> None:
+ """Delete cached provider configuration"""
+ ...
+
+
+class ProviderConfigEncrypter:
+ tenant_id: str
+ config: list[BasicProviderConfig]
+ provider_config_cache: ProviderConfigCache
+
+ def __init__(
+ self,
+ tenant_id: str,
+ config: list[BasicProviderConfig],
+ provider_config_cache: ProviderConfigCache,
+ ):
+ self.tenant_id = tenant_id
+ self.config = config
+ self.provider_config_cache = provider_config_cache
+
+ def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
+ """
+ deep copy data
+ """
+ return deepcopy(data)
+
+ def encrypt(self, data: dict[str, str]) -> dict[str, str]:
+ """
+ encrypt tool credentials with tenant id
+
+ return a deep copy of credentials with encrypted values
+ """
+ data = self._deep_copy(data)
+
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
+ data[field_name] = encrypted
+
+ return data
+
+ def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
+ """
+ mask tool credentials
+
+ return a deep copy of credentials with masked values
+ """
+ data = self._deep_copy(data)
+
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ if len(data[field_name]) > 6:
+ data[field_name] = (
+ data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
+ )
+ else:
+ data[field_name] = "*" * len(data[field_name])
+
+ return data
+
+ def decrypt(self, data: dict[str, str]) -> dict[str, Any]:
+ """
+ decrypt tool credentials with tenant id
+
+ return a deep copy of credentials with decrypted values
+ """
+ cached_credentials = self.provider_config_cache.get()
+ if cached_credentials:
+ return cached_credentials
+
+ data = self._deep_copy(data)
+ # get fields need to be decrypted
+ fields = dict[str, BasicProviderConfig]()
+ for credential in self.config:
+ fields[credential.name] = credential
+
+ for field_name, field in fields.items():
+ if field.type == BasicProviderConfig.Type.SECRET_INPUT:
+ if field_name in data:
+ try:
+ # if the value is None or empty string, skip decrypt
+ if not data[field_name]:
+ continue
+
+ data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
+ except Exception:
+ pass
+
+ self.provider_config_cache.set(data)
+ return data
+
+
+def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache):
+ return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache
+
+
+def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController):
+ cache = SingletonProviderCredentialsCache(
+ tenant_id=tenant_id,
+ provider_type=controller.provider_type.value,
+ provider_identity=controller.entity.identity.name,
+ )
+ encrypt = ProviderConfigEncrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()],
+ provider_config_cache=cache,
+ )
+ return encrypt, cache
diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py
index 3f844e8234..a3c84615ca 100644
--- a/api/core/tools/utils/parser.py
+++ b/api/core/tools/utils/parser.py
@@ -1,5 +1,4 @@
import re
-import uuid
from json import dumps as json_dumps
from json import loads as json_loads
from json.decoder import JSONDecodeError
@@ -154,7 +153,7 @@ class ApiBasedToolSchemaParser:
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
if not path:
- path = str(uuid.uuid4())
+ path = ""
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py
new file mode 100644
index 0000000000..f3c946b95f
--- /dev/null
+++ b/api/core/tools/utils/system_oauth_encryption.py
@@ -0,0 +1,187 @@
+import base64
+import hashlib
+import logging
+from collections.abc import Mapping
+from typing import Any, Optional
+
+from Crypto.Cipher import AES
+from Crypto.Random import get_random_bytes
+from Crypto.Util.Padding import pad, unpad
+from pydantic import TypeAdapter
+
+from configs import dify_config
+
+logger = logging.getLogger(__name__)
+
+
+class OAuthEncryptionError(Exception):
+ """OAuth encryption/decryption specific error"""
+
+ pass
+
+
+class SystemOAuthEncrypter:
+ """
+ A simple OAuth parameters encrypter using AES-CBC encryption.
+
+ This class provides methods to encrypt and decrypt OAuth parameters
+ using AES-CBC mode with a key derived from the application's SECRET_KEY.
+ """
+
+ def __init__(self, secret_key: Optional[str] = None):
+ """
+ Initialize the OAuth encrypter.
+
+ Args:
+ secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
+
+ Raises:
+ ValueError: If SECRET_KEY is not configured or empty
+ """
+ secret_key = secret_key or dify_config.SECRET_KEY or ""
+
+ # Generate a fixed 256-bit key using SHA-256
+ self.key = hashlib.sha256(secret_key.encode()).digest()
+
+ def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str:
+ """
+ Encrypt OAuth parameters.
+
+ Args:
+ oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
+
+ Returns:
+ Base64-encoded encrypted string
+
+ Raises:
+ OAuthEncryptionError: If encryption fails
+ ValueError: If oauth_params is invalid
+ """
+
+ try:
+ # Generate random IV (16 bytes)
+ iv = get_random_bytes(16)
+
+ # Create AES cipher (CBC mode)
+ cipher = AES.new(self.key, AES.MODE_CBC, iv)
+
+ # Encrypt data
+ padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size)
+ encrypted_data = cipher.encrypt(padded_data)
+
+ # Combine IV and encrypted data
+ combined = iv + encrypted_data
+
+ # Return base64 encoded string
+ return base64.b64encode(combined).decode()
+
+ except Exception as e:
+ raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e
+
+ def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]:
+ """
+ Decrypt OAuth parameters.
+
+ Args:
+ encrypted_data: Base64-encoded encrypted string
+
+ Returns:
+ Decrypted OAuth parameters dictionary
+
+ Raises:
+ OAuthEncryptionError: If decryption fails
+ ValueError: If encrypted_data is invalid
+ """
+ if not isinstance(encrypted_data, str):
+ raise ValueError("encrypted_data must be a string")
+
+ if not encrypted_data:
+ raise ValueError("encrypted_data cannot be empty")
+
+ try:
+ # Base64 decode
+ combined = base64.b64decode(encrypted_data)
+
+ # Check minimum length (IV + at least one AES block)
+ if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
+ raise ValueError("Invalid encrypted data format")
+
+ # Separate IV and encrypted data
+ iv = combined[:16]
+ encrypted_data_bytes = combined[16:]
+
+ # Create AES cipher
+ cipher = AES.new(self.key, AES.MODE_CBC, iv)
+
+ # Decrypt data
+ decrypted_data = cipher.decrypt(encrypted_data_bytes)
+ unpadded_data = unpad(decrypted_data, AES.block_size)
+
+ # Parse JSON
+ oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
+
+ if not isinstance(oauth_params, dict):
+ raise ValueError("Decrypted data is not a valid dictionary")
+
+ return oauth_params
+
+ except Exception as e:
+ raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e
+
+
+# Factory function for creating encrypter instances
+def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter:
+ """
+ Create an OAuth encrypter instance.
+
+ Args:
+ secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
+
+ Returns:
+ SystemOAuthEncrypter instance
+ """
+ return SystemOAuthEncrypter(secret_key=secret_key)
+
+
+# Global encrypter instance (for backward compatibility)
+_oauth_encrypter: Optional[SystemOAuthEncrypter] = None
+
+
+def get_system_oauth_encrypter() -> SystemOAuthEncrypter:
+ """
+ Get the global OAuth encrypter instance.
+
+ Returns:
+ SystemOAuthEncrypter instance
+ """
+ global _oauth_encrypter
+ if _oauth_encrypter is None:
+ _oauth_encrypter = SystemOAuthEncrypter()
+ return _oauth_encrypter
+
+
+# Convenience functions for backward compatibility
+def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str:
+ """
+ Encrypt OAuth parameters using the global encrypter.
+
+ Args:
+ oauth_params: OAuth parameters dictionary
+
+ Returns:
+ Base64-encoded encrypted string
+ """
+ return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params)
+
+
+def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]:
+ """
+ Decrypt OAuth parameters using the global encrypter.
+
+ Args:
+ encrypted_data: Base64-encoded encrypted string
+
+ Returns:
+ Decrypted OAuth parameters dictionary
+ """
+ return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data)
diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py
index 3046c08c89..bdcc33259d 100644
--- a/api/core/tools/utils/uuid_utils.py
+++ b/api/core/tools/utils/uuid_utils.py
@@ -1,7 +1,9 @@
import uuid
-def is_valid_uuid(uuid_str: str) -> bool:
+def is_valid_uuid(uuid_str: str | None) -> bool:
+ if uuid_str is None or len(uuid_str) == 0:
+ return False
try:
uuid.UUID(uuid_str)
return True
diff --git a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py
index 7661e1e6a5..83f5f558d5 100644
--- a/api/core/tools/workflow_as_tool/provider.py
+++ b/api/core/tools/workflow_as_tool/provider.py
@@ -84,7 +84,7 @@ class WorkflowToolProviderController(ToolProviderController):
"""
workflow: Workflow | None = (
db.session.query(Workflow)
- .filter(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
+ .where(Workflow.app_id == db_provider.app_id, Workflow.version == db_provider.version)
.first()
)
@@ -190,7 +190,7 @@ class WorkflowToolProviderController(ToolProviderController):
db_providers: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(
+ .where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.app_id == self.provider_id,
)
diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py
index 10bf8ca640..8b89c2a7a9 100644
--- a/api/core/tools/workflow_as_tool/tool.py
+++ b/api/core/tools/workflow_as_tool/tool.py
@@ -142,12 +142,12 @@ class WorkflowTool(Tool):
if not version:
workflow = (
db.session.query(Workflow)
- .filter(Workflow.app_id == app_id, Workflow.version != "draft")
+ .where(Workflow.app_id == app_id, Workflow.version != "draft")
.order_by(Workflow.created_at.desc())
.first()
)
else:
- workflow = db.session.query(Workflow).filter(Workflow.app_id == app_id, Workflow.version == version).first()
+ workflow = db.session.query(Workflow).where(Workflow.app_id == app_id, Workflow.version == version).first()
if not workflow:
raise ValueError("workflow not found or not published")
@@ -158,7 +158,7 @@ class WorkflowTool(Tool):
"""
get the app by app id
"""
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
raise ValueError("app not found")
diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py
index 6cf09e0372..13274f4e0e 100644
--- a/api/core/variables/segments.py
+++ b/api/core/variables/segments.py
@@ -1,9 +1,9 @@
import json
import sys
from collections.abc import Mapping, Sequence
-from typing import Any
+from typing import Annotated, Any, TypeAlias
-from pydantic import BaseModel, ConfigDict, field_validator
+from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
@@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel):
+ """Segment is runtime type used during the execution of workflow.
+
+ Note: this class is abstract, you should use subclasses of this class instead.
+ """
+
model_config = ConfigDict(frozen=True)
value_type: SegmentType
@@ -73,7 +78,7 @@ class StringSegment(Segment):
class FloatSegment(Segment):
- value_type: SegmentType = SegmentType.NUMBER
+ value_type: SegmentType = SegmentType.FLOAT
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
@@ -92,7 +97,7 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
- value_type: SegmentType = SegmentType.NUMBER
+ value_type: SegmentType = SegmentType.INTEGER
value: int
@@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment):
@property
def text(self) -> str:
return ""
+
+
+def get_segment_discriminator(v: Any) -> SegmentType | None:
+    """Return the `SegmentType` tag for `v`, or `None` when no tag can be determined.
+
+    A `Segment` instance is tagged by its own `value_type`. A dict is tagged by its
+    `"value_type"` entry, provided that entry is a valid `SegmentType` value. Any other
+    input yields `None`.
+    """
+    if isinstance(v, Segment):
+        return v.value_type
+    if not isinstance(v, dict):
+        # return None if the discriminator value isn't found
+        return None
+
+    raw_type = v.get("value_type")
+    if raw_type is None:
+        return None
+    try:
+        return SegmentType(raw_type)
+    except ValueError:
+        # The stored value does not name any `SegmentType` member.
+        return None
+
+
+# The `SegmentUnion` type is used to enable serialization and deserialization with Pydantic.
+# Use `Segment` for type hinting when serialization is not required.
+#
+# Each variant is tagged with its `SegmentType`; `get_segment_discriminator` supplies the tag
+# that selects the variant.
+#
+# Note:
+# - All variants in `SegmentUnion` must inherit from the `Segment` class.
+# - The union must include all non-abstract subclasses of `Segment`, except:
+#   - `SegmentGroup`, which is not added to the variable pool.
+#   - `Variable` and its subclasses, which are handled by `VariableUnion`.
+SegmentUnion: TypeAlias = Annotated[
+    (
+        Annotated[NoneSegment, Tag(SegmentType.NONE)]
+        | Annotated[StringSegment, Tag(SegmentType.STRING)]
+        | Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
+        | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
+        | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
+        | Annotated[FileSegment, Tag(SegmentType.FILE)]
+        | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
+        | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
+        | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
+        | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
+        | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
+    ),
+    Discriminator(get_segment_discriminator),
+]
diff --git a/api/core/variables/types.py b/api/core/variables/types.py
index 68d3d82883..e79b2410bf 100644
--- a/api/core/variables/types.py
+++ b/api/core/variables/types.py
@@ -1,8 +1,27 @@
+from collections.abc import Mapping
from enum import StrEnum
+from typing import Any, Optional
+
+from core.file.models import File
+
+
+class ArrayValidation(StrEnum):
+    """Strategy for validating array elements.
+
+    Passed as the `array_validation` argument of `SegmentType.is_valid`.
+    """
+
+    # Skip element validation (only check array container)
+    NONE = "none"
+
+    # Validate the first element (if array is non-empty)
+    FIRST = "first"
+
+    # Validate all elements in the array.
+    ALL = "all"
class SegmentType(StrEnum):
NUMBER = "number"
+ INTEGER = "integer"
+ FLOAT = "float"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
@@ -19,16 +38,139 @@ class SegmentType(StrEnum):
GROUP = "group"
- def is_array_type(self):
+ def is_array_type(self) -> bool:
return self in _ARRAY_TYPES
+    @classmethod
+    def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
+        """
+        Attempt to infer the `SegmentType` based on the Python type of the `value` parameter.
+
+        Returns `None` if no appropriate `SegmentType` can be determined for the given `value`.
+        For example, this may occur if the input is a generic Python object of type `object`.
+
+        Lists are inferred from their elements: the type of every element is inferred
+        recursively, and the list is `None` as soon as one element cannot be inferred.
+        `bool` values are explicitly excluded from `INTEGER` and match no other branch,
+        so they yield `None`.
+        """
+
+        if isinstance(value, list):
+            # Collect the distinct element types of the list.
+            elem_types: set[SegmentType] = set()
+            for i in value:
+                segment_type = cls.infer_segment_type(i)
+                if segment_type is None:
+                    return None
+
+                elem_types.add(segment_type)
+
+            if len(elem_types) != 1:
+                # Mixed numeric element types still form a number array.
+                # NOTE(review): an empty list also takes this branch, because the empty set is a
+                # subset of `_NUMERICAL_TYPES`, so `[]` is inferred as `ARRAY_NUMBER` -- confirm
+                # this is intended.
+                if elem_types.issubset(_NUMERICAL_TYPES):
+                    return SegmentType.ARRAY_NUMBER
+                return SegmentType.ARRAY_ANY
+            elif all(i.is_array_type() for i in elem_types):
+                # A list whose elements are all arrays of one type (nested list) has no dedicated array type.
+                return SegmentType.ARRAY_ANY
+            # Exactly one non-array element type remains; map it to its array type.
+            match elem_types.pop():
+                case SegmentType.STRING:
+                    return SegmentType.ARRAY_STRING
+                case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
+                    return SegmentType.ARRAY_NUMBER
+                case SegmentType.OBJECT:
+                    return SegmentType.ARRAY_OBJECT
+                case SegmentType.FILE:
+                    return SegmentType.ARRAY_FILE
+                case SegmentType.NONE:
+                    return SegmentType.ARRAY_ANY
+                case _:
+                    # This should be unreachable.
+                    raise ValueError(f"not supported value {value}")
+        if value is None:
+            return SegmentType.NONE
+        # `bool` is a subclass of `int`, so it must be excluded explicitly.
+        elif isinstance(value, int) and not isinstance(value, bool):
+            return SegmentType.INTEGER
+        elif isinstance(value, float):
+            return SegmentType.FLOAT
+        elif isinstance(value, str):
+            return SegmentType.STRING
+        elif isinstance(value, dict):
+            return SegmentType.OBJECT
+        elif isinstance(value, File):
+            return SegmentType.FILE
+        else:
+            return None
+
+    def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
+        """Check that `value` is a list whose elements match this array type.
+
+        Args:
+            value: The value to validate
+            array_validation: How many elements to check (none, the first one, or all of them)
+
+        Returns:
+            True if `value` is a list and the checked elements match the element type
+        """
+        if not isinstance(value, list):
+            return False
+        # Skip element validation if array is empty
+        if len(value) == 0:
+            return True
+        # ARRAY_ANY accepts any element and has no entry in the element type mapping.
+        if self == SegmentType.ARRAY_ANY:
+            return True
+        element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
+
+        if array_validation == ArrayValidation.NONE:
+            return True
+        elif array_validation == ArrayValidation.FIRST:
+            return element_type.is_valid(value[0])
+        else:
+            # Feed the boolean results themselves to `all`. Wrapping each result in a
+            # one-element list made every item truthy, so any non-empty list passed.
+            return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value)
+
+    def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
+        """
+        Check if a value matches the segment type.
+        Users of `SegmentType` should call this method, instead of using
+        `isinstance` manually.
+
+        Args:
+            value: The value to validate
+            array_validation: Validation strategy for array types (ignored for non-array types)
+
+        Returns:
+            True if the value matches the type under the given validation strategy
+
+        Raises:
+            AssertionError: If this segment type has no validation rule in this method.
+        """
+        if self.is_array_type():
+            return self._validate_array(value, array_validation)
+        elif self in (SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT):
+            # `INTEGER` and `FLOAT` are exposed as `NUMBER` (see `exposed_type`), so they share
+            # its rule. Without this branch they fell through to the AssertionError below.
+            # NOTE(review): tighten this if `INTEGER` must reject floats (or `FLOAT` ints).
+            return isinstance(value, (int, float))
+        elif self == SegmentType.STRING:
+            return isinstance(value, str)
+        elif self == SegmentType.OBJECT:
+            return isinstance(value, dict)
+        elif self == SegmentType.SECRET:
+            return isinstance(value, str)
+        elif self == SegmentType.FILE:
+            return isinstance(value, File)
+        elif self == SegmentType.NONE:
+            return value is None
+        else:
+            raise AssertionError("this statement should be unreachable.")
+
+    def exposed_type(self) -> "SegmentType":
+        """Map this type to the one reported to the frontend.
+
+        The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so both are reported as
+        `NUMBER`; every other type is returned unchanged.
+        """
+        reported_as_number = self == SegmentType.INTEGER or self == SegmentType.FLOAT
+        return SegmentType.NUMBER if reported_as_number else self
+
+
+# Maps each typed array `SegmentType` to the `SegmentType` of its elements.
+_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
+    # ARRAY_ANY does not have a corresponding element type, so it has no entry here.
+    SegmentType.ARRAY_STRING: SegmentType.STRING,
+    SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
+    SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
+    SegmentType.ARRAY_FILE: SegmentType.FILE,
+}
_ARRAY_TYPES = frozenset(
- [
+ list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ + [
SegmentType.ARRAY_ANY,
- SegmentType.ARRAY_STRING,
- SegmentType.ARRAY_NUMBER,
- SegmentType.ARRAY_OBJECT,
- SegmentType.ARRAY_FILE,
+ ]
+)
+
+
+_NUMERICAL_TYPES = frozenset(
+ [
+ SegmentType.NUMBER,
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
]
)
diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py
index b650b1682e..a31ebc848e 100644
--- a/api/core/variables/variables.py
+++ b/api/core/variables/variables.py
@@ -1,8 +1,8 @@
from collections.abc import Sequence
-from typing import cast
+from typing import Annotated, TypeAlias, cast
from uuid import uuid4
-from pydantic import Field
+from pydantic import Discriminator, Field, Tag
from core.helper import encrypter
@@ -20,6 +20,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
+ get_segment_discriminator,
)
from .types import SegmentType
@@ -27,6 +28,10 @@ from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
+
+ It is mainly used to store segments and their selector in VariablePool.
+
+ Note: this class is abstract, you should use subclasses of this class instead.
"""
id: str = Field(
@@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
+
+
+# The `VariableUnion` type is used to enable serialization and deserialization with Pydantic.
+# Use `Variable` for type hinting when serialization is not required.
+#
+# Note:
+# - All variants in `VariableUnion` must inherit from the `Variable` class.
+# - The union must include all non-abstract subclasses of `Variable`.
+VariableUnion: TypeAlias = Annotated[
+    (
+        Annotated[NoneVariable, Tag(SegmentType.NONE)]
+        | Annotated[StringVariable, Tag(SegmentType.STRING)]
+        | Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
+        | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
+        | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
+        | Annotated[FileVariable, Tag(SegmentType.FILE)]
+        | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
+        | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
+        | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
+        | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
+        | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
+        | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
+    ),
+    Discriminator(get_segment_discriminator),
+]
diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py
index 80dda2632d..fbb8df6b01 100644
--- a/api/core/workflow/entities/variable_pool.py
+++ b/api/core/workflow/entities/variable_pool.py
@@ -1,7 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
-from typing import Any, Union
+from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
+from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
-from core.workflow.enums import SystemVariableKey
+from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
@@ -23,31 +24,31 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
- variable_dictionary: dict[str, dict[int, Segment]] = Field(
+ variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
- # TODO: This user inputs is not used for pool.
+
+ # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
- system_variables: Mapping[SystemVariableKey, Any] = Field(
+ system_variables: SystemVariable = Field(
description="System variables",
- default_factory=dict,
)
- environment_variables: Sequence[Variable] = Field(
+ environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
)
- conversation_variables: Sequence[Variable] = Field(
+ conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
- for key, value in self.system_variables.items():
- self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
+        # Add system variables to the variable pool
+ self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@@ -83,8 +84,22 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
- hash_key = hash(tuple(selector[1:]))
- self.variable_dictionary[selector[0]][hash_key] = variable
+ key, hash_key = self._selector_to_keys(selector)
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
+
+    @classmethod
+    def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
+        """Split `selector` into the first-level key and the hashed second-level key.
+
+        The first element is returned as-is; the remaining elements are hashed as a tuple.
+        """
+        first_level_key = selector[0]
+        second_level_key = hash(tuple(selector[1:]))
+        return first_level_key, second_level_key
+
+    def _has(self, selector: Sequence[str]) -> bool:
+        """Report whether a variable is stored under `selector`.
+
+        Uses membership tests only, so no new entry is created in `variable_dictionary`.
+        """
+        key, hash_key = self._selector_to_keys(selector)
+        return key in self.variable_dictionary and hash_key in self.variable_dictionary[key]
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@@ -102,8 +117,8 @@ class VariablePool(BaseModel):
if len(selector) < MIN_SELECTORS_LENGTH:
return None
- hash_key = hash(tuple(selector[1:]))
- value = self.variable_dictionary[selector[0]].get(hash_key)
+ key, hash_key = self._selector_to_keys(selector)
+ value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
@@ -136,8 +151,8 @@ class VariablePool(BaseModel):
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
- hash_key = hash(tuple(selector[1:]))
- self.variable_dictionary[selector[0]].pop(hash_key, None)
+ key, hash_key = self._selector_to_keys(selector)
+ self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
@@ -154,3 +169,20 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment):
return segment
return None
+
+    def _add_system_variables(self, system_variable: SystemVariable):
+        """Add every non-`None` entry of `system_variable` under `SYSTEM_VARIABLE_NODE_ID`.
+
+        Selectors that already exist in the pool are skipped, which keeps the ids of
+        existing system variables intact.
+        """
+        for name, value in system_variable.to_dict().items():
+            selector = (SYSTEM_VARIABLE_NODE_ID, name)
+            # Unset values are not stored, and existing entries are never replaced.
+            if value is None or self._has(selector):
+                continue
+            self.add(selector, value)  # type: ignore
+
+    @classmethod
+    def empty(cls) -> "VariablePool":
+        """Create a variable pool built from an empty `SystemVariable`."""
+        blank_system_variables = SystemVariable.empty()
+        return cls(system_variables=blank_system_variables)
diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py
deleted file mode 100644
index 8896416f12..0000000000
--- a/api/core/workflow/entities/workflow_entities.py
+++ /dev/null
@@ -1,79 +0,0 @@
-from typing import Optional
-
-from pydantic import BaseModel
-
-from core.app.entities.app_invoke_entities import InvokeFrom
-from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode
-from models.enums import UserFrom
-from models.workflow import Workflow, WorkflowType
-
-from .node_entities import NodeRunResult
-from .variable_pool import VariablePool
-
-
-class WorkflowNodeAndResult:
- node: BaseNode
- result: Optional[NodeRunResult] = None
-
- def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None):
- self.node = node
- self.result = result
-
-
-class WorkflowRunState:
- tenant_id: str
- app_id: str
- workflow_id: str
- workflow_type: WorkflowType
- user_id: str
- user_from: UserFrom
- invoke_from: InvokeFrom
-
- workflow_call_depth: int
-
- start_at: float
- variable_pool: VariablePool
-
- total_tokens: int = 0
-
- workflow_nodes_and_results: list[WorkflowNodeAndResult]
-
- class NodeRun(BaseModel):
- node_id: str
- iteration_node_id: str
- loop_node_id: str
-
- workflow_node_runs: list[NodeRun]
- workflow_node_steps: int
-
- current_iteration_state: Optional[BaseIterationState]
- current_loop_state: Optional[BaseLoopState]
-
- def __init__(
- self,
- workflow: Workflow,
- start_at: float,
- variable_pool: VariablePool,
- user_id: str,
- user_from: UserFrom,
- invoke_from: InvokeFrom,
- workflow_call_depth: int,
- ):
- self.workflow_id = workflow.id
- self.tenant_id = workflow.tenant_id
- self.app_id = workflow.app_id
- self.workflow_type = WorkflowType.value_of(workflow.type)
- self.user_id = user_id
- self.user_from = user_from
- self.invoke_from = invoke_from
- self.workflow_call_depth = workflow_call_depth
-
- self.start_at = start_at
- self.variable_pool = variable_pool
-
- self.total_tokens = 0
-
- self.workflow_node_steps = 1
- self.workflow_node_runs = []
- self.current_iteration_state = None
- self.current_loop_state = None
diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py
index bd4ccc1072..594bb2b32e 100644
--- a/api/core/workflow/errors.py
+++ b/api/core/workflow/errors.py
@@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode
class WorkflowNodeRunFailedError(Exception):
- def __init__(self, node_instance: BaseNode, error: str):
- self.node_instance = node_instance
- self.error = error
- super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
+ def __init__(self, node: BaseNode, err_msg: str):
+ self._node = node
+ self._error = err_msg
+ super().__init__(f"Node {node.title} run failed: {err_msg}")
diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py
index 2fee3d7fad..12e1de464b 100644
--- a/api/core/workflow/graph_engine/__init__.py
+++ b/api/core/workflow/graph_engine/__init__.py
@@ -1,3 +1,4 @@
from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState
+from .graph_engine import GraphEngine
-__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
+__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"]
diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
index afc09bfac5..a62ffe46c9 100644
--- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py
+++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py
@@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
+
+ # The `outputs` field stores the final output values generated by executing workflows or chatflows.
+ #
+ # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
+ # after a serialization and deserialization round trip.
outputs: dict[str, Any] = {}
- """outputs"""
node_run_steps: int = 0
"""node run steps"""
diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py
index 5a2915e2d3..b315129763 100644
--- a/api/core/workflow/graph_engine/graph_engine.py
+++ b/api/core/workflow/graph_engine/graph_engine.py
@@ -12,7 +12,7 @@ from typing import Any, Optional, cast
from flask import Flask, current_app
from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult
from core.workflow.entities.variable_pool import VariablePool, VariableValue
@@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.base.entities import BaseNodeData
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
-from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom
@@ -260,12 +258,16 @@ class GraphEngine:
# convert to specific node
node_type = NodeType(node_config.get("data", {}).get("type"))
node_version = node_config.get("data", {}).get("version", "1")
+
+ # Import here to avoid circular import
+ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
# init workflow run state
- node_instance = node_cls( # type: ignore
+ node = node_cls(
id=route_node_state.id,
config=node_config,
graph_init_params=self.init_params,
@@ -274,11 +276,11 @@ class GraphEngine:
previous_node_id=previous_node_id,
thread_pool_id=self.thread_pool_id,
)
- node_instance = cast(BaseNode[BaseNodeData], node_instance)
+ node.init_node_data(node_config.get("data", {}))
try:
# run node
generator = self._run_node(
- node_instance=node_instance,
+ node=node,
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
@@ -306,16 +308,16 @@ class GraphEngine:
route_node_state.failed_reason = str(e)
yield NodeRunFailedEvent(
error=str(e),
- id=node_instance.id,
+ id=node.id,
node_id=next_node_id,
node_type=node_type,
- node_data=node_instance.node_data,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=in_parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
raise e
@@ -337,7 +339,7 @@ class GraphEngine:
edge = edge_mappings[0]
if (
previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
- and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
+ and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and edge.run_condition is None
):
break
@@ -413,8 +415,8 @@ class GraphEngine:
next_node_id = final_node_id
elif (
- node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
- and node_instance.should_continue_on_error
+ node.continue_on_error
+ and node.error_strategy == ErrorStrategy.FAIL_BRANCH
and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION
):
break
@@ -597,7 +599,7 @@ class GraphEngine:
def _run_node(
self,
- node_instance: BaseNode[BaseNodeData],
+ node: BaseNode,
route_node_state: RouteNodeState,
parallel_id: Optional[str] = None,
parallel_start_node_id: Optional[str] = None,
@@ -611,29 +613,29 @@ class GraphEngine:
# trigger node run start event
agent_strategy = (
AgentNodeStrategyInit(
- name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name,
- icon=cast(AgentNode, node_instance).agent_strategy_icon,
+ name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name,
+ icon=cast(AgentNode, node).agent_strategy_icon,
)
- if node_instance.node_type == NodeType.AGENT
+ if node.type_ == NodeType.AGENT
else None
)
yield NodeRunStartedEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
- predecessor_node_id=node_instance.previous_node_id,
+ predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
- max_retries = node_instance.node_data.retry_config.max_retries
- retry_interval = node_instance.node_data.retry_config.retry_interval_seconds
+ max_retries = node.retry_config.max_retries
+ retry_interval = node.retry_config.retry_interval_seconds
retries = 0
should_continue_retry = True
while should_continue_retry and retries <= max_retries:
@@ -642,7 +644,7 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads
time.sleep(0.001)
- event_stream = node_instance.run()
+ event_stream = node.run()
for event in event_stream:
if isinstance(event, GraphEngineEvent):
# add parallel info to iteration event
@@ -658,21 +660,21 @@ class GraphEngine:
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if (
retries == max_retries
- and node_instance.node_type == NodeType.HTTP_REQUEST
+ and node.type_ == NodeType.HTTP_REQUEST
and run_result.outputs
- and not node_instance.should_continue_on_error
+ and not node.continue_on_error
):
run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED
- if node_instance.should_retry and retries < max_retries:
+ if node.retry and retries < max_retries:
retries += 1
route_node_state.node_run_result = run_result
yield NodeRunRetryEvent(
id=str(uuid.uuid4()),
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
- predecessor_node_id=node_instance.previous_node_id,
+ predecessor_node_id=node.previous_node_id,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
@@ -680,17 +682,17 @@ class GraphEngine:
error=run_result.error or "Unknown error",
retry_index=retries,
start_at=retry_start_at,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
time.sleep(retry_interval)
break
route_node_state.set_finished(run_result=run_result)
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
- if node_instance.should_continue_on_error:
+ if node.continue_on_error:
# if run failed, handle error
run_result = self._handle_continue_on_error(
- node_instance,
+ node,
event.run_result,
self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions,
@@ -701,44 +703,44 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
- node_id=node_instance.node_id,
+ node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
yield NodeRunExceptionEvent(
error=run_result.error or "System Error",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
should_continue_retry = False
else:
yield NodeRunFailedEvent(
error=route_node_state.failed_reason or "Unknown error.",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
if (
- node_instance.should_continue_on_error
- and self.graph.edge_mapping.get(node_instance.node_id)
- and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH
+ node.continue_on_error
+ and self.graph.edge_mapping.get(node.node_id)
+ and node.error_strategy is ErrorStrategy.FAIL_BRANCH
):
run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS
if run_result.metadata and run_result.metadata.get(
@@ -758,7 +760,7 @@ class GraphEngine:
for variable_key, variable_value in run_result.outputs.items():
# append variables to variable pool recursively
self._append_variables_recursively(
- node_id=node_instance.node_id,
+ node_id=node.node_id,
variable_key_list=[variable_key],
variable_value=variable_value,
)
@@ -783,26 +785,26 @@ class GraphEngine:
run_result.metadata = metadata_dict
yield NodeRunSucceededEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
should_continue_retry = False
break
elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
chunk_content=event.chunk_content,
from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state,
@@ -810,14 +812,14 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent(
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
retriever_resources=event.retriever_resources,
context=event.context,
route_node_state=route_node_state,
@@ -825,7 +827,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
except GenerateTaskStoppedError:
# trigger node run failed event
@@ -833,20 +835,20 @@ class GraphEngine:
route_node_state.failed_reason = "Workflow stopped."
yield NodeRunFailedEvent(
error="Workflow stopped.",
- id=node_instance.id,
- node_id=node_instance.node_id,
- node_type=node_instance.node_type,
- node_data=node_instance.node_data,
+ id=node.id,
+ node_id=node.node_id,
+ node_type=node.type_,
+ node_data=node.get_base_node_data(),
route_node_state=route_node_state,
parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
- node_version=node_instance.version(),
+ node_version=node.version(),
)
return
except Exception as e:
- logger.exception(f"Node {node_instance.node_data.title} run failed")
+ logger.exception(f"Node {node.title} run failed")
raise e
def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue):
@@ -886,22 +888,14 @@ class GraphEngine:
def _handle_continue_on_error(
self,
- node_instance: BaseNode[BaseNodeData],
+ node: BaseNode,
error_result: NodeRunResult,
variable_pool: VariablePool,
handle_exceptions: list[str] = [],
) -> NodeRunResult:
- """
- handle continue on error when self._should_continue_on_error is True
-
-
- :param error_result (NodeRunResult): error run result
- :param variable_pool (VariablePool): variable pool
- :return: excption run result
- """
# add error message and error type to variable pool
- variable_pool.add([node_instance.node_id, "error_message"], error_result.error)
- variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type)
+ variable_pool.add([node.node_id, "error_message"], error_result.error)
+ variable_pool.add([node.node_id, "error_type"], error_result.error_type)
# add error message to handle_exceptions
handle_exceptions.append(error_result.error or "")
node_error_args: dict[str, Any] = {
@@ -909,21 +903,21 @@ class GraphEngine:
"error": error_result.error,
"inputs": error_result.inputs,
"metadata": {
- WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy,
+ WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy,
},
}
- if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+ if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
return NodeRunResult(
**node_error_args,
outputs={
- **node_instance.node_data.default_value_dict,
+ **node.default_value_dict,
"error_message": error_result.error,
"error_type": error_result.error_type,
},
)
- elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH:
- if self.graph.edge_mapping.get(node_instance.node_id):
+ elif node.error_strategy is ErrorStrategy.FAIL_BRANCH:
+ if self.graph.edge_mapping.get(node.node_id):
node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED
return NodeRunResult(
**node_error_args,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index 678b99d546..c83303034e 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -1,62 +1,100 @@
import json
-import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from packaging.version import Version
+from pydantic import ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.agent.strategy.plugin import PluginAgentStrategy
+from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
+from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
+from core.plugin.entities.request import InvokeCredentials
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
-from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType
+from core.tools.entities.tool_entities import (
+ ToolIdentity,
+ ToolInvokeMessage,
+ ToolParameter,
+ ToolProviderType,
+)
from core.tools.tool_manager import ToolManager
-from core.variables.segments import StringSegment
+from core.tools.utils.message_transformer import ToolFileMessageTransformer
+from core.variables.segments import ArrayFileSegment, StringSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
-from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
+from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
+from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
-from core.workflow.nodes.base.entities import BaseNodeData
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import RunCompletedEvent
-from core.workflow.nodes.tool.tool_node import ToolNode
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
+from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
+from models import ToolFile
from models.model import Conversation
+from services.tools.builtin_tools_manage_service import BuiltinToolManageService
+from .exc import (
+ AgentInputTypeError,
+ AgentInvocationError,
+ AgentMessageTransformError,
+ AgentVariableNotFoundError,
+ AgentVariableTypeError,
+ ToolFileNotFoundError,
+)
-class AgentNode(ToolNode):
+
+class AgentNode(BaseNode):
"""
Agent Node
"""
- _node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT
+ _node_data: AgentNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = AgentNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator:
- """
- Run the agent node
- """
- node_data = cast(AgentNodeData, self.node_data)
-
try:
strategy = get_plugin_agent_strategy(
tenant_id=self.tenant_id,
- agent_strategy_provider_name=node_data.agent_strategy_provider_name,
- agent_strategy_name=node_data.agent_strategy_name,
+ agent_strategy_provider_name=self._node_data.agent_strategy_provider_name,
+ agent_strategy_name=self._node_data.agent_strategy_name,
)
except Exception as e:
yield RunCompletedEvent(
@@ -74,16 +112,17 @@ class AgentNode(ToolNode):
parameters = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=node_data,
+ node_data=self._node_data,
strategy=strategy,
)
parameters_for_log = self._generate_agent_parameters(
agent_parameters=agent_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=node_data,
+ node_data=self._node_data,
for_log=True,
strategy=strategy,
)
+ credentials = self._generate_credentials(parameters=parameters)
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
@@ -94,61 +133,42 @@ class AgentNode(ToolNode):
user_id=self.user_id,
app_id=self.app_id,
conversation_id=conversation_id.text if conversation_id else None,
+ credentials=credentials,
)
except Exception as e:
+ error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- error=f"Failed to invoke agent: {str(e)}",
+ error=str(error),
)
)
return
try:
- # convert tool messages
- agent_thoughts: list = []
-
- thought_log_message = ToolInvokeMessage(
- type=ToolInvokeMessage.MessageType.LOG,
- message=ToolInvokeMessage.LogMessage(
- id=str(uuid.uuid4()),
- label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}",
- parent_id=None,
- error=None,
- status=ToolInvokeMessage.LogMessage.LogStatus.START,
- data={
- "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
- "parameters": parameters_for_log,
- "thought_process": "Agent strategy execution started",
- },
- metadata={
- "icon": self.agent_strategy_icon,
- "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
- },
- ),
- )
-
- def enhanced_message_stream():
- yield thought_log_message
-
- yield from message_stream
-
yield from self._transform_message(
- message_stream,
- {
+ messages=message_stream,
+ tool_info={
"icon": self.agent_strategy_icon,
- "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name,
+ "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name,
},
- parameters_for_log,
- agent_thoughts,
+ parameters_for_log=parameters_for_log,
+ user_id=self.user_id,
+ tenant_id=self.tenant_id,
+ node_type=self.type_,
+ node_id=self.node_id,
+ node_execution_id=self.id,
)
except PluginDaemonClientSideError as e:
+ transform_error = AgentMessageTransformError(
+ f"Failed to transform agent message: {str(e)}", original_error=e
+ )
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=parameters_for_log,
- error=f"Failed to transform agent message: {str(e)}",
+ error=str(transform_error),
)
)
@@ -185,7 +205,7 @@ class AgentNode(ToolNode):
if agent_input.type == "variable":
variable = variable_pool.get(agent_input.value) # type: ignore
if variable is None:
- raise ValueError(f"Variable {agent_input.value} does not exist")
+ raise AgentVariableNotFoundError(str(agent_input.value))
parameter_value = variable.value
elif agent_input.type in {"mixed", "constant"}:
# variable_pool.convert_template expects a string template,
@@ -207,7 +227,7 @@ class AgentNode(ToolNode):
except json.JSONDecodeError:
parameter_value = parameter_value
else:
- raise ValueError(f"Unknown agent input type '{agent_input.type}'")
+ raise AgentInputTypeError(agent_input.type)
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
@@ -246,10 +266,18 @@ class AgentNode(ToolNode):
tool_name=tool.get("tool_name", ""),
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
+ credential_id=tool.get("credential_id", None),
)
extra = tool.get("extra", {})
- runtime_variable_pool = variable_pool if self.node_data.version != "1" else None
+
+ # NOTE: this check has caused problems before.
+ # Logically, the node_data.version field should not be used to decide this,
+ # but for backward compatibility with historical data
+ # the version-field check is still preserved here.
+ runtime_variable_pool: VariablePool | None = None
+ if node_data.version != "1" or node_data.tool_node_version != "1":
+ runtime_variable_pool = variable_pool
tool_runtime = ToolManager.get_agent_tool_runtime(
self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool
)
@@ -276,11 +304,12 @@ class AgentNode(ToolNode):
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
+ "credential_id": tool.get("credential_id", None),
"provider_type": provider_type.value,
}
)
value = tool_value
- if parameter.type == "model-selector":
+ if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
model_instance, model_schema = self._fetch_model(value)
# memory config
@@ -305,25 +334,41 @@ class AgentNode(ToolNode):
return result
+ def _generate_credentials(
+ self,
+ parameters: dict[str, Any],
+ ) -> InvokeCredentials:
+ """
+ Generate credentials based on the given agent parameters.
+ """
+
+ credentials = InvokeCredentials()
+
+ # generate credentials for tools selector
+ credentials.tool_credentials = {}
+ for tool in parameters.get("tools", []):
+ if tool.get("credential_id"):
+ try:
+ identity = ToolIdentity.model_validate(tool.get("identity", {}))
+ credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
+ except ValidationError:
+ continue
+ return credentials
+
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: BaseNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- node_data = cast(AgentNodeData, node_data)
+ # Create typed NodeData from dict
+ typed_node_data = AgentNodeData.model_validate(node_data)
+
result: dict[str, Any] = {}
- for parameter_name in node_data.agent_parameters:
- input = node_data.agent_parameters[parameter_name]
+ for parameter_name in typed_node_data.agent_parameters:
+ input = typed_node_data.agent_parameters[parameter_name]
if input.type in ["mixed", "constant"]:
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
for selector in selectors:
@@ -348,7 +393,7 @@ class AgentNode(ToolNode):
plugin
for plugin in plugins
if f"{plugin.plugin_id}/{plugin.name}"
- == cast(AgentNodeData, self.node_data).agent_strategy_provider_name
+ == cast(AgentNodeData, self._node_data).agent_strategy_provider_name
)
icon = current_plugin.declaration.icon
except StopIteration:
@@ -416,3 +461,236 @@ class AgentNode(ToolNode):
return tools
else:
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value]
+
+ def _transform_message(
+ self,
+ messages: Generator[ToolInvokeMessage, None, None],
+ tool_info: Mapping[str, Any],
+ parameters_for_log: dict[str, Any],
+ user_id: str,
+ tenant_id: str,
+ node_type: NodeType,
+ node_id: str,
+ node_execution_id: str,
+ ) -> Generator:
+ """
+ Turn ToolInvokeMessages into stream-chunk and agent-log events plus a final RunCompletedEvent.
+ """
+ # transform message and handle file storage
+ message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
+ messages=messages,
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=None,
+ )
+
+ text = ""
+ files: list[File] = []
+ json_list: list[dict] = []
+
+ agent_logs: list[AgentLogEvent] = []
+ agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
+ llm_usage: LLMUsage | None = None
+ variables: dict[str, Any] = {}
+
+ for message in message_stream:
+ if message.type in {
+ ToolInvokeMessage.MessageType.IMAGE_LINK,
+ ToolInvokeMessage.MessageType.BINARY_LINK,
+ ToolInvokeMessage.MessageType.IMAGE,
+ }:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+
+ url = message.message.text
+ if message.meta:
+ transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
+ else:
+ transfer_method = FileTransferMethod.TOOL_FILE
+
+ tool_file_id = str(url).split("/")[-1].split(".")[0]
+
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileNotFoundError(tool_file_id)
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
+ "transfer_method": transfer_method,
+ "url": url,
+ }
+ file = file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ files.append(file)
+ elif message.type == ToolInvokeMessage.MessageType.BLOB:
+ # get tool file id
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ assert message.meta
+
+ tool_file_id = message.message.text.split("/")[-1].split(".")[0]
+ with Session(db.engine) as session:
+ stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
+ tool_file = session.scalar(stmt)
+ if tool_file is None:
+ raise ToolFileNotFoundError(tool_file_id)
+
+ mapping = {
+ "tool_file_id": tool_file_id,
+ "transfer_method": FileTransferMethod.TOOL_FILE,
+ }
+
+ files.append(
+ file_factory.build_from_mapping(
+ mapping=mapping,
+ tenant_id=tenant_id,
+ )
+ )
+ elif message.type == ToolInvokeMessage.MessageType.TEXT:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ text += message.message.text
+ yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
+ elif message.type == ToolInvokeMessage.MessageType.JSON:
+ assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
+ if node_type == NodeType.AGENT:
+ msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
+ llm_usage = LLMUsage.from_metadata(msg_metadata)
+ agent_execution_metadata = {
+ WorkflowNodeExecutionMetadataKey(key): value
+ for key, value in msg_metadata.items()
+ if key in WorkflowNodeExecutionMetadataKey.__members__.values()
+ }
+ if message.message.json_object is not None:
+ json_list.append(message.message.json_object)
+ elif message.type == ToolInvokeMessage.MessageType.LINK:
+ assert isinstance(message.message, ToolInvokeMessage.TextMessage)
+ stream_text = f"Link: {message.message.text}\n"
+ text += stream_text
+ yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
+ elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
+ assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
+ variable_name = message.message.variable_name
+ variable_value = message.message.variable_value
+ if message.message.stream:
+ if not isinstance(variable_value, str):
+ raise AgentVariableTypeError(
+ "When 'stream' is True, 'variable_value' must be a string.",
+ variable_name=variable_name,
+ expected_type="str",
+ actual_type=type(variable_value).__name__,
+ )
+ if variable_name not in variables:
+ variables[variable_name] = ""
+ variables[variable_name] += variable_value
+
+ yield RunStreamChunkEvent(
+ chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
+ )
+ else:
+ variables[variable_name] = variable_value
+ elif message.type == ToolInvokeMessage.MessageType.FILE:
+ assert message.meta is not None
+ assert isinstance(message.meta, File)
+ files.append(message.meta["file"])
+ elif message.type == ToolInvokeMessage.MessageType.LOG:
+ assert isinstance(message.message, ToolInvokeMessage.LogMessage)
+ if message.message.metadata:
+ icon = tool_info.get("icon", "")
+ dict_metadata = dict(message.message.metadata)
+ if dict_metadata.get("provider"):
+ manager = PluginInstaller()
+ plugins = manager.list_plugins(tenant_id)
+ try:
+ current_plugin = next(
+ plugin
+ for plugin in plugins
+ if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
+ )
+ icon = current_plugin.declaration.icon
+ except StopIteration:
+ pass
+ icon_dark = None
+ try:
+ builtin_tool = next(
+ provider
+ for provider in BuiltinToolManageService.list_builtin_tools(
+ user_id,
+ tenant_id,
+ )
+ if provider.name == dict_metadata["provider"]
+ )
+ icon = builtin_tool.icon
+ icon_dark = builtin_tool.icon_dark
+ except StopIteration:
+ pass
+
+ dict_metadata["icon"] = icon
+ dict_metadata["icon_dark"] = icon_dark
+ message.message.metadata = dict_metadata
+ agent_log = AgentLogEvent(
+ id=message.message.id,
+ node_execution_id=node_execution_id,
+ parent_id=message.message.parent_id,
+ error=message.message.error,
+ status=message.message.status.value,
+ data=message.message.data,
+ label=message.message.label,
+ metadata=message.message.metadata,
+ node_id=node_id,
+ )
+
+ # check if the agent log is already in the list
+ for log in agent_logs:
+ if log.id == agent_log.id:
+ # update the log
+ log.data = agent_log.data
+ log.status = agent_log.status
+ log.error = agent_log.error
+ log.label = agent_log.label
+ log.metadata = agent_log.metadata
+ break
+ else:
+ agent_logs.append(agent_log)
+
+ yield agent_log
+
+ # Add agent_logs to outputs['json'] to ensure frontend can access thinking process
+ json_output: list[dict[str, Any]] = []
+
+ # Step 1: append each agent log as its own dict.
+ if agent_logs:
+ for log in agent_logs:
+ json_output.append(
+ {
+ "id": log.id,
+ "parent_id": log.parent_id,
+ "error": log.error,
+ "status": log.status,
+ "data": log.data,
+ "label": log.label,
+ "metadata": log.metadata,
+ "node_id": log.node_id,
+ }
+ )
+ # Step 2: append the collected JSON messages, or a {"data": []} placeholder if there are none.
+ if json_list:
+ json_output.extend(json_list)
+ else:
+ json_output.append({"data": []})
+
+ yield RunCompletedEvent(
+ run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
+ outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
+ metadata={
+ **agent_execution_metadata,
+ WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
+ WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
+ },
+ inputs=parameters_for_log,
+ llm_usage=llm_usage,
+ )
+ )
diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py
index 075a41fb2f..11b11068e7 100644
--- a/api/core/workflow/nodes/agent/entities.py
+++ b/api/core/workflow/nodes/agent/entities.py
@@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData):
agent_strategy_name: str
agent_strategy_label: str # redundancy
memory: MemoryConfig | None = None
+ # The version of the tool parameter.
+ # If this value is None, it indicates this is a previous version
+ # and requires using the legacy parameter parsing rules.
+ tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py
new file mode 100644
index 0000000000..d5955bdd7d
--- /dev/null
+++ b/api/core/workflow/nodes/agent/exc.py
@@ -0,0 +1,124 @@
+from typing import Optional
+
+
+class AgentNodeError(Exception):
+ """Base exception for all agent node errors."""
+
+ def __init__(self, message: str):
+ self.message = message
+ super().__init__(self.message)
+
+
+class AgentStrategyError(AgentNodeError):
+ """Exception raised when there's an error with the agent strategy."""
+
+ def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None):
+ self.strategy_name = strategy_name
+ self.provider_name = provider_name
+ super().__init__(message)
+
+
+class AgentStrategyNotFoundError(AgentStrategyError):
+ """Exception raised when the specified agent strategy is not found."""
+
+ def __init__(self, strategy_name: str, provider_name: Optional[str] = None):
+ super().__init__(
+ f"Agent strategy '{strategy_name}' not found"
+ + (f" for provider '{provider_name}'" if provider_name else ""),
+ strategy_name,
+ provider_name,
+ )
+
+
+class AgentInvocationError(AgentNodeError):
+ """Exception raised when there's an error invoking the agent."""
+
+ def __init__(self, message: str, original_error: Optional[Exception] = None):
+ self.original_error = original_error
+ super().__init__(message)
+
+
+class AgentParameterError(AgentNodeError):
+ """Exception raised when there's an error with agent parameters."""
+
+ def __init__(self, message: str, parameter_name: Optional[str] = None):
+ self.parameter_name = parameter_name
+ super().__init__(message)
+
+
+class AgentVariableError(AgentNodeError):
+ """Exception raised when there's an error with variables in the agent node."""
+
+ def __init__(self, message: str, variable_name: Optional[str] = None):
+ self.variable_name = variable_name
+ super().__init__(message)
+
+
+class AgentVariableNotFoundError(AgentVariableError):
+ """Exception raised when a variable is not found in the variable pool."""
+
+ def __init__(self, variable_name: str):
+ super().__init__(f"Variable '{variable_name}' does not exist", variable_name)
+
+
+class AgentInputTypeError(AgentNodeError):
+ """Exception raised when an unknown agent input type is encountered."""
+
+ def __init__(self, input_type: str):
+ super().__init__(f"Unknown agent input type '{input_type}'")
+
+
+class ToolFileError(AgentNodeError):
+ """Exception raised when there's an error with a tool file."""
+
+ def __init__(self, message: str, file_id: Optional[str] = None):
+ self.file_id = file_id
+ super().__init__(message)
+
+
+class ToolFileNotFoundError(ToolFileError):
+ """Exception raised when a tool file is not found."""
+
+ def __init__(self, file_id: str):
+ super().__init__(f"Tool file '{file_id}' does not exist", file_id)
+
+
+class AgentMessageTransformError(AgentNodeError):
+ """Exception raised when there's an error transforming agent messages."""
+
+ def __init__(self, message: str, original_error: Optional[Exception] = None):
+ self.original_error = original_error
+ super().__init__(message)
+
+
+class AgentModelError(AgentNodeError):
+ """Exception raised when there's an error with the model used by the agent."""
+
+ def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None):
+ self.model_name = model_name
+ self.provider = provider
+ super().__init__(message)
+
+
+class AgentMemoryError(AgentNodeError):
+ """Exception raised when there's an error with the agent's memory."""
+
+ def __init__(self, message: str, conversation_id: Optional[str] = None):
+ self.conversation_id = conversation_id
+ super().__init__(message)
+
+
+class AgentVariableTypeError(AgentNodeError):
+ """Exception raised when a variable has an unexpected type."""
+
+ def __init__(
+ self,
+ message: str,
+ variable_name: Optional[str] = None,
+ expected_type: Optional[str] = None,
+ actual_type: Optional[str] = None,
+ ):
+ self.variable_name = variable_name
+ self.expected_type = expected_type
+ self.actual_type = actual_type
+ super().__init__(message)
diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py
index 38c2bcbdf5..84bbabca73 100644
--- a/api/core/workflow/nodes/answer/answer_node.py
+++ b/api/core/workflow/nodes/answer/answer_node.py
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
from core.variables import ArrayFileSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
@@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import (
VarGenerateRouteChunk,
)
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser
-class AnswerNode(BaseNode[AnswerNodeData]):
- _node_data_cls = AnswerNodeData
+class AnswerNode(BaseNode):
_node_type = NodeType.ANSWER
+ _node_data: AnswerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = AnswerNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]):
:return:
"""
# generate routes
- generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data)
+ generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data)
answer = ""
files = []
@@ -60,16 +83,12 @@ class AnswerNode(BaseNode[AnswerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: AnswerNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- variable_template_parser = VariableTemplateParser(template=node_data.answer)
+ # Create typed NodeData from dict
+ typed_node_data = AnswerNodeData.model_validate(node_data)
+
+ variable_template_parser = VariableTemplateParser(template=typed_node_data.answer)
variable_selectors = variable_template_parser.extract_variable_selectors()
variable_mapping = {}
diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py
index d853eb71be..dcfed5eed2 100644
--- a/api/core/workflow/nodes/base/entities.py
+++ b/api/core/workflow/nodes/base/entities.py
@@ -122,13 +122,13 @@ class RetryConfig(BaseModel):
class BaseNodeData(ABC, BaseModel):
title: str
desc: Optional[str] = None
+ version: str = "1"
error_strategy: Optional[ErrorStrategy] = None
default_value: Optional[list[DefaultValue]] = None
- version: str = "1"
retry_config: RetryConfig = RetryConfig()
@property
- def default_value_dict(self):
+ def default_value_dict(self) -> dict[str, Any]:
if self.default_value:
return {item.key: item.value for item in self.default_value}
return {}
diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py
index 6973401429..fb5ec55453 100644
--- a/api/core/workflow/nodes/base/node.py
+++ b/api/core/workflow/nodes/base/node.py
@@ -1,28 +1,22 @@
import logging
from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
-from .entities import BaseNodeData
-
if TYPE_CHECKING:
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
from core.workflow.graph_engine.entities.event import InNodeEvent
- from core.workflow.graph_engine.entities.graph import Graph
- from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
- from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
-GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
-
-class BaseNode(Generic[GenericNodeData]):
- _node_data_cls: type[GenericNodeData]
+class BaseNode:
_node_type: ClassVar[NodeType]
def __init__(
@@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]):
self.node_id = node_id
- node_data = self._node_data_cls.model_validate(config.get("data", {}))
- self.node_data = node_data
+ @abstractmethod
+ def init_node_data(self, data: Mapping[str, Any]) -> None: ...
@abstractmethod
def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]:
@@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]):
if not node_id:
raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
- node_data = cls._node_data_cls(**config.get("data", {}))
+ # Pass raw dict data instead of creating NodeData instance
data = cls._extract_variable_selector_to_variable_mapping(
- graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
+ graph_config=graph_config, node_id=node_id, node_data=config.get("data", {})
)
return data
@@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: GenericNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
return {}
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
- """
- Get default config of node.
- :param filters: filter by node config parameters.
- :return:
- """
return {}
@property
- def node_type(self) -> NodeType:
- """
- Get node type
- :return:
- """
+ def type_(self) -> NodeType:
return self._node_type
@classmethod
@@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]):
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property
- def should_continue_on_error(self) -> bool:
- """judge if should continue on error
+ def continue_on_error(self) -> bool:
+ return False
- Returns:
- bool: if should continue on error
- """
- return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE
+ @property
+ def retry(self) -> bool:
+ return False
+
+ # Abstract methods that subclasses must implement to provide access
+ # to BaseNodeData properties in a type-safe way
+
+ @abstractmethod
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ """Get the error strategy for this node."""
+ ...
+ @abstractmethod
+ def _get_retry_config(self) -> RetryConfig:
+ """Get the retry configuration for this node."""
+ ...
+
+ @abstractmethod
+ def _get_title(self) -> str:
+ """Get the node title."""
+ ...
+
+ @abstractmethod
+ def _get_description(self) -> Optional[str]:
+ """Get the node description."""
+ ...
+
+ @abstractmethod
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ """Get the default values dictionary for this node."""
+ ...
+
+ @abstractmethod
+ def get_base_node_data(self) -> BaseNodeData:
+ """Get the BaseNodeData object for this node."""
+ ...
+
+ # Public interface properties that delegate to abstract methods
@property
- def should_retry(self) -> bool:
- """judge if should retry
+ def error_strategy(self) -> Optional[ErrorStrategy]:
+ """Get the error strategy for this node."""
+ return self._get_error_strategy()
- Returns:
- bool: if should retry
- """
- return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE
+ @property
+ def retry_config(self) -> RetryConfig:
+ """Get the retry configuration for this node."""
+ return self._get_retry_config()
+
+ @property
+ def title(self) -> str:
+ """Get the node title."""
+ return self._get_title()
+
+ @property
+ def description(self) -> Optional[str]:
+ """Get the node description."""
+ return self._get_description()
+
+ @property
+ def default_value_dict(self) -> dict[str, Any]:
+ """Get the default values dictionary for this node."""
+ return self._get_default_value_dict()
diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py
index 22ed9e2651..fdf3932827 100644
--- a/api/core/workflow/nodes/code/code_node.py
+++ b/api/core/workflow/nodes/code/code_node.py
@@ -1,4 +1,5 @@
from collections.abc import Mapping, Sequence
+from decimal import Decimal
from typing import Any, Optional
from configs import dify_config
@@ -10,8 +11,9 @@ from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.code.entities import CodeNodeData
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .exc import (
CodeNodeError,
@@ -20,10 +22,32 @@ from .exc import (
)
-class CodeNode(BaseNode[CodeNodeData]):
- _node_data_cls = CodeNodeData
+class CodeNode(BaseNode):
_node_type = NodeType.CODE
+ _node_data: CodeNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = CodeNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -46,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]):
def _run(self) -> NodeRunResult:
# Get code language
- code_language = self.node_data.code_language
- code = self.node_data.code
+ code_language = self._node_data.code_language
+ code = self._node_data.code
# Get variables
variables = {}
- for variable_selector in self.node_data.variables:
+ for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
if isinstance(variable, ArrayFileSegment):
@@ -67,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]):
)
# Transform result
- result = self._transform_result(result=result, output_schema=self.node_data.outputs)
+ result = self._transform_result(result=result, output_schema=self._node_data.outputs)
except (CodeExecutionError, CodeNodeError) as e:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__
@@ -114,8 +138,10 @@ class CodeNode(BaseNode[CodeNodeData]):
)
if isinstance(value, float):
+ decimal_value = Decimal(str(value)).normalize()
+ precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator]
# raise error if precision is too high
- if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION:
+ if precision > dify_config.CODE_MAX_PRECISION:
raise OutputValidationError(
f"Output variable `{variable}` has too high precision,"
f" it must be less than {dify_config.CODE_MAX_PRECISION} digits."
@@ -331,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: CodeNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = CodeNodeData.model_validate(node_data)
+
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
- for variable_selector in node_data.variables
+ for variable_selector in typed_node_data.variables
}
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 8e6150f9cc..ab5964ebd4 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -5,7 +5,7 @@ import logging
import os
import tempfile
from collections.abc import Mapping, Sequence
-from typing import Any, cast
+from typing import Any, Optional, cast
import chardet
import docx
@@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import DocumentExtractorNodeData
from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError
@@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError,
logger = logging.getLogger(__name__)
-class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
+class DocumentExtractorNode(BaseNode):
"""
Extracts text content from various file types.
Supports plain text, PDF, and DOC/DOCX files.
"""
- _node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR
+ _node_data: DocumentExtractorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = DocumentExtractorNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
def _run(self):
- variable_selector = self.node_data.variable_selector
+ variable_selector = self._node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector)
if variable is None:
@@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: DocumentExtractorNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- return {node_id + ".files": node_data.variable_selector}
+ # Create typed NodeData from dict
+ typed_node_data = DocumentExtractorNodeData.model_validate(node_data)
+
+ return {node_id + ".files": typed_node_data.variable_selector}
def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py
index 17a0b3adeb..f86f2e8129 100644
--- a/api/core/workflow/nodes/end/end_node.py
+++ b/api/core/workflow/nodes/end/end_node.py
@@ -1,14 +1,40 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.end.entities import EndNodeData
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
-class EndNode(BaseNode[EndNodeData]):
- _node_data_cls = EndNodeData
+class EndNode(BaseNode):
_node_type = NodeType.END
+ _node_data: EndNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = EndNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]):
Run node
:return:
"""
- output_variables = self.node_data.outputs
+ output_variables = self._node_data.outputs
outputs = {}
for variable_selector in output_variables:
diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py
index 73b43eeaf7..7cf9ab9107 100644
--- a/api/core/workflow/nodes/enums.py
+++ b/api/core/workflow/nodes/enums.py
@@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum):
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
-
-
-CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST]
-RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index 971e0f73e7..6799d5c63c 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.http_request.executor import Executor
from core.workflow.utils import variable_template_parser
from factories import file_factory
@@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout(
logger = logging.getLogger(__name__)
-class HttpRequestNode(BaseNode[HttpRequestNodeData]):
- _node_data_cls = HttpRequestNodeData
+class HttpRequestNode(BaseNode):
_node_type = NodeType.HTTP_REQUEST
+ _node_data: HttpRequestNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = HttpRequestNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict:
return {
@@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
process_data = {}
try:
http_executor = Executor(
- node_data=self.node_data,
- timeout=self._get_request_timeout(self.node_data),
+ node_data=self._node_data,
+ timeout=self._get_request_timeout(self._node_data),
variable_pool=self.graph_runtime_state.variable_pool,
max_retries=0,
)
@@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
response = http_executor.invoke()
files = self.extract_files(url=http_executor.url, response=response)
- if not response.response.is_success and (self.should_continue_on_error or self.should_retry):
+ if not response.response.is_success and (self.continue_on_error or self.retry):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
outputs={
@@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: HttpRequestNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = HttpRequestNodeData.model_validate(node_data)
+
selectors: list[VariableSelector] = []
- selectors += variable_template_parser.extract_selectors_from_template(node_data.url)
- selectors += variable_template_parser.extract_selectors_from_template(node_data.headers)
- selectors += variable_template_parser.extract_selectors_from_template(node_data.params)
- if node_data.body:
- body_type = node_data.body.type
- data = node_data.body.data
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url)
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers)
+ selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params)
+ if typed_node_data.body:
+ body_type = typed_node_data.body.type
+ data = typed_node_data.body.data
match body_type:
case "binary":
if len(data) != 1:
@@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
files.append(file)
return ArrayFileSegment(value=files)
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py
index 22b748030c..86e703dc68 100644
--- a/api/core/workflow/nodes/if_else/if_else_node.py
+++ b/api/core/workflow/nodes/if_else/if_else_node.py
@@ -1,5 +1,5 @@
from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Any, Literal, Optional
from typing_extensions import deprecated
@@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.if_else.entities import IfElseNodeData
from core.workflow.utils.condition.entities import Condition
from core.workflow.utils.condition.processor import ConditionProcessor
-class IfElseNode(BaseNode[IfElseNodeData]):
- _node_data_cls = IfElseNodeData
+class IfElseNode(BaseNode):
_node_type = NodeType.IF_ELSE
+ _node_data: IfElseNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IfElseNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
condition_processor = ConditionProcessor()
try:
# Check if the new cases structure is used
- if self.node_data.cases:
- for case in self.node_data.cases:
+ if self._node_data.cases:
+ for case in self._node_data.cases:
input_conditions, group_result, final_result = condition_processor.process_conditions(
variable_pool=self.graph_runtime_state.variable_pool,
conditions=case.conditions,
@@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]):
input_conditions, group_result, final_result = _should_not_use_old_function(
condition_processor=condition_processor,
variable_pool=self.graph_runtime_state.variable_pool,
- conditions=self.node_data.conditions or [],
- operator=self.node_data.logical_operator or "and",
+ conditions=self._node_data.conditions or [],
+ operator=self._node_data.logical_operator or "and",
)
selected_case_id = "true" if final_result else "false"
@@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: IfElseNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = IfElseNodeData.model_validate(node_data)
+
var_mapping: dict[str, list[str]] = {}
- for case in node_data.cases or []:
+ for case in typed_node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector
diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py
index c447f433aa..5842c8d64b 100644
--- a/api/core/workflow/nodes/iteration/iteration_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_node.py
@@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
@@ -56,14 +57,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class IterationNode(BaseNode[IterationNodeData]):
+class IterationNode(BaseNode):
"""
Iteration Node.
"""
- _node_data_cls = IterationNodeData
_node_type = NodeType.ITERATION
+ _node_data: IterationNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IterationNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
return {
@@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
Run the node.
"""
- variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector)
if not variable:
- raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found")
+ raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found")
if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
@@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]):
graph_config = self.graph_config
- if not self.node_data.start_node_id:
+ if not self._node_data.start_node_id:
raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found")
- root_node_id = self.node_data.start_node_id
+ root_node_id = self._node_data.start_node_id
# init graph
iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id)
@@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunStartedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"iterator_length": len(iterator_list_value)},
@@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=0,
pre_iteration_output=None,
duration=None,
@@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]):
iter_run_map: dict[str, float] = {}
outputs: list[Any] = [None] * len(iterator_list_value)
try:
- if self.node_data.is_parallel:
+ if self._node_data.is_parallel:
futures: list[Future] = []
q: Queue = Queue()
thread_pool = GraphEngineThreadPool(
- max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
+ max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT
)
for index, item in enumerate(iterator_list_value):
future: Future = thread_pool.submit(
@@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]):
iteration_graph=iteration_graph,
iter_run_map=iter_run_map,
)
- if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+ if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
outputs = [output for output in outputs if output is not None]
# Flatten the list of lists
@@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunSucceededEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: IterationNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = IterationNodeData.model_validate(node_data)
+
variable_mapping: dict[str, Sequence[str]] = {
- f"{node_id}.input_selector": node_data.iterator_selector,
+ f"{node_id}.input_selector": typed_node_data.iterator_selector,
}
# init graph
- iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
+ iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not iteration_graph:
raise IterationGraphNotFoundError("iteration graph not found")
@@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]):
"""
if not isinstance(event, BaseNodeEvent):
return event
- if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
+ if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id
iter_metadata = {
@@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]):
elif isinstance(event, BaseGraphEvent):
if isinstance(event, GraphRunFailedEvent):
# iteration run failed
- if self.node_data.is_parallel:
+ if self._node_data.is_parallel:
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
parallel_mode_run_id=parallel_mode_run_id,
start_at=start_at,
inputs=inputs,
@@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": outputs},
@@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]):
event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id
)
if isinstance(event, NodeRunFailedEvent):
- if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
+ if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
- elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
+ elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT:
yield NodeInIterationFailedEvent(
**metadata_event.model_dump(),
)
@@ -512,30 +531,64 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=None,
duration=duration,
)
return
- elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
- yield IterationRunFailedEvent(
- iteration_id=self.id,
- iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
- start_at=start_at,
- inputs=inputs,
- outputs={"output": None},
- steps=len(iterator_list_value),
- metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
- error=event.error,
+ elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED:
+ yield NodeInIterationFailedEvent(
+ **metadata_event.model_dump(),
+ )
+ outputs[current_index] = None
+
+ # clean nodes resources
+ for node_id in iteration_graph.node_ids:
+ variable_pool.remove([node_id])
+
+ # iteration run failed
+ if self._node_data.is_parallel:
+ yield IterationRunFailedEvent(
+ iteration_id=self.id,
+ iteration_node_id=self.node_id,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
+ parallel_mode_run_id=parallel_mode_run_id,
+ start_at=start_at,
+ inputs=inputs,
+ outputs={"output": outputs},
+ steps=len(iterator_list_value),
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+ error=event.error,
+ )
+ else:
+ yield IterationRunFailedEvent(
+ iteration_id=self.id,
+ iteration_node_id=self.node_id,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
+ start_at=start_at,
+ inputs=inputs,
+ outputs={"output": outputs},
+ steps=len(iterator_list_value),
+ metadata={"total_tokens": graph_engine.graph_runtime_state.total_tokens},
+ error=event.error,
+ )
+
+ # stop the iterator
+ yield RunCompletedEvent(
+ run_result=NodeRunResult(
+ status=WorkflowNodeExecutionStatus.FAILED,
+ error=event.error,
+ )
)
+ return
yield metadata_event
- current_output_segment = variable_pool.get(self.node_data.output_selector)
+ current_output_segment = variable_pool.get(self._node_data.output_selector)
if current_output_segment is None:
raise IterationNodeError("iteration output selector not found")
current_iteration_output = current_output_segment.value
@@ -554,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunNextEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
index=next_index,
parallel_mode_run_id=parallel_mode_run_id,
pre_iteration_output=current_iteration_output or None,
@@ -567,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]):
yield IterationRunFailedEvent(
iteration_id=self.id,
iteration_node_id=self.node_id,
- iteration_node_type=self.node_type,
- iteration_node_data=self.node_data,
+ iteration_node_type=self.type_,
+ iteration_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
outputs={"output": None},
diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py
index 9900aa225d..b82c29291a 100644
--- a/api/core/workflow/nodes/iteration/iteration_start_node.py
+++ b/api/core/workflow/nodes/iteration/iteration_start_node.py
@@ -1,18 +1,44 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.iteration.entities import IterationStartNodeData
-class IterationStartNode(BaseNode[IterationStartNodeData]):
+class IterationStartNode(BaseNode):
"""
Iteration Start Node.
"""
- _node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START
+ _node_data: IterationStartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = IterationStartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py
index 19bdee4fe2..f1767bdf9e 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/entities.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py
@@ -1,10 +1,10 @@
from collections.abc import Sequence
-from typing import Any, Literal, Optional
+from typing import Literal, Optional
from pydantic import BaseModel, Field
from core.workflow.nodes.base import BaseNodeData
-from core.workflow.nodes.llm.entities import VisionConfig
+from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig
class RerankingModelConfig(BaseModel):
@@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel):
weights: Optional[WeightedScoreConfig] = None
-class ModelConfig(BaseModel):
- """
- Model Config.
- """
-
- provider: str
- name: str
- mode: str
- completion_params: dict[str, Any] = {}
-
-
class SingleRetrievalConfig(BaseModel):
"""
Single Retrieval Config.
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index b34d62d669..34b0afc75d 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -4,7 +4,7 @@ import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
@@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
-from core.model_runtime.entities.message_entities import PromptMessageRole
-from core.model_runtime.entities.model_entities import ModelFeature, ModelType
+from core.model_runtime.entities.message_entities import (
+ PromptMessageRole,
+)
+from core.model_runtime.entities.model_entities import (
+ ModelFeature,
+ ModelType,
+)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
-from core.variables import StringSegment
+from core.variables import (
+ StringSegment,
+)
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import NodeType
-from core.workflow.nodes.event.event import ModelInvokeCompletedEvent
+from core.workflow.nodes.base import BaseNode
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
+from core.workflow.nodes.event import (
+ ModelInvokeCompletedEvent,
+)
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
@@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
-from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
+from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
+from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
-from .entities import KnowledgeRetrievalNodeData, ModelConfig
+from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
@@ -56,6 +68,10 @@ from .exc import (
ModelQuotaExceededError,
)
+if TYPE_CHECKING:
+ from core.file.models import File
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
+
logger = logging.getLogger(__name__)
default_retrieval_model = {
@@ -67,18 +83,76 @@ default_retrieval_model = {
}
-class KnowledgeRetrievalNode(LLMNode):
- _node_data_cls = KnowledgeRetrievalNodeData # type: ignore
+class KnowledgeRetrievalNode(BaseNode):
_node_type = NodeType.KNOWLEDGE_RETRIEVAL
+ _node_data: KnowledgeRetrievalNodeData
+
+ # Instance attributes specific to LLMNode.
+ # Output variable for file
+ _file_outputs: list["File"]
+
+ _llm_file_saver: LLMFileSaver
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ *,
+ llm_file_saver: LLMFileSaver | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+ # LLM file outputs, used for MultiModal outputs.
+ self._file_outputs: list[File] = []
+
+ if llm_file_saver is None:
+ llm_file_saver = FileSaverImpl(
+ user_id=graph_init_params.user_id,
+ tenant_id=graph_init_params.tenant_id,
+ )
+ self._llm_file_saver = llm_file_saver
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = KnowledgeRetrievalNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore
- node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables
- variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector)
if not isinstance(variable, StringSegment):
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge
try:
- results = self._fetch_dataset_retriever(node_data=node_data, query=query)
+ results = self._fetch_dataset_retriever(node_data=self._node_data, query=query)
outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -144,6 +218,8 @@ class KnowledgeRetrievalNode(LLMNode):
error=str(e),
error_type=type(e).__name__,
)
+ finally:
+ db.session.close()
def _fetch_dataset_retriever(self, node_data: KnowledgeRetrievalNodeData, query: str) -> list[dict[str, Any]]:
available_datasets = []
@@ -152,7 +228,7 @@ class KnowledgeRetrievalNode(LLMNode):
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
- .filter(
+ .where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@@ -166,11 +242,14 @@ class KnowledgeRetrievalNode(LLMNode):
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
- .filter(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
- .filter((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
+ .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
+ .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
+ # avoid blocking at retrieval
+ db.session.close()
+
for dataset in results:
# pass if dataset is not available
if not dataset:
@@ -291,7 +370,7 @@ class KnowledgeRetrievalNode(LLMNode):
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = (
db.session.query(Document)
- .filter(
+ .where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
@@ -336,7 +415,7 @@ class KnowledgeRetrievalNode(LLMNode):
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[Optional[dict[str, list[str]]], Optional[MetadataCondition]]:
- document_query = db.session.query(Document).filter(
+ document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
@@ -383,7 +462,7 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
- if expected_value.value_type == "number": # type: ignore
+ if expected_value.value_type in {"number", "integer", "float"}: # type: ignore
expected_value = expected_value.value # type: ignore
elif expected_value.value_type == "string": # type: ignore
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
@@ -414,9 +493,9 @@ class KnowledgeRetrievalNode(LLMNode):
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
): # type: ignore
- document_query = document_query.filter(and_(*filters))
+ document_query = document_query.where(and_(*filters))
else:
- document_query = document_query.filter(or_(*filters))
+ document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@@ -428,22 +507,19 @@ class KnowledgeRetrievalNode(LLMNode):
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> list[dict[str, Any]]:
# get all metadata field
- metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
+ metadata_fields = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
- # get metadata model config
- metadata_model_config = node_data.metadata_model_config
- if metadata_model_config is None:
+ if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
- # get metadata model instance
- # fetch model config
- model_instance, model_config = self.get_model_config(metadata_model_config)
+ # get metadata model instance and fetch model config
+ model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
@@ -453,16 +529,23 @@ class KnowledgeRetrievalNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
+ tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
- generator = self._invoke_llm(
- node_data_model=node_data.metadata_model_config, # type: ignore
+ generator = LLMNode.invoke_llm(
+ node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=self._node_data.structured_output_enabled,
+ structured_output=None,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
for event in generator:
@@ -552,17 +635,13 @@ class KnowledgeRetrievalNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: KnowledgeRetrievalNodeData, # type: ignore
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data)
+
variable_mapping = {}
- variable_mapping[node_id + ".query"] = node_data.query_variable_selector
+ variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
@@ -624,7 +703,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
- model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore
+ model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []
diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py
index 3c9ba44cf1..b91fc622f6 100644
--- a/api/core/workflow/nodes/list_operator/node.py
+++ b/api/core/workflow/nodes/list_operator/node.py
@@ -1,5 +1,5 @@
-from collections.abc import Callable, Sequence
-from typing import Any, Literal, Union
+from collections.abc import Callable, Mapping, Sequence
+from typing import Any, Literal, Optional, Union
from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
@@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from .entities import ListOperatorNodeData
from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError
-class ListOperatorNode(BaseNode[ListOperatorNodeData]):
- _node_data_cls = ListOperatorNodeData
+class ListOperatorNode(BaseNode):
_node_type = NodeType.LIST_OPERATOR
+ _node_data: ListOperatorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ListOperatorNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
process_data: dict[str, list] = {}
outputs: dict[str, Any] = {}
- variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable)
+ variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable)
if variable is None:
- error_message = f"Variable not found for selector: {self.node_data.variable}"
+ error_message = f"Variable not found for selector: {self._node_data.variable}"
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs
)
@@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
)
if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment):
error_message = (
- f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
+ f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment "
"or ArrayStringSegment"
)
return NodeRunResult(
@@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
try:
# Filter
- if self.node_data.filter_by.enabled:
+ if self._node_data.filter_by.enabled:
variable = self._apply_filter(variable)
# Extract
- if self.node_data.extract_by.enabled:
+ if self._node_data.extract_by.enabled:
variable = self._extract_slice(variable)
# Order
- if self.node_data.order_by.enabled:
+ if self._node_data.order_by.enabled:
variable = self._apply_order(variable)
# Slice
- if self.node_data.limit.enabled:
+ if self._node_data.limit.enabled:
variable = self._apply_slice(variable)
outputs = {
@@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
filter_func: Callable[[Any], bool]
result: list[Any] = []
- for condition in self.node_data.filter_by.conditions:
+ for condition in self._node_data.filter_by.conditions:
if isinstance(variable, ArrayStringSegment):
if not isinstance(condition.value, str):
raise InvalidFilterValueError(f"Invalid filter value: {condition.value}")
@@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
if isinstance(variable, ArrayStringSegment):
- result = _order_string(order=self.node_data.order_by.value, array=variable.value)
+ result = _order_string(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayNumberSegment):
- result = _order_number(order=self.node_data.order_by.value, array=variable.value)
+ result = _order_number(order=self._node_data.order_by.value, array=variable.value)
variable = variable.model_copy(update={"value": result})
elif isinstance(variable, ArrayFileSegment):
result = _order_file(
- order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value
+ order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value
)
variable = variable.model_copy(update={"value": result})
return variable
@@ -152,20 +175,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
def _apply_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
- result = variable.value[: self.node_data.limit.size]
+ result = variable.value[: self._node_data.limit.size]
return variable.model_copy(update={"value": result})
def _extract_slice(
self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]
) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]:
- value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text)
+ value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text)
if value < 1:
raise ValueError(f"Invalid serial index: must be >= 1, got {value}")
+ if value > len(variable.value):
+ raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}")
value -= 1
- if len(variable.value) > int(value):
- result = variable.value[value]
- else:
- result = ""
+ result = variable.value[value]
return variable.model_copy(update={"value": [result]})
diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py
index 36d0688807..4bb62d35a2 100644
--- a/api/core/workflow/nodes/llm/entities.py
+++ b/api/core/workflow/nodes/llm/entities.py
@@ -1,4 +1,4 @@
-from collections.abc import Sequence
+from collections.abc import Mapping, Sequence
from typing import Any, Optional
from pydantic import BaseModel, Field, field_validator
@@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData):
memory: Optional[MemoryConfig] = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
- structured_output: dict | None = None
+ structured_output: Mapping[str, Any] | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 9bfb402dc8..90a0397b67 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import (
ModelInvokeCompletedEvent,
NodeEvent,
@@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
- from core.workflow.graph_engine.entities.graph import Graph
- from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
- from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
logger = logging.getLogger(__name__)
-class LLMNode(BaseNode[LLMNodeData]):
- _node_data_cls = LLMNodeData
+class LLMNode(BaseNode):
_node_type = NodeType.LLM
+ _node_data: LLMNodeData
+
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
@@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]):
)
self._llm_file_saver = llm_file_saver
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LLMNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]):
try:
# init messages template
- self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+ self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template)
# fetch variables and fetch values from variable pool
- inputs = self._fetch_inputs(node_data=self.node_data)
+ inputs = self._fetch_inputs(node_data=self._node_data)
# fetch jinja2 inputs
- jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data)
+ jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data)
# merge inputs
inputs.update(jinja_inputs)
@@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]):
files = (
llm_utils.fetch_files(
variable_pool=variable_pool,
- selector=self.node_data.vision.configs.variable_selector,
+ selector=self._node_data.vision.configs.variable_selector,
)
- if self.node_data.vision.enabled
+ if self._node_data.vision.enabled
else []
)
@@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#files#"] = [file.to_dict() for file in files]
# fetch context value
- generator = self._fetch_context(node_data=self.node_data)
+ generator = self._fetch_context(node_data=self._node_data)
context = None
for event in generator:
if isinstance(event, RunRetrieverResourceEvent):
@@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]):
node_inputs["#context#"] = context
# fetch model config
- model_instance, model_config = self._fetch_model_config(self.node_data.model)
+ model_instance, model_config = LLMNode._fetch_model_config(
+ node_data_model=self._node_data.model,
+ tenant_id=self.tenant_id,
+ )
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
app_id=self.app_id,
- node_data_memory=self.node_data.memory,
+ node_data_memory=self._node_data.memory,
model_instance=model_instance,
)
query = None
- if self.node_data.memory:
- query = self.node_data.memory.query_prompt_template
+ if self._node_data.memory:
+ query = self._node_data.memory.query_prompt_template
if not query and (
query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY))
):
query = query_variable.text
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
- prompt_template=self.node_data.prompt_template,
- memory_config=self.node_data.memory,
- vision_enabled=self.node_data.vision.enabled,
- vision_detail=self.node_data.vision.configs.detail,
+ prompt_template=self._node_data.prompt_template,
+ memory_config=self._node_data.memory,
+ vision_enabled=self._node_data.vision.enabled,
+ vision_detail=self._node_data.vision.configs.detail,
variable_pool=variable_pool,
- jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+ jinja2_variables=self._node_data.prompt_config.jinja2_variables,
+ tenant_id=self.tenant_id,
)
# handle invoke result
- generator = self._invoke_llm(
- node_data_model=self.node_data.model,
+ generator = LLMNode.invoke_llm(
+ node_data_model=self._node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=self._node_data.structured_output_enabled,
+ structured_output=self._node_data.structured_output,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
structured_output: LLMStructuredOutput | None = None
@@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
- def _invoke_llm(
- self,
+ @staticmethod
+ def invoke_llm(
+ *,
node_data_model: ModelConfig,
model_instance: ModelInstance,
prompt_messages: Sequence[PromptMessage],
stop: Optional[Sequence[str]] = None,
+ user_id: str,
+ structured_output_enabled: bool,
+ structured_output: Optional[Mapping[str, Any]] = None,
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
+ node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
model_schema = model_instance.model_type_instance.get_model_schema(
node_data_model.name, model_instance.credentials
@@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]):
if not model_schema:
raise ValueError(f"Model schema not found for {node_data_model.name}")
- if self.node_data.structured_output_enabled:
- output_schema = self._fetch_structured_output_schema()
+ if structured_output_enabled:
+ output_schema = LLMNode.fetch_structured_output_schema(
+ structured_output=structured_output or {},
+ )
invoke_result = invoke_llm_with_structured_output(
provider=model_instance.provider,
model_schema=model_schema,
@@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
- user=self.user_id,
+ user=user_id,
)
else:
invoke_result = model_instance.invoke_llm(
@@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]):
model_parameters=node_data_model.completion_params,
stop=list(stop or []),
stream=True,
- user=self.user_id,
+ user=user_id,
)
- return self._handle_invoke_result(invoke_result=invoke_result)
+ return LLMNode.handle_invoke_result(
+ invoke_result=invoke_result,
+ file_saver=file_saver,
+ file_outputs=file_outputs,
+ node_id=node_id,
+ )
- def _handle_invoke_result(
- self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None]
+ @staticmethod
+ def handle_invoke_result(
+ *,
+ invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None],
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
+ node_id: str,
) -> Generator[NodeEvent | LLMStructuredOutput, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
- event = self._handle_blocking_result(invoke_result=invoke_result)
+ event = LLMNode.handle_blocking_result(
+ invoke_result=invoke_result,
+ saver=file_saver,
+ file_outputs=file_outputs,
+ )
yield event
return
@@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]):
yield result
if isinstance(result, LLMResultChunk):
contents = result.delta.message.content
- for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
+ for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+ contents=contents,
+ file_saver=file_saver,
+ file_outputs=file_outputs,
+ ):
full_text_buffer.write(text_part)
- yield RunStreamChunkEvent(
- chunk_content=text_part, from_variable_selector=[self.node_id, "text"]
- )
+ yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"])
# Update the whole metadata
if not model and result.model:
@@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]):
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
- def _image_file_to_markdown(self, file: "File", /):
+ @staticmethod
+ def _image_file_to_markdown(file: "File", /):
text_chunk = f"})"
return text_chunk
@@ -508,7 +565,7 @@ class LLMNode(BaseNode[LLMNodeData]):
retriever_resources=original_retriever_resource, context=context_str.strip()
)
- def _convert_to_original_retriever_resource(self, context_dict: dict):
+ def _convert_to_original_retriever_resource(self, context_dict: dict) -> RetrievalSourceMetadata | None:
if (
"metadata" in context_dict
and "_source" in context_dict["metadata"]
@@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]):
return None
+ @staticmethod
def _fetch_model_config(
- self, node_data_model: ModelConfig
+ *,
+ node_data_model: ModelConfig,
+ tenant_id: str,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model, model_config_with_cred = llm_utils.fetch_model_config(
- tenant_id=self.tenant_id, node_data_model=node_data_model
+ tenant_id=tenant_id, node_data_model=node_data_model
)
completion_params = model_config_with_cred.parameters
@@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]):
node_data_model.completion_params = completion_params
return model, model_config_with_cred
- def _fetch_prompt_messages(
- self,
+ @staticmethod
+ def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence["File"],
@@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]):
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
+ tenant_id: str,
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
- self._handle_list_messages(
+ LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
@@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]):
edition_type="basic",
)
prompt_messages.extend(
- self._handle_list_messages(
+ LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
@@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
model = ModelManager().get_model_instance(
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.model,
@@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: LLMNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- prompt_template = node_data.prompt_template
+ # Create typed NodeData from dict
+ typed_node_data = LLMNodeData.model_validate(node_data)
+ prompt_template = typed_node_data.prompt_template
variable_selectors = []
if isinstance(prompt_template, list) and all(
isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template
@@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
- memory = node_data.memory
+ memory = typed_node_data.memory
if memory and memory.query_prompt_template:
query_variable_selectors = VariableTemplateParser(
template=memory.query_prompt_template
@@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]):
for variable_selector in query_variable_selectors:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
- if node_data.context.enabled:
- variable_mapping["#context#"] = node_data.context.variable_selector
+ if typed_node_data.context.enabled:
+ variable_mapping["#context#"] = typed_node_data.context.variable_selector
- if node_data.vision.enabled:
- variable_mapping["#files#"] = node_data.vision.configs.variable_selector
+ if typed_node_data.vision.enabled:
+ variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector
- if node_data.memory:
+ if typed_node_data.memory:
variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value]
- if node_data.prompt_config:
+ if typed_node_data.prompt_config:
enable_jinja = False
if isinstance(prompt_template, list):
@@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]):
enable_jinja = True
if enable_jinja:
- for variable_selector in node_data.prompt_config.jinja2_variables or []:
+ for variable_selector in typed_node_data.prompt_config.jinja2_variables or []:
variable_mapping[variable_selector.variable] = variable_selector.value_selector
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]):
},
}
- def _handle_list_messages(
- self,
+ @staticmethod
+ def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: Optional[str],
@@ -849,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]):
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
- jinjia2_variables=jinja2_variables,
+ jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
@@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
- def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+ @staticmethod
+ def handle_blocking_result(
+ *,
+ invoke_result: LLMResult,
+ saver: LLMFileSaver,
+ file_outputs: list["File"],
+ ) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
- for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
+ for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown(
+ contents=invoke_result.message.content,
+ file_saver=saver,
+ file_outputs=file_outputs,
+ ):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
@@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]):
finish_reason=None,
)
- def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
+ @staticmethod
+ def save_multimodal_image_output(
+ *,
+ content: ImagePromptMessageContent,
+ file_saver: LLMFileSaver,
+ ) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
@@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]):
Currently, only image files are supported.
"""
- # Inject the saver somehow...
- _saver = self._llm_file_saver
-
- # If this
if content.url != "":
- saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
+ saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE)
else:
- saved_file = _saver.save_binary_string(
+ saved_file = file_saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
- self._file_outputs.append(saved_file)
return saved_file
def _fetch_model_schema(self, provider: str) -> AIModelEntity | None:
"""
Fetch model schema
"""
- model_name = self.node_data.model.name
+ model_name = self._node_data.model.name
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name
@@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]):
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
return model_schema
- def _fetch_structured_output_schema(self) -> dict[str, Any]:
+ @staticmethod
+ def fetch_structured_output_schema(
+ *,
+ structured_output: Mapping[str, Any],
+ ) -> dict[str, Any]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
"""
- if not self.node_data.structured_output:
+ if not structured_output:
raise LLMNodeError("Please provide a valid structured output schema")
- structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False)
+ structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
@@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]):
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
+ @staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(
- self,
+ *,
contents: str | list[PromptMessageContentUnionTypes] | None,
+ file_saver: LLMFileSaver,
+ file_outputs: list["File"],
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
@@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]):
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
- file = self._save_multimodal_image_output(item)
- self._file_outputs.append(file)
- yield self._image_file_to_markdown(file)
+ file = LLMNode.save_multimodal_image_output(
+ content=item,
+ file_saver=file_saver,
+ )
+ file_outputs.append(file)
+ yield LLMNode._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
@@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]):
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
+
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
@@ -1021,20 +1112,20 @@ def _combine_message_content_with_role(
def _render_jinja2_message(
*,
template: str,
- jinjia2_variables: Sequence[VariableSelector],
+ jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
- jinjia2_inputs = {}
- for jinja2_variable in jinjia2_variables:
+ jinja2_inputs = {}
+ for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
- jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+ jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
- inputs=jinjia2_inputs,
+ inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
@@ -1130,7 +1221,7 @@ def _handle_completion_template(
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
- jinjia2_variables=jinja2_variables,
+ jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py
index 3f4a5edab9..d04e0bfae1 100644
--- a/api/core/workflow/nodes/loop/entities.py
+++ b/api/core/workflow/nodes/loop/entities.py
@@ -1,11 +1,29 @@
from collections.abc import Mapping
-from typing import Any, Literal, Optional
+from typing import Annotated, Any, Literal, Optional
-from pydantic import BaseModel, Field
+from pydantic import AfterValidator, BaseModel, Field
+from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
+_VALID_VAR_TYPE = frozenset(
+ [
+ SegmentType.STRING,
+ SegmentType.NUMBER,
+ SegmentType.OBJECT,
+ SegmentType.ARRAY_STRING,
+ SegmentType.ARRAY_NUMBER,
+ SegmentType.ARRAY_OBJECT,
+ ]
+)
+
+
+def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:  # validator: accept only types in _VALID_VAR_TYPE
+    if seg_type not in _VALID_VAR_TYPE:
+        raise ValueError(f"Invalid variable type: {seg_type}")  # name the rejected type instead of passing Ellipsis
+    return seg_type
+
class LoopVariableData(BaseModel):
"""
@@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
"""
label: str
- var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
+ var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None
diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py
index b144021bab..53cadc5251 100644
--- a/api/core/workflow/nodes/loop/loop_end_node.py
+++ b/api/core/workflow/nodes/loop/loop_end_node.py
@@ -1,18 +1,44 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopEndNodeData
-class LoopEndNode(BaseNode[LoopEndNodeData]):
+class LoopEndNode(BaseNode):
"""
Loop End Node.
"""
- _node_data_cls = LoopEndNodeData
_node_type = NodeType.LOOP_END
+ _node_data: LoopEndNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopEndNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py
index 11fd7b6c2d..655de9362f 100644
--- a/api/core/workflow/nodes/loop/loop_node.py
+++ b/api/core/workflow/nodes/loop/loop_node.py
@@ -3,18 +3,13 @@ import logging
import time
from collections.abc import Generator, Mapping, Sequence
from datetime import UTC, datetime
-from typing import TYPE_CHECKING, Any, Literal, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from configs import dify_config
from core.variables import (
- ArrayNumberSegment,
- ArrayObjectSegment,
- ArrayStringSegment,
IntegerSegment,
- ObjectSegment,
Segment,
SegmentType,
- StringSegment,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -35,10 +30,12 @@ from core.workflow.graph_engine.entities.event import (
)
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
+from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@@ -47,14 +44,36 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-class LoopNode(BaseNode[LoopNodeData]):
+class LoopNode(BaseNode):
"""
Loop Node.
"""
- _node_data_cls = LoopNodeData
_node_type = NodeType.LOOP
+ _node_data: LoopNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -62,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]):
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
"""Run the node."""
# Get inputs
- loop_count = self.node_data.loop_count
- break_conditions = self.node_data.break_conditions
- logical_operator = self.node_data.logical_operator
+ loop_count = self._node_data.loop_count
+ break_conditions = self._node_data.break_conditions
+ logical_operator = self._node_data.logical_operator
inputs = {"loop_count": loop_count}
- if not self.node_data.start_node_id:
+ if not self._node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self.node_id} not found")
# Initialize graph
- loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id)
+ loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -82,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]):
# Initialize loop variables
loop_variable_selectors = {}
- if self.node_data.loop_variables:
- for loop_variable in self.node_data.loop_variables:
+ if self._node_data.loop_variables:
+ for loop_variable in self._node_data.loop_variables:
value_processor = {
"constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value),
"variable": lambda var=loop_variable: variable_pool.get(var.value),
@@ -131,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunStartedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
metadata={"loop_length": loop_count},
@@ -188,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunSucceededEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
- outputs=self.node_data.outputs,
+ outputs=self._node_data.outputs,
steps=loop_count,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,
@@ -210,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]):
WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map,
WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map,
},
- outputs=self.node_data.outputs,
+ outputs=self._node_data.outputs,
inputs=inputs,
)
)
@@ -221,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=loop_count,
@@ -324,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -355,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunFailedEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
start_at=start_at,
inputs=inputs,
steps=current_index,
@@ -392,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]):
_outputs[loop_variable_key] = None
_outputs["loop_round"] = current_index + 1
- self.node_data.outputs = _outputs
+ self._node_data.outputs = _outputs
if check_break_result:
return {"check_break_result": True}
@@ -404,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]):
yield LoopRunNextEvent(
loop_id=self.id,
loop_node_id=self.node_id,
- loop_node_type=self.node_type,
- loop_node_data=self.node_data,
+ loop_node_type=self.type_,
+ loop_node_data=self._node_data,
index=next_index,
- pre_loop_output=self.node_data.outputs,
+ pre_loop_output=self._node_data.outputs,
)
return {"check_break_result": False}
@@ -442,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: LoopNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = LoopNodeData.model_validate(node_data)
+
variable_mapping = {}
# init graph
- loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id)
+ loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id)
if not loop_graph:
raise ValueError("loop graph not found")
@@ -490,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]):
variable_mapping.update(sub_node_variable_mapping)
- for loop_variable in node_data.loop_variables or []:
+ for loop_variable in typed_node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
@@ -505,23 +520,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping
@staticmethod
- def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
+ def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
- segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
- "string": (StringSegment, SegmentType.STRING),
- "number": (IntegerSegment, SegmentType.NUMBER),
- "object": (ObjectSegment, SegmentType.OBJECT),
- "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
- "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
- "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
- }
if var_type in ["array[string]", "array[number]", "array[object]"]:
- if value:
+ if value and isinstance(value, str):
value = json.loads(value)
else:
value = []
- segment_info = segment_mapping.get(var_type)
- if not segment_info:
- raise ValueError(f"Invalid variable type: {var_type}")
- segment_class, value_type = segment_info
- return segment_class(value=value, value_type=value_type)
+ try:
+ return build_segment_with_type(var_type, value)
+ except TypeMismatchError as type_exc:
+ # Attempt to parse the value as a JSON-encoded string, if applicable.
+ if not isinstance(value, str):
+ raise
+ try:
+ value = json.loads(value)
+ except ValueError:
+ raise type_exc
+ return build_segment_with_type(var_type, value)
diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py
index f5e38b7516..29b45ea0c3 100644
--- a/api/core/workflow/nodes/loop/loop_start_node.py
+++ b/api/core/workflow/nodes/loop/loop_start_node.py
@@ -1,18 +1,44 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.loop.entities import LoopStartNodeData
-class LoopStartNode(BaseNode[LoopStartNodeData]):
+class LoopStartNode(BaseNode):
"""
Loop Start Node.
"""
- _node_data_cls = LoopStartNodeData
_node_type = NodeType.LOOP_START
+ _node_data: LoopStartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = LoopStartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py
index ccfaec4a8c..294b47670b 100644
--- a/api/core/workflow/nodes/node_mapping.py
+++ b/api/core/workflow/nodes/node_mapping.py
@@ -73,6 +73,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.TOOL: {
LATEST_VERSION: ToolNode,
+        # Versions "1" and "2" intentionally resolve to the same class.
+        # Logically, two different versions should not point to the same class,
+        # but this mapping is retained to stay compatible with historical data.
"2": ToolNode,
"1": ToolNode,
},
@@ -123,6 +126,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
},
NodeType.AGENT: {
LATEST_VERSION: AgentNode,
+        # Versions "1" and "2" intentionally resolve to the same class.
+        # Logically, two different versions should not point to the same class,
+        # but this mapping is retained to stay compatible with historical data.
"2": AgentNode,
"1": AgentNode,
},
diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
index 25a534256b..a23d284626 100644
--- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
+++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py
@@ -29,8 +29,9 @@ from core.variables.types import SegmentType
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
from core.workflow.nodes.base.node import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.llm import ModelConfig, llm_utils
from core.workflow.utils import variable_template_parser
from factories.variable_factory import build_segment_with_type
@@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode):
Parameter Extractor Node.
"""
- # FIXME: figure out why here is different from super class
- _node_data_cls = ParameterExtractorNodeData # type: ignore
_node_type = NodeType.PARAMETER_EXTRACTOR
+ _node_data: ParameterExtractorNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ParameterExtractorNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
_model_instance: Optional[ModelInstance] = None
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
@@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode):
"""
Run the node.
"""
- node_data = cast(ParameterExtractorNodeData, self.node_data)
+ node_data = cast(ParameterExtractorNodeData, self._node_data)
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
query = variable.text if variable else ""
@@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode):
"""
Generate prompt engineering prompt.
"""
- model_mode = ModelMode.value_of(data.model.mode)
+ model_mode = ModelMode(data.model.mode)
if model_mode == ModelMode.COMPLETION:
return self._generate_prompt_engineering_completion_prompt(
@@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
) -> list[ChatModelMessage]:
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
input_text = query
memory_str = ""
instruction = variable_pool.convert_template(node_data.instruction or "").text
@@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ParameterExtractorNodeData, # type: ignore
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query}
+ # Create typed NodeData from dict
+ typed_node_data = ParameterExtractorNodeData.model_validate(node_data)
+
+ variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query}
- if node_data.instruction:
- selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction)
+ if typed_node_data.instruction:
+ selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction)
for selector in selectors:
variable_mapping[selector.variable] = selector.value_selector
diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
index 74024ed90c..15012fa48d 100644
--- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py
+++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py
@@ -1,6 +1,6 @@
import json
from collections.abc import Mapping, Sequence
-from typing import Any, Optional, cast
+from typing import TYPE_CHECKING, Any, Optional, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.simple_prompt_transform import ModelMode
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.workflow.entities.node_entities import NodeRunResult
+from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.base.node import BaseNode
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import ModelInvokeCompletedEvent
from core.workflow.nodes.llm import (
LLMNode,
@@ -20,6 +23,7 @@ from core.workflow.nodes.llm import (
LLMNodeCompletionModelPromptTemplate,
llm_utils,
)
+from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -35,17 +39,77 @@ from .template_prompts import (
QUESTION_CLASSIFIER_USER_PROMPT_3,
)
+if TYPE_CHECKING:
+ from core.file.models import File
+ from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState
-class QuestionClassifierNode(LLMNode):
- _node_data_cls = QuestionClassifierNodeData # type: ignore
+
+class QuestionClassifierNode(BaseNode):
_node_type = NodeType.QUESTION_CLASSIFIER
+ _node_data: QuestionClassifierNodeData
+
+ _file_outputs: list["File"]
+ _llm_file_saver: LLMFileSaver
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ *,
+ llm_file_saver: LLMFileSaver | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+ # LLM file outputs, used for MultiModal outputs.
+ self._file_outputs: list[File] = []
+
+ if llm_file_saver is None:
+ llm_file_saver = FileSaverImpl(
+ user_id=graph_init_params.user_id,
+ tenant_id=graph_init_params.tenant_id,
+ )
+ self._llm_file_saver = llm_file_saver
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = QuestionClassifierNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls):
return "1"
def _run(self):
- node_data = cast(QuestionClassifierNodeData, self.node_data)
+ node_data = cast(QuestionClassifierNodeData, self._node_data)
variable_pool = self.graph_runtime_state.variable_pool
# extract variables
@@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode):
query = variable.value if variable else None
variables = {"query": query}
# fetch model config
- model_instance, model_config = self._fetch_model_config(node_data.model)
+ model_instance, model_config = LLMNode._fetch_model_config(
+ node_data_model=node_data.model,
+ tenant_id=self.tenant_id,
+ )
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
- prompt_messages, stop = self._fetch_prompt_messages(
+ prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
+ tenant_id=self.tenant_id,
)
result_text = ""
@@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode):
try:
# handle invoke result
- generator = self._invoke_llm(
+ generator = LLMNode.invoke_llm(
node_data_model=node_data.model,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
+ user_id=self.user_id,
+ structured_output_enabled=False,
+ structured_output=None,
+ file_saver=self._llm_file_saver,
+ file_outputs=self._file_outputs,
+ node_id=self.node_id,
)
for event in generator:
@@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: Any,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
- node_data = cast(QuestionClassifierNodeData, node_data)
- variable_mapping = {"query": node_data.query_variable_selector}
- variable_selectors = []
- if node_data.instruction:
- variable_template_parser = VariableTemplateParser(template=node_data.instruction)
+ # Create typed NodeData from dict
+ typed_node_data = QuestionClassifierNodeData.model_validate(node_data)
+
+ variable_mapping = {"query": typed_node_data.query_variable_selector}
+ variable_selectors: list[VariableSelector] = []
+ if typed_node_data.instruction:
+ variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction)
variable_selectors.extend(variable_template_parser.extract_variable_selectors())
for variable_selector in variable_selectors:
- variable_mapping[variable_selector.variable] = variable_selector.value_selector
+ variable_mapping[variable_selector.variable] = list(variable_selector.value_selector)
variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()}
@@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode):
memory: Optional[TokenBufferMemory],
max_token_limit: int = 2000,
):
- model_mode = ModelMode.value_of(node_data.model.mode)
+ model_mode = ModelMode(node_data.model.mode)
classes = node_data.classes
categories = []
for class_ in classes:
diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py
index 5ee9bc331f..9e401e76bb 100644
--- a/api/core/workflow/nodes/start/start_node.py
+++ b/api/core/workflow/nodes/start/start_node.py
@@ -1,22 +1,48 @@
+from collections.abc import Mapping
+from typing import Any, Optional
+
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.start.entities import StartNodeData
-class StartNode(BaseNode[StartNodeData]):
- _node_data_cls = StartNodeData
+class StartNode(BaseNode):
_node_type = NodeType.START
+ _node_data: StartNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = StartNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
- system_inputs = self.graph_runtime_state.variable_pool.system_variables
+ system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.
diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py
index ba573074c3..1962c82db1 100644
--- a/api/core/workflow/nodes/template_transform/template_transform_node.py
+++ b/api/core/workflow/nodes/template_transform/template_transform_node.py
@@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000"))
-class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
- _node_data_cls = TemplateTransformNodeData
+class TemplateTransformNode(BaseNode):
_node_type = NodeType.TEMPLATE_TRANSFORM
+ _node_data: TemplateTransformNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = TemplateTransformNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
"""
@@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
def _run(self) -> NodeRunResult:
# Get variables
variables = {}
- for variable_selector in self.node_data.variables:
+ for variable_selector in self._node_data.variables:
variable_name = variable_selector.variable
value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector)
variables[variable_name] = value.to_object() if value else None
# Run code
try:
result = CodeExecutor.execute_workflow_code_template(
- language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables
+ language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables
)
except CodeExecutionError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
@@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]):
@classmethod
def _extract_variable_selector_to_variable_mapping(
- cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData
+ cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any]
) -> Mapping[str, Sequence[str]]:
- """
- Extract variable selector to variable mapping
- :param graph_config: graph config
- :param node_id: node id
- :param node_data: node data
- :return:
- """
+ # Create typed NodeData from dict
+ typed_node_data = TemplateTransformNodeData.model_validate(node_data)
+
return {
node_id + "." + variable_selector.variable: variable_selector.value_selector
- for variable_selector in node_data.variables
+ for variable_selector in typed_node_data.variables
}
diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py
index 691f6e0196..f0a44d919b 100644
--- a/api/core/workflow/nodes/tool/entities.py
+++ b/api/core/workflow/nodes/tool/entities.py
@@ -14,6 +14,7 @@ class ToolEntity(BaseModel):
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
+ credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@@ -58,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
return typ
tool_parameters: dict[str, ToolInput]
+ # The version of the tool parameter.
+ # If this value is None, it indicates this is a previous version
+ # and requires using the legacy parameter parsing rules.
+ tool_node_version: str | None = None
@field_validator("tool_parameters", mode="before")
@classmethod
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 48627a229d..f437ac841d 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
-from core.model_runtime.entities.llm_entities import LLMUsage
from core.plugin.impl.exc import PluginDaemonClientSideError
from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
@@ -19,9 +18,9 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.enums import SystemVariableKey
-from core.workflow.graph_engine.entities.event import AgentLogEvent
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
from core.workflow.utils.variable_template_parser import VariableTemplateParser
from extensions.ext_database import db
@@ -37,14 +36,18 @@ from .exc import (
)
-class ToolNode(BaseNode[ToolNodeData]):
+class ToolNode(BaseNode):
"""
Tool Node
"""
- _node_data_cls = ToolNodeData
_node_type = NodeType.TOOL
+ _node_data: ToolNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = ToolNodeData.model_validate(data)
+
@classmethod
def version(cls) -> str:
return "1"
@@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]):
Run the tool node
"""
- node_data = cast(ToolNodeData, self.node_data)
+ node_data = cast(ToolNodeData, self._node_data)
# fetch tool icon
tool_info = {
@@ -67,9 +70,15 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
from core.tools.tool_manager import ToolManager
- variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None
+ # `node_data.version` is not the right field for deciding whether the variable pool is
+ # passed to the tool runtime, and relying on it has caused problems before. The check
+ # is preserved only for backward compatibility with historical workflow data.
+ # NOTE(review): `tool_node_version != "1"` is also true when it is None (legacy) - confirm.
+ variable_pool: VariablePool | None = None
+ if node_data.version != "1" or node_data.tool_node_version != "1":
+ variable_pool = self.graph_runtime_state.variable_pool
tool_runtime = ToolManager.get_workflow_tool_runtime(
- self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool
+ self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool
)
except ToolNodeError as e:
yield RunCompletedEvent(
@@ -88,12 +97,12 @@ class ToolNode(BaseNode[ToolNodeData]):
parameters = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=self.node_data,
+ node_data=self._node_data,
)
parameters_for_log = self._generate_parameters(
tool_parameters=tool_parameters,
variable_pool=self.graph_runtime_state.variable_pool,
- node_data=self.node_data,
+ node_data=self._node_data,
for_log=True,
)
# get conversation id
@@ -124,7 +133,14 @@ class ToolNode(BaseNode[ToolNodeData]):
try:
# convert tool messages
- yield from self._transform_message(message_stream, tool_info, parameters_for_log)
+ yield from self._transform_message(
+ messages=message_stream,
+ tool_info=tool_info,
+ parameters_for_log=parameters_for_log,
+ user_id=self.user_id,
+ tenant_id=self.tenant_id,
+ node_id=self.node_id,
+ )
except (PluginDaemonClientSideError, ToolInvokeError) as e:
yield RunCompletedEvent(
run_result=NodeRunResult(
@@ -191,7 +207,9 @@ class ToolNode(BaseNode[ToolNodeData]):
messages: Generator[ToolInvokeMessage, None, None],
tool_info: Mapping[str, Any],
parameters_for_log: dict[str, Any],
- agent_thoughts: Optional[list] = None,
+ user_id: str,
+ tenant_id: str,
+ node_id: str,
) -> Generator:
"""
Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -199,8 +217,8 @@ class ToolNode(BaseNode[ToolNodeData]):
# transform message and handle file storage
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=messages,
- user_id=self.user_id,
- tenant_id=self.tenant_id,
+ user_id=user_id,
+ tenant_id=tenant_id,
conversation_id=None,
)
@@ -208,9 +226,6 @@ class ToolNode(BaseNode[ToolNodeData]):
files: list[File] = []
json: list[dict] = []
- agent_logs: list[AgentLogEvent] = []
- agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
- llm_usage: LLMUsage | None = None
variables: dict[str, Any] = {}
for message in message_stream:
@@ -243,7 +258,7 @@ class ToolNode(BaseNode[ToolNodeData]):
}
file = file_factory.build_from_mapping(
mapping=mapping,
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
)
files.append(file)
elif message.type == ToolInvokeMessage.MessageType.BLOB:
@@ -266,51 +281,49 @@ class ToolNode(BaseNode[ToolNodeData]):
files.append(
file_factory.build_from_mapping(
mapping=mapping,
- tenant_id=self.tenant_id,
+ tenant_id=tenant_id,
)
)
elif message.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
text += message.message.text
- yield RunStreamChunkEvent(
- chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
- )
+ yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.JSON:
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
- if self.node_type == NodeType.AGENT:
- msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
- llm_usage = LLMUsage.from_metadata(msg_metadata)
- agent_execution_metadata = {
- WorkflowNodeExecutionMetadataKey(key): value
- for key, value in msg_metadata.items()
- if key in WorkflowNodeExecutionMetadataKey.__members__.values()
- }
+ # JSON message handling for tool node
if message.message.json_object is not None:
json.append(message.message.json_object)
elif message.type == ToolInvokeMessage.MessageType.LINK:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
stream_text = f"Link: {message.message.text}\n"
text += stream_text
- yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
+ yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"])
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
variable_name = message.message.variable_name
variable_value = message.message.variable_value
if message.message.stream:
if not isinstance(variable_value, str):
- raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
+ raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.")
if variable_name not in variables:
variables[variable_name] = ""
variables[variable_name] += variable_value
yield RunStreamChunkEvent(
- chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
+ chunk_content=variable_value, from_variable_selector=[node_id, variable_name]
)
else:
variables[variable_name] = variable_value
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
- assert isinstance(message.meta, File)
+ assert isinstance(message.meta, dict)
+ # Validate that meta contains a 'file' key
+ if "file" not in message.meta:
+ raise ToolNodeError("File message is missing 'file' key in meta")
+
+ # Validate that the file is an instance of File
+ if not isinstance(message.meta["file"], File):
+ raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
@@ -319,7 +332,7 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
manager = PluginInstaller()
- plugins = manager.list_plugins(self.tenant_id)
+ plugins = manager.list_plugins(tenant_id)
try:
current_plugin = next(
plugin
@@ -334,8 +347,8 @@ class ToolNode(BaseNode[ToolNodeData]):
builtin_tool = next(
provider
for provider in BuiltinToolManageService.list_builtin_tools(
- self.user_id,
- self.tenant_id,
+ user_id,
+ tenant_id,
)
if provider.name == dict_metadata["provider"]
)
@@ -347,51 +360,10 @@ class ToolNode(BaseNode[ToolNodeData]):
dict_metadata["icon"] = icon
dict_metadata["icon_dark"] = icon_dark
message.message.metadata = dict_metadata
- agent_log = AgentLogEvent(
- id=message.message.id,
- node_execution_id=self.id,
- parent_id=message.message.parent_id,
- error=message.message.error,
- status=message.message.status.value,
- data=message.message.data,
- label=message.message.label,
- metadata=message.message.metadata,
- node_id=self.node_id,
- )
-
- # check if the agent log is already in the list
- for log in agent_logs:
- if log.id == agent_log.id:
- # update the log
- log.data = agent_log.data
- log.status = agent_log.status
- log.error = agent_log.error
- log.label = agent_log.label
- log.metadata = agent_log.metadata
- break
- else:
- agent_logs.append(agent_log)
-
- yield agent_log
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
json_output: list[dict[str, Any]] = []
- # Step 1: append each agent log as its own dict.
- if agent_logs:
- for log in agent_logs:
- json_output.append(
- {
- "id": log.id,
- "parent_id": log.parent_id,
- "error": log.error,
- "status": log.status,
- "data": log.data,
- "label": log.label,
- "metadata": log.metadata,
- "node_id": log.node_id,
- }
- )
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
if json:
json_output.extend(json)
@@ -403,12 +375,9 @@ class ToolNode(BaseNode[ToolNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables},
metadata={
- **agent_execution_metadata,
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
- WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
},
inputs=parameters_for_log,
- llm_usage=llm_usage,
)
)
@@ -418,7 +387,7 @@ class ToolNode(BaseNode[ToolNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: ToolNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selector to variable mapping
@@ -427,9 +396,12 @@ class ToolNode(BaseNode[ToolNodeData]):
:param node_data: node data
:return:
"""
+ # Create typed NodeData from dict
+ typed_node_data = ToolNodeData.model_validate(node_data)
+
result = {}
- for parameter_name in node_data.tool_parameters:
- input = node_data.tool_parameters[parameter_name]
+ for parameter_name in typed_node_data.tool_parameters:
+ input = typed_node_data.tool_parameters[parameter_name]
if input.type == "mixed":
assert isinstance(input.value, str)
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
@@ -443,3 +415,29 @@ class ToolNode(BaseNode[ToolNodeData]):
result = {node_id + "." + key: value for key, value in result.items()}
return result
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
+ @property
+ def continue_on_error(self) -> bool:
+ return self._node_data.error_strategy is not None
+
+ @property
+ def retry(self) -> bool:
+ return self._node_data.retry_config.retry_enabled
diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
index 96bb3e793a..98127bbeb6 100644
--- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
+++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py
@@ -1,17 +1,41 @@
from collections.abc import Mapping
+from typing import Any, Optional
from core.variables.segments import Segment
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData
-class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
- _node_data_cls = VariableAssignerNodeData
+class VariableAggregatorNode(BaseNode):
_node_type = NodeType.VARIABLE_AGGREGATOR
+ _node_data: VariableAssignerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerNodeData(**data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
@classmethod
def version(cls) -> str:
return "1"
@@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
outputs: dict[str, Segment | Mapping[str, Segment]] = {}
inputs = {}
- if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled:
- for selector in self.node_data.variables:
+ if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled:
+ for selector in self._node_data.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
if variable is not None:
outputs = {"output": variable}
@@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]):
inputs = {".".join(selector[1:]): variable.to_object()}
break
else:
- for group in self.node_data.advanced_settings.groups:
+ for group in self._node_data.advanced_settings.groups:
for selector in group.variables:
variable = self.graph_runtime_state.variable_pool.get(selector)
diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py
index be5083c9c1..51383fa588 100644
--- a/api/core/workflow/nodes/variable_assigner/v1/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v1/node.py
@@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from factories import variable_factory
@@ -22,11 +23,33 @@ if TYPE_CHECKING:
_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-class VariableAssignerNode(BaseNode[VariableAssignerData]):
- _node_data_cls = VariableAssignerData
+class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
_conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY
+ _node_data: VariableAssignerData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
def __init__(
self,
id: str,
@@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: VariableAssignerData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = VariableAssignerData.model_validate(node_data)
+
mapping = {}
- assigned_variable_node_id = node_data.assigned_variable_selector[0]
+ assigned_variable_node_id = typed_node_data.assigned_variable_selector[0]
if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID:
- selector_key = ".".join(node_data.assigned_variable_selector)
+ selector_key = ".".join(typed_node_data.assigned_variable_selector)
key = f"{node_id}.#{selector_key}#"
- mapping[key] = node_data.assigned_variable_selector
+ mapping[key] = typed_node_data.assigned_variable_selector
- selector_key = ".".join(node_data.input_variable_selector)
+ selector_key = ".".join(typed_node_data.input_variable_selector)
key = f"{node_id}.#{selector_key}#"
- mapping[key] = node_data.input_variable_selector
+ mapping[key] = typed_node_data.input_variable_selector
return mapping
def _run(self) -> NodeRunResult:
- assigned_variable_selector = self.node_data.assigned_variable_selector
+ assigned_variable_selector = self._node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
raise VariableOperatorNodeError("assigned variable not found")
- match self.node_data.write_mode:
+ match self._node_data.write_mode:
case WriteMode.OVER_WRITE:
- income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
+ income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
- income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
+ income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector)
if not income_value:
raise VariableOperatorNodeError("input value not found")
updated_value = original_variable.value + [income_value.value]
@@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
- raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
+ raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
@@ -130,6 +156,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
def get_zero_value(t: SegmentType):
+ # TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
@@ -137,6 +164,10 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
+ case SegmentType.INTEGER:
+ return variable_factory.build_segment(0)
+ case SegmentType.FLOAT:
+ return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py
index 3797bfa77a..7f760e5baa 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/constants.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py
@@ -1,5 +1,6 @@
from core.variables import SegmentType
+# NOTE: This mapping duplicates the zero values defined in `get_zero_value`. Consider refactoring to remove the redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,
diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
index 8fb2a27388..7a20975b15 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
@@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
- return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
+ return variable_type in {
+ SegmentType.OBJECT,
+ SegmentType.STRING,
+ SegmentType.NUMBER,
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ }
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
- return variable_type == SegmentType.NUMBER
+ return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
@@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
@@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 9292da6f1c..c0215cae71 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -1,6 +1,6 @@
import json
-from collections.abc import Callable, Mapping, MutableMapping, Sequence
-from typing import Any, TypeAlias, cast
+from collections.abc import Mapping, MutableMapping, Sequence
+from typing import Any, Optional, cast
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import SegmentType, Variable
@@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode
-from core.workflow.nodes.enums import NodeType
+from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
+from core.workflow.nodes.enums import ErrorStrategy, NodeType
from core.workflow.nodes.variable_assigner.common import helpers as common_helpers
from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory
@@ -28,8 +29,6 @@ from .exc import (
VariableNotFoundError,
)
-_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater]
-
def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem):
selector_node_id = item.variable_selector[0]
@@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_
mapping[key] = selector
-class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
- _node_data_cls = VariableAssignerNodeData
+class VariableAssignerNode(BaseNode):
_node_type = NodeType.VARIABLE_ASSIGNER
+ _node_data: VariableAssignerNodeData
+
+ def init_node_data(self, data: Mapping[str, Any]) -> None:
+ self._node_data = VariableAssignerNodeData.model_validate(data)
+
+ def _get_error_strategy(self) -> Optional[ErrorStrategy]:
+ return self._node_data.error_strategy
+
+ def _get_retry_config(self) -> RetryConfig:
+ return self._node_data.retry_config
+
+ def _get_title(self) -> str:
+ return self._node_data.title
+
+ def _get_description(self) -> Optional[str]:
+ return self._node_data.desc
+
+ def _get_default_value_dict(self) -> dict[str, Any]:
+ return self._node_data.default_value_dict
+
+ def get_base_node_data(self) -> BaseNodeData:
+ return self._node_data
+
def _conv_var_updater_factory(self) -> ConversationVariableUpdater:
return conversation_variable_updater_factory()
@@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
*,
graph_config: Mapping[str, Any],
node_id: str,
- node_data: VariableAssignerNodeData,
+ node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
+ # Create typed NodeData from dict
+ typed_node_data = VariableAssignerNodeData.model_validate(node_data)
+
var_mapping: dict[str, Sequence[str]] = {}
- for item in node_data.items:
+ for item in typed_node_data.items:
_target_mapping_from_item(var_mapping, node_id, item)
_source_mapping_from_item(var_mapping, node_id, item)
return var_mapping
def _run(self) -> NodeRunResult:
- inputs = self.node_data.model_dump()
+ inputs = self._node_data.model_dump()
process_data: dict[str, Any] = {}
# NOTE: This node has no outputs
updated_variable_selectors: list[Sequence[str]] = []
try:
- for item in self.node_data.items:
+ for item in self._node_data.items:
variable = self.graph_runtime_state.variable_pool.get(item.variable_selector)
# ==================== Validation Part
diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py
index 5917310c8b..bcbd253392 100644
--- a/api/core/workflow/repositories/workflow_execution_repository.py
+++ b/api/core/workflow/repositories/workflow_execution_repository.py
@@ -1,4 +1,4 @@
-from typing import Optional, Protocol
+from typing import Protocol
from core.workflow.entities.workflow_execution import WorkflowExecution
@@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol):
execution: The WorkflowExecution instance to save or update
"""
...
-
- def get(self, execution_id: str) -> Optional[WorkflowExecution]:
- """
- Retrieve a WorkflowExecution by its ID.
-
- Args:
- execution_id: The workflow execution ID
-
- Returns:
- The WorkflowExecution instance if found, None otherwise
- """
- ...
diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py
index 1908a6b190..8bf81f5442 100644
--- a/api/core/workflow/repositories/workflow_node_execution_repository.py
+++ b/api/core/workflow/repositories/workflow_node_execution_repository.py
@@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol):
"""
...
- def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
- """
- Retrieve a NodeExecution by its node_execution_id.
-
- Args:
- node_execution_id: The node execution ID
-
- Returns:
- The NodeExecution instance if found, None otherwise
- """
- ...
-
def get_by_workflow_run(
self,
workflow_run_id: str,
@@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol):
A list of NodeExecution instances
"""
...
-
- def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
- """
- Retrieve all running NodeExecution instances for a specific workflow run.
-
- Args:
- workflow_run_id: The workflow run ID
-
- Returns:
- A list of running NodeExecution instances
- """
- ...
-
- def clear(self) -> None:
- """
- Clear all NodeExecution records based on implementation-specific criteria.
-
- This method is intended to be used for bulk deletion operations, such as removing
- all records associated with a specific app_id and tenant_id in multi-tenant implementations.
- """
- ...
diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py
new file mode 100644
index 0000000000..df90c16596
--- /dev/null
+++ b/api/core/workflow/system_variable.py
@@ -0,0 +1,89 @@
+from collections.abc import Sequence
+from typing import Any
+
+from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
+
+from core.file.models import File
+from core.workflow.enums import SystemVariableKey
+
+
+class SystemVariable(BaseModel):
+ """A model for managing system variables.
+
+ Fields with a value of `None` are treated as absent and will not be included
+ in the variable pool.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ serialize_by_alias=True,
+ validate_by_alias=True,
+ )
+
+ user_id: str | None = None
+
+ # Ideally, `app_id` and `workflow_id` should be required and not `None`.
+ # However, there are scenarios in the codebase where these fields are not set.
+ # To maintain compatibility, they are marked as optional here.
+ app_id: str | None = None
+ workflow_id: str | None = None
+
+ files: Sequence[File] = Field(default_factory=list)
+
+ # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
+ # To maintain compatibility with existing workflows, it must be serialized
+ # as `workflow_run_id` in dictionaries or JSON objects, and also referenced
+ # as `workflow_run_id` in the variable pool.
+ workflow_execution_id: str | None = Field(
+ validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
+ serialization_alias="workflow_run_id",
+ default=None,
+ )
+ # Chatflow related fields.
+ query: str | None = None
+ conversation_id: str | None = None
+ dialogue_count: int | None = None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_json_fields(cls, data):
+ if isinstance(data, dict):
+ # Reconcile the two accepted keys for the execution ID: `workflow_execution_id` and `workflow_run_id`.
+ if "workflow_execution_id" in data and "workflow_run_id" not in data:
+ # This is likely from direct instantiation, allow it
+ return data
+ elif "workflow_execution_id" in data and "workflow_run_id" in data:
+ # Both keys present: drop `workflow_execution_id` from a copy so that `workflow_run_id` takes precedence.
+ data = data.copy()
+ data.pop("workflow_execution_id")
+ return data
+ return data
+
+ @classmethod
+ def empty(cls) -> "SystemVariable":
+ return cls()
+
+ def to_dict(self) -> dict[SystemVariableKey, Any]:
+ # NOTE: This method is provided for compatibility with legacy code.
+ # New code should use the `SystemVariable` object directly instead of converting
+ # it to a dictionary, as this conversion results in the loss of type information
+ # for each key, making static analysis more difficult.
+
+ d: dict[SystemVariableKey, Any] = {
+ SystemVariableKey.FILES: self.files,
+ }
+ if self.user_id is not None:
+ d[SystemVariableKey.USER_ID] = self.user_id
+ if self.app_id is not None:
+ d[SystemVariableKey.APP_ID] = self.app_id
+ if self.workflow_id is not None:
+ d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
+ if self.workflow_execution_id is not None:
+ d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
+ if self.query is not None:
+ d[SystemVariableKey.QUERY] = self.query
+ if self.conversation_id is not None:
+ d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
+ if self.dialogue_count is not None:
+ d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
+ return d
diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py
index 0aab2426af..03f670707e 100644
--- a/api/core/workflow/workflow_cycle_manager.py
+++ b/api/core/workflow/workflow_cycle_manager.py
@@ -1,6 +1,6 @@
from collections.abc import Mapping
from dataclasses import dataclass
-from datetime import UTC, datetime
+from datetime import datetime
from typing import Any, Optional, Union
from uuid import uuid4
@@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import (
from core.workflow.enums import SystemVariableKey
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from libs.datetime_utils import naive_utc_now
@@ -43,7 +44,7 @@ class WorkflowCycleManager:
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
- workflow_system_variables: dict[SystemVariableKey, Any],
+ workflow_system_variables: SystemVariable,
workflow_info: CycleManagerWorkflowInfo,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
@@ -54,19 +55,15 @@ class WorkflowCycleManager:
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
- def handle_workflow_run_start(self) -> WorkflowExecution:
- inputs = {**self._application_generate_entity.inputs}
- for key, value in (self._workflow_system_variables or {}).items():
- if key.value == "conversation":
- continue
- inputs[f"sys.{key.value}"] = value
+ # Initialize caches for workflow execution cycle
+ # These caches avoid redundant repository calls during a single workflow execution
+ self._workflow_execution_cache: dict[str, WorkflowExecution] = {}
+ self._node_execution_cache: dict[str, WorkflowNodeExecution] = {}
- # handle special values
- inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
+ def handle_workflow_run_start(self) -> WorkflowExecution:
+ inputs = self._prepare_workflow_inputs()
+ execution_id = self._get_or_generate_execution_id()
- # init workflow run
- # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
- execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4())
execution = WorkflowExecution.new(
id_=execution_id,
workflow_id=self._workflow_info.workflow_id,
@@ -74,12 +71,10 @@ class WorkflowCycleManager:
workflow_version=self._workflow_info.version,
graph=self._workflow_info.graph_data,
inputs=inputs,
- started_at=datetime.now(UTC).replace(tzinfo=None),
+ started_at=naive_utc_now(),
)
- self._workflow_execution_repository.save(execution)
-
- return execution
+ return self._save_and_cache_workflow_execution(execution)
def handle_workflow_run_success(
self,
@@ -90,26 +85,19 @@ class WorkflowCycleManager:
outputs: Mapping[str, Any] | None = None,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
+ external_trace_id: Optional[str] = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
- # outputs = WorkflowEntry.handle_special_values(outputs)
+ self._update_workflow_execution_completion(
+ workflow_execution,
+ status=WorkflowExecutionStatus.SUCCEEDED,
+ outputs=outputs,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ )
- workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED
- workflow_execution.outputs = outputs or {}
- workflow_execution.total_tokens = total_tokens
- workflow_execution.total_steps = total_steps
- workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
-
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=workflow_execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -124,26 +112,20 @@ class WorkflowCycleManager:
exceptions_count: int = 0,
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
+ external_trace_id: Optional[str] = None,
) -> WorkflowExecution:
execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
- # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
- execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED
- execution.outputs = outputs or {}
- execution.total_tokens = total_tokens
- execution.total_steps = total_steps
- execution.finished_at = datetime.now(UTC).replace(tzinfo=None)
- execution.exceptions_count = exceptions_count
+ self._update_workflow_execution_completion(
+ execution,
+ status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
+ outputs=outputs,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ exceptions_count=exceptions_count,
+ )
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._add_trace_task_if_needed(trace_manager, execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(execution)
return execution
@@ -159,43 +141,23 @@ class WorkflowCycleManager:
conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None,
exceptions_count: int = 0,
+ external_trace_id: Optional[str] = None,
) -> WorkflowExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id)
now = naive_utc_now()
- workflow_execution.status = WorkflowExecutionStatus(status.value)
- workflow_execution.error_message = error_message
- workflow_execution.total_tokens = total_tokens
- workflow_execution.total_steps = total_steps
- workflow_execution.finished_at = now
- workflow_execution.exceptions_count = exceptions_count
-
- # Use the instance repository to find running executions for a workflow run
- running_node_executions = self._workflow_node_execution_repository.get_running_executions(
- workflow_run_id=workflow_execution.id_
+ self._update_workflow_execution_completion(
+ workflow_execution,
+ status=status,
+ total_tokens=total_tokens,
+ total_steps=total_steps,
+ error_message=error_message,
+ exceptions_count=exceptions_count,
+ finished_at=now,
)
- # Update the domain models
- for node_execution in running_node_executions:
- if node_execution.node_execution_id:
- # Update the domain model
- node_execution.status = WorkflowNodeExecutionStatus.FAILED
- node_execution.error = error_message
- node_execution.finished_at = now
- node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
-
- # Update the repository with the domain model
- self._workflow_node_execution_repository.save(node_execution)
-
- if trace_manager:
- trace_manager.add_trace_task(
- TraceTask(
- TraceTaskName.WORKFLOW_TRACE,
- workflow_execution=workflow_execution,
- conversation_id=conversation_id,
- user_id=trace_manager.user_id,
- )
- )
+ self._fail_running_node_executions(workflow_execution.id_, error_message, now)
+ self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id, external_trace_id)
self._workflow_execution_repository.save(workflow_execution)
return workflow_execution
@@ -208,65 +170,24 @@ class WorkflowCycleManager:
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
- # Create a domain model
- created_at = datetime.now(UTC).replace(tzinfo=None)
- metadata = {
- WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
- WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
- WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
- }
-
- domain_execution = WorkflowNodeExecution(
- id=str(uuid4()),
- workflow_id=workflow_execution.workflow_id,
- workflow_execution_id=workflow_execution.id_,
- predecessor_node_id=event.predecessor_node_id,
- index=event.node_run_index,
- node_execution_id=event.node_execution_id,
- node_id=event.node_id,
- node_type=event.node_type,
- title=event.node_data.title,
+ domain_execution = self._create_node_execution_from_event(
+ workflow_execution=workflow_execution,
+ event=event,
status=WorkflowNodeExecutionStatus.RUNNING,
- metadata=metadata,
- created_at=created_at,
)
- # Use the instance repository to save the domain model
- self._workflow_node_execution_repository.save(domain_execution)
-
- return domain_execution
+ return self._save_and_cache_node_execution(domain_execution)
def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution:
- # Get the domain model from repository
- domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
- if not domain_execution:
- raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
+ domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
- # Process data
- inputs = event.inputs
- process_data = event.process_data
- outputs = event.outputs
-
- # Convert metadata keys to strings
- execution_metadata_dict = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - event.start_at).total_seconds()
-
- # Update domain model
- domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED
- domain_execution.update_from_mapping(
- inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+ self._update_node_execution_completion(
+ domain_execution,
+ event=event,
+ status=WorkflowNodeExecutionStatus.SUCCEEDED,
)
- domain_execution.finished_at = finished_at
- domain_execution.elapsed_time = elapsed_time
- # Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
-
return domain_execution
def handle_workflow_node_execution_failed(
@@ -282,96 +203,253 @@ class WorkflowCycleManager:
:param event: queue node failed event
:return:
"""
- # Get the domain model from repository
- domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id)
- if not domain_execution:
- raise ValueError(f"Domain node execution not found: {event.node_execution_id}")
-
- # Process data
- inputs = WorkflowEntry.handle_special_values(event.inputs)
- process_data = WorkflowEntry.handle_special_values(event.process_data)
- outputs = event.outputs
-
- # Convert metadata keys to strings
- execution_metadata_dict = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - event.start_at).total_seconds()
+ domain_execution = self._get_node_execution_from_cache(event.node_execution_id)
- # Update domain model
- domain_execution.status = (
- WorkflowNodeExecutionStatus.FAILED
- if not isinstance(event, QueueNodeExceptionEvent)
- else WorkflowNodeExecutionStatus.EXCEPTION
+ status = (
+ WorkflowNodeExecutionStatus.EXCEPTION
+ if isinstance(event, QueueNodeExceptionEvent)
+ else WorkflowNodeExecutionStatus.FAILED
)
- domain_execution.error = event.error
- domain_execution.update_from_mapping(
- inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict
+
+ self._update_node_execution_completion(
+ domain_execution,
+ event=event,
+ status=status,
+ error=event.error,
+ handle_special_values=True,
)
- domain_execution.finished_at = finished_at
- domain_execution.elapsed_time = elapsed_time
- # Update the repository with the domain model
self._workflow_node_execution_repository.save(domain_execution)
-
return domain_execution
def handle_workflow_node_execution_retried(
self, *, workflow_execution_id: str, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id)
- created_at = event.start_at
- finished_at = datetime.now(UTC).replace(tzinfo=None)
- elapsed_time = (finished_at - created_at).total_seconds()
+
+ domain_execution = self._create_node_execution_from_event(
+ workflow_execution=workflow_execution,
+ event=event,
+ status=WorkflowNodeExecutionStatus.RETRY,
+ error=event.error,
+ created_at=event.start_at,
+ )
+
+ # Handle inputs and outputs
inputs = WorkflowEntry.handle_special_values(event.inputs)
outputs = event.outputs
+ metadata = self._merge_event_metadata(event)
- # Convert metadata keys to strings
- origin_metadata = {
- WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
+ domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata)
+
+ return self._save_and_cache_node_execution(domain_execution)
+
+ def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
+ # Check cache first
+ if id in self._workflow_execution_cache:
+ return self._workflow_execution_cache[id]
+
+ raise WorkflowRunNotFoundError(id)
+
+ def _prepare_workflow_inputs(self) -> dict[str, Any]:
+ """Prepare workflow inputs by merging application inputs with system variables."""
+ inputs = {**self._application_generate_entity.inputs}
+
+ if self._workflow_system_variables:
+ for field_name, value in self._workflow_system_variables.to_dict().items():
+ if field_name != SystemVariableKey.CONVERSATION_ID:
+ inputs[f"sys.{field_name}"] = value
+
+ return dict(WorkflowEntry.handle_special_values(inputs) or {})
+
+ def _get_or_generate_execution_id(self) -> str:
+ """Get execution ID from system variables or generate a new one."""
+ if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id:
+ return str(self._workflow_system_variables.workflow_execution_id)
+ return str(uuid4())
+
+ def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution:
+ """Save workflow execution to repository and cache it."""
+ self._workflow_execution_repository.save(execution)
+ self._workflow_execution_cache[execution.id_] = execution
+ return execution
+
+ def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution:
+ """Save node execution to repository and cache it if it has an ID."""
+ self._workflow_node_execution_repository.save(execution)
+ if execution.node_execution_id:
+ self._node_execution_cache[execution.node_execution_id] = execution
+ return execution
+
+ def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution:
+ """Get node execution from cache or raise error if not found."""
+ domain_execution = self._node_execution_cache.get(node_execution_id)
+ if not domain_execution:
+ raise ValueError(f"Domain node execution not found: {node_execution_id}")
+ return domain_execution
+
+ def _update_workflow_execution_completion(
+ self,
+ execution: WorkflowExecution,
+ *,
+ status: WorkflowExecutionStatus,
+ total_tokens: int,
+ total_steps: int,
+ outputs: Mapping[str, Any] | None = None,
+ error_message: Optional[str] = None,
+ exceptions_count: int = 0,
+ finished_at: Optional[datetime] = None,
+ ) -> None:
+ """Update workflow execution with completion data."""
+ execution.status = status
+ execution.outputs = outputs or {}
+ execution.total_tokens = total_tokens
+ execution.total_steps = total_steps
+ execution.finished_at = finished_at or naive_utc_now()
+ execution.exceptions_count = exceptions_count
+ if error_message:
+ execution.error_message = error_message
+
+ def _add_trace_task_if_needed(
+ self,
+ trace_manager: Optional[TraceQueueManager],
+ workflow_execution: WorkflowExecution,
+ conversation_id: Optional[str],
+ external_trace_id: Optional[str],
+ ) -> None:
+ """Add trace task if trace manager is provided."""
+ if trace_manager:
+ trace_manager.add_trace_task(
+ TraceTask(
+ TraceTaskName.WORKFLOW_TRACE,
+ workflow_execution=workflow_execution,
+ conversation_id=conversation_id,
+ user_id=trace_manager.user_id,
+ external_trace_id=external_trace_id,
+ )
+ )
+
+ def _fail_running_node_executions(
+ self,
+ workflow_execution_id: str,
+ error_message: str,
+ now: datetime,
+ ) -> None:
+ """Fail all running node executions for a workflow."""
+ running_node_executions = [
+ node_exec
+ for node_exec in self._node_execution_cache.values()
+ if node_exec.workflow_execution_id == workflow_execution_id
+ and node_exec.status == WorkflowNodeExecutionStatus.RUNNING
+ ]
+
+ for node_execution in running_node_executions:
+ if node_execution.node_execution_id:
+ node_execution.status = WorkflowNodeExecutionStatus.FAILED
+ node_execution.error = error_message
+ node_execution.finished_at = now
+ node_execution.elapsed_time = (now - node_execution.created_at).total_seconds()
+ self._workflow_node_execution_repository.save(node_execution)
+
+ def _create_node_execution_from_event(
+ self,
+ *,
+ workflow_execution: WorkflowExecution,
+ event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent],
+ status: WorkflowNodeExecutionStatus,
+ error: Optional[str] = None,
+ created_at: Optional[datetime] = None,
+ ) -> WorkflowNodeExecution:
+ """Create a node execution from an event."""
+ now = naive_utc_now()
+ created_at = created_at or now
+
+ metadata = {
WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+ WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
}
- # Convert execution metadata keys to strings
- execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
- if event.execution_metadata:
- for key, value in event.execution_metadata.items():
- execution_metadata_dict[key] = value
-
- merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
-
- # Create a domain model
domain_execution = WorkflowNodeExecution(
id=str(uuid4()),
workflow_id=workflow_execution.workflow_id,
workflow_execution_id=workflow_execution.id_,
predecessor_node_id=event.predecessor_node_id,
+ index=event.node_run_index,
node_execution_id=event.node_execution_id,
node_id=event.node_id,
node_type=event.node_type,
title=event.node_data.title,
- status=WorkflowNodeExecutionStatus.RETRY,
+ status=status,
+ metadata=metadata,
created_at=created_at,
- finished_at=finished_at,
- elapsed_time=elapsed_time,
- error=event.error,
- index=event.node_run_index,
+ error=error,
)
- # Update with mappings
- domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata)
-
- # Use the instance repository to save the domain model
- self._workflow_node_execution_repository.save(domain_execution)
+ if status == WorkflowNodeExecutionStatus.RETRY:
+ domain_execution.finished_at = now
+ domain_execution.elapsed_time = (now - created_at).total_seconds()
return domain_execution
- def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution:
- execution = self._workflow_execution_repository.get(id)
- if not execution:
- raise WorkflowRunNotFoundError(id)
- return execution
+ def _update_node_execution_completion(
+ self,
+ domain_execution: WorkflowNodeExecution,
+ *,
+ event: Union[
+ QueueNodeSucceededEvent,
+ QueueNodeFailedEvent,
+ QueueNodeInIterationFailedEvent,
+ QueueNodeInLoopFailedEvent,
+ QueueNodeExceptionEvent,
+ ],
+ status: WorkflowNodeExecutionStatus,
+ error: Optional[str] = None,
+ handle_special_values: bool = False,
+ ) -> None:
+ """Update node execution with completion data."""
+ finished_at = naive_utc_now()
+ elapsed_time = (finished_at - event.start_at).total_seconds()
+
+ # Process data
+ if handle_special_values:
+ inputs = WorkflowEntry.handle_special_values(event.inputs)
+ process_data = WorkflowEntry.handle_special_values(event.process_data)
+ else:
+ inputs = event.inputs
+ process_data = event.process_data
+
+ outputs = event.outputs
+
+ # Copy the event's execution metadata (if any) into a fresh dict
+ execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {}
+ if event.execution_metadata:
+ execution_metadata_dict.update(event.execution_metadata)
+
+ # Update domain model
+ domain_execution.status = status
+ domain_execution.update_from_mapping(
+ inputs=inputs,
+ process_data=process_data,
+ outputs=outputs,
+ metadata=execution_metadata_dict,
+ )
+ domain_execution.finished_at = finished_at
+ domain_execution.elapsed_time = elapsed_time
+
+ if error:
+ domain_execution.error = error
+
+ def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]:
+ """Merge event metadata with origin metadata."""
+ origin_metadata = {
+ WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
+ WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
+ WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
+ }
+
+ execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {}
+ if event.execution_metadata:
+ execution_metadata_dict.update(event.execution_metadata)
+
+ return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 2868dcb7de..c8082ebf50 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
-from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
+from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
@@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
+from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
@@ -145,7 +146,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
- node_instance = node_cls(
+ node = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -162,6 +163,7 @@ class WorkflowEntry:
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
)
+ node.init_node_data(node_config_data)
try:
# variable selector to variable mapping
@@ -189,17 +191,11 @@ class WorkflowEntry:
try:
# run node
- generator = node_instance.run()
+ generator = node.run()
except Exception as e:
- logger.exception(
- "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
- workflow.id,
- node_instance.id,
- node_instance.node_type,
- node_instance.version(),
- )
- raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
- return node_instance, generator
+ logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
+ return node, generator
@classmethod
def run_free_node(
@@ -254,14 +250,14 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
- node_instance: BaseNode = node_cls(
+ node: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@@ -278,6 +274,7 @@ class WorkflowEntry:
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
)
+ node.init_node_data(node_data)
try:
# variable selector to variable mapping
@@ -296,17 +293,12 @@ class WorkflowEntry:
)
# run node
- generator = node_instance.run()
+ generator = node.run()
- return node_instance, generator
+ return node, generator
except Exception as e:
- logger.exception(
- "error while running node_instance, node_id=%s, type=%s, version=%s",
- node_instance.id,
- node_instance.node_type,
- node_instance.version(),
- )
- raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
+ logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
+ raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:
diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py
index 0123fdac18..2c634d25ec 100644
--- a/api/core/workflow/workflow_type_encoder.py
+++ b/api/core/workflow/workflow_type_encoder.py
@@ -1,4 +1,3 @@
-import json
from collections.abc import Mapping
from typing import Any
@@ -8,18 +7,6 @@ from core.file.models import File
from core.variables import Segment
-class WorkflowRuntimeTypeEncoder(json.JSONEncoder):
- def default(self, o: Any):
- if isinstance(o, Segment):
- return o.value
- elif isinstance(o, File):
- return o.to_dict()
- elif isinstance(o, BaseModel):
- return o.model_dump(mode="json")
- else:
- return super().default(o)
-
-
class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 18d4f4885d..4de9a25c2f 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -5,6 +5,11 @@ set -e
if [[ "${MIGRATION_ENABLED}" == "true" ]]; then
echo "Running migrations"
flask upgrade-db
+ # Pure migration mode: when MODE is "migration", exit with status 0 right after
+ # `flask upgrade-db` instead of continuing to the MODE checks below. This test is
+ # nested in the MIGRATION_ENABLED branch, so it only runs when that is "true".
+ if [[ "${MODE}" == "migration" ]]; then
+ echo "Migration completed, exiting normally"
+ exit 0
+ fi
fi
if [[ "${MODE}" == "worker" ]]; then
@@ -22,7 +27,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
- -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
+ -Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}
diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py
index 8a677f6b6f..dc50ca8d96 100644
--- a/api/events/event_handlers/create_document_index.py
+++ b/api/events/event_handlers/create_document_index.py
@@ -1,4 +1,3 @@
-import datetime
import logging
import time
@@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from events.event_handlers.document_index_event import document_index_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Document
@@ -22,7 +22,7 @@ def handle(sender, **kwargs):
document = (
db.session.query(Document)
- .filter(
+ .where(
Document.id == document_id,
Document.dataset_id == dataset_id,
)
@@ -33,7 +33,7 @@ def handle(sender, **kwargs):
raise NotFound("Document not found")
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
index 249bd14429..6c9fc0bf1d 100644
--- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
+++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py
@@ -20,6 +20,7 @@ def handle(sender, **kwargs):
provider_id=tool_entity.provider_id,
tool_name=tool_entity.tool_name,
tenant_id=app.tenant_id,
+ credential_id=tool_entity.credential_id,
)
manager = ToolParameterConfigurationManager(
tenant_id=app.tenant_id,
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
index 14396e9920..b8b5a89dc5 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_model_config_updated.py
@@ -13,7 +13,7 @@ def handle(sender, **kwargs):
dataset_ids = get_dataset_ids_from_model_config(app_model_config)
- app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
+ app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
@@ -27,7 +27,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
- db.session.query(AppDatasetJoin).filter(
+ db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
index dd2efed94b..cf4ba69833 100644
--- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
+++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py
@@ -15,7 +15,7 @@ def handle(sender, **kwargs):
published_workflow = cast(Workflow, published_workflow)
dataset_ids = get_dataset_ids_from_workflow(published_workflow)
- app_dataset_joins = db.session.query(AppDatasetJoin).filter(AppDatasetJoin.app_id == app.id).all()
+ app_dataset_joins = db.session.query(AppDatasetJoin).where(AppDatasetJoin.app_id == app.id).all()
removed_dataset_ids: set[str] = set()
if not app_dataset_joins:
@@ -29,7 +29,7 @@ def handle(sender, **kwargs):
if removed_dataset_ids:
for dataset_id in removed_dataset_ids:
- db.session.query(AppDatasetJoin).filter(
+ db.session.query(AppDatasetJoin).where(
AppDatasetJoin.app_id == app.id, AppDatasetJoin.dataset_id == dataset_id
).delete()
diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py
index 6279b1ad36..2c2846ba26 100644
--- a/api/extensions/ext_celery.py
+++ b/api/extensions/ext_celery.py
@@ -64,49 +64,62 @@ def init_app(app: DifyApp) -> Celery:
celery_app.set_default()
app.extensions["celery"] = celery_app
- imports = [
- "schedule.clean_embedding_cache_task",
- "schedule.clean_unused_datasets_task",
- "schedule.create_tidb_serverless_task",
- "schedule.update_tidb_serverless_status_task",
- "schedule.clean_messages",
- "schedule.mail_clean_document_notify_task",
- "schedule.queue_monitor_task",
- ]
+ imports = []
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
- beat_schedule = {
- "clean_embedding_cache_task": {
+
+ # if you add a new task, please add the switch to CeleryScheduleTasksConfig
+ beat_schedule = {}
+ if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
+ imports.append("schedule.clean_embedding_cache_task")
+ beat_schedule["clean_embedding_cache_task"] = {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
"schedule": timedelta(days=day),
- },
- "clean_unused_datasets_task": {
+ }
+ if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK:
+ imports.append("schedule.clean_unused_datasets_task")
+ beat_schedule["clean_unused_datasets_task"] = {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
- },
- "create_tidb_serverless_task": {
+ }
+ if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK:
+ imports.append("schedule.create_tidb_serverless_task")
+ beat_schedule["create_tidb_serverless_task"] = {
"task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
"schedule": crontab(minute="0", hour="*"),
- },
- "update_tidb_serverless_status_task": {
+ }
+ if dify_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:
+ imports.append("schedule.update_tidb_serverless_status_task")
+ beat_schedule["update_tidb_serverless_status_task"] = {
"task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
"schedule": timedelta(minutes=10),
- },
- "clean_messages": {
+ }
+ if dify_config.ENABLE_CLEAN_MESSAGES:
+ imports.append("schedule.clean_messages")
+ beat_schedule["clean_messages"] = {
"task": "schedule.clean_messages.clean_messages",
"schedule": timedelta(days=day),
- },
- # every Monday
- "mail_clean_document_notify_task": {
+ }
+ if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:
+ imports.append("schedule.mail_clean_document_notify_task")
+ beat_schedule["mail_clean_document_notify_task"] = {
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
- },
- "datasets-queue-monitor": {
+ }
+ if dify_config.ENABLE_DATASETS_QUEUE_MONITOR:
+ imports.append("schedule.queue_monitor_task")
+ beat_schedule["datasets-queue-monitor"] = {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
- },
- }
+ }
+ if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
+ imports.append("schedule.check_upgradable_plugin_task")
+ beat_schedule["check_upgradable_plugin_task"] = {
+ "task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
+ "schedule": crontab(minute="*/15"),
+ }
+
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index ddc2158a02..600e336c19 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -18,6 +18,7 @@ def init_app(app: DifyApp):
reset_email,
reset_encrypt_key_pair,
reset_password,
+ setup_system_tool_oauth_client,
upgrade_db,
vdb_migrate,
)
@@ -40,6 +41,7 @@ def init_app(app: DifyApp):
clear_free_plan_tenant_expired_logs,
clear_orphaned_file_records,
remove_orphaned_files_on_storage,
+ setup_system_tool_oauth_client,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py
index 11d1856ac4..9b18e25eaa 100644
--- a/api/extensions/ext_login.py
+++ b/api/extensions/ext_login.py
@@ -40,9 +40,9 @@ def load_user_from_request(request_from_flask_login):
if workspace_id:
tenant_account_join = (
db.session.query(Tenant, TenantAccountJoin)
- .filter(Tenant.id == workspace_id)
- .filter(TenantAccountJoin.tenant_id == Tenant.id)
- .filter(TenantAccountJoin.role == "owner")
+ .where(Tenant.id == workspace_id)
+ .where(TenantAccountJoin.tenant_id == Tenant.id)
+ .where(TenantAccountJoin.role == "owner")
.one_or_none()
)
if tenant_account_join:
@@ -70,7 +70,7 @@ def load_user_from_request(request_from_flask_login):
end_user_id = decoded.get("end_user_id")
if not end_user_id:
raise Unauthorized("Invalid Authorization token.")
- end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first()
+ end_user = db.session.query(EndUser).where(EndUser.id == decoded["end_user_id"]).first()
if not end_user:
raise NotFound("End user not found.")
return end_user
@@ -78,12 +78,12 @@ def load_user_from_request(request_from_flask_login):
server_code = request.view_args.get("server_code") if request.view_args else None
if not server_code:
raise Unauthorized("Invalid Authorization token.")
- app_mcp_server = db.session.query(AppMCPServer).filter(AppMCPServer.server_code == server_code).first()
+ app_mcp_server = db.session.query(AppMCPServer).where(AppMCPServer.server_code == server_code).first()
if not app_mcp_server:
raise NotFound("App MCP server not found.")
end_user = (
db.session.query(EndUser)
- .filter(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
+ .where(EndUser.external_user_id == app_mcp_server.id, EndUser.type == "mcp")
.first()
)
if not end_user:
diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py
index b62b0b60d6..b027a165f9 100644
--- a/api/extensions/ext_otel.py
+++ b/api/extensions/ext_otel.py
@@ -193,13 +193,22 @@ def init_app(app: DifyApp):
insecure=True,
)
else:
+ headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None
+
+ trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT
+ if not trace_endpoint:
+ trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces"
exporter = HTTPSpanExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
+ endpoint=trace_endpoint,
+ headers=headers,
)
+
+ metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT
+ if not metric_endpoint:
+ metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics"
metric_exporter = HTTPMetricExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
+ endpoint=metric_endpoint,
+ headers=headers,
)
else:
exporter = ConsoleSpanExporter()
diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py
index 7448fd4a6b..81eec94da4 100644
--- a/api/extensions/storage/azure_blob_storage.py
+++ b/api/extensions/storage/azure_blob_storage.py
@@ -1,5 +1,5 @@
from collections.abc import Generator
-from datetime import UTC, datetime, timedelta
+from datetime import timedelta
from typing import Optional
from azure.identity import ChainedTokenCredential, DefaultAzureCredential
@@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc
from configs import dify_config
from extensions.ext_redis import redis_client
from extensions.storage.base_storage import BaseStorage
+from libs.datetime_utils import naive_utc_now
class AzureBlobStorage(BaseStorage):
@@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage):
account_key=self.account_key or "",
resource_types=ResourceTypes(service=True, container=True, object=True),
permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True),
- expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1),
+ expiry=naive_utc_now() + timedelta(hours=1),
)
redis_client.set(cache_key, sas_token, ex=3000)
return BlobServiceClient(account_url=self.account_url or "", credential=sas_token)
diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py
index 25d1390492..512a9cb608 100644
--- a/api/factories/file_factory.py
+++ b/api/factories/file_factory.py
@@ -148,9 +148,7 @@ def _build_from_local_file(
if strict_type_validation and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
- file_type = (
- FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
- )
+ file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),
@@ -199,9 +197,7 @@ def _build_from_remote_url(
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
file_type = (
- FileType(specified_type)
- if specified_type and specified_type != FileType.CUSTOM.value
- else detected_file_type
+ FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
)
return File(
@@ -265,13 +261,11 @@ def _build_from_tool_file(
transfer_method: FileTransferMethod,
strict_type_validation: bool = False,
) -> File:
- tool_file = (
- db.session.query(ToolFile)
- .filter(
+ tool_file = db.session.scalar(
+ select(ToolFile).where(
ToolFile.id == mapping.get("tool_file_id"),
ToolFile.tenant_id == tenant_id,
)
- .first()
)
if tool_file is None:
@@ -279,16 +273,14 @@ def _build_from_tool_file(
extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin"
- detected_file_type = _standardize_file_type(extension="." + extension, mime_type=tool_file.mimetype)
+ detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype)
specified_type = mapping.get("type")
if strict_type_validation and specified_type and detected_file_type.value != specified_type:
raise ValueError("Detected file type does not match the specified type. Please verify the file.")
- file_type = (
- FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type
- )
+ file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type
return File(
id=mapping.get("id"),
diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py
index 250ee4695e..39ebd009d5 100644
--- a/api/factories/variable_factory.py
+++ b/api/factories/variable_factory.py
@@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
- case SegmentType.NUMBER if isinstance(value, int):
+ case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
+ mapping = dict(mapping)
+ mapping["value_type"] = SegmentType.INTEGER
result = IntegerVariable.model_validate(mapping)
- case SegmentType.NUMBER if isinstance(value, float):
+ case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
+ mapping = dict(mapping)
+ mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
@@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType:
def build_segment(value: Any, /) -> Segment:
+ # NOTE: If you have runtime type information available, consider using
+ # `build_segment_with_type` (defined below) instead.
if value is None:
return NoneSegment()
if isinstance(value, str):
@@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, list):
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
- if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
+ if all(isinstance(item, ArraySegment) for item in items):
return ArrayAnySegment(value=value)
+ elif len(types) != 1:
+ if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
+ return ArrayNumberSegment(value=value)
+ return ArrayAnySegment(value=value)
+
match types.pop():
case SegmentType.STRING:
return ArrayStringSegment(value=value)
- case SegmentType.NUMBER:
+ case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
@@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment:
raise ValueError(f"not supported value {value}")
+# Lookup table from a SegmentType to the Segment class that build_segment_with_type
+# instantiates for it. SegmentType.NUMBER is not a key: when NUMBER is requested and the
+# value is inferred as INTEGER or FLOAT, build_segment_with_type looks up the inferred
+# type instead.
+_segment_factory: Mapping[SegmentType, type[Segment]] = {
+ SegmentType.NONE: NoneSegment,
+ SegmentType.STRING: StringSegment,
+ SegmentType.INTEGER: IntegerSegment,
+ SegmentType.FLOAT: FloatSegment,
+ SegmentType.FILE: FileSegment,
+ SegmentType.OBJECT: ObjectSegment,
+ # Array types
+ SegmentType.ARRAY_ANY: ArrayAnySegment,
+ SegmentType.ARRAY_STRING: ArrayStringSegment,
+ SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
+ SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
+ SegmentType.ARRAY_FILE: ArrayFileSegment,
+}
+
+
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
"""
Build a segment with explicit type checking.
@@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
if segment_type == SegmentType.NONE:
return NoneSegment()
else:
- raise TypeMismatchError(f"Expected {segment_type}, but got None")
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
# Handle empty list special case for array types
if isinstance(value, list) and len(value) == 0:
@@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
elif segment_type == SegmentType.ARRAY_FILE:
return ArrayFileSegment(value=value)
else:
- raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
-
- # Build segment using existing logic to infer actual type
- inferred_segment = build_segment(value)
- inferred_type = inferred_segment.value_type
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
+ inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
+ if inferred_type is None:
+ raise TypeMismatchError(
+ f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
+ )
if inferred_type == segment_type:
- return inferred_segment
-
- # Type mismatch - raise error with descriptive message
- raise TypeMismatchError(
- f"Type mismatch: expected {segment_type}, but value '{value}' "
- f"(type: {type(value).__name__}) corresponds to {inferred_type}"
- )
+ segment_class = _segment_factory[segment_type]
+ return segment_class(value_type=segment_type, value=value)
+ elif segment_type == SegmentType.NUMBER and inferred_type in (
+ SegmentType.INTEGER,
+ SegmentType.FLOAT,
+ ):
+ segment_class = _segment_factory[inferred_type]
+ return segment_class(value_type=inferred_type, value=value)
+ else:
+ raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
def segment_to_variable(
@@ -247,6 +278,6 @@ def segment_to_variable(
name=name,
description=description,
value=segment.value,
- selector=selector,
+ selector=list(selector),
),
)
diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py
new file mode 100644
index 0000000000..8288bd54a3
--- /dev/null
+++ b/api/fields/_value_type_serializer.py
@@ -0,0 +1,15 @@
+from typing import TypedDict
+
+from core.variables.segments import Segment
+from core.variables.types import SegmentType
+
+
+class _VarTypedDict(TypedDict, total=False):
+ """Dict form of a variable accepted by ``serialize_value_type``.
+
+ Only the ``value_type`` key is described; ``total=False`` marks it as optional for
+ type checkers.
+ """
+
+ value_type: SegmentType
+
+
+def serialize_value_type(v: _VarTypedDict | Segment) -> str:
+ """Return ``value_type.exposed_type().value`` for a variable.
+
+ A ``Segment`` is read through its ``value_type`` attribute; any other input is treated
+ as a dict and read through its ``"value_type"`` key.
+
+ NOTE(review): a dict without ``"value_type"`` raises ``KeyError`` here even though the
+ TypedDict is declared ``total=False`` - confirm callers always provide the key.
+ """
+ if isinstance(v, Segment):
+ return v.value_type.exposed_type().value
+ else:
+ return v["value_type"].exposed_type().value
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index 73c224542a..b6d85e0e24 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -188,6 +188,7 @@ app_detail_fields_with_site = {
"site": fields.Nested(site_fields),
"api_base_url": fields.String,
"use_icon_as_answer_icon": fields.Boolean,
+ "max_active_requests": fields.Integer,
"created_by": fields.String,
"created_at": TimestampField,
"updated_by": fields.String,
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index 71785e7d67..c5a0c9a49d 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -2,10 +2,12 @@ from flask_restful import fields
from libs.helper import TimestampField
+from ._value_type_serializer import serialize_value_type
+
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
- "value_type": fields.String(attribute="value_type.value"),
+ "value_type": fields.String(attribute=serialize_value_type),
"value": fields.String,
"description": fields.String,
"created_at": TimestampField,
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index f00ea71c54..930e59cc1c 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
+from ._value_type_serializer import serialize_value_type
+
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
@@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
- "value_type": value.value_type.value,
+ "value_type": value.value_type.exposed_type().value,
"description": value.description,
}
if isinstance(value, dict):
- value_type = value.get("value_type")
+ value_type_str = value.get("value_type")
+ if not isinstance(value_type_str, str):
+ raise TypeError(
+ f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
+ )
+ value_type = SegmentType(value_type_str).exposed_type()
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value
@@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw):
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
- "value_type": fields.String(attribute="value_type.value"),
+ "value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw,
"description": fields.String,
}
diff --git a/api/libs/email_i18n.py b/api/libs/email_i18n.py
new file mode 100644
index 0000000000..b7c9f3ec6c
--- /dev/null
+++ b/api/libs/email_i18n.py
@@ -0,0 +1,474 @@
+"""
+Email Internationalization Module
+
+This module provides a centralized, elegant way to handle email internationalization
+in Dify. It follows Domain-Driven Design principles with proper type hints and
+eliminates the need for repetitive language switching logic.
+"""
+
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, Optional, Protocol
+
+from flask import render_template
+from pydantic import BaseModel, Field
+
+from extensions.ext_mail import mail
+from services.feature_service import BrandingModel, FeatureService
+
+
+class EmailType(Enum):
+    """Kinds of email the service can send; used as keys of EmailI18nConfig.templates."""
+
+    RESET_PASSWORD = "reset_password"
+    INVITE_MEMBER = "invite_member"
+    EMAIL_CODE_LOGIN = "email_code_login"
+    CHANGE_EMAIL_OLD = "change_email_old"  # chosen by send_change_email(phase="old_email")
+    CHANGE_EMAIL_NEW = "change_email_new"  # chosen by send_change_email(phase="new_email")
+    CHANGE_EMAIL_COMPLETED = "change_email_completed"
+    OWNER_TRANSFER_CONFIRM = "owner_transfer_confirm"
+    OWNER_TRANSFER_OLD_NOTIFY = "owner_transfer_old_notify"
+    OWNER_TRANSFER_NEW_NOTIFY = "owner_transfer_new_notify"
+    ACCOUNT_DELETION_SUCCESS = "account_deletion_success"
+    ACCOUNT_DELETION_VERIFICATION = "account_deletion_verification"
+    ENTERPRISE_CUSTOM = "enterprise_custom"  # no entry in create_default_email_config()
+    QUEUE_MONITOR_ALERT = "queue_monitor_alert"
+    DOCUMENT_CLEAN_NOTIFY = "document_clean_notify"
+
+
+class EmailLanguage(Enum):
+    """Languages that have dedicated email templates; anything else maps to English."""
+
+    EN_US = "en-US"
+    ZH_HANS = "zh-Hans"
+
+    @classmethod
+    def from_language_code(cls, language_code: str) -> "EmailLanguage":
+        """Map a language code onto a member, defaulting to ``EN_US`` when it is not recognised."""
+        # Only the simplified-Chinese code is special-cased; every other input means English.
+        is_simplified_chinese = language_code == "zh-Hans"
+        return cls.ZH_HANS if is_simplified_chinese else cls.EN_US
+
+
+@dataclass(frozen=True)
+class EmailTemplate:
+    """Immutable value object holding the subject and template paths for one email type and language."""
+
+    subject: str  # may contain str.format placeholders such as {application_title}
+    template_path: str  # rendered when branding is disabled
+    branded_template_path: str  # rendered when branding is enabled
+
+
+@dataclass(frozen=True)
+class EmailContent:
+    """Immutable value object containing a rendered email: final subject, HTML body and the context used."""
+
+    subject: str  # subject after placeholder substitution
+    html_content: str  # output of the renderer for the selected template
+    template_context: dict[str, Any]  # caller context plus branding_enabled / application_title
+
+
+class EmailI18nConfig(BaseModel):
+    """Frozen lookup table mapping each email type and language to its template."""
+
+    model_config = {"frozen": True, "extra": "forbid"}
+
+    templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = Field(
+        default_factory=dict, description="Mapping of email types to language-specific templates"
+    )
+
+    def get_template(self, email_type: EmailType, language: EmailLanguage) -> EmailTemplate:
+        """Return the template for ``email_type`` in ``language``, falling back to English.
+
+        Raises:
+            ValueError: If the type has no templates, or neither ``language`` nor English is configured.
+        """
+        per_language = self.templates.get(email_type)
+        if not per_language:
+            raise ValueError(f"No templates configured for email type: {email_type}")
+        # ``or`` applies the same truthiness test: English is consulted only if the first lookup is empty.
+        chosen = per_language.get(language) or per_language.get(EmailLanguage.EN_US)
+        if not chosen:
+            raise ValueError(f"No template found for {email_type} in {language} or English")
+        return chosen
+
+
+class EmailRenderer(Protocol):
+    """Structural interface for objects that render an email template to a string."""
+
+    def render_template(self, template_path: str, **context: Any) -> str:
+        """Render the template at ``template_path`` using ``context`` as its variables."""
+        ...
+
+
+class FlaskEmailRenderer:
+    """EmailRenderer implementation that delegates to Flask's ``render_template``."""
+
+    def render_template(self, template_path: str, **context: Any) -> str:
+        """Pass ``template_path`` and ``context`` straight to ``flask.render_template`` and return its result."""
+        return render_template(template_path, **context)
+
+
+class BrandingService(Protocol):
+    """Structural interface for the source of branding settings used when rendering emails."""
+
+    def get_branding_config(self) -> BrandingModel:
+        """Return the branding configuration as a ``BrandingModel``."""
+        ...
+
+
+class FeatureBrandingService:
+    """BrandingService implementation backed by ``FeatureService``."""
+
+    def get_branding_config(self) -> BrandingModel:
+        """Return the ``branding`` attribute of ``FeatureService.get_system_features()``."""
+        return FeatureService.get_system_features().branding
+
+
+class EmailSender(Protocol):
+    """Structural interface for the transport that delivers a rendered email."""
+
+    def send_email(self, to: str, subject: str, html_content: str) -> None:
+        """Deliver ``html_content`` with ``subject`` to the single recipient ``to``."""
+        ...
+
+
+class FlaskMailSender:
+    """EmailSender implementation that sends through the ``mail`` extension."""
+
+    def send_email(self, to: str, subject: str, html_content: str) -> None:
+        """Send via ``mail.send``; silently does nothing when ``mail.is_inited()`` is false."""
+        if mail.is_inited():
+            mail.send(to=to, subject=subject, html=html_content)
+
+
+class EmailI18nService:
+    """
+    Main service for internationalized email handling.
+
+    It selects the template for an email type and language, renders the body and
+    subject with branding-aware context and hands the result to the sender.
+    """
+
+    def __init__(
+        self,
+        config: EmailI18nConfig,
+        renderer: EmailRenderer,
+        branding_service: BrandingService,
+        sender: EmailSender,
+    ) -> None:
+        self._config = config
+        self._renderer = renderer
+        self._branding_service = branding_service
+        self._sender = sender
+
+    def send_email(
+        self,
+        email_type: EmailType,
+        language_code: str,
+        to: str,
+        template_context: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """
+        Send internationalized email with branding support.
+
+        Args:
+            email_type: Type of email to send
+            language_code: Target language code; unrecognised codes fall back to English
+            to: Recipient email address
+            template_context: Additional context for template rendering
+        """
+        if template_context is None:
+            template_context = {}
+
+        language = EmailLanguage.from_language_code(language_code)
+        email_content = self._render_email_content(email_type, language, template_context)
+
+        self._sender.send_email(to=to, subject=email_content.subject, html_content=email_content.html_content)
+
+    def send_change_email(
+        self,
+        language_code: str,
+        to: str,
+        code: str,
+        phase: str,
+    ) -> None:
+        """
+        Send change email notification with phase-specific handling.
+
+        Args:
+            language_code: Target language code
+            to: Recipient email address
+            code: Verification code
+            phase: Either 'old_email' or 'new_email'; any other value raises ValueError
+        """
+        if phase == "old_email":
+            email_type = EmailType.CHANGE_EMAIL_OLD
+        elif phase == "new_email":
+            email_type = EmailType.CHANGE_EMAIL_NEW
+        else:
+            raise ValueError(f"Invalid phase: {phase}. Must be 'old_email' or 'new_email'")
+
+        self.send_email(
+            email_type=email_type,
+            language_code=language_code,
+            to=to,
+            template_context={
+                "to": to,
+                "code": code,
+            },
+        )
+
+    def send_raw_email(
+        self,
+        to: str | list[str],
+        subject: str,
+        html_content: str,
+    ) -> None:
+        """
+        Send a raw email directly without template processing.
+
+        This method is provided for backward compatibility with legacy email
+        sending that uses pre-rendered HTML content (e.g., enterprise emails
+        with custom templates).
+
+        Args:
+            to: Recipient email address(es); a list is sent as one message per recipient
+            subject: Email subject
+            html_content: Pre-rendered HTML content
+        """
+        if isinstance(to, list):
+            for recipient in to:
+                self._sender.send_email(to=recipient, subject=subject, html_content=html_content)
+        else:
+            self._sender.send_email(to=to, subject=subject, html_content=html_content)
+
+    def _render_email_content(
+        self,
+        email_type: EmailType,
+        language: EmailLanguage,
+        template_context: dict[str, Any],
+    ) -> EmailContent:
+        """Render body and subject; subject placeholders without a context value are kept verbatim."""
+        template_config = self._config.get_template(email_type, language)
+        branding = self._branding_service.get_branding_config()
+
+        # Determine template path based on branding
+        template_path = template_config.branded_template_path if branding.enabled else template_config.template_path
+
+        # Branding keys come last, so they override caller keys of the same name
+        full_context = {
+            **template_context,
+            "branding_enabled": branding.enabled,
+            "application_title": branding.application_title if branding.enabled else "Dify",
+        }
+
+        # Render template
+        html_content = self._renderer.render_template(template_path, **full_context)
+
+        # Format the subject. A placeholder with no matching context key is kept
+        # verbatim rather than raising KeyError, so one missing optional variable
+        # can neither abort the send nor discard the substitutions that are known.
+        class _KeepMissing(dict):
+            def __missing__(self, key: str) -> str:
+                return "{" + key + "}"
+
+        subject = template_config.subject.format_map(_KeepMissing(full_context))
+
+        return EmailContent(
+            subject=subject,
+            html_content=html_content,
+            template_context=full_context,
+        )
+
+
+def create_default_email_config() -> EmailI18nConfig:
+    """Build the default config: en-US and zh-Hans templates for every EmailType except ENTERPRISE_CUSTOM."""
+    templates: dict[EmailType, dict[EmailLanguage, EmailTemplate]] = {
+        EmailType.RESET_PASSWORD: {  # "without-brand/" templates are the ones rendered when branding is enabled
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Set Your {application_title} Password",
+                template_path="reset_password_mail_template_en-US.html",
+                branded_template_path="without-brand/reset_password_mail_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="设置您的 {application_title} 密码",
+                template_path="reset_password_mail_template_zh-CN.html",
+                branded_template_path="without-brand/reset_password_mail_template_zh-CN.html",
+            ),
+        },
+        EmailType.INVITE_MEMBER: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Join {application_title} Workspace Now",
+                template_path="invite_member_mail_template_en-US.html",
+                branded_template_path="without-brand/invite_member_mail_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="立即加入 {application_title} 工作空间",
+                template_path="invite_member_mail_template_zh-CN.html",
+                branded_template_path="without-brand/invite_member_mail_template_zh-CN.html",
+            ),
+        },
+        EmailType.EMAIL_CODE_LOGIN: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="{application_title} Login Code",
+                template_path="email_code_login_mail_template_en-US.html",
+                branded_template_path="without-brand/email_code_login_mail_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="{application_title} 登录验证码",
+                template_path="email_code_login_mail_template_zh-CN.html",
+                branded_template_path="without-brand/email_code_login_mail_template_zh-CN.html",
+            ),
+        },
+        EmailType.CHANGE_EMAIL_OLD: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Check your current email",
+                template_path="change_mail_confirm_old_template_en-US.html",
+                branded_template_path="without-brand/change_mail_confirm_old_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="检测您现在的邮箱",
+                template_path="change_mail_confirm_old_template_zh-CN.html",
+                branded_template_path="without-brand/change_mail_confirm_old_template_zh-CN.html",
+            ),
+        },
+        EmailType.CHANGE_EMAIL_NEW: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Confirm your new email address",
+                template_path="change_mail_confirm_new_template_en-US.html",
+                branded_template_path="without-brand/change_mail_confirm_new_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="确认您的邮箱地址变更",
+                template_path="change_mail_confirm_new_template_zh-CN.html",
+                branded_template_path="without-brand/change_mail_confirm_new_template_zh-CN.html",
+            ),
+        },
+        EmailType.CHANGE_EMAIL_COMPLETED: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Your login email has been changed",
+                template_path="change_mail_completed_template_en-US.html",
+                branded_template_path="without-brand/change_mail_completed_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="您的登录邮箱已更改",
+                template_path="change_mail_completed_template_zh-CN.html",
+                branded_template_path="without-brand/change_mail_completed_template_zh-CN.html",
+            ),
+        },
+        EmailType.OWNER_TRANSFER_CONFIRM: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Verify Your Request to Transfer Workspace Ownership",
+                template_path="transfer_workspace_owner_confirm_template_en-US.html",
+                branded_template_path="without-brand/transfer_workspace_owner_confirm_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="验证您转移工作空间所有权的请求",
+                template_path="transfer_workspace_owner_confirm_template_zh-CN.html",
+                branded_template_path="without-brand/transfer_workspace_owner_confirm_template_zh-CN.html",
+            ),
+        },
+        EmailType.OWNER_TRANSFER_OLD_NOTIFY: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Workspace ownership has been transferred",
+                template_path="transfer_workspace_old_owner_notify_template_en-US.html",
+                branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="工作区所有权已转移",
+                template_path="transfer_workspace_old_owner_notify_template_zh-CN.html",
+                branded_template_path="without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html",
+            ),
+        },
+        EmailType.OWNER_TRANSFER_NEW_NOTIFY: {  # {WorkspaceName} is filled from the caller's template_context
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="You are now the owner of {WorkspaceName}",
+                template_path="transfer_workspace_new_owner_notify_template_en-US.html",
+                branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="您现在是 {WorkspaceName} 的所有者",
+                template_path="transfer_workspace_new_owner_notify_template_zh-CN.html",
+                branded_template_path="without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html",
+            ),
+        },
+        EmailType.ACCOUNT_DELETION_SUCCESS: {  # from here on both paths are identical (no "without-brand/" variant)
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Your Dify.AI Account Has Been Successfully Deleted",
+                template_path="delete_account_success_template_en-US.html",
+                branded_template_path="delete_account_success_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="您的 Dify.AI 账户已成功删除",
+                template_path="delete_account_success_template_zh-CN.html",
+                branded_template_path="delete_account_success_template_zh-CN.html",
+            ),
+        },
+        EmailType.ACCOUNT_DELETION_VERIFICATION: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Dify.AI Account Deletion and Verification",
+                template_path="delete_account_code_email_template_en-US.html",
+                branded_template_path="delete_account_code_email_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="Dify.AI 账户删除和验证",
+                template_path="delete_account_code_email_template_zh-CN.html",
+                branded_template_path="delete_account_code_email_template_zh-CN.html",
+            ),
+        },
+        EmailType.QUEUE_MONITOR_ALERT: {
+            EmailLanguage.EN_US: EmailTemplate(
+                subject="Alert: Dataset Queue pending tasks exceeded the limit",
+                template_path="queue_monitor_alert_email_template_en-US.html",
+                branded_template_path="queue_monitor_alert_email_template_en-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="警报:数据集队列待处理任务超过限制",
+                template_path="queue_monitor_alert_email_template_zh-CN.html",
+                branded_template_path="queue_monitor_alert_email_template_zh-CN.html",
+            ),
+        },
+        EmailType.DOCUMENT_CLEAN_NOTIFY: {
+            EmailLanguage.EN_US: EmailTemplate(  # NOTE(review): paths end in "-US", not "_en-US" - confirm file name
+                subject="Dify Knowledge base auto disable notification",
+                template_path="clean_document_job_mail_template-US.html",
+                branded_template_path="clean_document_job_mail_template-US.html",
+            ),
+            EmailLanguage.ZH_HANS: EmailTemplate(
+                subject="Dify 知识库自动禁用通知",
+                template_path="clean_document_job_mail_template_zh-CN.html",
+                branded_template_path="clean_document_job_mail_template_zh-CN.html",
+            ),
+        },
+    }
+
+    return EmailI18nConfig(templates=templates)
+
+
+# Factory: builds a new service with the default dependencies on every call
+def get_default_email_i18n_service() -> EmailI18nService:
+    """Build an email i18n service wired with the default dependencies.
+
+    Every call constructs a new config, renderer, branding service and sender;
+    nothing is cached here.
+    """
+    # Keyword arguments are evaluated top to bottom: config, renderer, branding, sender.
+    return EmailI18nService(
+        config=create_default_email_config(),
+        renderer=FlaskEmailRenderer(),
+        branding_service=FeatureBrandingService(),
+        sender=FlaskMailSender(),
+    )
+
+
+# Lazily initialised shared instance; populated by get_email_i18n_service()
+_email_i18n_service: Optional[EmailI18nService] = None
+
+
+def get_email_i18n_service() -> EmailI18nService:
+    """Return the shared module-level service, building it on the first call."""
+    global _email_i18n_service
+    service = _email_i18n_service if _email_i18n_service is not None else get_default_email_i18n_service()
+    _email_i18n_service = service
+    return service
diff --git a/api/libs/helper.py b/api/libs/helper.py
index 48126461a3..00772d530a 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -148,25 +148,6 @@ class StrLen:
return value
-class FloatRange:
- """Restrict input to an float in a range (inclusive)"""
-
- def __init__(self, low, high, argument="argument"):
- self.low = low
- self.high = high
- self.argument = argument
-
- def __call__(self, value):
- value = _get_float(value)
- if value < self.low or value > self.high:
- error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format(
- arg=self.argument, val=value, lo=self.low, hi=self.high
- )
- raise ValueError(error)
-
- return value
-
-
class DatetimeString:
def __init__(self, format, argument="argument"):
self.format = format
diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py
deleted file mode 100644
index fa29671034..0000000000
--- a/api/libs/jsonutil.py
+++ /dev/null
@@ -1,11 +0,0 @@
-import json
-
-from pydantic import BaseModel
-
-
-class PydanticModelEncoder(json.JSONEncoder):
- def default(self, o):
- if isinstance(o, BaseModel):
- return o.model_dump()
- else:
- super().default(o)
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index 218109522d..987c5d7135 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -1,11 +1,12 @@
-import datetime
import urllib.parse
from typing import Any
import requests
from flask_login import current_user
+from sqlalchemy import select
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
@@ -61,21 +62,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
- data_source_binding = (
- db.session.query(DataSourceOauthBinding)
- .filter(
- db.and_(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
- )
+ data_source_binding = db.session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
)
- .first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@@ -101,21 +98,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages),
}
# save data source binding
- data_source_binding = (
- db.session.query(DataSourceOauthBinding)
- .filter(
- db.and_(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.access_token == access_token,
- )
+ data_source_binding = db.session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.access_token == access_token,
)
- .first()
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
new_data_source_binding = DataSourceOauthBinding(
@@ -129,18 +122,15 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str):
# save data source binding
- data_source_binding = (
- db.session.query(DataSourceOauthBinding)
- .filter(
- db.and_(
- DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
- DataSourceOauthBinding.provider == "notion",
- DataSourceOauthBinding.id == binding_id,
- DataSourceOauthBinding.disabled == False,
- )
+ data_source_binding = db.session.scalar(
+ select(DataSourceOauthBinding).where(
+ DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
+ DataSourceOauthBinding.provider == "notion",
+ DataSourceOauthBinding.id == binding_id,
+ DataSourceOauthBinding.disabled == False,
)
- .first()
)
+
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
@@ -154,7 +144,7 @@ class NotionOAuth(OAuthDataSource):
}
data_source_binding.source_info = new_source_info
data_source_binding.disabled = False
- data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
diff --git a/api/libs/rsa.py b/api/libs/rsa.py
index 637bcc4a1d..ed7a0eb116 100644
--- a/api/libs/rsa.py
+++ b/api/libs/rsa.py
@@ -1,4 +1,6 @@
import hashlib
+import os
+from typing import Union
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
@@ -9,14 +11,14 @@ from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher
-def generate_key_pair(tenant_id):
+def generate_key_pair(tenant_id: str) -> str:
private_key = RSA.generate(2048)
public_key = private_key.publickey()
pem_private = private_key.export_key()
pem_public = public_key.export_key()
- filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
+ filepath = os.path.join("privkeys", tenant_id, "private.pem")
storage.save(filepath, pem_private)
@@ -26,7 +28,7 @@ def generate_key_pair(tenant_id):
prefix_hybrid = b"HYBRID:"
-def encrypt(text, public_key):
+def encrypt(text: str, public_key: Union[str, bytes]) -> bytes:
if isinstance(public_key, str):
public_key = public_key.encode()
@@ -38,15 +40,15 @@ def encrypt(text, public_key):
rsa_key = RSA.import_key(public_key)
cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key)
- enc_aes_key = cipher_rsa.encrypt(aes_key)
+ enc_aes_key: bytes = cipher_rsa.encrypt(aes_key)
encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext
return prefix_hybrid + encrypted_data
-def get_decrypt_decoding(tenant_id):
- filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
+def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]:
+ filepath = os.path.join("privkeys", tenant_id, "private.pem")
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
private_key = redis_client.get(cache_key)
@@ -64,7 +66,7 @@ def get_decrypt_decoding(tenant_id):
return rsa_key, cipher_rsa
-def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
+def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str:
if encrypted_text.startswith(prefix_hybrid):
encrypted_text = encrypted_text[len(prefix_hybrid) :]
@@ -83,10 +85,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa):
return decrypted_text.decode()
-def decrypt(encrypted_text, tenant_id):
+def decrypt(encrypted_text: bytes, tenant_id: str) -> str:
rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id)
- return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa)
+ return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa)
class PrivkeyNotFoundError(Exception):
diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py
new file mode 100644
index 0000000000..a8190011ed
--- /dev/null
+++ b/api/libs/uuid_utils.py
@@ -0,0 +1,164 @@
+import secrets
+import struct
+import time
+import uuid
+
+# Reference for UUIDv7 specification:
+# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7
+
+# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
+#
+# For details on the `struct.pack` format, refer to:
+# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
+_PACK_TIMESTAMP = ">Q"
+
+# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
+# into an unsigned 16-bit integer (big-endian).
+_PACK_RAND_A = ">H"
+
+
+def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
+    """Create UUIDv7 byte structure with given timestamp and random bytes.
+
+    This is a private helper function that handles the common logic for creating
+    UUIDv7 byte structure according to RFC 9562 specification.
+
+    UUIDv7 Structure (most significant bits first):
+    - 48 bits: timestamp (milliseconds since Unix epoch)
+    -  4 bits: version, always 7 (0b0111)
+    - 12 bits: random data A
+    -  2 bits: variant, always 0b10
+    - 62 bits: random data B
+
+    Args:
+        timestamp_ms: The timestamp in milliseconds since Unix epoch; must fit in 48 bits.
+        random_bytes: At least 10 bytes supplying the random portions.
+            Bytes 0-1 are used for random data A (their top 4 bits are replaced by the version).
+            Bytes 2-9 are used for random data B (their top 2 bits are replaced by the variant).
+            Any bytes after the tenth are ignored.
+
+    Returns:
+        A 16-byte bytes object representing the complete UUIDv7 structure.
+
+    Raises:
+        ValueError: If ``timestamp_ms`` needs more than 48 bits (packing would otherwise
+            silently drop its high bits), or if fewer than 10 random bytes are given
+            (the slice assignment below would otherwise shrink the 16-byte buffer).
+        struct.error: If ``timestamp_ms`` is negative.
+    """
+    # Reject inputs that would otherwise be packed into a wrong or malformed result.
+    if timestamp_ms >= 1 << 48:
+        raise ValueError(f"timestamp_ms does not fit in 48 bits: {timestamp_ms}")
+    if len(random_bytes) < 10:
+        raise ValueError(f"random_bytes must contain at least 10 bytes, got {len(random_bytes)}")
+
+    # Create the 128-bit UUID structure
+    uuid_bytes = bytearray(16)
+
+    # Pack timestamp (48 bits) into first 6 bytes
+    uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8]  # Take last 6 bytes of 8-byte big-endian
+
+    # Next 16 bits: version (4 bits) + random data A (12 bits)
+    rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
+    # Clear the highest 4 bits (AND with 0x0FFF) to make room for the version field,
+    # then set that field to 7 (OR with 0x7000).
+    rand_a = (rand_a & 0x0FFF) | 0x7000
+    uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)
+
+    # Last 64 bits: variant (2 bits) + random data B (62 bits)
+    uuid_bytes[8:16] = random_bytes[2:10]
+
+    # Set variant bits (first 2 bits of byte 8 should be '10')
+    uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80  # Set variant to 10xxxxxx
+
+    return bytes(uuid_bytes)
+
+
+def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
+    """Generate a UUID version 7 according to RFC 9562 specification.
+
+    A UUIDv7 starts with a time-ordered field: the Unix Epoch timestamp, i.e.
+    the number of milliseconds since midnight 1 Jan 1970 UTC (leap seconds
+    excluded). Everything after the timestamp is random, except for the
+    fixed version and variant bits.
+
+    Structure:
+    - 48 bits: timestamp (milliseconds since Unix epoch)
+    - 12 bits: random data A (with version bits)
+    - 62 bits: random data B (with variant bits)
+
+    Args:
+        timestamp_ms: Milliseconds since Unix epoch to embed in the UUID.
+            The current time is used when this is None (the default).
+
+    Returns:
+        A UUID object representing a UUIDv7.
+
+    Example:
+        >>> import time
+        >>> # Generate UUIDv7 with current time
+        >>> uuid_current = uuidv7()
+        >>> # Generate UUIDv7 with specific timestamp
+        >>> uuid_specific = uuidv7(int(time.time() * 1000))
+    """
+    # Resolve the timestamp before drawing randomness, as a missing value means "now".
+    effective_ms = int(time.time() * 1000) if timestamp_ms is None else timestamp_ms
+
+    # 10 random bytes: 2 feed random data A, 8 feed random data B.
+    entropy = secrets.token_bytes(10)
+
+    # The helper lays out timestamp, version, variant and random bits.
+    layout = _create_uuidv7_bytes(effective_ms, entropy)
+    return uuid.UUID(bytes=layout)
+
+
+def uuidv7_timestamp(id_: uuid.UUID) -> int:
+    """Extract the timestamp from a UUIDv7.
+
+    The most significant 48 bits of a UUIDv7 store the number of milliseconds
+    since the Unix epoch (1970-01-01 00:00:00 UTC). This function returns that
+    field unchanged, as an integer.
+
+    Args:
+        id_: A UUID object that should be a UUIDv7 (version 7).
+
+    Returns:
+        The timestamp as an integer representing milliseconds since Unix epoch.
+
+    Raises:
+        ValueError: If the provided UUID is not version 7.
+
+    Example:
+        >>> uuid_v7 = uuidv7()
+        >>> timestamp = uuidv7_timestamp(uuid_v7)
+        >>> print(f"UUID was created at: {timestamp} ms")
+    """
+    # Verify this is a UUIDv7
+    version = id_.version
+    if version != 7:
+        raise ValueError(f"Expected UUIDv7 (version 7), got version {version}")
+
+    # A UUID is a single 128-bit number whose most significant 48 bits are the
+    # timestamp, so the field can be read arithmetically instead of unpacking
+    # bytes: shift out everything below it.
+    total_bits = 128
+    timestamp_bits = 48
+    bits_below_timestamp = total_bits - timestamp_bits
+
+    # ``UUID.int`` is a plain Python int, so the result is already an int.
+    ts_in_ms = id_.int >> bits_below_timestamp
+    return ts_in_ms
+
+
+def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
+    """Return the smallest possible UUIDv7 for ``timestamp_ms``.
+
+    The first 48 bits hold the given timestamp and all random bits are 0, so
+    the result is non-random; it may be used as a boundary for partitions.
+    """
+    # ``bytes(10)`` is ten zero bytes, i.e. an all-zero "random" portion.
+    all_zero = bytes(10)
+
+    # The shared helper still stamps the version and variant bits.
+    boundary_bytes = _create_uuidv7_bytes(timestamp_ms, all_zero)
+    return uuid.UUID(bytes=boundary_bytes)
diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py
new file mode 100644
index 0000000000..2bbbb3d28e
--- /dev/null
+++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py
@@ -0,0 +1,86 @@
+"""add uuidv7 function in SQL
+
+Revision ID: 1c9ba48be8e4
+Revises: 58eb7bdb93fe
+Create Date: 2025-07-02 23:32:38.484499
+
+"""
+
+"""
+The functions in this file come from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications.
+
+LICENSE:
+
+# Copyright and License
+
+Copyright (c) 2024, Daniel Vérité
+
+Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies.
+
+In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage.
+
+Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications.
+"""
+
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1c9ba48be8e4'
+down_revision = '58eb7bdb93fe'
+branch_labels: None = None
+depends_on: None = None
+
+
+def upgrade():
+    # Both functions are adapted from https://github.com/dverite/postgres-uuidv7-sql/.
+    # Unlike the upstream uuidv7, this one takes no source-timestamp argument:
+    # that signature is incompatible with PostgreSQL 18's built-in `uuidv7`.
+    # The capability is rarely needed in practice, as IDs can be generated and
+    # controlled within the application layer.
+    create_uuidv7 = r"""
+/* Main function to generate a uuidv7 value with millisecond precision */
+CREATE FUNCTION uuidv7() RETURNS uuid
+AS
+$$
+    -- Replace the first 48 bits of a uuidv4 with the current
+    -- number of milliseconds since 1970-01-01 UTC
+    -- and set the "ver" field to 7 by setting additional bits
+SELECT encode(
+               set_bit(
+                       set_bit(
+                               overlay(uuid_send(gen_random_uuid()) placing
+                                       substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
+                                                 3)
+                                       from 1 for 6),
+                               52, 1),
+                       53, 1), 'hex')::uuid;
+$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
+
+COMMENT ON FUNCTION uuidv7 IS
+    'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
+"""
+    create_uuidv7_boundary = r"""
+CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
+AS
+$$
+    /* uuid fields: version=0b0111, variant=0b10 */
+SELECT encode(
+               overlay('\x00000000000070008000000000000000'::bytea
+                       placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
+                       from 1 for 6),
+               'hex')::uuid;
+$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
+
+COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
+    'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
+"""
+    for statement in (create_uuidv7, create_uuidv7_boundary):
+        op.execute(sa.text(statement))
+
+
+def downgrade():
+    for statement in ("DROP FUNCTION uuidv7", "DROP FUNCTION uuidv7_boundary"):  # undo upgrade()
+        op.execute(sa.text(statement))
diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py
new file mode 100644
index 0000000000..df4fbf0a0e
--- /dev/null
+++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py
@@ -0,0 +1,62 @@
+"""tool oauth
+
+Revision ID: 71f5020c6470
+Revises: 1c9ba48be8e4
+Create Date: 2025-06-24 17:05:43.118647
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '71f5020c6470'
+down_revision = '1c9ba48be8e4'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():  # Create the two tool OAuth client tables and add name/is_default/credential_type to tool_builtin_providers.
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tool_oauth_system_clients',  # unique on (plugin_id, provider)
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('plugin_id', sa.String(length=512), nullable=False),
+ sa.Column('provider', sa.String(length=255), nullable=False),
+ sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'),
+ sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx')
+ )
+ op.create_table('tool_oauth_tenant_clients',  # unique on (tenant_id, plugin_id, provider)
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('plugin_id', sa.String(length=512), nullable=False),
+ sa.Column('provider', sa.String(length=255), nullable=False),
+ sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False),
+ sa.Column('encrypted_oauth_params', sa.Text(), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'),
+ sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client')
+ )
+
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:  # all three new columns are NOT NULL with a server default
+ batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False))
+ batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False))
+ batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False))
+ batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')  # recreated on the next line with 'name' included
+ batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name'])
+
+ # ### end Alembic commands ###
+
+
+def downgrade():  # Undo upgrade(): restore the two-column unique constraint, drop the added columns, then drop both tables.
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+ batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique')
+ batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider'])  # NOTE(review): presumably fails if rows now share (tenant_id, provider) - confirm before downgrading
+ batch_op.drop_column('credential_type')
+ batch_op.drop_column('is_default')
+ batch_op.drop_column('name')
+
+ op.drop_table('tool_oauth_tenant_clients')  # tables are dropped in the reverse of their creation order
+ op.drop_table('tool_oauth_system_clients')
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
new file mode 100644
index 0000000000..3bdbafda7c
--- /dev/null
+++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py
@@ -0,0 +1,51 @@
+"""update models
+
+Revision ID: 1a83934ad6d1
+Revises: 71f5020c6470
+Create Date: 2025-07-21 09:35:48.774794
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = '1a83934ad6d1'
+down_revision = '71f5020c6470'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():  # Lengthen two string columns: server_identifier 24 -> 64 and tool_name 40 -> 128.
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+ batch_op.alter_column('server_identifier',  # only the length changes; nullability stays NOT NULL
+ existing_type=sa.VARCHAR(length=24),
+ type_=sa.String(length=64),
+ existing_nullable=False)
+
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.alter_column('tool_name',  # only the length changes; nullability stays NOT NULL
+ existing_type=sa.VARCHAR(length=40),
+ type_=sa.String(length=128),
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
+
+
+def downgrade():  # Restore the previous column lengths: tool_name back to 40 and server_identifier back to 24.
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op:
+ batch_op.alter_column('tool_name',
+ existing_type=sa.String(length=128),
+ type_=sa.VARCHAR(length=40),  # NOTE(review): shrinking presumably fails if stored values exceed 40 characters - confirm
+ existing_nullable=False)
+
+ with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op:
+ batch_op.alter_column('server_identifier',
+ existing_type=sa.String(length=64),
+ type_=sa.VARCHAR(length=24),  # NOTE(review): shrinking presumably fails if stored values exceed 24 characters - confirm
+ existing_nullable=False)
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py
new file mode 100644
index 0000000000..76d0cb2940
--- /dev/null
+++ b/api/migrations/versions/2025_07_22_0019-375fe79ead14_oauth_refresh_token.py
@@ -0,0 +1,34 @@
+"""oauth_refresh_token
+
+Revision ID: 375fe79ead14
+Revises: 1a83934ad6d1
+Create Date: 2025-07-22 00:19:45.599636
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '375fe79ead14'
+down_revision = '1a83934ad6d1'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():  # Add the expires_at column to tool_builtin_providers.
+ # ### commands auto generated by Alembic - please adjust! ###
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+ batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))  # NOT NULL with a server default of -1
+
+ # ### end Alembic commands ###
+
+
+def downgrade():  # Remove the expires_at column added by upgrade().
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
+ batch_op.drop_column('expires_at')
+
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py b/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py
new file mode 100644
index 0000000000..4ff0402a97
--- /dev/null
+++ b/api/migrations/versions/2025_07_23_1508-8bcc02c9bd07_add_tenant_plugin_autoupgrade_table.py
@@ -0,0 +1,42 @@
+"""add_tenant_plugin_autoupgrade_table
+
+Revision ID: 8bcc02c9bd07
+Revises: 375fe79ead14
+Create Date: 2025-07-23 15:08:50.161441
+
+"""
+from alembic import op
+import models as models
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = '8bcc02c9bd07'
+down_revision = '375fe79ead14'
+branch_labels = None
+depends_on = None
+
+
+def upgrade():  # Create the tenant_plugin_auto_upgrade_strategies table.
+ # ### commands auto generated by Alembic - please adjust! ###
+ op.create_table('tenant_plugin_auto_upgrade_strategies',  # unique on tenant_id, so at most one row per tenant
+ sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
+ sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
+ sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
+ sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),  # NOT NULL without a server default, so inserts must supply a value
+ sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
+ sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),  # NOT NULL without a server default
+ sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),  # NOT NULL without a server default
+ sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
+ sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
+ sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
+ )
+ # ### end Alembic commands ###
+
+
+def downgrade():  # Drop the tenant_plugin_auto_upgrade_strategies table created by upgrade().
+ # ### commands auto generated by Alembic - please adjust! ###
+
+ op.drop_table('tenant_plugin_auto_upgrade_strategies')
+ # ### end Alembic commands ###
diff --git a/api/models/account.py b/api/models/account.py
index 7ffeefa980..d63c5d7fb5 100644
--- a/api/models/account.py
+++ b/api/models/account.py
@@ -1,9 +1,10 @@
import enum
import json
+from datetime import datetime
from typing import Optional, cast
from flask_login import UserMixin # type: ignore
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.orm import Mapped, mapped_column, reconstructor
from models.base import Base
@@ -85,21 +86,23 @@ class Account(UserMixin, Base):
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- name = db.Column(db.String(255), nullable=False)
- email = db.Column(db.String(255), nullable=False)
- password = db.Column(db.String(255), nullable=True)
- password_salt = db.Column(db.String(255), nullable=True)
- avatar = db.Column(db.String(255))
- interface_language = db.Column(db.String(255))
- interface_theme = db.Column(db.String(255))
- timezone = db.Column(db.String(255))
- last_login_at = db.Column(db.DateTime)
- last_login_ip = db.Column(db.String(255))
- last_active_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- status = db.Column(db.String(16), nullable=False, server_default=db.text("'active'::character varying"))
- initialized_at = db.Column(db.DateTime)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ name: Mapped[str] = mapped_column(db.String(255))
+ email: Mapped[str] = mapped_column(db.String(255))
+ password: Mapped[Optional[str]] = mapped_column(db.String(255))
+ password_salt: Mapped[Optional[str]] = mapped_column(db.String(255))
+ avatar: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
+ interface_language: Mapped[Optional[str]] = mapped_column(db.String(255))
+ interface_theme: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
+ timezone: Mapped[Optional[str]] = mapped_column(db.String(255))
+ last_login_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
+ last_login_ip: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
+ last_active_at: Mapped[datetime] = mapped_column(
+ db.DateTime, server_default=func.current_timestamp(), nullable=False
+ )
+ status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'active'::character varying"))
+ initialized_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
@reconstructor
def init_on_load(self):
@@ -116,7 +119,7 @@ class Account(UserMixin, Base):
@current_tenant.setter
def current_tenant(self, tenant: "Tenant"):
- ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
+ ta = db.session.scalar(select(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).limit(1))
if ta:
self.role = TenantAccountRole(ta.role)
self._current_tenant = tenant
@@ -132,9 +135,9 @@ class Account(UserMixin, Base):
tuple[Tenant, TenantAccountJoin],
(
db.session.query(Tenant, TenantAccountJoin)
- .filter(Tenant.id == tenant_id)
- .filter(TenantAccountJoin.tenant_id == Tenant.id)
- .filter(TenantAccountJoin.account_id == self.id)
+ .where(Tenant.id == tenant_id)
+ .where(TenantAccountJoin.tenant_id == Tenant.id)
+ .where(TenantAccountJoin.account_id == self.id)
.one_or_none()
),
)
@@ -143,7 +146,7 @@ class Account(UserMixin, Base):
return
tenant, join = tenant_account_join
- self.role = join.role
+ self.role = TenantAccountRole(join.role)
self._current_tenant = tenant
@property
@@ -158,11 +161,11 @@ class Account(UserMixin, Base):
def get_by_openid(cls, provider: str, open_id: str):
account_integrate = (
db.session.query(AccountIntegrate)
- .filter(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
+ .where(AccountIntegrate.provider == provider, AccountIntegrate.open_id == open_id)
.one_or_none()
)
if account_integrate:
- return db.session.query(Account).filter(Account.id == account_integrate.account_id).one_or_none()
+ return db.session.query(Account).where(Account.id == account_integrate.account_id).one_or_none()
return None
# check current_user.current_tenant.current_role in ['admin', 'owner']
@@ -196,19 +199,19 @@ class Tenant(Base):
__tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- name = db.Column(db.String(255), nullable=False)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ name: Mapped[str] = mapped_column(db.String(255))
encrypt_public_key = db.Column(db.Text)
- plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
- custom_config = db.Column(db.Text)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ plan: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'basic'::character varying"))
+ status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
+ custom_config: Mapped[Optional[str]] = mapped_column(db.Text)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp(), nullable=False)
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
def get_accounts(self) -> list[Account]:
return (
db.session.query(Account)
- .filter(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
+ .where(Account.id == TenantAccountJoin.account_id, TenantAccountJoin.tenant_id == self.id)
.all()
)
@@ -230,14 +233,14 @@ class TenantAccountJoin(Base):
db.UniqueConstraint("tenant_id", "account_id", name="unique_tenant_account_join"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- current = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- role = db.Column(db.String(16), nullable=False, server_default="normal")
- invited_by = db.Column(StringUUID, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID)
+ account_id: Mapped[str] = mapped_column(StringUUID)
+ current: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
+ role: Mapped[str] = mapped_column(db.String(16), server_default="normal")
+ invited_by: Mapped[Optional[str]] = mapped_column(StringUUID)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
class AccountIntegrate(Base):
@@ -248,13 +251,13 @@ class AccountIntegrate(Base):
db.UniqueConstraint("provider", "open_id", name="unique_provider_open_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- account_id = db.Column(StringUUID, nullable=False)
- provider = db.Column(db.String(16), nullable=False)
- open_id = db.Column(db.String(255), nullable=False)
- encrypted_token = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ account_id: Mapped[str] = mapped_column(StringUUID)
+ provider: Mapped[str] = mapped_column(db.String(16))
+ open_id: Mapped[str] = mapped_column(db.String(255))
+ encrypted_token: Mapped[str] = mapped_column(db.String(255))
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
class InvitationCode(Base):
@@ -265,15 +268,15 @@ class InvitationCode(Base):
db.Index("invitation_codes_code_idx", "code", "status"),
)
- id = db.Column(db.Integer, nullable=False)
- batch = db.Column(db.String(255), nullable=False)
- code = db.Column(db.String(32), nullable=False)
- status = db.Column(db.String(16), nullable=False, server_default=db.text("'unused'::character varying"))
- used_at = db.Column(db.DateTime)
- used_by_tenant_id = db.Column(StringUUID)
- used_by_account_id = db.Column(StringUUID)
- deprecated_at = db.Column(db.DateTime)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ id: Mapped[int] = mapped_column(db.Integer)
+ batch: Mapped[str] = mapped_column(db.String(255))
+ code: Mapped[str] = mapped_column(db.String(32))
+ status: Mapped[str] = mapped_column(db.String(16), server_default=db.text("'unused'::character varying"))
+ used_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
+ used_by_tenant_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+ used_by_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+ deprecated_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TenantPluginPermission(Base):
@@ -299,3 +302,35 @@ class TenantPluginPermission(Base):
db.String(16), nullable=False, server_default="everyone"
)
debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")
+
+
+class TenantPluginAutoUpgradeStrategy(Base):  # Plugin auto-upgrade settings; tenant_id is unique, so one row per tenant.
+    class StrategySetting(enum.StrEnum):  # values annotated for the strategy_setting column
+        DISABLED = "disabled"
+        FIX_ONLY = "fix_only"
+        LATEST = "latest"
+
+    class UpgradeMode(enum.StrEnum):  # values annotated for the upgrade_mode column
+        ALL = "all"
+        PARTIAL = "partial"
+        EXCLUDE = "exclude"
+
+    __tablename__ = "tenant_plugin_auto_upgrade_strategies"
+    __table_args__ = (
+        db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
+        db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
+    )
+
+    id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+    strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only")
+    upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0)  # seconds of the day
+    upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude")
+    exclude_plugins: Mapped[list[str]] = mapped_column(
+        db.ARRAY(db.String(255)), nullable=False
+    )  # plugin_id (author/name)
+    include_plugins: Mapped[list[str]] = mapped_column(
+        db.ARRAY(db.String(255)), nullable=False
+    )  # plugin_id (author/name)
+    created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())  # typed mapped_column, matching the other columns
+    updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py
index 5a70e18622..3cef5a0fb2 100644
--- a/api/models/api_based_extension.py
+++ b/api/models/api_based_extension.py
@@ -1,6 +1,7 @@
import enum
from sqlalchemy import func
+from sqlalchemy.orm import mapped_column
from .base import Base
from .engine import db
@@ -21,9 +22,9 @@ class APIBasedExtension(Base):
db.Index("api_based_extension_tenant_idx", "tenant_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- name = db.Column(db.String(255), nullable=False)
- api_endpoint = db.Column(db.String(255), nullable=False)
- api_key = db.Column(db.Text, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ name = mapped_column(db.String(255), nullable=False)
+ api_endpoint = mapped_column(db.String(255), nullable=False)
+ api_key = mapped_column(db.Text, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/dataset.py b/api/models/dataset.py
index 9d299bb6f7..d877540213 100644
--- a/api/models/dataset.py
+++ b/api/models/dataset.py
@@ -8,12 +8,13 @@ import os
import pickle
import re
import time
+from datetime import datetime
from json import JSONDecodeError
-from typing import Any, cast
+from typing import Any, Optional, cast
-from sqlalchemy import func
+from sqlalchemy import func, select
from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import Mapped, mapped_column
from configs import dify_config
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
@@ -45,29 +46,29 @@ class Dataset(Base):
INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
PROVIDER_LIST = ["vendor", "external", None]
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- name = db.Column(db.String(255), nullable=False)
- description = db.Column(db.Text, nullable=True)
- provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
- permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
- data_source_type = db.Column(db.String(255))
- indexing_technique = db.Column(db.String(255), nullable=True)
- index_struct = db.Column(db.Text, nullable=True)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- embedding_model = db.Column(db.String(255), nullable=True)
- embedding_model_provider = db.Column(db.String(255), nullable=True)
- collection_binding_id = db.Column(StringUUID, nullable=True)
- retrieval_model = db.Column(JSONB, nullable=True)
- built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID)
+ name: Mapped[str] = mapped_column(db.String(255))
+ description = mapped_column(db.Text, nullable=True)
+ provider: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'vendor'::character varying"))
+ permission: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'only_me'::character varying"))
+ data_source_type = mapped_column(db.String(255))
+ indexing_technique: Mapped[Optional[str]] = mapped_column(db.String(255))
+ index_struct = mapped_column(db.Text, nullable=True)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ embedding_model = db.Column(db.String(255), nullable=True) # TODO: mapped_column
+ embedding_model_provider = db.Column(db.String(255), nullable=True) # TODO: mapped_column
+ collection_binding_id = mapped_column(StringUUID, nullable=True)
+ retrieval_model = mapped_column(JSONB, nullable=True)
+ built_in_field_enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
def dataset_keyword_table(self):
dataset_keyword_table = (
- db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
+ db.session.query(DatasetKeywordTable).where(DatasetKeywordTable.dataset_id == self.id).first()
)
if dataset_keyword_table:
return dataset_keyword_table
@@ -94,7 +95,7 @@ class Dataset(Base):
def latest_process_rule(self):
return (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.dataset_id == self.id)
+ .where(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc())
.first()
)
@@ -103,19 +104,19 @@ class Dataset(Base):
def app_count(self):
return (
db.session.query(func.count(AppDatasetJoin.id))
- .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
+ .where(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
.scalar()
)
@property
def document_count(self):
- return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()
+ return db.session.query(func.count(Document.id)).where(Document.dataset_id == self.id).scalar()
@property
def available_document_count(self):
return (
db.session.query(func.count(Document.id))
- .filter(
+ .where(
Document.dataset_id == self.id,
Document.indexing_status == "completed",
Document.enabled == True,
@@ -128,7 +129,7 @@ class Dataset(Base):
def available_segment_count(self):
return (
db.session.query(func.count(DocumentSegment.id))
- .filter(
+ .where(
DocumentSegment.dataset_id == self.id,
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
@@ -141,13 +142,13 @@ class Dataset(Base):
return (
db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count), 0))
- .filter(Document.dataset_id == self.id)
+ .where(Document.dataset_id == self.id)
.scalar()
)
@property
def doc_form(self):
- document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
+ document = db.session.query(Document).where(Document.dataset_id == self.id).first()
if document:
return document.doc_form
return None
@@ -168,7 +169,7 @@ class Dataset(Base):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
- .filter(
+ .where(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
@@ -184,14 +185,14 @@ class Dataset(Base):
if self.provider != "external":
return None
external_knowledge_binding = (
- db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
+ db.session.query(ExternalKnowledgeBindings).where(ExternalKnowledgeBindings.dataset_id == self.id).first()
)
if not external_knowledge_binding:
return None
- external_knowledge_api = (
- db.session.query(ExternalKnowledgeApis)
- .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
- .first()
+ external_knowledge_api = db.session.scalar(
+ select(ExternalKnowledgeApis).where(
+ ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id
+ )
)
if not external_knowledge_api:
return None
@@ -204,7 +205,7 @@ class Dataset(Base):
@property
def doc_metadata(self):
- dataset_metadatas = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == self.id).all()
+ dataset_metadatas = db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == self.id).all()
doc_metadata = [
{
@@ -255,7 +256,7 @@ class Dataset(Base):
@staticmethod
def gen_collection_name_by_id(dataset_id: str) -> str:
normalized_dataset_id = dataset_id.replace("-", "_")
- return f"Vector_index_{normalized_dataset_id}_Node"
+ return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node"
class DatasetProcessRule(Base):
@@ -265,12 +266,12 @@ class DatasetProcessRule(Base):
db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- dataset_id = db.Column(StringUUID, nullable=False)
- mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
- rules = db.Column(db.Text, nullable=True)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ mode = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
+ rules = mapped_column(db.Text, nullable=True)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
@@ -309,62 +310,64 @@ class Document(Base):
)
# initial fields
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- position = db.Column(db.Integer, nullable=False)
- data_source_type = db.Column(db.String(255), nullable=False)
- data_source_info = db.Column(db.Text, nullable=True)
- dataset_process_rule_id = db.Column(StringUUID, nullable=True)
- batch = db.Column(db.String(255), nullable=False)
- name = db.Column(db.String(255), nullable=False)
- created_from = db.Column(db.String(255), nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_api_request_id = db.Column(StringUUID, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ position = mapped_column(db.Integer, nullable=False)
+ data_source_type = mapped_column(db.String(255), nullable=False)
+ data_source_info = mapped_column(db.Text, nullable=True)
+ dataset_process_rule_id = mapped_column(StringUUID, nullable=True)
+ batch = mapped_column(db.String(255), nullable=False)
+ name = mapped_column(db.String(255), nullable=False)
+ created_from = mapped_column(db.String(255), nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_api_request_id = mapped_column(StringUUID, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
# start processing
- processing_started_at = db.Column(db.DateTime, nullable=True)
+ processing_started_at = mapped_column(db.DateTime, nullable=True)
# parsing
- file_id = db.Column(db.Text, nullable=True)
- word_count = db.Column(db.Integer, nullable=True)
- parsing_completed_at = db.Column(db.DateTime, nullable=True)
+ file_id = mapped_column(db.Text, nullable=True)
+ word_count = mapped_column(db.Integer, nullable=True)
+ parsing_completed_at = mapped_column(db.DateTime, nullable=True)
# cleaning
- cleaning_completed_at = db.Column(db.DateTime, nullable=True)
+ cleaning_completed_at = mapped_column(db.DateTime, nullable=True)
# split
- splitting_completed_at = db.Column(db.DateTime, nullable=True)
+ splitting_completed_at = mapped_column(db.DateTime, nullable=True)
# indexing
- tokens = db.Column(db.Integer, nullable=True)
- indexing_latency = db.Column(db.Float, nullable=True)
- completed_at = db.Column(db.DateTime, nullable=True)
+ tokens = mapped_column(db.Integer, nullable=True)
+ indexing_latency = mapped_column(db.Float, nullable=True)
+ completed_at = mapped_column(db.DateTime, nullable=True)
# pause
- is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
- paused_by = db.Column(StringUUID, nullable=True)
- paused_at = db.Column(db.DateTime, nullable=True)
+ is_paused = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
+ paused_by = mapped_column(StringUUID, nullable=True)
+ paused_at = mapped_column(db.DateTime, nullable=True)
# error
- error = db.Column(db.Text, nullable=True)
- stopped_at = db.Column(db.DateTime, nullable=True)
+ error = mapped_column(db.Text, nullable=True)
+ stopped_at = mapped_column(db.DateTime, nullable=True)
# basic fields
- indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- disabled_at = db.Column(db.DateTime, nullable=True)
- disabled_by = db.Column(StringUUID, nullable=True)
- archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- archived_reason = db.Column(db.String(255), nullable=True)
- archived_by = db.Column(StringUUID, nullable=True)
- archived_at = db.Column(db.DateTime, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- doc_type = db.Column(db.String(40), nullable=True)
- doc_metadata = db.Column(JSONB, nullable=True)
- doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
- doc_language = db.Column(db.String(255), nullable=True)
+ indexing_status = mapped_column(
+ db.String(255), nullable=False, server_default=db.text("'waiting'::character varying")
+ )
+ enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ disabled_at = mapped_column(db.DateTime, nullable=True)
+ disabled_by = mapped_column(StringUUID, nullable=True)
+ archived = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ archived_reason = mapped_column(db.String(255), nullable=True)
+ archived_by = mapped_column(StringUUID, nullable=True)
+ archived_at = mapped_column(db.DateTime, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ doc_type = mapped_column(db.String(40), nullable=True)
+ doc_metadata = mapped_column(JSONB, nullable=True)
+ doc_form = mapped_column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
+ doc_language = mapped_column(db.String(255), nullable=True)
DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]
@@ -405,7 +408,7 @@ class Document(Base):
data_source_info_dict = json.loads(self.data_source_info)
file_detail = (
db.session.query(UploadFile)
- .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
+ .where(UploadFile.id == data_source_info_dict["upload_file_id"])
.one_or_none()
)
if file_detail:
@@ -438,24 +441,24 @@ class Document(Base):
@property
def dataset(self):
- return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()
+ return db.session.query(Dataset).where(Dataset.id == self.dataset_id).one_or_none()
@property
def segment_count(self):
- return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
+ return db.session.query(DocumentSegment).where(DocumentSegment.document_id == self.id).count()
@property
def hit_count(self):
return (
db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
- .filter(DocumentSegment.document_id == self.id)
+ .where(DocumentSegment.document_id == self.id)
.scalar()
)
@property
def uploader(self):
- user = db.session.query(Account).filter(Account.id == self.created_by).first()
+ user = db.session.query(Account).where(Account.id == self.created_by).first()
return user.name if user else None
@property
@@ -472,7 +475,7 @@ class Document(Base):
document_metadatas = (
db.session.query(DatasetMetadata)
.join(DatasetMetadataBinding, DatasetMetadataBinding.metadata_id == DatasetMetadata.id)
- .filter(
+ .where(
DatasetMetadataBinding.dataset_id == self.dataset_id, DatasetMetadataBinding.document_id == self.id
)
.all()
@@ -652,58 +655,58 @@ class DocumentSegment(Base):
)
# initial fields
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- document_id = db.Column(StringUUID, nullable=False)
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ document_id = mapped_column(StringUUID, nullable=False)
position: Mapped[int]
- content = db.Column(db.Text, nullable=False)
- answer = db.Column(db.Text, nullable=True)
- word_count = db.Column(db.Integer, nullable=False)
- tokens = db.Column(db.Integer, nullable=False)
+ content = mapped_column(db.Text, nullable=False)
+ answer = mapped_column(db.Text, nullable=True)
+ word_count: Mapped[int]
+ tokens: Mapped[int]
# indexing fields
- keywords = db.Column(db.JSON, nullable=True)
- index_node_id = db.Column(db.String(255), nullable=True)
- index_node_hash = db.Column(db.String(255), nullable=True)
+ keywords = mapped_column(db.JSON, nullable=True)
+ index_node_id = mapped_column(db.String(255), nullable=True)
+ index_node_hash = mapped_column(db.String(255), nullable=True)
# basic fields
- hit_count = db.Column(db.Integer, nullable=False, default=0)
- enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- disabled_at = db.Column(db.DateTime, nullable=True)
- disabled_by = db.Column(StringUUID, nullable=True)
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- indexing_at = db.Column(db.DateTime, nullable=True)
- completed_at = db.Column(db.DateTime, nullable=True)
- error = db.Column(db.Text, nullable=True)
- stopped_at = db.Column(db.DateTime, nullable=True)
+ hit_count = mapped_column(db.Integer, nullable=False, default=0)
+ enabled = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ disabled_at = mapped_column(db.DateTime, nullable=True)
+ disabled_by = mapped_column(StringUUID, nullable=True)
+ status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'waiting'::character varying"))
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ indexing_at = mapped_column(db.DateTime, nullable=True)
+ completed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True)
+ error = mapped_column(db.Text, nullable=True)
+ stopped_at = mapped_column(db.DateTime, nullable=True)
@property
def dataset(self):
- return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
+ return db.session.scalar(select(Dataset).where(Dataset.id == self.dataset_id))
@property
def document(self):
- return db.session.query(Document).filter(Document.id == self.document_id).first()
+ return db.session.scalar(select(Document).where(Document.id == self.document_id))
@property
def previous_segment(self):
- return (
- db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
- .first()
+ return db.session.scalar(
+ select(DocumentSegment).where(
+ DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1
+ )
)
@property
def next_segment(self):
- return (
- db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
- .first()
+ return db.session.scalar(
+ select(DocumentSegment).where(
+ DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1
+ )
)
@property
@@ -714,7 +717,7 @@ class DocumentSegment(Base):
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
- .filter(ChildChunk.segment_id == self.id)
+ .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
@@ -731,7 +734,7 @@ class DocumentSegment(Base):
if rules.parent_mode:
child_chunks = (
db.session.query(ChildChunk)
- .filter(ChildChunk.segment_id == self.id)
+ .where(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
@@ -800,37 +803,37 @@ class ChildChunk(Base):
)
# initial fields
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- document_id = db.Column(StringUUID, nullable=False)
- segment_id = db.Column(StringUUID, nullable=False)
- position = db.Column(db.Integer, nullable=False)
- content = db.Column(db.Text, nullable=False)
- word_count = db.Column(db.Integer, nullable=False)
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ document_id = mapped_column(StringUUID, nullable=False)
+ segment_id = mapped_column(StringUUID, nullable=False)
+ position = mapped_column(db.Integer, nullable=False)
+ content = mapped_column(db.Text, nullable=False)
+ word_count = mapped_column(db.Integer, nullable=False)
# indexing fields
- index_node_id = db.Column(db.String(255), nullable=True)
- index_node_hash = db.Column(db.String(255), nullable=True)
- type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
- indexing_at = db.Column(db.DateTime, nullable=True)
- completed_at = db.Column(db.DateTime, nullable=True)
- error = db.Column(db.Text, nullable=True)
+ index_node_id = mapped_column(db.String(255), nullable=True)
+ index_node_hash = mapped_column(db.String(255), nullable=True)
+ type = mapped_column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ indexing_at = mapped_column(db.DateTime, nullable=True)
+ completed_at = mapped_column(db.DateTime, nullable=True)
+ error = mapped_column(db.Text, nullable=True)
@property
def dataset(self):
- return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
+ return db.session.query(Dataset).where(Dataset.id == self.dataset_id).first()
@property
def document(self):
- return db.session.query(Document).filter(Document.id == self.document_id).first()
+ return db.session.query(Document).where(Document.id == self.document_id).first()
@property
def segment(self):
- return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
+ return db.session.query(DocumentSegment).where(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(Base):
@@ -840,10 +843,10 @@ class AppDatasetJoin(Base):
db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
)
- id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+ id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def app(self):
@@ -857,14 +860,14 @@ class DatasetQuery(Base):
db.Index("dataset_query_dataset_id_idx", "dataset_id"),
)
- id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
- dataset_id = db.Column(StringUUID, nullable=False)
- content = db.Column(db.Text, nullable=False)
- source = db.Column(db.String(255), nullable=False)
- source_app_id = db.Column(StringUUID, nullable=True)
- created_by_role = db.Column(db.String, nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+ id = mapped_column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ content = mapped_column(db.Text, nullable=False)
+ source = mapped_column(db.String(255), nullable=False)
+ source_app_id = mapped_column(StringUUID, nullable=True)
+ created_by_role = mapped_column(db.String, nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(Base):
@@ -874,10 +877,10 @@ class DatasetKeywordTable(Base):
db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- dataset_id = db.Column(StringUUID, nullable=False, unique=True)
- keyword_table = db.Column(db.Text, nullable=False)
- data_source_type = db.Column(
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ dataset_id = mapped_column(StringUUID, nullable=False, unique=True)
+ keyword_table = mapped_column(db.Text, nullable=False)
+ data_source_type = mapped_column(
db.String(255), nullable=False, server_default=db.text("'database'::character varying")
)
@@ -920,14 +923,14 @@ class Embedding(Base):
db.Index("created_at_idx", "created_at"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- model_name = db.Column(
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ model_name = mapped_column(
db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
)
- hash = db.Column(db.String(64), nullable=False)
- embedding = db.Column(db.LargeBinary, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
+ hash = mapped_column(db.String(64), nullable=False)
+ embedding = mapped_column(db.LargeBinary, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ provider_name = mapped_column(db.String(255), nullable=False, server_default=db.text("''::character varying"))
def set_embedding(self, embedding_data: list[float]):
self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
@@ -943,12 +946,12 @@ class DatasetCollectionBinding(Base):
db.Index("provider_model_name_idx", "provider_name", "model_name"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- provider_name = db.Column(db.String(255), nullable=False)
- model_name = db.Column(db.String(255), nullable=False)
- type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
- collection_name = db.Column(db.String(64), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ provider_name = mapped_column(db.String(255), nullable=False)
+ model_name = mapped_column(db.String(255), nullable=False)
+ type = mapped_column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
+ collection_name = mapped_column(db.String(64), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(Base):
@@ -960,15 +963,15 @@ class TidbAuthBinding(Base):
db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
db.Index("tidb_auth_bindings_status_idx", "status"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=True)
- cluster_id = db.Column(db.String(255), nullable=False)
- cluster_name = db.Column(db.String(255), nullable=False)
- active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
- account = db.Column(db.String(255), nullable=False)
- password = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=True)
+ cluster_id = mapped_column(db.String(255), nullable=False)
+ cluster_name = mapped_column(db.String(255), nullable=False)
+ active = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ status = mapped_column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
+ account = mapped_column(db.String(255), nullable=False)
+ password = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(Base):
@@ -977,10 +980,10 @@ class Whitelist(Base):
db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
db.Index("whitelists_tenant_idx", "tenant_id"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=True)
- category = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=True)
+ category = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(Base):
@@ -992,12 +995,12 @@ class DatasetPermission(Base):
db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
- dataset_id = db.Column(StringUUID, nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- tenant_id = db.Column(StringUUID, nullable=False)
- has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ account_id = mapped_column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ has_permission = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(Base):
@@ -1008,15 +1011,15 @@ class ExternalKnowledgeApis(Base):
db.Index("external_knowledge_apis_name_idx", "name"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- name = db.Column(db.String(255), nullable=False)
- description = db.Column(db.String(255), nullable=False)
- tenant_id = db.Column(StringUUID, nullable=False)
- settings = db.Column(db.Text, nullable=True)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ name = mapped_column(db.String(255), nullable=False)
+ description = mapped_column(db.String(255), nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ settings = mapped_column(db.Text, nullable=True)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
def to_dict(self):
return {
@@ -1041,11 +1044,11 @@ class ExternalKnowledgeApis(Base):
def dataset_bindings(self):
external_knowledge_bindings = (
db.session.query(ExternalKnowledgeBindings)
- .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
+ .where(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
.all()
)
dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
- datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
+ datasets = db.session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
dataset_bindings = []
for dataset in datasets:
dataset_bindings.append({"id": dataset.id, "name": dataset.name})
@@ -1063,15 +1066,15 @@ class ExternalKnowledgeBindings(Base):
db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- external_knowledge_api_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- external_knowledge_id = db.Column(db.Text, nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ external_knowledge_api_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ external_knowledge_id = mapped_column(db.Text, nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(Base):
@@ -1083,12 +1086,12 @@ class DatasetAutoDisableLog(Base):
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- document_id = db.Column(StringUUID, nullable=False)
- notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ document_id = mapped_column(StringUUID, nullable=False)
+ notified = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(Base):
@@ -1099,11 +1102,11 @@ class RateLimitLog(Base):
db.Index("rate_limit_log_operation_idx", "operation"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- subscription_plan = db.Column(db.String(255), nullable=False)
- operation = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ subscription_plan = mapped_column(db.String(255), nullable=False)
+ operation = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(Base):
@@ -1114,15 +1117,15 @@ class DatasetMetadata(Base):
db.Index("dataset_metadata_dataset_idx", "dataset_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- type = db.Column(db.String(255), nullable=False)
- name = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
- updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
- created_by = db.Column(StringUUID, nullable=False)
- updated_by = db.Column(StringUUID, nullable=True)
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ type = mapped_column(db.String(255), nullable=False)
+ name = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ created_by = mapped_column(StringUUID, nullable=False)
+ updated_by = mapped_column(StringUUID, nullable=True)
class DatasetMetadataBinding(Base):
@@ -1135,10 +1138,10 @@ class DatasetMetadataBinding(Base):
db.Index("dataset_metadata_binding_document_idx", "document_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- metadata_id = db.Column(StringUUID, nullable=False)
- document_id = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- created_by = db.Column(StringUUID, nullable=False)
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ metadata_id = mapped_column(StringUUID, nullable=False)
+ document_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_by = mapped_column(StringUUID, nullable=False)
diff --git a/api/models/model.py b/api/models/model.py
index b1007c4a79..a78a91ebd5 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -40,8 +40,8 @@ class DifySetup(Base):
__tablename__ = "dify_setups"
__table_args__ = (db.PrimaryKeyConstraint("version", name="dify_setup_pkey"),)
- version = db.Column(db.String(255), nullable=False)
- setup_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ version = mapped_column(db.String(255), nullable=False)
+ setup_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMode(StrEnum):
@@ -50,7 +50,6 @@ class AppMode(StrEnum):
CHAT = "chat"
ADVANCED_CHAT = "advanced-chat"
AGENT_CHAT = "agent-chat"
- CHANNEL = "channel"
@classmethod
def value_of(cls, value: str) -> "AppMode":
@@ -75,31 +74,31 @@ class App(Base):
__tablename__ = "apps"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_pkey"), db.Index("app_tenant_id_idx", "tenant_id"))
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- name = db.Column(db.String(255), nullable=False)
- description = db.Column(db.Text, nullable=False, server_default=db.text("''::character varying"))
- mode: Mapped[str] = mapped_column(db.String(255), nullable=False)
- icon_type = db.Column(db.String(255), nullable=True) # image, emoji
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID)
+ name: Mapped[str] = mapped_column(db.String(255))
+ description: Mapped[str] = mapped_column(db.Text, server_default=db.text("''::character varying"))
+ mode: Mapped[str] = mapped_column(db.String(255))
+ icon_type: Mapped[Optional[str]] = mapped_column(db.String(255)) # image, emoji
icon = db.Column(db.String(255))
- icon_background = db.Column(db.String(255))
- app_model_config_id = db.Column(StringUUID, nullable=True)
- workflow_id = db.Column(StringUUID, nullable=True)
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
- enable_site = db.Column(db.Boolean, nullable=False)
- enable_api = db.Column(db.Boolean, nullable=False)
- api_rpm = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- api_rph = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- is_demo = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- is_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- is_universal = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- tracing = db.Column(db.Text, nullable=True)
- max_active_requests: Mapped[Optional[int]] = mapped_column(nullable=True)
- created_by = db.Column(StringUUID, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+ icon_background: Mapped[Optional[str]] = mapped_column(db.String(255))
+ app_model_config_id = mapped_column(StringUUID, nullable=True)
+ workflow_id = mapped_column(StringUUID, nullable=True)
+ status: Mapped[str] = mapped_column(db.String(255), server_default=db.text("'normal'::character varying"))
+ enable_site: Mapped[bool] = mapped_column(db.Boolean)
+ enable_api: Mapped[bool] = mapped_column(db.Boolean)
+ api_rpm: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
+ api_rph: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"))
+ is_demo: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
+ is_public: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
+ is_universal: Mapped[bool] = mapped_column(db.Boolean, server_default=db.text("false"))
+ tracing = mapped_column(db.Text, nullable=True)
+ max_active_requests: Mapped[Optional[int]]
+ created_by = mapped_column(StringUUID, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ use_icon_as_answer_icon: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
def desc_or_prompt(self):
@@ -114,13 +113,13 @@ class App(Base):
@property
def site(self):
- site = db.session.query(Site).filter(Site.app_id == self.id).first()
+ site = db.session.query(Site).where(Site.app_id == self.id).first()
return site
@property
def app_model_config(self):
if self.app_model_config_id:
- return db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
+ return db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
return None
@@ -129,7 +128,7 @@ class App(Base):
if self.workflow_id:
from .workflow import Workflow
- return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
+ return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
return None
@@ -139,7 +138,7 @@ class App(Base):
@property
def tenant(self):
- tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+ tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
@property
@@ -283,7 +282,7 @@ class App(Base):
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
- .filter(
+ .where(
TagBinding.target_id == self.id,
TagBinding.tenant_id == self.tenant_id,
Tag.tenant_id == self.tenant_id,
@@ -297,7 +296,7 @@ class App(Base):
@property
def author_name(self):
if self.created_by:
- account = db.session.query(Account).filter(Account.id == self.created_by).first()
+ account = db.session.query(Account).where(Account.id == self.created_by).first()
if account:
return account.name
@@ -308,38 +307,38 @@ class AppModelConfig(Base):
__tablename__ = "app_model_configs"
__table_args__ = (db.PrimaryKeyConstraint("id", name="app_model_config_pkey"), db.Index("app_app_id_idx", "app_id"))
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- provider = db.Column(db.String(255), nullable=True)
- model_id = db.Column(db.String(255), nullable=True)
- configs = db.Column(db.JSON, nullable=True)
- created_by = db.Column(StringUUID, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- opening_statement = db.Column(db.Text)
- suggested_questions = db.Column(db.Text)
- suggested_questions_after_answer = db.Column(db.Text)
- speech_to_text = db.Column(db.Text)
- text_to_speech = db.Column(db.Text)
- more_like_this = db.Column(db.Text)
- model = db.Column(db.Text)
- user_input_form = db.Column(db.Text)
- dataset_query_variable = db.Column(db.String(255))
- pre_prompt = db.Column(db.Text)
- agent_mode = db.Column(db.Text)
- sensitive_word_avoidance = db.Column(db.Text)
- retriever_resource = db.Column(db.Text)
- prompt_type = db.Column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying"))
- chat_prompt_config = db.Column(db.Text)
- completion_prompt_config = db.Column(db.Text)
- dataset_configs = db.Column(db.Text)
- external_data_tools = db.Column(db.Text)
- file_upload = db.Column(db.Text)
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ provider = mapped_column(db.String(255), nullable=True)
+ model_id = mapped_column(db.String(255), nullable=True)
+ configs = mapped_column(db.JSON, nullable=True)
+ created_by = mapped_column(StringUUID, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ opening_statement = mapped_column(db.Text)
+ suggested_questions = mapped_column(db.Text)
+ suggested_questions_after_answer = mapped_column(db.Text)
+ speech_to_text = mapped_column(db.Text)
+ text_to_speech = mapped_column(db.Text)
+ more_like_this = mapped_column(db.Text)
+ model = mapped_column(db.Text)
+ user_input_form = mapped_column(db.Text)
+ dataset_query_variable = mapped_column(db.String(255))
+ pre_prompt = mapped_column(db.Text)
+ agent_mode = mapped_column(db.Text)
+ sensitive_word_avoidance = mapped_column(db.Text)
+ retriever_resource = mapped_column(db.Text)
+ prompt_type = mapped_column(db.String(255), nullable=False, server_default=db.text("'simple'::character varying"))
+ chat_prompt_config = mapped_column(db.Text)
+ completion_prompt_config = mapped_column(db.Text)
+ dataset_configs = mapped_column(db.Text)
+ external_data_tools = mapped_column(db.Text)
+ file_upload = mapped_column(db.Text)
@property
def app(self):
- app = db.session.query(App).filter(App.id == self.app_id).first()
+ app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
@@ -373,7 +372,7 @@ class AppModelConfig(Base):
@property
def annotation_reply_dict(self) -> dict:
annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == self.app_id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == self.app_id).first()
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
@@ -562,23 +561,23 @@ class RecommendedApp(Base):
db.Index("recommended_app_is_listed_idx", "is_listed", "language"),
)
- id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- description = db.Column(db.JSON, nullable=False)
- copyright = db.Column(db.String(255), nullable=False)
- privacy_policy = db.Column(db.String(255), nullable=False)
+ id = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ description = mapped_column(db.JSON, nullable=False)
+ copyright = mapped_column(db.String(255), nullable=False)
+ privacy_policy = mapped_column(db.String(255), nullable=False)
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
- category = db.Column(db.String(255), nullable=False)
- position = db.Column(db.Integer, nullable=False, default=0)
- is_listed = db.Column(db.Boolean, nullable=False, default=True)
- install_count = db.Column(db.Integer, nullable=False, default=0)
- language = db.Column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ category = mapped_column(db.String(255), nullable=False)
+ position = mapped_column(db.Integer, nullable=False, default=0)
+ is_listed = mapped_column(db.Boolean, nullable=False, default=True)
+ install_count = mapped_column(db.Integer, nullable=False, default=0)
+ language = mapped_column(db.String(255), nullable=False, server_default=db.text("'en-US'::character varying"))
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
- app = db.session.query(App).filter(App.id == self.app_id).first()
+ app = db.session.query(App).where(App.id == self.app_id).first()
return app
@@ -591,34 +590,26 @@ class InstalledApp(Base):
db.UniqueConstraint("tenant_id", "app_id", name="unique_tenant_app"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- app_id = db.Column(StringUUID, nullable=False)
- app_owner_tenant_id = db.Column(StringUUID, nullable=False)
- position = db.Column(db.Integer, nullable=False, default=0)
- is_pinned = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- last_used_at = db.Column(db.DateTime, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ app_owner_tenant_id = mapped_column(StringUUID, nullable=False)
+ position = mapped_column(db.Integer, nullable=False, default=0)
+ is_pinned = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ last_used_at = mapped_column(db.DateTime, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def app(self):
- app = db.session.query(App).filter(App.id == self.app_id).first()
+ app = db.session.query(App).where(App.id == self.app_id).first()
return app
@property
def tenant(self):
- tenant = db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+ tenant = db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
return tenant
-class ConversationSource(StrEnum):
- """This enumeration is designed for use with `Conversation.from_source`."""
-
- # NOTE(QuantumGhost): The enumeration members may not cover all possible cases.
- API = "api"
- CONSOLE = "console"
-
-
class Conversation(Base):
__tablename__ = "conversations"
__table_args__ = (
@@ -627,42 +618,42 @@ class Conversation(Base):
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- app_model_config_id = db.Column(StringUUID, nullable=True)
- model_provider = db.Column(db.String(255), nullable=True)
- override_model_configs = db.Column(db.Text)
- model_id = db.Column(db.String(255), nullable=True)
+ app_id = mapped_column(StringUUID, nullable=False)
+ app_model_config_id = mapped_column(StringUUID, nullable=True)
+ model_provider = mapped_column(db.String(255), nullable=True)
+ override_model_configs = mapped_column(db.Text)
+ model_id = mapped_column(db.String(255), nullable=True)
mode: Mapped[str] = mapped_column(db.String(255))
- name = db.Column(db.String(255), nullable=False)
- summary = db.Column(db.Text)
+ name = mapped_column(db.String(255), nullable=False)
+ summary = mapped_column(db.Text)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
- introduction = db.Column(db.Text)
- system_instruction = db.Column(db.Text)
- system_instruction_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- status = db.Column(db.String(255), nullable=False)
+ introduction = mapped_column(db.Text)
+ system_instruction = mapped_column(db.Text)
+ system_instruction_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ status = mapped_column(db.String(255), nullable=False)
# The `invoke_from` records how the conversation is created.
#
# Its value corresponds to the members of `InvokeFrom`.
# (api/core/app/entities/app_invoke_entities.py)
- invoke_from = db.Column(db.String(255), nullable=True)
+ invoke_from = mapped_column(db.String(255), nullable=True)
# ref: ConversationSource.
- from_source = db.Column(db.String(255), nullable=False)
- from_end_user_id = db.Column(StringUUID)
- from_account_id = db.Column(StringUUID)
- read_at = db.Column(db.DateTime)
- read_account_id = db.Column(StringUUID)
+ from_source = mapped_column(db.String(255), nullable=False)
+ from_end_user_id = mapped_column(StringUUID)
+ from_account_id = mapped_column(StringUUID)
+ read_at = mapped_column(db.DateTime)
+ read_account_id = mapped_column(StringUUID)
dialogue_count: Mapped[int] = mapped_column(default=0)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
messages = db.relationship("Message", backref="conversation", lazy="select", passive_deletes="all")
message_annotations = db.relationship(
"MessageAnnotation", backref="conversation", lazy="select", passive_deletes="all"
)
- is_deleted = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+ is_deleted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
@property
def inputs(self):
@@ -723,7 +714,7 @@ class Conversation(Base):
model_config["configs"] = override_model_configs
else:
app_model_config = (
- db.session.query(AppModelConfig).filter(AppModelConfig.id == self.app_model_config_id).first()
+ db.session.query(AppModelConfig).where(AppModelConfig.id == self.app_model_config_id).first()
)
if app_model_config:
model_config = app_model_config.to_dict()
@@ -746,21 +737,21 @@ class Conversation(Base):
@property
def annotated(self):
- return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).count() > 0
+ return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).count() > 0
@property
def annotation(self):
- return db.session.query(MessageAnnotation).filter(MessageAnnotation.conversation_id == self.id).first()
+ return db.session.query(MessageAnnotation).where(MessageAnnotation.conversation_id == self.id).first()
@property
def message_count(self):
- return db.session.query(Message).filter(Message.conversation_id == self.id).count()
+ return db.session.query(Message).where(Message.conversation_id == self.id).count()
@property
def user_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
- .filter(
+ .where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
@@ -770,7 +761,7 @@ class Conversation(Base):
dislike = (
db.session.query(MessageFeedback)
- .filter(
+ .where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
@@ -784,7 +775,7 @@ class Conversation(Base):
def admin_feedback_stats(self):
like = (
db.session.query(MessageFeedback)
- .filter(
+ .where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
@@ -794,7 +785,7 @@ class Conversation(Base):
dislike = (
db.session.query(MessageFeedback)
- .filter(
+ .where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
@@ -806,7 +797,7 @@ class Conversation(Base):
@property
def status_count(self):
- messages = db.session.query(Message).filter(Message.conversation_id == self.id).all()
+ messages = db.session.query(Message).where(Message.conversation_id == self.id).all()
status_counts = {
WorkflowExecutionStatus.RUNNING: 0,
WorkflowExecutionStatus.SUCCEEDED: 0,
@@ -833,19 +824,19 @@ class Conversation(Base):
def first_message(self):
return (
db.session.query(Message)
- .filter(Message.conversation_id == self.id)
+ .where(Message.conversation_id == self.id)
.order_by(Message.created_at.asc())
.first()
)
@property
def app(self):
- return db.session.query(App).filter(App.id == self.app_id).first()
+ return db.session.query(App).where(App.id == self.app_id).first()
@property
def from_end_user_session_id(self):
if self.from_end_user_id:
- end_user = db.session.query(EndUser).filter(EndUser.id == self.from_end_user_id).first()
+ end_user = db.session.query(EndUser).where(EndUser.id == self.from_end_user_id).first()
if end_user:
return end_user.session_id
@@ -854,7 +845,7 @@ class Conversation(Base):
@property
def from_account_name(self):
if self.from_account_id:
- account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
+ account = db.session.query(Account).where(Account.id == self.from_account_id).first()
if account:
return account.name
@@ -905,36 +896,36 @@ class Message(Base):
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- model_provider = db.Column(db.String(255), nullable=True)
- model_id = db.Column(db.String(255), nullable=True)
- override_model_configs = db.Column(db.Text)
- conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ model_provider = mapped_column(db.String(255), nullable=True)
+ model_id = mapped_column(db.String(255), nullable=True)
+ override_model_configs = mapped_column(db.Text)
+ conversation_id = mapped_column(StringUUID, db.ForeignKey("conversations.id"), nullable=False)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
- query: Mapped[str] = db.Column(db.Text, nullable=False)
- message = db.Column(db.JSON, nullable=False)
- message_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- message_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
- message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
- answer: Mapped[str] = db.Column(db.Text, nullable=False)
- answer_tokens: Mapped[int] = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
- answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
- parent_message_id = db.Column(StringUUID, nullable=True)
- provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
- total_price = db.Column(db.Numeric(10, 7))
- currency = db.Column(db.String(255), nullable=False)
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
- error = db.Column(db.Text)
- message_metadata = db.Column(db.Text)
- invoke_from: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
- from_source = db.Column(db.String(255), nullable=False)
- from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
- from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
+ query: Mapped[str] = mapped_column(db.Text, nullable=False)
+ message = mapped_column(db.JSON, nullable=False)
+ message_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ message_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
+ message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
+ answer: Mapped[str] = db.Column(db.Text, nullable=False) # TODO make it mapped_column
+ answer_tokens: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
+ answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
+ parent_message_id = mapped_column(StringUUID, nullable=True)
+ provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
+ total_price = mapped_column(db.Numeric(10, 7))
+ currency = mapped_column(db.String(255), nullable=False)
+ status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+ error = mapped_column(db.Text)
+ message_metadata = mapped_column(db.Text)
+ invoke_from: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
+ from_source = mapped_column(db.String(255), nullable=False)
+ from_end_user_id: Mapped[Optional[str]] = mapped_column(StringUUID)
+ from_account_id: Mapped[Optional[str]] = mapped_column(StringUUID)
created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- workflow_run_id = db.Column(StringUUID)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ agent_based = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
@property
def inputs(self):
@@ -1049,7 +1040,7 @@ class Message(Base):
def user_feedback(self):
feedback = (
db.session.query(MessageFeedback)
- .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
+ .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "user")
.first()
)
return feedback
@@ -1058,30 +1049,30 @@ class Message(Base):
def admin_feedback(self):
feedback = (
db.session.query(MessageFeedback)
- .filter(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
+ .where(MessageFeedback.message_id == self.id, MessageFeedback.from_source == "admin")
.first()
)
return feedback
@property
def feedbacks(self):
- feedbacks = db.session.query(MessageFeedback).filter(MessageFeedback.message_id == self.id).all()
+ feedbacks = db.session.query(MessageFeedback).where(MessageFeedback.message_id == self.id).all()
return feedbacks
@property
def annotation(self):
- annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == self.id).first()
+ annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == self.id).first()
return annotation
@property
def annotation_hit_history(self):
annotation_history = (
- db.session.query(AppAnnotationHitHistory).filter(AppAnnotationHitHistory.message_id == self.id).first()
+ db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.message_id == self.id).first()
)
if annotation_history:
annotation = (
db.session.query(MessageAnnotation)
- .filter(MessageAnnotation.id == annotation_history.annotation_id)
+ .where(MessageAnnotation.id == annotation_history.annotation_id)
.first()
)
return annotation
@@ -1089,11 +1080,9 @@ class Message(Base):
@property
def app_model_config(self):
- conversation = db.session.query(Conversation).filter(Conversation.id == self.conversation_id).first()
+ conversation = db.session.query(Conversation).where(Conversation.id == self.conversation_id).first()
if conversation:
- return (
- db.session.query(AppModelConfig).filter(AppModelConfig.id == conversation.app_model_config_id).first()
- )
+ return db.session.query(AppModelConfig).where(AppModelConfig.id == conversation.app_model_config_id).first()
return None
@@ -1109,7 +1098,7 @@ class Message(Base):
def agent_thoughts(self):
return (
db.session.query(MessageAgentThought)
- .filter(MessageAgentThought.message_id == self.id)
+ .where(MessageAgentThought.message_id == self.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
@@ -1122,8 +1111,8 @@ class Message(Base):
def message_files(self):
from factories import file_factory
- message_files = db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
- current_app = db.session.query(App).filter(App.id == self.app_id).first()
+ message_files = db.session.query(MessageFile).where(MessageFile.message_id == self.id).all()
+ current_app = db.session.query(App).where(App.id == self.app_id).first()
if not current_app:
raise ValueError(f"App {self.app_id} not found")
@@ -1187,7 +1176,7 @@ class Message(Base):
if self.workflow_run_id:
from .workflow import WorkflowRun
- return db.session.query(WorkflowRun).filter(WorkflowRun.id == self.workflow_run_id).first()
+ return db.session.query(WorkflowRun).where(WorkflowRun.id == self.workflow_run_id).first()
return None
@@ -1248,21 +1237,21 @@ class MessageFeedback(Base):
db.Index("message_feedback_conversation_idx", "conversation_id", "from_source", "rating"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- conversation_id = db.Column(StringUUID, nullable=False)
- message_id = db.Column(StringUUID, nullable=False)
- rating = db.Column(db.String(255), nullable=False)
- content = db.Column(db.Text)
- from_source = db.Column(db.String(255), nullable=False)
- from_end_user_id = db.Column(StringUUID)
- from_account_id = db.Column(StringUUID)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ conversation_id = mapped_column(StringUUID, nullable=False)
+ message_id = mapped_column(StringUUID, nullable=False)
+ rating = mapped_column(db.String(255), nullable=False)
+ content = mapped_column(db.Text)
+ from_source = mapped_column(db.String(255), nullable=False)
+ from_end_user_id = mapped_column(StringUUID)
+ from_account_id = mapped_column(StringUUID)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def from_account(self):
- account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
+ account = db.session.query(Account).where(Account.id == self.from_account_id).first()
return account
def to_dict(self):
@@ -1310,16 +1299,16 @@ class MessageFile(Base):
self.created_by_role = created_by_role.value
self.created_by = created_by
- id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- message_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- type: Mapped[str] = db.Column(db.String(255), nullable=False)
- transfer_method: Mapped[str] = db.Column(db.String(255), nullable=False)
- url: Mapped[Optional[str]] = db.Column(db.Text, nullable=True)
- belongs_to: Mapped[Optional[str]] = db.Column(db.String(255), nullable=True)
- upload_file_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
- created_by_role: Mapped[str] = db.Column(db.String(255), nullable=False)
- created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ type: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ transfer_method: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ url: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True)
+ belongs_to: Mapped[Optional[str]] = mapped_column(db.String(255), nullable=True)
+ upload_file_id: Mapped[Optional[str]] = mapped_column(StringUUID, nullable=True)
+ created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(Base):
@@ -1331,25 +1320,25 @@ class MessageAnnotation(Base):
db.Index("message_annotation_message_idx", "message_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- conversation_id = db.Column(StringUUID, db.ForeignKey("conversations.id"), nullable=True)
- message_id = db.Column(StringUUID, nullable=True)
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id: Mapped[str] = mapped_column(StringUUID)
+ conversation_id: Mapped[Optional[str]] = mapped_column(StringUUID, db.ForeignKey("conversations.id"))
+ message_id: Mapped[Optional[str]] = mapped_column(StringUUID)
question = db.Column(db.Text, nullable=True)
- content = db.Column(db.Text, nullable=False)
- hit_count = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- account_id = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ content = mapped_column(db.Text, nullable=False)
+ hit_count = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ account_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def account(self):
- account = db.session.query(Account).filter(Account.id == self.account_id).first()
+ account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@property
def annotation_create_account(self):
- account = db.session.query(Account).filter(Account.id == self.account_id).first()
+ account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@@ -1363,31 +1352,31 @@ class AppAnnotationHitHistory(Base):
db.Index("app_annotation_hit_histories_message_idx", "message_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- source = db.Column(db.Text, nullable=False)
- question = db.Column(db.Text, nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- score = db.Column(Float, nullable=False, server_default=db.text("0"))
- message_id = db.Column(StringUUID, nullable=False)
- annotation_question = db.Column(db.Text, nullable=False)
- annotation_content = db.Column(db.Text, nullable=False)
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ annotation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ source = mapped_column(db.Text, nullable=False)
+ question = mapped_column(db.Text, nullable=False)
+ account_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ score = mapped_column(Float, nullable=False, server_default=db.text("0"))
+ message_id = mapped_column(StringUUID, nullable=False)
+ annotation_question = mapped_column(db.Text, nullable=False)
+ annotation_content = mapped_column(db.Text, nullable=False)
@property
def account(self):
account = (
db.session.query(Account)
.join(MessageAnnotation, MessageAnnotation.account_id == Account.id)
- .filter(MessageAnnotation.id == self.annotation_id)
+ .where(MessageAnnotation.id == self.annotation_id)
.first()
)
return account
@property
def annotation_create_account(self):
- account = db.session.query(Account).filter(Account.id == self.account_id).first()
+ account = db.session.query(Account).where(Account.id == self.account_id).first()
return account
@@ -1398,14 +1387,14 @@ class AppAnnotationSetting(Base):
db.Index("app_annotation_settings_app_idx", "app_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- score_threshold = db.Column(Float, nullable=False, server_default=db.text("0"))
- collection_binding_id = db.Column(StringUUID, nullable=False)
- created_user_id = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_user_id = db.Column(StringUUID, nullable=False)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ score_threshold = mapped_column(Float, nullable=False, server_default=db.text("0"))
+ collection_binding_id = mapped_column(StringUUID, nullable=False)
+ created_user_id = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_user_id = mapped_column(StringUUID, nullable=False)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def collection_binding_detail(self):
@@ -1413,7 +1402,7 @@ class AppAnnotationSetting(Base):
collection_binding_detail = (
db.session.query(DatasetCollectionBinding)
- .filter(DatasetCollectionBinding.id == self.collection_binding_id)
+ .where(DatasetCollectionBinding.id == self.collection_binding_id)
.first()
)
return collection_binding_detail
@@ -1426,14 +1415,14 @@ class OperationLog(Base):
db.Index("operation_log_account_action_idx", "tenant_id", "account_id", "action"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- account_id = db.Column(StringUUID, nullable=False)
- action = db.Column(db.String(255), nullable=False)
- content = db.Column(db.JSON)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- created_ip = db.Column(db.String(255), nullable=False)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ account_id = mapped_column(StringUUID, nullable=False)
+ action = mapped_column(db.String(255), nullable=False)
+ content = mapped_column(db.JSON)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_ip = mapped_column(db.String(255), nullable=False)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class EndUser(Base, UserMixin):
@@ -1444,16 +1433,16 @@ class EndUser(Base, UserMixin):
db.Index("end_user_tenant_session_id_idx", "tenant_id", "session_id", "type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- app_id = db.Column(StringUUID, nullable=True)
- type = db.Column(db.String(255), nullable=False)
- external_user_id = db.Column(db.String(255), nullable=True)
- name = db.Column(db.String(255))
- is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=True)
+ type = mapped_column(db.String(255), nullable=False)
+ external_user_id = mapped_column(db.String(255), nullable=True)
+ name = mapped_column(db.String(255))
+ is_anonymous = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
session_id: Mapped[str] = mapped_column()
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AppMCPServer(Base):
@@ -1463,23 +1452,23 @@ class AppMCPServer(Base):
db.UniqueConstraint("tenant_id", "app_id", name="unique_app_mcp_server_tenant_app_id"),
db.UniqueConstraint("server_code", name="unique_app_mcp_server_server_code"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- app_id = db.Column(StringUUID, nullable=False)
- name = db.Column(db.String(255), nullable=False)
- description = db.Column(db.String(255), nullable=False)
- server_code = db.Column(db.String(255), nullable=False)
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
- parameters = db.Column(db.Text, nullable=False)
-
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ app_id = mapped_column(StringUUID, nullable=False)
+ name = mapped_column(db.String(255), nullable=False)
+ description = mapped_column(db.String(255), nullable=False)
+ server_code = mapped_column(db.String(255), nullable=False)
+ status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+ parameters = mapped_column(db.Text, nullable=False)
+
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_server_code(n):
while True:
result = generate_string(n)
- while db.session.query(AppMCPServer).filter(AppMCPServer.server_code == result).count() > 0:
+ while db.session.query(AppMCPServer).where(AppMCPServer.server_code == result).count() > 0:
result = generate_string(n)
return result
@@ -1497,30 +1486,30 @@ class Site(Base):
db.Index("site_code_idx", "code", "status"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- title = db.Column(db.String(255), nullable=False)
- icon_type = db.Column(db.String(255), nullable=True)
- icon = db.Column(db.String(255))
- icon_background = db.Column(db.String(255))
- description = db.Column(db.Text)
- default_language = db.Column(db.String(255), nullable=False)
- chat_color_theme = db.Column(db.String(255))
- chat_color_theme_inverted = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- copyright = db.Column(db.String(255))
- privacy_policy = db.Column(db.String(255))
- show_workflow_steps = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
- use_icon_as_answer_icon = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ title = mapped_column(db.String(255), nullable=False)
+ icon_type = mapped_column(db.String(255), nullable=True)
+ icon = mapped_column(db.String(255))
+ icon_background = mapped_column(db.String(255))
+ description = mapped_column(db.Text)
+ default_language = mapped_column(db.String(255), nullable=False)
+ chat_color_theme = mapped_column(db.String(255))
+ chat_color_theme_inverted = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ copyright = mapped_column(db.String(255))
+ privacy_policy = mapped_column(db.String(255))
+ show_workflow_steps = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ use_icon_as_answer_icon = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
_custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", sa.TEXT, default="")
- customize_domain = db.Column(db.String(255))
- customize_token_strategy = db.Column(db.String(255), nullable=False)
- prompt_public = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
- created_by = db.Column(StringUUID, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_by = db.Column(StringUUID, nullable=True)
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- code = db.Column(db.String(255))
+ customize_domain = mapped_column(db.String(255))
+ customize_token_strategy = mapped_column(db.String(255), nullable=False)
+ prompt_public = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ status = mapped_column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
+ created_by = mapped_column(StringUUID, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_by = mapped_column(StringUUID, nullable=True)
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ code = mapped_column(db.String(255))
@property
def custom_disclaimer(self):
@@ -1536,7 +1525,7 @@ class Site(Base):
def generate_code(n):
while True:
result = generate_string(n)
- while db.session.query(Site).filter(Site.code == result).count() > 0:
+ while db.session.query(Site).where(Site.code == result).count() > 0:
result = generate_string(n)
return result
@@ -1555,19 +1544,19 @@ class ApiToken(Base):
db.Index("api_token_tenant_idx", "tenant_id", "type"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=True)
- tenant_id = db.Column(StringUUID, nullable=True)
- type = db.Column(db.String(16), nullable=False)
- token = db.Column(db.String(255), nullable=False)
- last_used_at = db.Column(db.DateTime, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=True)
+ tenant_id = mapped_column(StringUUID, nullable=True)
+ type = mapped_column(db.String(16), nullable=False)
+ token = mapped_column(db.String(255), nullable=False)
+ last_used_at = mapped_column(db.DateTime, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@staticmethod
def generate_api_key(prefix, n):
while True:
result = prefix + generate_string(n)
- if db.session.query(ApiToken).filter(ApiToken.token == result).count() > 0:
+ if db.session.query(ApiToken).where(ApiToken.token == result).count() > 0:
continue
return result
@@ -1579,23 +1568,23 @@ class UploadFile(Base):
db.Index("upload_file_tenant_idx", "tenant_id"),
)
- id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- storage_type: Mapped[str] = db.Column(db.String(255), nullable=False)
- key: Mapped[str] = db.Column(db.String(255), nullable=False)
- name: Mapped[str] = db.Column(db.String(255), nullable=False)
- size: Mapped[int] = db.Column(db.Integer, nullable=False)
- extension: Mapped[str] = db.Column(db.String(255), nullable=False)
- mime_type: Mapped[str] = db.Column(db.String(255), nullable=True)
- created_by_role: Mapped[str] = db.Column(
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ storage_type: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ key: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ name: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ size: Mapped[int] = mapped_column(db.Integer, nullable=False)
+ extension: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ mime_type: Mapped[str] = mapped_column(db.String(255), nullable=True)
+ created_by_role: Mapped[str] = mapped_column(
db.String(255), nullable=False, server_default=db.text("'account'::character varying")
)
- created_by: Mapped[str] = db.Column(StringUUID, nullable=False)
- created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- used: Mapped[bool] = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
- used_by: Mapped[str | None] = db.Column(StringUUID, nullable=True)
- used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
- hash: Mapped[str | None] = db.Column(db.String(255), nullable=True)
+ created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ used: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ used_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
+ used_at: Mapped[datetime | None] = mapped_column(db.DateTime, nullable=True)
+ hash: Mapped[str | None] = mapped_column(db.String(255), nullable=True)
source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
def __init__(
@@ -1641,14 +1630,14 @@ class ApiRequest(Base):
db.Index("api_request_token_idx", "tenant_id", "api_token_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- api_token_id = db.Column(StringUUID, nullable=False)
- path = db.Column(db.String(255), nullable=False)
- request = db.Column(db.Text, nullable=True)
- response = db.Column(db.Text, nullable=True)
- ip = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ api_token_id = mapped_column(StringUUID, nullable=False)
+ path = mapped_column(db.String(255), nullable=False)
+ request = mapped_column(db.Text, nullable=True)
+ response = mapped_column(db.Text, nullable=True)
+ ip = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageChain(Base):
@@ -1658,12 +1647,12 @@ class MessageChain(Base):
db.Index("message_chain_message_id_idx", "message_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- message_id = db.Column(StringUUID, nullable=False)
- type = db.Column(db.String(255), nullable=False)
- input = db.Column(db.Text, nullable=True)
- output = db.Column(db.Text, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ message_id = mapped_column(StringUUID, nullable=False)
+ type = mapped_column(db.String(255), nullable=False)
+ input = mapped_column(db.Text, nullable=True)
+ output = mapped_column(db.Text, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class MessageAgentThought(Base):
@@ -1674,34 +1663,34 @@ class MessageAgentThought(Base):
db.Index("message_agent_thought_message_chain_id_idx", "message_chain_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- message_id = db.Column(StringUUID, nullable=False)
- message_chain_id = db.Column(StringUUID, nullable=True)
- position = db.Column(db.Integer, nullable=False)
- thought = db.Column(db.Text, nullable=True)
- tool = db.Column(db.Text, nullable=True)
- tool_labels_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
- tool_meta_str = db.Column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
- tool_input = db.Column(db.Text, nullable=True)
- observation = db.Column(db.Text, nullable=True)
- # plugin_id = db.Column(StringUUID, nullable=True) ## for future design
- tool_process_data = db.Column(db.Text, nullable=True)
- message = db.Column(db.Text, nullable=True)
- message_token = db.Column(db.Integer, nullable=True)
- message_unit_price = db.Column(db.Numeric, nullable=True)
- message_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
- message_files = db.Column(db.Text, nullable=True)
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ message_id = mapped_column(StringUUID, nullable=False)
+ message_chain_id = mapped_column(StringUUID, nullable=True)
+ position = mapped_column(db.Integer, nullable=False)
+ thought = mapped_column(db.Text, nullable=True)
+ tool = mapped_column(db.Text, nullable=True)
+ tool_labels_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
+ tool_meta_str = mapped_column(db.Text, nullable=False, server_default=db.text("'{}'::text"))
+ tool_input = mapped_column(db.Text, nullable=True)
+ observation = mapped_column(db.Text, nullable=True)
+ # plugin_id = mapped_column(StringUUID, nullable=True) ## for future design
+ tool_process_data = mapped_column(db.Text, nullable=True)
+ message = mapped_column(db.Text, nullable=True)
+ message_token = mapped_column(db.Integer, nullable=True)
+ message_unit_price = mapped_column(db.Numeric, nullable=True)
+ message_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
+ message_files = mapped_column(db.Text, nullable=True)
answer = db.Column(db.Text, nullable=True)
- answer_token = db.Column(db.Integer, nullable=True)
- answer_unit_price = db.Column(db.Numeric, nullable=True)
- answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
- tokens = db.Column(db.Integer, nullable=True)
- total_price = db.Column(db.Numeric, nullable=True)
- currency = db.Column(db.String, nullable=True)
- latency = db.Column(db.Float, nullable=True)
- created_by_role = db.Column(db.String, nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+ answer_token = mapped_column(db.Integer, nullable=True)
+ answer_unit_price = mapped_column(db.Numeric, nullable=True)
+ answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
+ tokens = mapped_column(db.Integer, nullable=True)
+ total_price = mapped_column(db.Numeric, nullable=True)
+ currency = mapped_column(db.String, nullable=True)
+ latency = mapped_column(db.Float, nullable=True)
+ created_by_role = mapped_column(db.String, nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@property
def files(self) -> list:
@@ -1787,24 +1776,24 @@ class DatasetRetrieverResource(Base):
db.Index("dataset_retriever_resource_message_id_idx", "message_id"),
)
- id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
- message_id = db.Column(StringUUID, nullable=False)
- position = db.Column(db.Integer, nullable=False)
- dataset_id = db.Column(StringUUID, nullable=False)
- dataset_name = db.Column(db.Text, nullable=False)
- document_id = db.Column(StringUUID, nullable=True)
- document_name = db.Column(db.Text, nullable=False)
- data_source_type = db.Column(db.Text, nullable=True)
- segment_id = db.Column(StringUUID, nullable=True)
- score = db.Column(db.Float, nullable=True)
- content = db.Column(db.Text, nullable=False)
- hit_count = db.Column(db.Integer, nullable=True)
- word_count = db.Column(db.Integer, nullable=True)
- segment_position = db.Column(db.Integer, nullable=True)
- index_node_hash = db.Column(db.Text, nullable=True)
- retriever_from = db.Column(db.Text, nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
+ id = mapped_column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
+ message_id = mapped_column(StringUUID, nullable=False)
+ position = mapped_column(db.Integer, nullable=False)
+ dataset_id = mapped_column(StringUUID, nullable=False)
+ dataset_name = mapped_column(db.Text, nullable=False)
+ document_id = mapped_column(StringUUID, nullable=True)
+ document_name = mapped_column(db.Text, nullable=False)
+ data_source_type = mapped_column(db.Text, nullable=True)
+ segment_id = mapped_column(StringUUID, nullable=True)
+ score = mapped_column(db.Float, nullable=True)
+ content = mapped_column(db.Text, nullable=False)
+ hit_count = mapped_column(db.Integer, nullable=True)
+ word_count = mapped_column(db.Integer, nullable=True)
+ segment_position = mapped_column(db.Integer, nullable=True)
+ index_node_hash = mapped_column(db.Text, nullable=True)
+ retriever_from = mapped_column(db.Text, nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class Tag(Base):
@@ -1817,12 +1806,12 @@ class Tag(Base):
TAG_TYPE_LIST = ["knowledge", "app"]
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=True)
- type = db.Column(db.String(16), nullable=False)
- name = db.Column(db.String(255), nullable=False)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=True)
+ type = mapped_column(db.String(16), nullable=False)
+ name = mapped_column(db.String(255), nullable=False)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TagBinding(Base):
@@ -1833,12 +1822,12 @@ class TagBinding(Base):
db.Index("tag_bind_tag_id_idx", "tag_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=True)
- tag_id = db.Column(StringUUID, nullable=True)
- target_id = db.Column(StringUUID, nullable=True)
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=True)
+ tag_id = mapped_column(StringUUID, nullable=True)
+ target_id = mapped_column(StringUUID, nullable=True)
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TraceAppConfig(Base):
@@ -1848,15 +1837,15 @@ class TraceAppConfig(Base):
db.Index("trace_app_config_app_id_idx", "app_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- tracing_provider = db.Column(db.String(255), nullable=True)
- tracing_config = db.Column(db.JSON, nullable=True)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ tracing_provider = mapped_column(db.String(255), nullable=True)
+ tracing_config = mapped_column(db.JSON, nullable=True)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
- is_active = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
+ is_active = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
@property
def tracing_config_dict(self):
diff --git a/api/models/source.py b/api/models/source.py
index f6e0900ae6..100e0d96ef 100644
--- a/api/models/source.py
+++ b/api/models/source.py
@@ -2,6 +2,7 @@ import json
from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.orm import mapped_column
from models.base import Base
@@ -17,14 +18,14 @@ class DataSourceOauthBinding(Base):
db.Index("source_info_idx", "source_info", postgresql_using="gin"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- access_token = db.Column(db.String(255), nullable=False)
- provider = db.Column(db.String(255), nullable=False)
- source_info = db.Column(JSONB, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ access_token = mapped_column(db.String(255), nullable=False)
+ provider = mapped_column(db.String(255), nullable=False)
+ source_info = mapped_column(JSONB, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
class DataSourceApiKeyAuthBinding(Base):
@@ -35,14 +36,14 @@ class DataSourceApiKeyAuthBinding(Base):
db.Index("data_source_api_key_auth_binding_provider_idx", "provider"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- tenant_id = db.Column(StringUUID, nullable=False)
- category = db.Column(db.String(255), nullable=False)
- provider = db.Column(db.String(255), nullable=False)
- credentials = db.Column(db.Text, nullable=True) # JSON
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- disabled = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ tenant_id = mapped_column(StringUUID, nullable=False)
+ category = mapped_column(db.String(255), nullable=False)
+ provider = mapped_column(db.String(255), nullable=False)
+ credentials = mapped_column(db.Text, nullable=True) # JSON
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ disabled = mapped_column(db.Boolean, nullable=True, server_default=db.text("false"))
def to_dict(self):
return {
diff --git a/api/models/task.py b/api/models/task.py
index d853c1dd9a..3e5ebd2099 100644
--- a/api/models/task.py
+++ b/api/models/task.py
@@ -1,7 +1,10 @@
-from datetime import UTC, datetime
+from datetime import datetime
+from typing import Optional
from celery import states # type: ignore
+from sqlalchemy.orm import Mapped, mapped_column
+from libs.datetime_utils import naive_utc_now
from models.base import Base
from .engine import db
@@ -12,23 +15,23 @@ class CeleryTask(Base):
__tablename__ = "celery_taskmeta"
- id = db.Column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
- task_id = db.Column(db.String(155), unique=True)
- status = db.Column(db.String(50), default=states.PENDING)
- result = db.Column(db.PickleType, nullable=True)
- date_done = db.Column(
+ id = mapped_column(db.Integer, db.Sequence("task_id_sequence"), primary_key=True, autoincrement=True)
+ task_id = mapped_column(db.String(155), unique=True)
+ status = mapped_column(db.String(50), default=states.PENDING)
+ result = mapped_column(db.PickleType, nullable=True)
+ date_done = mapped_column(
db.DateTime,
- default=lambda: datetime.now(UTC).replace(tzinfo=None),
- onupdate=lambda: datetime.now(UTC).replace(tzinfo=None),
+ default=lambda: naive_utc_now(),
+ onupdate=lambda: naive_utc_now(),
nullable=True,
)
- traceback = db.Column(db.Text, nullable=True)
- name = db.Column(db.String(155), nullable=True)
- args = db.Column(db.LargeBinary, nullable=True)
- kwargs = db.Column(db.LargeBinary, nullable=True)
- worker = db.Column(db.String(155), nullable=True)
- retries = db.Column(db.Integer, nullable=True)
- queue = db.Column(db.String(155), nullable=True)
+ traceback = mapped_column(db.Text, nullable=True)
+ name = mapped_column(db.String(155), nullable=True)
+ args = mapped_column(db.LargeBinary, nullable=True)
+ kwargs = mapped_column(db.LargeBinary, nullable=True)
+ worker = mapped_column(db.String(155), nullable=True)
+ retries = mapped_column(db.Integer, nullable=True)
+ queue = mapped_column(db.String(155), nullable=True)
class CeleryTaskSet(Base):
@@ -36,7 +39,9 @@ class CeleryTaskSet(Base):
__tablename__ = "celery_tasksetmeta"
- id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True)
- taskset_id = db.Column(db.String(155), unique=True)
- result = db.Column(db.PickleType, nullable=True)
- date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True)
+ id: Mapped[int] = mapped_column(
+ db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True
+ )
+ taskset_id = mapped_column(db.String(155), unique=True)
+ result = mapped_column(db.PickleType, nullable=True)
+ date_done: Mapped[Optional[datetime]] = mapped_column(db.DateTime, default=lambda: naive_utc_now(), nullable=True)
diff --git a/api/models/tools.py b/api/models/tools.py
index 9d2c3baea5..68f4211e59 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -21,6 +21,43 @@ from .model import Account, App, Tenant
from .types import StringUUID
+# system level tool oauth client params (client_id, client_secret, etc.)
+class ToolOAuthSystemClient(Base):
+ __tablename__ = "tool_oauth_system_clients"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"),
+ db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
+ provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ # oauth params of the tool provider
+ encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
+
+
+# tenant level tool oauth client params (client_id, client_secret, etc.)
+class ToolOAuthTenantClient(Base):
+ __tablename__ = "tool_oauth_tenant_clients"
+ __table_args__ = (
+ db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"),
+ db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"),
+ )
+
+ id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ # tenant id
+ tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
+ plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False)
+ provider: Mapped[str] = mapped_column(db.String(255), nullable=False)
+ enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true"))
+ # oauth params of the tool provider
+ encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False)
+
+ @property
+ def oauth_params(self) -> dict:
+ return cast(dict, json.loads(self.encrypted_oauth_params or "{}"))
+
+
class BuiltinToolProvider(Base):
"""
This table stores the tool provider information for built-in tools for each tenant.
@@ -29,12 +66,14 @@ class BuiltinToolProvider(Base):
__tablename__ = "tool_builtin_providers"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"),
- # one tenant can only have one tool provider with the same name
- db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"),
+ db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"),
)
# id of the tool provider
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ name: Mapped[str] = mapped_column(
+ db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying")
+ )
# id of the tenant
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
# who created this tool provider
@@ -49,6 +88,12 @@ class BuiltinToolProvider(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")
)
+ is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
+ # credential type, e.g., "api-key", "oauth2"
+ credential_type: Mapped[str] = mapped_column(
+ db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
+ )
+ expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
@property
def credentials(self) -> dict:
@@ -66,26 +111,26 @@ class ApiToolProvider(Base):
db.UniqueConstraint("name", "tenant_id", name="unique_api_tool_provider"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# name of the api provider
- name = db.Column(db.String(255), nullable=False)
+ name = mapped_column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying"))
# icon
- icon = db.Column(db.String(255), nullable=False)
+ icon = mapped_column(db.String(255), nullable=False)
# original schema
- schema = db.Column(db.Text, nullable=False)
- schema_type_str: Mapped[str] = db.Column(db.String(40), nullable=False)
+ schema = mapped_column(db.Text, nullable=False)
+ schema_type_str: Mapped[str] = mapped_column(db.String(40), nullable=False)
# who created this tool
- user_id = db.Column(StringUUID, nullable=False)
+ user_id = mapped_column(StringUUID, nullable=False)
# tenant id
- tenant_id = db.Column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
# description of the provider
- description = db.Column(db.Text, nullable=False)
+ description = mapped_column(db.Text, nullable=False)
# json format tools
- tools_str = db.Column(db.Text, nullable=False)
+ tools_str = mapped_column(db.Text, nullable=False)
# json format credentials
- credentials_str = db.Column(db.Text, nullable=False)
+ credentials_str = mapped_column(db.Text, nullable=False)
# privacy policy
- privacy_policy = db.Column(db.String(255), nullable=True)
+ privacy_policy = mapped_column(db.String(255), nullable=True)
# custom_disclaimer
custom_disclaimer: Mapped[str] = mapped_column(sa.TEXT, default="")
@@ -108,11 +153,11 @@ class ApiToolProvider(Base):
def user(self) -> Account | None:
if not self.user_id:
return None
- return db.session.query(Account).filter(Account.id == self.user_id).first()
+ return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
- return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+ return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
class ToolLabelBinding(Base):
@@ -178,11 +223,11 @@ class WorkflowToolProvider(Base):
@property
def user(self) -> Account | None:
- return db.session.query(Account).filter(Account.id == self.user_id).first()
+ return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
- return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+ return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def parameter_configurations(self) -> list[WorkflowToolParameterConfiguration]:
@@ -190,7 +235,7 @@ class WorkflowToolProvider(Base):
@property
def app(self) -> App | None:
- return db.session.query(App).filter(App.id == self.app_id).first()
+ return db.session.query(App).where(App.id == self.app_id).first()
class MCPToolProvider(Base):
@@ -210,7 +255,7 @@ class MCPToolProvider(Base):
# name of the mcp provider
name: Mapped[str] = mapped_column(db.String(40), nullable=False)
# server identifier of the mcp provider
- server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False)
+ server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False)
# encrypted url of the mcp provider
server_url: Mapped[str] = mapped_column(db.Text, nullable=False)
# hash of server_url for uniqueness check
@@ -235,11 +280,11 @@ class MCPToolProvider(Base):
)
def load_user(self) -> Account | None:
- return db.session.query(Account).filter(Account.id == self.user_id).first()
+ return db.session.query(Account).where(Account.id == self.user_id).first()
@property
def tenant(self) -> Tenant | None:
- return db.session.query(Tenant).filter(Tenant.id == self.tenant_id).first()
+ return db.session.query(Tenant).where(Tenant.id == self.tenant_id).first()
@property
def credentials(self) -> dict:
@@ -281,18 +326,19 @@ class MCPToolProvider(Base):
@property
def decrypted_credentials(self) -> dict:
+ from core.helper.provider_cache import NoOpProviderCredentialCache
from core.tools.mcp_tool.provider import MCPToolProviderController
- from core.tools.utils.configuration import ProviderConfigEncrypter
+ from core.tools.utils.encryption import create_provider_encrypter
provider_controller = MCPToolProviderController._from_db(self)
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_provider_encrypter(
tenant_id=self.tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.provider_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
+ cache=NoOpProviderCredentialCache(),
)
- return tool_configuration.decrypt(self.credentials, use_cache=False)
+
+ return encrypter.decrypt(self.credentials) # type: ignore
class ToolModelInvoke(Base):
@@ -303,33 +349,33 @@ class ToolModelInvoke(Base):
__tablename__ = "tool_model_invokes"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tool_model_invoke_pkey"),)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# who invoke this tool
- user_id = db.Column(StringUUID, nullable=False)
+ user_id = mapped_column(StringUUID, nullable=False)
# tenant id
- tenant_id = db.Column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
# provider
- provider = db.Column(db.String(255), nullable=False)
+ provider = mapped_column(db.String(255), nullable=False)
# type
- tool_type = db.Column(db.String(40), nullable=False)
+ tool_type = mapped_column(db.String(40), nullable=False)
# tool name
- tool_name = db.Column(db.String(40), nullable=False)
+ tool_name = mapped_column(db.String(128), nullable=False)
# invoke parameters
- model_parameters = db.Column(db.Text, nullable=False)
+ model_parameters = mapped_column(db.Text, nullable=False)
# prompt messages
- prompt_messages = db.Column(db.Text, nullable=False)
+ prompt_messages = mapped_column(db.Text, nullable=False)
# invoke response
- model_response = db.Column(db.Text, nullable=False)
+ model_response = mapped_column(db.Text, nullable=False)
- prompt_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- answer_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0"))
- answer_unit_price = db.Column(db.Numeric(10, 4), nullable=False)
- answer_price_unit = db.Column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
- provider_response_latency = db.Column(db.Float, nullable=False, server_default=db.text("0"))
- total_price = db.Column(db.Numeric(10, 7))
- currency = db.Column(db.String(255), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ prompt_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ answer_tokens = mapped_column(db.Integer, nullable=False, server_default=db.text("0"))
+ answer_unit_price = mapped_column(db.Numeric(10, 4), nullable=False)
+ answer_price_unit = mapped_column(db.Numeric(10, 7), nullable=False, server_default=db.text("0.001"))
+ provider_response_latency = mapped_column(db.Float, nullable=False, server_default=db.text("0"))
+ total_price = mapped_column(db.Numeric(10, 7))
+ currency = mapped_column(db.String(255), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@deprecated
@@ -346,18 +392,18 @@ class ToolConversationVariables(Base):
db.Index("conversation_id_idx", "conversation_id"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# conversation user id
- user_id = db.Column(StringUUID, nullable=False)
+ user_id = mapped_column(StringUUID, nullable=False)
# tenant id
- tenant_id = db.Column(StringUUID, nullable=False)
+ tenant_id = mapped_column(StringUUID, nullable=False)
# conversation id
- conversation_id = db.Column(StringUUID, nullable=False)
+ conversation_id = mapped_column(StringUUID, nullable=False)
# variables pool
- variables_str = db.Column(db.Text, nullable=False)
+ variables_str = mapped_column(db.Text, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
- updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def variables(self) -> Any:
@@ -406,26 +452,26 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
- app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
+ app_id = mapped_column(StringUUID, ForeignKey("apps.id"), nullable=False)
- user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
+ user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
# who published this tool
- description = db.Column(db.Text, nullable=False)
+ description = mapped_column(db.Text, nullable=False)
# llm_description of the tool, for LLM
- llm_description = db.Column(db.Text, nullable=False)
+ llm_description = mapped_column(db.Text, nullable=False)
# query description, query will be seen as a parameter of the tool,
# to describe this parameter to llm, we need this field
- query_description = db.Column(db.Text, nullable=False)
+ query_description = mapped_column(db.Text, nullable=False)
# query name, the name of the query parameter
- query_name = db.Column(db.String(40), nullable=False)
+ query_name = mapped_column(db.String(40), nullable=False)
# name of the tool provider
- tool_name = db.Column(db.String(40), nullable=False)
+ tool_name = mapped_column(db.String(40), nullable=False)
# author
- author = db.Column(db.String(40), nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
- updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ author = mapped_column(db.String(40), nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
+ updated_at = mapped_column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
@property
def description_i18n(self) -> I18nObject:
diff --git a/api/models/web.py b/api/models/web.py
index fe2f0c47f8..ce00f4010f 100644
--- a/api/models/web.py
+++ b/api/models/web.py
@@ -15,16 +15,18 @@ class SavedMessage(Base):
db.Index("saved_message_message_idx", "app_id", "message_id", "created_by_role", "created_by"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
- message_id = db.Column(StringUUID, nullable=False)
- created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
+ message_id = mapped_column(StringUUID, nullable=False)
+ created_by_role = mapped_column(
+ db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
+ )
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def message(self):
- return db.session.query(Message).filter(Message.id == self.message_id).first()
+ return db.session.query(Message).where(Message.id == self.message_id).first()
class PinnedConversation(Base):
@@ -34,9 +36,11 @@ class PinnedConversation(Base):
db.Index("pinned_conversation_conversation_idx", "app_id", "conversation_id", "created_by_role", "created_by"),
)
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- app_id = db.Column(StringUUID, nullable=False)
+ id = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
+ app_id = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID)
- created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'end_user'::character varying"))
- created_by = db.Column(StringUUID, nullable=False)
- created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
+ created_by_role = mapped_column(
+ db.String(255), nullable=False, server_default=db.text("'end_user'::character varying")
+ )
+ created_by = mapped_column(StringUUID, nullable=False)
+ created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 77d48bec4f..79d96e42dd 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,7 +1,7 @@
import json
import logging
from collections.abc import Mapping, Sequence
-from datetime import UTC, datetime
+from datetime import datetime
from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Union
from uuid import uuid4
@@ -12,9 +12,11 @@ from sqlalchemy import orm
from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
+from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
+from libs.datetime_utils import naive_utc_now
from libs.helper import extract_tenant_id
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@@ -137,7 +139,7 @@ class Workflow(Base):
updated_at: Mapped[datetime] = mapped_column(
db.DateTime,
nullable=False,
- default=datetime.now(UTC).replace(tzinfo=None),
+ # Pass the callable itself: SQLAlchemy invokes it on every INSERT. Calling it here
+ # would freeze the default at the moment this module is imported.
+ default=naive_utc_now,
server_onupdate=func.current_timestamp(),
)
_environment_variables: Mapped[str] = mapped_column(
@@ -178,7 +180,7 @@ class Workflow(Base):
workflow.conversation_variables = conversation_variables or []
workflow.marked_name = marked_name
workflow.marked_comment = marked_comment
- workflow.created_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.created_at = naive_utc_now()
workflow.updated_at = workflow.created_at
return workflow
@@ -341,13 +343,13 @@ class Workflow(Base):
return (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
+ .where(WorkflowToolProvider.tenant_id == self.tenant_id, WorkflowToolProvider.app_id == self.app_id)
.count()
> 0
)
@property
- def environment_variables(self) -> Sequence[Variable]:
+ def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
@@ -367,11 +369,15 @@ class Workflow(Base):
def decrypt_func(var):
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
- else:
+ elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
+ else:
+ raise AssertionError("this statement should be unreachable.")
- results = list(map(decrypt_func, results))
- return results
+ decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
+ map(decrypt_func, results)
+ )
+ return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
@@ -543,12 +549,12 @@ class WorkflowRun(Base):
from models.model import Message
return (
- db.session.query(Message).filter(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
+ db.session.query(Message).where(Message.app_id == self.app_id, Message.workflow_run_id == self.id).first()
)
@property
def workflow(self):
- return db.session.query(Workflow).filter(Workflow.id == self.workflow_id).first()
+ return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
def to_dict(self):
return {
@@ -902,7 +908,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"])
def _naive_utc_datetime():
- return datetime.now(UTC).replace(tzinfo=None)
+ return naive_utc_now()
class WorkflowDraftVariable(Base):
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 7f1efa671f..7ec8a91198 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.6.0"
+version = "1.7.0"
requires-python = ">=3.11,<3.13"
dependencies = [
diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..00a2d1f87d
--- /dev/null
+++ b/api/repositories/api_workflow_node_execution_repository.py
@@ -0,0 +1,197 @@
+"""
+Service-layer repository protocol for WorkflowNodeExecutionModel operations.
+
+This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
+that abstracts database queries currently done directly in service classes. This repository is
+specifically designed for service-layer needs and is separate from the core domain repository.
+
+The service repository handles operations that require access to database-specific fields like
+tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, Protocol
+
+from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from models.workflow import WorkflowNodeExecutionModel
+
+
+class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
+ """
+ Protocol for service-layer operations on WorkflowNodeExecutionModel.
+
+ This repository provides database access patterns specifically needed by service classes,
+ handling queries that involve database-specific fields and multi-tenancy concerns.
+
+ Key responsibilities:
+ - Manages database operations for workflow node executions
+ - Handles multi-tenant data isolation
+ - Provides batch processing capabilities
+ - Supports execution lifecycle management
+
+ Implementation notes:
+ - Returns database models directly (WorkflowNodeExecutionModel)
+ - Handles tenant/app filtering automatically
+ - Provides service-specific query patterns
+ - Focuses on database operations without domain logic
+ - Supports cleanup and maintenance operations
+ """
+
+ def get_node_last_execution(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ node_id: str,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get the most recent execution for a specific node.
+
+ This method finds the latest execution of a specific node within a workflow,
+ ordered by creation time. Used primarily for debugging and inspection purposes.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_id: The workflow identifier
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ ...
+
+ def get_executions_by_workflow_run(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get all node executions for a specific workflow run.
+
+ This method retrieves all node executions that belong to a specific workflow run,
+ ordered by index in descending order for proper trace visualization.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_run_id: The workflow run identifier
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
+ """
+ ...
+
+ def get_execution_by_id(
+ self,
+ execution_id: str,
+ tenant_id: Optional[str] = None,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get a workflow node execution by its ID.
+
+ This method retrieves a specific execution by its unique identifier.
+ Tenant filtering is optional for cases where the execution ID is globally unique.
+
+ When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
+ If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
+ set `tenant_id` to prevent horizontal privilege escalation.
+
+ Args:
+ execution_id: The execution identifier
+ tenant_id: Optional tenant identifier for additional filtering
+
+ Returns:
+ The WorkflowNodeExecutionModel if found, or None if not found
+ """
+ ...
+
+ def delete_expired_executions(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete workflow node executions that are older than the specified date.
+
+ This method is used for cleanup operations to remove expired executions
+ in batches to avoid overwhelming the database.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Delete executions created before this date
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The number of executions deleted
+ """
+ ...
+
+ def delete_executions_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow node executions for a specific app.
+
+ This method is used when removing an app and all its related data.
+ Executions are deleted in batches to avoid overwhelming the database.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The total number of executions deleted
+ """
+ ...
+
+ def get_expired_executions_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get a batch of expired workflow node executions for backup purposes.
+
+ This method retrieves expired executions without deleting them,
+ allowing the caller to backup the data before deletion.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Get executions created before this date
+ batch_size: Maximum number of executions to retrieve
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances
+ """
+ ...
+
+ def delete_executions_by_ids(
+ self,
+ execution_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow node executions by their IDs.
+
+ This method deletes specific executions by their IDs,
+ typically used after backing up the data.
+
+ This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
+ data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
+ additional tenant validation should be implemented to prevent unauthorized access.
+
+ Args:
+ execution_ids: List of execution IDs to delete
+
+ Returns:
+ The number of executions deleted
+ """
+ ...
diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py
new file mode 100644
index 0000000000..59e7baeb79
--- /dev/null
+++ b/api/repositories/api_workflow_run_repository.py
@@ -0,0 +1,181 @@
+"""
+API WorkflowRun Repository Protocol
+
+This module defines the protocol for service-layer WorkflowRun operations.
+The repository provides an abstraction layer for WorkflowRun database operations
+used by service classes, separating service-layer concerns from core domain logic.
+
+Key Features:
+- Paginated workflow run queries with filtering
+- Bulk deletion operations with OSS backup support
+- Multi-tenant data isolation
+- Expired record cleanup with data retention
+- Service-layer specific query patterns
+
+Usage:
+ This protocol should be used by service classes that need to perform
+ WorkflowRun database operations. It provides a clean interface that
+ hides implementation details and supports dependency injection.
+
+Example:
+ ```python
+ from repositories.factory import DifyAPIRepositoryFactory
+
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ # Get paginated workflow runs
+ runs = repo.get_paginated_workflow_runs(
+ tenant_id="tenant-123",
+ app_id="app-456",
+ triggered_from="debugging",
+ limit=20
+ )
+ ```
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, Protocol
+
+from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.workflow import WorkflowRun
+
+
+class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
+ """
+ Protocol for service-layer WorkflowRun repository operations.
+
+ This protocol defines the interface for WorkflowRun database operations
+ that are specific to service-layer needs, including pagination, filtering,
+ and bulk operations with data backup support.
+ """
+
+ def get_paginated_workflow_runs(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ limit: int = 20,
+ last_id: Optional[str] = None,
+ ) -> InfiniteScrollPagination:
+ """
+ Get paginated workflow runs with filtering.
+
+ Retrieves workflow runs for a specific app and trigger source with
+ cursor-based pagination support. Used primarily for debugging and
+ workflow run listing in the UI.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
+ limit: Maximum number of records to return (default: 20)
+ last_id: Cursor for pagination - ID of the last record from previous page
+
+ Returns:
+ InfiniteScrollPagination object containing:
+ - data: List of WorkflowRun objects
+ - limit: Applied limit
+ - has_more: Boolean indicating if more records exist
+
+ Raises:
+ ValueError: If last_id is provided but the corresponding record doesn't exist
+ """
+ ...
+
+ def get_workflow_run_by_id(
+ self,
+ tenant_id: str,
+ app_id: str,
+ run_id: str,
+ ) -> Optional[WorkflowRun]:
+ """
+ Get a specific workflow run by ID.
+
+ Retrieves a single workflow run with tenant and app isolation.
+ Used for workflow run detail views and execution tracking.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ run_id: Workflow run identifier
+
+ Returns:
+ WorkflowRun object if found, None otherwise
+ """
+ ...
+
+ def get_expired_runs_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Get a batch of expired workflow runs for cleanup.
+
+ Retrieves workflow runs created before the specified date for
+ cleanup operations. Used by scheduled tasks to remove old data
+ while maintaining data retention policies.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ before_date: Only return runs created before this date
+ batch_size: Maximum number of records to return
+
+ Returns:
+ Sequence of WorkflowRun objects to be processed for cleanup
+ """
+ ...
+
+ def delete_runs_by_ids(
+ self,
+ run_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow runs by their IDs.
+
+ Performs bulk deletion of workflow runs by ID. This method should
+ be used after backing up the data to OSS storage for retention.
+
+ Args:
+ run_ids: Sequence of workflow run IDs to delete
+
+ Returns:
+ Number of records actually deleted
+
+ Note:
+ This method performs hard deletion. Ensure data is backed up
+ to OSS storage before calling this method for compliance with
+ data retention policies.
+ """
+ ...
+
+ def delete_runs_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow runs for a specific app.
+
+ Performs bulk deletion of all workflow runs associated with an app.
+ Used during app cleanup operations. Processes records in batches
+ to avoid memory issues and long-running transactions.
+
+ Args:
+ tenant_id: Tenant identifier for multi-tenant isolation
+ app_id: Application identifier
+ batch_size: Number of records to process in each batch
+
+ Returns:
+ Total number of records deleted across all batches
+
+ Note:
+ This method performs hard deletion without backup. Use with caution
+ and ensure proper data retention policies are followed.
+ """
+ ...
diff --git a/api/repositories/factory.py b/api/repositories/factory.py
new file mode 100644
index 0000000000..0a0adbf2c2
--- /dev/null
+++ b/api/repositories/factory.py
@@ -0,0 +1,103 @@
+"""
+DifyAPI Repository Factory for creating repository instances.
+
+This factory is specifically designed for DifyAPI repositories that handle
+service-layer operations with dependency injection patterns.
+"""
+
+import logging
+
+from sqlalchemy.orm import sessionmaker
+
+from configs import dify_config
+from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+
+logger = logging.getLogger(__name__)
+
+
+class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
+ """
+ Factory for creating DifyAPI repository instances based on configuration.
+
+ Each factory method reads a class path from dify_config, imports and validates that
+ class through the inherited `_import_class` / `_validate_*` helpers, and instantiates
+ it with the caller's sessionmaker (dependency injection, which eases testing).
+ """
+
+ @classmethod
+ def create_api_workflow_node_execution_repository(
+ cls, session_maker: sessionmaker
+ ) -> DifyAPIWorkflowNodeExecutionRepository:
+ """
+ Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.
+
+ The class path comes from dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY. The class
+ is imported, checked against DifyAPIWorkflowNodeExecutionRepository, checked for a
+ `session_maker` constructor parameter, and then instantiated with the given sessionmaker.
+ Any unexpected exception is logged and re-raised as RepositoryImportError.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker to inject for database session management.
+
+ Returns:
+ Configured DifyAPIWorkflowNodeExecutionRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be imported or instantiated
+ """
+ class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
+ logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
+ # The instance is built with session_maker=..., so the constructor must accept that name
+ cls._validate_constructor_signature(repository_class, ["session_maker"])
+
+ return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+ except RepositoryImportError:
+ # Already the documented error type: propagate it unchanged instead of wrapping it below
+ raise
+ except Exception as e:
+ logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
+ raise RepositoryImportError(
+ f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
+ ) from e
+
+ @classmethod
+ def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository:
+ """
+ Create an APIWorkflowRunRepository instance based on configuration.
+
+ The class path comes from dify_config.API_WORKFLOW_RUN_REPOSITORY. The class is imported,
+ checked against APIWorkflowRunRepository, checked for a `session_maker` constructor
+ parameter, and then instantiated with the given sessionmaker. Any unexpected exception
+ is logged and re-raised as RepositoryImportError.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker to inject for database session management.
+
+ Returns:
+ Configured APIWorkflowRunRepository instance
+
+ Raises:
+ RepositoryImportError: If the configured repository cannot be imported or instantiated
+ """
+ class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
+ logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}")
+
+ try:
+ repository_class = cls._import_class(class_path)
+ cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
+ # The instance is built with session_maker=..., so the constructor must accept that name
+ cls._validate_constructor_signature(repository_class, ["session_maker"])
+
+ return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
+ except RepositoryImportError:
+ # Already the documented error type: propagate it unchanged instead of wrapping it below
+ raise
+ except Exception as e:
+ logger.exception("Failed to create APIWorkflowRunRepository")
+ raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
new file mode 100644
index 0000000000..e6a23ddf9f
--- /dev/null
+++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py
@@ -0,0 +1,290 @@
+"""
+SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+This module provides a concrete implementation of the service repository protocol
+using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
+"""
+
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional
+
+from sqlalchemy import delete, desc, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from models.workflow import WorkflowNodeExecutionModel
+from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
+
+
+class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
+ """
+ SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
+
+ This repository provides service-layer database operations for WorkflowNodeExecutionModel
+ using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
+ protocol with the following features:
+
+ - Multi-tenancy data isolation through tenant_id filtering
+ - Direct database model operations without domain conversion
+ - Batch processing for efficient large-scale operations
+ - Optimized query patterns for common access patterns
+ - Dependency injection for better testability and maintainability
+ - Session management and transaction handling with proper cleanup
+ - Maintenance operations for data lifecycle management
+ - Thread-safe database operations using session-per-request pattern
+ """
+
+ def __init__(self, session_maker: sessionmaker[Session]):
+ """
+ Store the sessionmaker; each repository method opens its own session from it.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker for creating database sessions
+ """
+ self._session_maker = session_maker
+
+ def get_node_last_execution(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_id: str,
+ node_id: str,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get the most recent execution of a node, i.e. the matching row with the latest created_at.
+
+ This method replicates the query pattern from WorkflowService.get_node_last_run()
+ using SQLAlchemy 2.0 style syntax.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_id: The workflow identifier
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ WorkflowNodeExecutionModel.workflow_id == workflow_id,
+ WorkflowNodeExecutionModel.node_id == node_id,
+ )
+ .order_by(desc(WorkflowNodeExecutionModel.created_at))
+ .limit(1)
+ )
+ # NOTE(review): the session ends with this block, so the result is presumably detached; confirm callers cope.
+ with self._session_maker() as session:
+ return session.scalar(stmt)
+
+ def get_executions_by_workflow_run(
+ self,
+ tenant_id: str,
+ app_id: str,
+ workflow_run_id: str,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get all node executions of one workflow run, highest index first.
+
+ This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
+ using SQLAlchemy 2.0 style syntax.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ workflow_run_id: The workflow run identifier
+
+ Returns:
+ A sequence of WorkflowNodeExecutionModel instances ordered by index (desc); empty if none match
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
+ )
+ .order_by(desc(WorkflowNodeExecutionModel.index))
+ )
+
+ with self._session_maker() as session:
+ return session.execute(stmt).scalars().all()
+
+ def get_execution_by_id(
+ self,
+ execution_id: str,
+ tenant_id: Optional[str] = None,
+ ) -> Optional[WorkflowNodeExecutionModel]:
+ """
+ Get a workflow node execution by its ID.
+
+ This method replicates the query pattern from WorkflowDraftVariableService
+ and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.
+
+ When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
+ If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
+ set `tenant_id` to prevent horizontal privilege escalation.
+
+ Args:
+ execution_id: The execution identifier
+ tenant_id: Optional tenant identifier; when given, the row must also belong to this tenant
+
+ Returns:
+ The WorkflowNodeExecutionModel if found, or None if not found
+ """
+ stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
+
+ # Scope the lookup to the tenant only when one is given; None means no tenant filter at all
+ if tenant_id is not None:
+ stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
+
+ with self._session_maker() as session:
+ return session.scalar(stmt)
+
+ def delete_expired_executions(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete a tenant's workflow node executions created before the given date, batch by batch.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Delete executions created before this date
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The number of executions deleted (sum of the rowcount of every batch delete)
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Select the IDs of up to batch_size matching executions (no ORDER BY is applied)
+ stmt = (
+ select(WorkflowNodeExecutionModel.id)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+
+ execution_ids = session.execute(stmt).scalars().all()
+ if not execution_ids:
+ break
+
+ # Delete exactly those IDs and commit before the next batch is selected
+ delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+ total_deleted += result.rowcount
+
+ # Fewer IDs than batch_size were selected, so this was the last batch
+ if len(execution_ids) < batch_size:
+ break
+
+ return total_deleted
+
+ def delete_executions_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow node executions of one app (within one tenant), batch by batch.
+
+ Args:
+ tenant_id: The tenant identifier
+ app_id: The application identifier
+ batch_size: Maximum number of executions to delete in one batch
+
+ Returns:
+ The total number of executions deleted (sum of the rowcount of every batch delete)
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Select the IDs of up to batch_size executions of this tenant and app
+ stmt = (
+ select(WorkflowNodeExecutionModel.id)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.app_id == app_id,
+ )
+ .limit(batch_size)
+ )
+
+ execution_ids = session.execute(stmt).scalars().all()
+ if not execution_ids:
+ break
+
+ # Delete exactly those IDs and commit before the next batch is selected
+ delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+ total_deleted += result.rowcount
+
+ # Fewer IDs than batch_size were selected, so this was the last batch
+ if len(execution_ids) < batch_size:
+ break
+
+ return total_deleted
+
+ def get_expired_executions_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowNodeExecutionModel]:
+ """
+ Get a batch of expired workflow node executions for backup purposes.
+
+ Args:
+ tenant_id: The tenant identifier
+ before_date: Get executions created before this date
+ batch_size: Maximum number of executions to retrieve
+
+ Returns:
+ Up to batch_size WorkflowNodeExecutionModel instances; no ordering is applied to the query
+ """
+ stmt = (
+ select(WorkflowNodeExecutionModel)
+ .where(
+ WorkflowNodeExecutionModel.tenant_id == tenant_id,
+ WorkflowNodeExecutionModel.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+
+ with self._session_maker() as session:
+ return session.execute(stmt).scalars().all()
+
+ def delete_executions_by_ids(
+ self,
+ execution_ids: Sequence[str],
+ ) -> int:
+ """
+ Delete workflow node executions by their IDs (no tenant filter is applied).
+
+ Args:
+ execution_ids: List of execution IDs to delete; an empty list issues no query
+
+ Returns:
+ The number of executions deleted (0 for empty input)
+ """
+ if not execution_ids:
+ return 0
+
+ with self._session_maker() as session:
+ stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
+ result = session.execute(stmt)
+ session.commit()
+ return result.rowcount
diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py b/api/repositories/sqlalchemy_api_workflow_run_repository.py
new file mode 100644
index 0000000000..ebd1d74b20
--- /dev/null
+++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py
@@ -0,0 +1,203 @@
+"""
+SQLAlchemy API WorkflowRun Repository Implementation
+
+This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository
+protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0
+style queries with proper session management and multi-tenant data isolation.
+
+Key Features:
+- SQLAlchemy 2.0 style queries for modern database operations
+- Cursor-based pagination for efficient large dataset handling
+- Bulk operations with batch processing for performance
+- Multi-tenant data isolation and security
+- Proper session management with dependency injection
+
+Implementation Notes:
+- Uses sessionmaker for consistent session management
+- Implements cursor-based pagination using created_at timestamps
+- Provides efficient bulk deletion with batch processing
+- Maintains data consistency with proper transaction handling
+"""
+
+import logging
+from collections.abc import Sequence
+from datetime import datetime
+from typing import Optional, cast
+
+from sqlalchemy import delete, select
+from sqlalchemy.orm import Session, sessionmaker
+
+from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models.workflow import WorkflowRun
+from repositories.api_workflow_run_repository import APIWorkflowRunRepository
+
+logger = logging.getLogger(__name__)
+
+
+class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
+ """
+ SQLAlchemy implementation of APIWorkflowRunRepository.
+
+ Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0
+ style queries. Supports dependency injection through sessionmaker and
+ maintains proper multi-tenant data isolation.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker instance for database connections
+ """
+
+ def __init__(self, session_maker: sessionmaker[Session]) -> None:
+ """
+ Store the sessionmaker; each repository method opens its own session from it.
+
+ Args:
+ session_maker: SQLAlchemy sessionmaker for database connections
+ """
+ self._session_maker = session_maker
+
+ def get_paginated_workflow_runs(
+ self,
+ tenant_id: str,
+ app_id: str,
+ triggered_from: str,
+ limit: int = 20,
+ last_id: Optional[str] = None,
+ ) -> InfiniteScrollPagination:
+ """
+ Get one page of workflow runs, newest first (ordered by created_at descending).
+
+ Runs are filtered by tenant, app and trigger source. When last_id is given it acts as
+ the cursor: only runs created before that run are returned, and ValueError is raised
+ if the cursor run is not found. The page holds at most `limit` runs plus a has_more flag.
+ """
+ with self._session_maker() as session:
+ # Filters shared by the cursor lookup and the page query
+ base_stmt = select(WorkflowRun).where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ WorkflowRun.triggered_from == triggered_from,
+ )
+
+ if last_id:
+ # Load the cursor run under the same filters, so a foreign or unknown ID is rejected
+ last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
+ last_workflow_run = session.scalar(last_run_stmt)
+
+ if not last_workflow_run:
+ raise ValueError("Last workflow run not exists")
+
+ # Keep only strictly older runs. NOTE(review): runs sharing the cursor's created_at are skipped.
+ base_stmt = base_stmt.where(
+ WorkflowRun.created_at < last_workflow_run.created_at,
+ WorkflowRun.id != last_workflow_run.id,
+ )
+
+ # Runs for every page: fetch limit + 1 rows, newest first; the extra row only probes for a next page
+ workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
+
+ # More than `limit` rows came back: report has_more and drop the probe row
+ has_more = len(workflow_runs) > limit
+ if has_more:
+ workflow_runs = workflow_runs[:-1]
+
+ return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
+
+ def get_workflow_run_by_id(
+ self,
+ tenant_id: str,
+ app_id: str,
+ run_id: str,
+ ) -> Optional[WorkflowRun]:
+ """
+ Get the workflow run with this ID if it belongs to the given tenant and app, else None.
+ """
+ with self._session_maker() as session:
+ stmt = select(WorkflowRun).where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ WorkflowRun.id == run_id,
+ )
+ return cast(Optional[WorkflowRun], session.scalar(stmt))
+
+ def get_expired_runs_batch(
+ self,
+ tenant_id: str,
+ before_date: datetime,
+ batch_size: int = 1000,
+ ) -> Sequence[WorkflowRun]:
+ """
+ Get up to batch_size workflow runs of a tenant created before before_date (no ordering applied).
+ """
+ with self._session_maker() as session:
+ stmt = (
+ select(WorkflowRun)
+ .where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.created_at < before_date,
+ )
+ .limit(batch_size)
+ )
+ return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
+
+ def delete_runs_by_ids(
+ self,
+ run_ids: Sequence[str],
+ ) -> int:
+ """
+ Bulk-delete workflow runs by ID and return the deleted row count (0 for empty input; no tenant filter).
+ """
+ if not run_ids:
+ return 0
+
+ with self._session_maker() as session:
+ stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
+ result = session.execute(stmt)
+ session.commit()
+
+ deleted_count = cast(int, result.rowcount)
+ logger.info(f"Deleted {deleted_count} workflow runs by IDs")
+ return deleted_count
+
+ def delete_runs_by_app(
+ self,
+ tenant_id: str,
+ app_id: str,
+ batch_size: int = 1000,
+ ) -> int:
+ """
+ Delete all workflow runs of an app in batches, committing each batch; returns the total row count.
+ """
+ total_deleted = 0
+
+ while True:
+ with self._session_maker() as session:
+ # Select the IDs of up to batch_size runs of this tenant and app
+ stmt = (
+ select(WorkflowRun.id)
+ .where(
+ WorkflowRun.tenant_id == tenant_id,
+ WorkflowRun.app_id == app_id,
+ )
+ .limit(batch_size)
+ )
+ run_ids = session.scalars(stmt).all()
+
+ if not run_ids:
+ break
+
+ # Delete exactly those IDs and commit before the next batch is selected
+ delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
+ result = session.execute(delete_stmt)
+ session.commit()
+
+ batch_deleted = result.rowcount
+ total_deleted += batch_deleted
+
+ logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}")
+
+ # The delete reported fewer rows than batch_size, so this is treated as the last batch
+ if batch_deleted < batch_size:
+ break
+
+ logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}")
+ return total_deleted
diff --git a/api/schedule/check_upgradable_plugin_task.py b/api/schedule/check_upgradable_plugin_task.py
new file mode 100644
index 0000000000..c1d6018827
--- /dev/null
+++ b/api/schedule/check_upgradable_plugin_task.py
@@ -0,0 +1,49 @@
+import time
+
+import click
+
+import app
+from extensions.ext_database import db
+from models.account import TenantPluginAutoUpgradeStrategy
+from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
+
+AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
+
+
+@app.celery.task(queue="plugin")
+def check_upgradable_plugin_task():
+    """Queue an auto-upgrade check for each tenant whose upgrade time falls in the current window."""
+    click.echo(click.style("Start check upgradable plugin.", fg="green"))
+    start_at = time.perf_counter()
+
+    # Window: [now - 30s, now - 30s + checking interval) in seconds of day; we assume the tz is UTC
+    now_seconds_of_day = (time.time() - 30) % 86400
+    click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green"))
+    window_end = now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL
+    upgrade_time = TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
+    if window_end <= 86400:
+        in_window = (upgrade_time >= now_seconds_of_day) & (upgrade_time < window_end)
+    else:
+        # The window crosses midnight, so also match the part that spills into the next day.
+        in_window = (upgrade_time >= now_seconds_of_day) | (upgrade_time < window_end - 86400)
+
+    strategies = (
+        db.session.query(TenantPluginAutoUpgradeStrategy)
+        .where(
+            in_window,
+            TenantPluginAutoUpgradeStrategy.strategy_setting
+            != TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
+        )
+        .all()
+    )
+    for strategy in strategies:
+        process_tenant_plugin_autoupgrade_check_task.delay(
+            strategy.tenant_id,
+            strategy.strategy_setting,
+            strategy.upgrade_time_of_day,
+            strategy.upgrade_mode,
+            strategy.exclude_plugins,
+            strategy.include_plugins,
+        )
+    end_at = time.perf_counter()
+    click.echo(click.style("Checked upgradable plugin success latency: {}".format(end_at - start_at), fg="green"))
diff --git a/api/schedule/clean_embedding_cache_task.py b/api/schedule/clean_embedding_cache_task.py
index 9efe120b7a..024e3d6f50 100644
--- a/api/schedule/clean_embedding_cache_task.py
+++ b/api/schedule/clean_embedding_cache_task.py
@@ -21,7 +21,7 @@ def clean_embedding_cache_task():
try:
embedding_ids = (
db.session.query(Embedding.id)
- .filter(Embedding.created_at < thirty_days_ago)
+ .where(Embedding.created_at < thirty_days_ago)
.order_by(Embedding.created_at.desc())
.limit(100)
.all()
diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py
index d02bc81f33..a6851e36e5 100644
--- a/api/schedule/clean_messages.py
+++ b/api/schedule/clean_messages.py
@@ -36,7 +36,7 @@ def clean_messages():
# Main query with join and filter
messages = (
db.session.query(Message)
- .filter(Message.created_at < plan_sandbox_clean_message_day)
+ .where(Message.created_at < plan_sandbox_clean_message_day)
.order_by(Message.created_at.desc())
.limit(100)
.all()
@@ -66,25 +66,25 @@ def clean_messages():
plan = plan_cache.decode()
if plan == "sandbox":
# clean related message
- db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message.id).delete(
+ db.session.query(MessageFeedback).where(MessageFeedback.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message.id).delete(
+ db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(MessageChain).filter(MessageChain.message_id == message.id).delete(
+ db.session.query(MessageChain).where(MessageChain.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message.id).delete(
+ db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(MessageFile).filter(MessageFile.message_id == message.id).delete(
+ db.session.query(MessageFile).where(MessageFile.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(SavedMessage).filter(SavedMessage.message_id == message.id).delete(
+ db.session.query(SavedMessage).where(SavedMessage.message_id == message.id).delete(
synchronize_session=False
)
- db.session.query(Message).filter(Message.id == message.id).delete()
+ db.session.query(Message).where(Message.id == message.id).delete()
db.session.commit()
end_at = time.perf_counter()
click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green"))
diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py
index c0cd42a226..72e2e73e65 100644
--- a/api/schedule/clean_unused_datasets_task.py
+++ b/api/schedule/clean_unused_datasets_task.py
@@ -27,7 +27,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
- .filter(
+ .where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@@ -40,7 +40,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
- .filter(
+ .where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@@ -55,7 +55,7 @@ def clean_unused_datasets_task():
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
- .filter(
+ .where(
Dataset.created_at < plan_sandbox_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@@ -72,7 +72,7 @@ def clean_unused_datasets_task():
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
- .filter(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
+ .where(DatasetQuery.created_at > plan_sandbox_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
@@ -80,7 +80,7 @@ def clean_unused_datasets_task():
# add auto disable log
documents = (
db.session.query(Document)
- .filter(
+ .where(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
@@ -99,9 +99,7 @@ def clean_unused_datasets_task():
index_processor.clean(dataset, None)
# update document
- update_params = {Document.enabled: False}
-
- db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
+ db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e:
@@ -113,7 +111,7 @@ def clean_unused_datasets_task():
# Subquery for counting new documents
document_subquery_new = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
- .filter(
+ .where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@@ -126,7 +124,7 @@ def clean_unused_datasets_task():
# Subquery for counting old documents
document_subquery_old = (
db.session.query(Document.dataset_id, func.count(Document.id).label("document_count"))
- .filter(
+ .where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
@@ -141,7 +139,7 @@ def clean_unused_datasets_task():
select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
- .filter(
+ .where(
Dataset.created_at < plan_pro_clean_day,
func.coalesce(document_subquery_new.c.document_count, 0) == 0,
func.coalesce(document_subquery_old.c.document_count, 0) > 0,
@@ -157,7 +155,7 @@ def clean_unused_datasets_task():
for dataset in datasets:
dataset_query = (
db.session.query(DatasetQuery)
- .filter(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
+ .where(DatasetQuery.created_at > plan_pro_clean_day, DatasetQuery.dataset_id == dataset.id)
.all()
)
if not dataset_query or len(dataset_query) == 0:
@@ -176,9 +174,7 @@ def clean_unused_datasets_task():
index_processor.clean(dataset, None)
# update document
- update_params = {Document.enabled: False}
-
- db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
+ db.session.query(Document).filter_by(dataset_id=dataset.id).update({Document.enabled: False})
db.session.commit()
click.echo(
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")
diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py
index 8a02278de8..91953354e6 100644
--- a/api/schedule/create_tidb_serverless_task.py
+++ b/api/schedule/create_tidb_serverless_task.py
@@ -20,7 +20,7 @@ def create_tidb_serverless_task():
try:
# check the number of idle tidb serverless
idle_tidb_serverless_number = (
- db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
+ db.session.query(TidbAuthBinding).where(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number:
break
diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py
index 5ee813e1de..5911c98b0a 100644
--- a/api/schedule/mail_clean_document_notify_task.py
+++ b/api/schedule/mail_clean_document_notify_task.py
@@ -3,12 +3,12 @@ import time
from collections import defaultdict
import click
-from flask import render_template # type: ignore
import app
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
from models.account import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog
from services.feature_service import FeatureService
@@ -30,7 +30,7 @@ def mail_clean_document_notify_task():
# send document clean notify mail
try:
dataset_auto_disable_logs = (
- db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
+ db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
@@ -45,7 +45,7 @@ def mail_clean_document_notify_task():
if plan != "sandbox":
knowledge_details = []
# check tenant
- tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
+ tenant = db.session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner
@@ -54,7 +54,7 @@ def mail_clean_document_notify_task():
)
if not current_owner_join:
continue
- account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
+ account = db.session.query(Account).where(Account.id == current_owner_join.account_id).first()
if not account:
continue
@@ -67,19 +67,21 @@ def mail_clean_document_notify_task():
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
- html_content = render_template(
- "clean_document_job_mail_template-US.html",
- userName=account.email,
- knowledge_details=knowledge_details,
- url=url,
- )
- mail.send(
- to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.DOCUMENT_CLEAN_NOTIFY,
+ language_code="en-US",
+ to=account.email,
+ template_context={
+ "userName": account.email,
+ "knowledge_details": knowledge_details,
+ "url": url,
+ },
)
# update notified to True
diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py
index e3a7021b9d..a05e1358ed 100644
--- a/api/schedule/queue_monitor_task.py
+++ b/api/schedule/queue_monitor_task.py
@@ -3,13 +3,12 @@ from datetime import datetime
from urllib.parse import urlparse
import click
-from flask import render_template
from redis import Redis
import app
from configs import dify_config
from extensions.ext_database import db
-from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
# Create a dedicated Redis connection (using the same configuration as Celery)
celery_broker_url = dify_config.CELERY_BROKER_URL
@@ -39,18 +38,20 @@ def queue_monitor_task():
alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS
if alter_emails:
to_list = alter_emails.split(",")
+ email_service = get_email_i18n_service()
for to in to_list:
try:
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
- html_content = render_template(
- "queue_monitor_alert_email_template_en-US.html",
- queue_name=queue_name,
- queue_length=queue_length,
- threshold=threshold,
- alert_time=current_time,
- )
- mail.send(
- to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content
+ email_service.send_email(
+ email_type=EmailType.QUEUE_MONITOR_ALERT,
+ language_code="en-US",
+ to=to,
+ template_context={
+ "queue_name": queue_name,
+ "queue_length": queue_length,
+ "threshold": threshold,
+ "alert_time": current_time,
+ },
)
except Exception as e:
logging.exception(click.style("Exception occurred during sending email", fg="red"))
diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py
index ce4ecb6e7c..4d6c1f1877 100644
--- a/api/schedule/update_tidb_serverless_status_task.py
+++ b/api/schedule/update_tidb_serverless_status_task.py
@@ -17,7 +17,7 @@ def update_tidb_serverless_status_task():
# check the number of idle tidb serverless
tidb_serverless_list = (
db.session.query(TidbAuthBinding)
- .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
+ .where(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0:
diff --git a/api/services/account_service.py b/api/services/account_service.py
index 2ba6f4345b..eb57b675c4 100644
--- a/api/services/account_service.py
+++ b/api/services/account_service.py
@@ -17,6 +17,7 @@ from constants.languages import language_timezone_mapping, languages
from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client, redis_fallback
+from libs.datetime_utils import naive_utc_now
from libs.helper import RateLimiter, TokenManager
from libs.passport import PassportService
from libs.password import compare_password, hash_password, valid_password
@@ -28,6 +29,7 @@ from models.account import (
Tenant,
TenantAccountJoin,
TenantAccountRole,
+ TenantPluginAutoUpgradeStrategy,
TenantStatus,
)
from models.model import DifySetup
@@ -52,8 +54,17 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces
from services.feature_service import FeatureService
from tasks.delete_account_task import delete_account_task
from tasks.mail_account_deletion_task import send_account_deletion_verification_code
+from tasks.mail_change_mail_task import (
+ send_change_mail_completed_notification_task,
+ send_change_mail_task,
+)
from tasks.mail_email_code_login import send_email_code_login_mail_task
from tasks.mail_invite_member_task import send_invite_member_mail_task
+from tasks.mail_owner_transfer_task import (
+ send_new_owner_transfer_notify_email_task,
+ send_old_owner_transfer_notify_email_task,
+ send_owner_transfer_confirm_task,
+)
from tasks.mail_reset_password_task import send_reset_password_mail_task
@@ -75,8 +86,13 @@ class AccountService:
email_code_account_deletion_rate_limiter = RateLimiter(
prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1
)
+ change_email_rate_limiter = RateLimiter(prefix="change_email_rate_limit", max_attempts=1, time_window=60 * 1)
+ owner_transfer_rate_limiter = RateLimiter(prefix="owner_transfer_rate_limit", max_attempts=1, time_window=60 * 1)
+
LOGIN_MAX_ERROR_LIMITS = 5
FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5
+ CHANGE_EMAIL_MAX_ERROR_LIMITS = 5
+ OWNER_TRANSFER_MAX_ERROR_LIMITS = 5
@staticmethod
def _get_refresh_token_key(refresh_token: str) -> str:
@@ -124,8 +140,8 @@ class AccountService:
available_ta.current = True
db.session.commit()
- if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10):
- account.last_active_at = datetime.now(UTC).replace(tzinfo=None)
+ if naive_utc_now() - account.last_active_at > timedelta(minutes=10):
+ account.last_active_at = naive_utc_now()
db.session.commit()
return cast(Account, account)
@@ -169,7 +185,7 @@ class AccountService:
if account.status == AccountStatus.PENDING.value:
account.status = AccountStatus.ACTIVE.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
db.session.commit()
@@ -307,7 +323,7 @@ class AccountService:
# If it exists, update the record
account_integrate.open_id = open_id
account_integrate.encrypted_token = "" # todo
- account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ account_integrate.updated_at = naive_utc_now()
else:
# If it does not exist, create a new record
account_integrate = AccountIntegrate(
@@ -342,7 +358,7 @@ class AccountService:
@staticmethod
def update_login_info(account: Account, *, ip_address: str) -> None:
"""Update last login time and ip"""
- account.last_login_at = datetime.now(UTC).replace(tzinfo=None)
+ account.last_login_at = naive_utc_now()
account.last_login_ip = ip_address
db.session.add(account)
db.session.commit()
@@ -419,6 +435,117 @@ class AccountService:
cls.reset_password_rate_limiter.increment_rate_limit(account_email)
return token
+    @classmethod
+    def send_change_email_email(
+        cls,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        old_email: Optional[str] = None,
+        language: Optional[str] = "en-US",
+        phase: Optional[str] = None,
+    ):
+        """Queue a change-email verification mail and return the matching token.
+
+        The recipient is ``account.email`` when an account is supplied, otherwise ``email``.
+
+        Raises:
+            ValueError: no recipient address could be determined.
+            EmailChangeRateLimitExceededError: the recipient is currently rate limited.
+        """
+        recipient = email if not account else account.email
+        if recipient is None:
+            raise ValueError("Email must be provided.")
+
+        if cls.change_email_rate_limiter.is_rate_limited(recipient):
+            # Kept as a function-scope import, exactly as before.
+            from controllers.console.auth.error import EmailChangeRateLimitExceededError
+
+            raise EmailChangeRateLimitExceededError()
+
+        code, token = cls.generate_change_email_token(recipient, account, old_email=old_email)
+        send_change_mail_task.delay(language=language, to=recipient, code=code, phase=phase)
+
+        # Count this send only after the task has been queued.
+        cls.change_email_rate_limiter.increment_rate_limit(recipient)
+        return token
+
+    @classmethod
+    def send_change_email_completed_notify_email(
+        cls,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        language: Optional[str] = "en-US",
+    ):
+        """Queue the "email change completed" notification.
+
+        The recipient is ``account.email`` when an account is supplied, otherwise ``email``.
+
+        Raises:
+            ValueError: no recipient address could be determined.
+        """
+        recipient = email if not account else account.email
+        if recipient is None:
+            raise ValueError("Email must be provided.")
+
+        send_change_mail_completed_notification_task.delay(language=language, to=recipient)
+
+    @classmethod
+    def send_owner_transfer_email(
+        cls,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        language: Optional[str] = "en-US",
+        workspace_name: Optional[str] = "",
+    ):
+        """Queue an owner-transfer confirmation mail and return the matching token.
+
+        The recipient is ``account.email`` when an account is supplied, otherwise ``email``.
+
+        Raises:
+            ValueError: no recipient address could be determined.
+            OwnerTransferRateLimitExceededError: the recipient is currently rate limited.
+        """
+        recipient = email if not account else account.email
+        if recipient is None:
+            raise ValueError("Email must be provided.")
+
+        if cls.owner_transfer_rate_limiter.is_rate_limited(recipient):
+            # Kept as a function-scope import, exactly as before.
+            from controllers.console.auth.error import OwnerTransferRateLimitExceededError
+
+            raise OwnerTransferRateLimitExceededError()
+
+        code, token = cls.generate_owner_transfer_token(recipient, account)
+        send_owner_transfer_confirm_task.delay(
+            language=language, to=recipient, code=code, workspace=workspace_name
+        )
+
+        # Count this send only after the task has been queued.
+        cls.owner_transfer_rate_limiter.increment_rate_limit(recipient)
+        return token
+
+    @classmethod
+    def send_old_owner_transfer_notify_email(
+        cls,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        language: Optional[str] = "en-US",
+        workspace_name: Optional[str] = "",
+        new_owner_email: Optional[str] = "",
+    ):
+        """Queue the owner-transfer notification addressed to the previous owner.
+
+        The recipient is ``account.email`` when an account is supplied, otherwise ``email``.
+
+        Raises:
+            ValueError: no recipient address could be determined.
+        """
+        recipient = email if not account else account.email
+        if recipient is None:
+            raise ValueError("Email must be provided.")
+
+        send_old_owner_transfer_notify_email_task.delay(
+            language=language, to=recipient, workspace=workspace_name, new_owner_email=new_owner_email
+        )
+
+    @classmethod
+    def send_new_owner_transfer_notify_email(
+        cls,
+        account: Optional[Account] = None,
+        email: Optional[str] = None,
+        language: Optional[str] = "en-US",
+        workspace_name: Optional[str] = "",
+    ):
+        """Queue the owner-transfer notification addressed to the new owner.
+
+        The recipient is ``account.email`` when an account is supplied, otherwise ``email``.
+
+        Raises:
+            ValueError: no recipient address could be determined.
+        """
+        recipient = email if not account else account.email
+        if recipient is None:
+            raise ValueError("Email must be provided.")
+
+        send_new_owner_transfer_notify_email_task.delay(
+            language=language, to=recipient, workspace=workspace_name
+        )
+
@classmethod
def generate_reset_password_token(
cls,
@@ -435,14 +562,64 @@ class AccountService:
)
return code, token
+    @classmethod
+    def generate_change_email_token(
+        cls,
+        email: str,
+        account: Optional[Account] = None,
+        code: Optional[str] = None,
+        old_email: Optional[str] = None,
+        additional_data: Optional[dict[str, Any]] = None,
+    ):
+        """Create a ``change_email`` token that carries a verification code.
+
+        Args:
+            email: Address the token is issued for.
+            account: Account forwarded to ``TokenManager.generate_token``, if any.
+            code: Verification code to embed; a random 6-digit code is generated when falsy.
+            old_email: Stored in the token data under ``"old_email"``.
+            additional_data: Extra entries for the token data. The mapping is copied,
+                so the caller's dict is never modified.
+
+        Returns:
+            A ``(code, token)`` tuple.
+        """
+        if not code:
+            code = "".join(str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6))
+
+        # Build a fresh dict on every call. The previous ``{}`` default was a single
+        # shared object mutated in place, so entries written by one call leaked into
+        # all later calls (and a caller-supplied dict was modified behind its back).
+        token_data = {**(additional_data or {}), "code": code, "old_email": old_email}
+        token = TokenManager.generate_token(
+            account=account, email=email, token_type="change_email", additional_data=token_data
+        )
+        return code, token
+
+    @classmethod
+    def generate_owner_transfer_token(
+        cls,
+        email: str,
+        account: Optional[Account] = None,
+        code: Optional[str] = None,
+        additional_data: Optional[dict[str, Any]] = None,
+    ):
+        """Create an ``owner_transfer`` token that carries a verification code.
+
+        Args:
+            email: Address the token is issued for.
+            account: Account forwarded to ``TokenManager.generate_token``, if any.
+            code: Verification code to embed; a random 6-digit code is generated when falsy.
+            additional_data: Extra entries for the token data. The mapping is copied,
+                so the caller's dict is never modified.
+
+        Returns:
+            A ``(code, token)`` tuple.
+        """
+        if not code:
+            code = "".join(str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6))
+
+        # Build a fresh dict on every call. The previous ``{}`` default was a single
+        # shared object mutated in place, so entries written by one call leaked into
+        # all later calls (and a caller-supplied dict was modified behind its back).
+        token_data = {**(additional_data or {}), "code": code}
+        token = TokenManager.generate_token(
+            account=account, email=email, token_type="owner_transfer", additional_data=token_data
+        )
+        return code, token
+
@classmethod
def revoke_reset_password_token(cls, token: str):
TokenManager.revoke_token(token, "reset_password")
+ @classmethod
+ def revoke_change_email_token(cls, token: str):
+ """Revoke `token` through TokenManager under the "change_email" token type."""
+ TokenManager.revoke_token(token, "change_email")
+
+ @classmethod
+ def revoke_owner_transfer_token(cls, token: str):
+ """Revoke `token` through TokenManager under the "owner_transfer" token type."""
+ TokenManager.revoke_token(token, "owner_transfer")
+
@classmethod
def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]:
return TokenManager.get_token_data(token, "reset_password")
+ @classmethod
+ def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]:
+ """Look up the data TokenManager holds for `token` under the "change_email" type.
+
+ The result is typed as optional, so callers must handle ``None``.
+ """
+ return TokenManager.get_token_data(token, "change_email")
+
+ @classmethod
+ def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]:
+ """Look up the data TokenManager holds for `token` under the "owner_transfer" type.
+
+ The result is typed as optional, so callers must handle ``None``.
+ """
+ return TokenManager.get_token_data(token, "owner_transfer")
+
@classmethod
def send_email_code_login_email(
cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US"
@@ -485,7 +662,7 @@ class AccountService:
)
)
- account = db.session.query(Account).filter(Account.email == email).first()
+ account = db.session.query(Account).where(Account.email == email).first()
if not account:
return None
@@ -552,6 +729,62 @@ class AccountService:
key = f"forgot_password_error_rate_limit:{email}"
redis_client.delete(key)
+    @staticmethod
+    @redis_fallback(default_return=None)
+    def add_change_email_error_rate_limit(email: str) -> None:
+        """Increment the change-email error counter for `email`.
+
+        The counter is rewritten with ``CHANGE_EMAIL_LOCKOUT_DURATION`` as its expiry,
+        so each recorded error restarts the expiry period.
+        """
+        key = f"change_email_error_rate_limit:{email}"
+        stored = redis_client.get(key)
+        attempts = 1 if stored is None else int(stored) + 1
+        redis_client.setex(key, dify_config.CHANGE_EMAIL_LOCKOUT_DURATION, attempts)
+
+    @staticmethod
+    @redis_fallback(default_return=False)
+    def is_change_email_error_rate_limit(email: str) -> bool:
+        """Tell whether `email` has exceeded the allowed number of change-email errors.
+
+        Returns ``True`` only when a counter exists and is strictly greater than
+        ``CHANGE_EMAIL_MAX_ERROR_LIMITS``; a missing counter means not limited.
+        """
+        key = f"change_email_error_rate_limit:{email}"
+        stored = redis_client.get(key)
+        return stored is not None and int(stored) > AccountService.CHANGE_EMAIL_MAX_ERROR_LIMITS
+
+ @staticmethod
+ @redis_fallback(default_return=None)
+ def reset_change_email_error_rate_limit(email: str):
+ """Clear the change-email error counter for `email` by deleting its Redis key."""
+ key = f"change_email_error_rate_limit:{email}"
+ redis_client.delete(key)
+
+    @staticmethod
+    @redis_fallback(default_return=None)
+    def add_owner_transfer_error_rate_limit(email: str) -> None:
+        """Increment the owner-transfer error counter for `email`.
+
+        The counter is rewritten with ``OWNER_TRANSFER_LOCKOUT_DURATION`` as its expiry,
+        so each recorded error restarts the expiry period.
+        """
+        key = f"owner_transfer_error_rate_limit:{email}"
+        stored = redis_client.get(key)
+        attempts = 1 if stored is None else int(stored) + 1
+        redis_client.setex(key, dify_config.OWNER_TRANSFER_LOCKOUT_DURATION, attempts)
+
+    @staticmethod
+    @redis_fallback(default_return=False)
+    def is_owner_transfer_error_rate_limit(email: str) -> bool:
+        """Tell whether `email` has exceeded the allowed number of owner-transfer errors.
+
+        Returns ``True`` only when a counter exists and is strictly greater than
+        ``OWNER_TRANSFER_MAX_ERROR_LIMITS``; a missing counter means not limited.
+        """
+        key = f"owner_transfer_error_rate_limit:{email}"
+        stored = redis_client.get(key)
+        return stored is not None and int(stored) > AccountService.OWNER_TRANSFER_MAX_ERROR_LIMITS
+
+ @staticmethod
+ @redis_fallback(default_return=None)
+ def reset_owner_transfer_error_rate_limit(email: str):
+ """Clear the owner-transfer error counter for `email` by deleting its Redis key."""
+ key = f"owner_transfer_error_rate_limit:{email}"
+ redis_client.delete(key)
+
@staticmethod
@redis_fallback(default_return=False)
def is_email_send_ip_limit(ip_address: str):
@@ -593,6 +826,10 @@ class AccountService:
return False
+ @staticmethod
+ def check_email_unique(email: str) -> bool:
+ """Return True when no Account row has exactly this `email` value."""
+ return db.session.query(Account).filter_by(email=email).first() is None
+
class TenantService:
@staticmethod
@@ -611,6 +848,17 @@ class TenantService:
db.session.add(tenant)
db.session.commit()
+ plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
+ tenant_id=tenant.id,
+ strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
+ upgrade_time_of_day=0,
+ upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
+ exclude_plugins=[],
+ include_plugins=[],
+ )
+ db.session.add(plugin_upgrade_strategy)
+ db.session.commit()
+
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
return tenant
@@ -671,7 +919,7 @@ class TenantService:
return (
db.session.query(Tenant)
.join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id)
- .filter(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
+ .where(TenantAccountJoin.account_id == account.id, Tenant.status == TenantStatus.NORMAL)
.all()
)
@@ -700,7 +948,7 @@ class TenantService:
tenant_account_join = (
db.session.query(TenantAccountJoin)
.join(Tenant, TenantAccountJoin.tenant_id == Tenant.id)
- .filter(
+ .where(
TenantAccountJoin.account_id == account.id,
TenantAccountJoin.tenant_id == tenant_id,
Tenant.status == TenantStatus.NORMAL,
@@ -711,7 +959,7 @@ class TenantService:
if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else:
- db.session.query(TenantAccountJoin).filter(
+ db.session.query(TenantAccountJoin).where(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False})
tenant_account_join.current = True
@@ -726,7 +974,7 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
- .filter(TenantAccountJoin.tenant_id == tenant.id)
+ .where(TenantAccountJoin.tenant_id == tenant.id)
)
# Initialize an empty list to store the updated accounts
@@ -745,8 +993,8 @@ class TenantService:
db.session.query(Account, TenantAccountJoin.role)
.select_from(Account)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
- .filter(TenantAccountJoin.tenant_id == tenant.id)
- .filter(TenantAccountJoin.role == "dataset_operator")
+ .where(TenantAccountJoin.tenant_id == tenant.id)
+ .where(TenantAccountJoin.role == "dataset_operator")
)
# Initialize an empty list to store the updated accounts
@@ -766,9 +1014,7 @@ class TenantService:
return (
db.session.query(TenantAccountJoin)
- .filter(
- TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles])
- )
+ .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.role.in_([role.value for role in roles]))
.first()
is not None
)
@@ -778,10 +1024,10 @@ class TenantService:
"""Get the role of the current account for a given tenant"""
join = (
db.session.query(TenantAccountJoin)
- .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
+ .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == account.id)
.first()
)
- return join.role if join else None
+ return TenantAccountRole(join.role) if join else None
@staticmethod
def get_tenant_count() -> int:
@@ -850,21 +1096,21 @@ class TenantService:
target_member_join.role = new_role
db.session.commit()
- @staticmethod
- def dissolve_tenant(tenant: Tenant, operator: Account) -> None:
- """Dissolve tenant"""
- if not TenantService.check_member_permission(tenant, operator, operator, "remove"):
- raise NoPermissionError("No permission to dissolve tenant.")
- db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete()
- db.session.delete(tenant)
- db.session.commit()
-
@staticmethod
def get_custom_config(tenant_id: str) -> dict:
tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict)
+ @staticmethod
+ def is_owner(account: Account, tenant: Tenant) -> bool:
+ """Return True when the account's role in the tenant is TenantAccountRole.OWNER."""
+ return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER
+
+ @staticmethod
+ def is_member(account: Account, tenant: Tenant) -> bool:
+ """Check if the account is a member of the tenant.
+
+ Membership means ``get_user_role`` returns any role (i.e. not ``None``) for the pair.
+ """
+ return TenantService.get_user_role(account, tenant) is not None
+
class RegisterService:
@classmethod
@@ -892,7 +1138,7 @@ class RegisterService:
)
account.last_login_ip = ip_address
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True)
@@ -933,7 +1179,7 @@ class RegisterService:
is_setup=is_setup,
)
account.status = AccountStatus.ACTIVE.value if not status else status.value
- account.initialized_at = datetime.now(UTC).replace(tzinfo=None)
+ account.initialized_at = naive_utc_now()
if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account)
@@ -1045,7 +1291,7 @@ class RegisterService:
tenant = (
db.session.query(Tenant)
- .filter(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
+ .where(Tenant.id == invitation_data["workspace_id"], Tenant.status == "normal")
.first()
)
@@ -1055,7 +1301,7 @@ class RegisterService:
tenant_account = (
db.session.query(Account, TenantAccountJoin.role)
.join(TenantAccountJoin, Account.id == TenantAccountJoin.account_id)
- .filter(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
+ .where(Account.email == invitation_data["email"], TenantAccountJoin.tenant_id == tenant.id)
.first()
)
diff --git a/api/services/agent_service.py b/api/services/agent_service.py
index 503b31ede2..7c6df2428f 100644
--- a/api/services/agent_service.py
+++ b/api/services/agent_service.py
@@ -25,7 +25,7 @@ class AgentService:
conversation: Conversation | None = (
db.session.query(Conversation)
- .filter(
+ .where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
)
@@ -37,7 +37,7 @@ class AgentService:
message: Optional[Message] = (
db.session.query(Message)
- .filter(
+ .where(
Message.id == message_id,
Message.conversation_id == conversation_id,
)
@@ -52,12 +52,10 @@ class AgentService:
if conversation.from_end_user_id:
# only select name field
executor = (
- db.session.query(EndUser, EndUser.name).filter(EndUser.id == conversation.from_end_user_id).first()
+ db.session.query(EndUser, EndUser.name).where(EndUser.id == conversation.from_end_user_id).first()
)
else:
- executor = (
- db.session.query(Account, Account.name).filter(Account.id == conversation.from_account_id).first()
- )
+ executor = db.session.query(Account, Account.name).where(Account.id == conversation.from_account_id).first()
if executor:
executor = executor.name
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 8c950abc24..7cb0b46517 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -26,7 +26,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -35,7 +35,7 @@ class AppAnnotationService:
if args.get("message_id"):
message_id = str(args["message_id"])
# get message info
- message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app.id).first()
+ message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
if not message:
raise NotFound("Message Not Exists.")
@@ -61,9 +61,7 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
- annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
- )
+ annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -117,7 +115,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -126,8 +124,8 @@ class AppAnnotationService:
if keyword:
stmt = (
select(MessageAnnotation)
- .filter(MessageAnnotation.app_id == app_id)
- .filter(
+ .where(MessageAnnotation.app_id == app_id)
+ .where(
or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)),
MessageAnnotation.content.ilike("%{}%".format(keyword)),
@@ -138,7 +136,7 @@ class AppAnnotationService:
else:
stmt = (
select(MessageAnnotation)
- .filter(MessageAnnotation.app_id == app_id)
+ .where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
)
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -149,7 +147,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -157,7 +155,7 @@ class AppAnnotationService:
raise NotFound("App not found")
annotations = (
db.session.query(MessageAnnotation)
- .filter(MessageAnnotation.app_id == app_id)
+ .where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
@@ -168,7 +166,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -181,9 +179,7 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
- annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
- )
+ annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -199,14 +195,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
- annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
@@ -217,7 +213,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
@@ -236,14 +232,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
- annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
@@ -252,7 +248,7 @@ class AppAnnotationService:
annotation_hit_histories = (
db.session.query(AppAnnotationHitHistory)
- .filter(AppAnnotationHitHistory.annotation_id == annotation_id)
+ .where(AppAnnotationHitHistory.annotation_id == annotation_id)
.all()
)
if annotation_hit_histories:
@@ -262,7 +258,7 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
@@ -275,7 +271,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -314,21 +310,21 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
- annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
raise NotFound("Annotation not found")
stmt = (
select(AppAnnotationHitHistory)
- .filter(
+ .where(
AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id,
)
@@ -341,7 +337,7 @@ class AppAnnotationService:
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
- annotation = db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).first()
+ annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
if not annotation:
return None
@@ -361,7 +357,7 @@ class AppAnnotationService:
score: float,
):
# add hit count to annotation
- db.session.query(MessageAnnotation).filter(MessageAnnotation.id == annotation_id).update(
+ db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
)
@@ -384,16 +380,14 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app:
raise NotFound("App not found")
- annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
- )
+ annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
return {
@@ -412,7 +406,7 @@ class AppAnnotationService:
# get app info
app = (
db.session.query(App)
- .filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
+ .where(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
@@ -421,7 +415,7 @@ class AppAnnotationService:
annotation_setting = (
db.session.query(AppAnnotationSetting)
- .filter(
+ .where(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
diff --git a/api/services/api_based_extension_service.py b/api/services/api_based_extension_service.py
index 601d67d2fb..457c91e5c0 100644
--- a/api/services/api_based_extension_service.py
+++ b/api/services/api_based_extension_service.py
@@ -73,7 +73,7 @@ class APIBasedExtensionService:
db.session.query(APIBasedExtension)
.filter_by(tenant_id=extension_data.tenant_id)
.filter_by(name=extension_data.name)
- .filter(APIBasedExtension.id != extension_data.id)
+ .where(APIBasedExtension.id != extension_data.id)
.first()
)
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 20257fa345..fe0efd061d 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -41,7 +41,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
-CURRENT_DSL_VERSION = "0.3.0"
+CURRENT_DSL_VERSION = "0.3.1"
class ImportMode(StrEnum):
@@ -575,13 +575,26 @@ class AppDslService:
raise ValueError("Missing draft workflow configuration, please check.")
workflow_dict = workflow.to_dict(include_secret=include_secret)
+ # TODO: refactor: we need a better way to filter workspace related data from nodes
for node in workflow_dict.get("graph", {}).get("nodes", []):
- if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
- dataset_ids = node["data"].get("dataset_ids", [])
- node["data"]["dataset_ids"] = [
+ node_data = node.get("data", {})
+ if not node_data:
+ continue
+ data_type = node_data.get("type", "")
+ if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value:
+ dataset_ids = node_data.get("dataset_ids", [])
+ node_data["dataset_ids"] = [
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id)
for dataset_id in dataset_ids
]
+ # filter credential id from tool node
+ if not include_secret and data_type == NodeType.TOOL.value:
+ node_data.pop("credential_id", None)
+ # filter credential id from agent node
+ if not include_secret and data_type == NodeType.AGENT.value:
+ for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []):
+ tool.pop("credential_id", None)
+
export_data["workflow"] = workflow_dict
dependencies = cls._extract_dependencies_from_workflow(workflow)
export_data["dependencies"] = [
@@ -602,7 +615,15 @@ class AppDslService:
if not app_model_config:
raise ValueError("Missing app configuration, please check.")
- export_data["model_config"] = app_model_config.to_dict()
+ model_config = app_model_config.to_dict()
+
+ # TODO: refactor: we need a better way to filter workspace related data from model config
+ # filter credential id from model config
+ for tool in model_config.get("agent_mode", {}).get("tools", []):
+ tool.pop("credential_id", None)
+
+ export_data["model_config"] = model_config
+
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
export_data["dependencies"] = [
jsonable_encoder(d.model_dump())
diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py
index 245c123a04..6f7e705b52 100644
--- a/api/services/app_generate_service.py
+++ b/api/services/app_generate_service.py
@@ -129,11 +129,25 @@ class AppGenerateService:
rate_limit.exit(request_id)
@staticmethod
- def _get_max_active_requests(app_model: App) -> int:
- max_active_requests = app_model.max_active_requests
- if max_active_requests is None:
- max_active_requests = int(dify_config.APP_MAX_ACTIVE_REQUESTS)
- return max_active_requests
+ def _get_max_active_requests(app: App) -> int:
+ """
+ Get the maximum number of active requests allowed for an app.
+
+ Returns the smaller value between app's custom limit and global config limit.
+ A value of 0 means infinite (no limit).
+
+ Args:
+ app: The App model instance
+
+ Returns:
+ The maximum number of active requests allowed
+ """
+ app_limit = app.max_active_requests or 0
+ config_limit = dify_config.APP_MAX_ACTIVE_REQUESTS
+
+ # Filter out infinite (0) values and return the minimum, or 0 if both are infinite
+ limits = [limit for limit in [app_limit, config_limit] if limit > 0]
+ return min(limits) if limits else 0
@classmethod
def generate_single_iteration(cls, app_model: App, user: Account, node_id: str, args: Any, streaming: bool = True):
diff --git a/api/services/app_service.py b/api/services/app_service.py
index d08462d001..0b6b85bcb2 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -1,7 +1,6 @@
import json
import logging
-from datetime import UTC, datetime
-from typing import Optional, cast
+from typing import Optional, TypedDict, cast
from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
@@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode, AppModelConfig, Site
from models.tools import ApiToolProvider
@@ -47,8 +47,6 @@ class AppService:
filters.append(App.mode == AppMode.ADVANCED_CHAT.value)
elif args["mode"] == "agent-chat":
filters.append(App.mode == AppMode.AGENT_CHAT.value)
- elif args["mode"] == "channel":
- filters.append(App.mode == AppMode.CHANNEL.value)
if args.get("is_created_by_me", False):
filters.append(App.created_by == user_id)
@@ -222,21 +220,31 @@ class AppService:
return app
- def update_app(self, app: App, args: dict) -> App:
+ class ArgsDict(TypedDict):
+ name: str
+ description: str
+ icon_type: str
+ icon: str
+ icon_background: str
+ use_icon_as_answer_icon: bool
+ max_active_requests: int
+
+ def update_app(self, app: App, args: ArgsDict) -> App:
"""
Update app
:param app: App instance
:param args: request args
:return: App instance
"""
- app.name = args.get("name")
- app.description = args.get("description", "")
- app.icon_type = args.get("icon_type", "emoji")
- app.icon = args.get("icon")
- app.icon_background = args.get("icon_background")
+ app.name = args["name"]
+ app.description = args["description"]
+ app.icon_type = args["icon_type"]
+ app.icon = args["icon"]
+ app.icon_background = args["icon_background"]
app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False)
+ app.max_active_requests = args.get("max_active_requests")
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -250,7 +258,7 @@ class AppService:
"""
app.name = name
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -266,7 +274,7 @@ class AppService:
app.icon = icon
app.icon_background = icon_background
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -283,7 +291,7 @@ class AppService:
app.enable_site = enable_site
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -300,7 +308,7 @@ class AppService:
app.enable_api = enable_api
app.updated_by = current_user.id
- app.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ app.updated_at = naive_utc_now()
db.session.commit()
return app
@@ -374,7 +382,7 @@ class AppService:
elif provider_type == "api":
try:
provider: Optional[ApiToolProvider] = (
- db.session.query(ApiToolProvider).filter(ApiToolProvider.id == provider_id).first()
+ db.session.query(ApiToolProvider).where(ApiToolProvider.id == provider_id).first()
)
if provider is None:
raise ValueError(f"provider not found for tool {tool_name}")
@@ -391,7 +399,7 @@ class AppService:
:param app_id: app id
:return: app code
"""
- site = db.session.query(Site).filter(Site.app_id == app_id).first()
+ site = db.session.query(Site).where(Site.app_id == app_id).first()
if not site:
raise ValueError(f"App with id {app_id} not found")
return str(site.code)
@@ -403,7 +411,7 @@ class AppService:
:param app_code: app code
:return: app id
"""
- site = db.session.query(Site).filter(Site.code == app_code).first()
+ site = db.session.query(Site).where(Site.code == app_code).first()
if not site:
raise ValueError(f"App with code {app_code} not found")
return str(site.app_id)
diff --git a/api/services/audio_service.py b/api/services/audio_service.py
index e8923eb51b..0084eebb32 100644
--- a/api/services/audio_service.py
+++ b/api/services/audio_service.py
@@ -135,7 +135,7 @@ class AudioService:
uuid.UUID(message_id)
except ValueError:
return None
- message = db.session.query(Message).filter(Message.id == message_id).first()
+ message = db.session.query(Message).where(Message.id == message_id).first()
if message is None:
return None
if message.answer == "" and message.status == MessageStatus.NORMAL:
diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py
index e5f4a3ef6e..996e9187f3 100644
--- a/api/services/auth/api_key_auth_service.py
+++ b/api/services/auth/api_key_auth_service.py
@@ -11,7 +11,7 @@ class ApiKeyAuthService:
def get_provider_auth_list(tenant_id: str) -> list:
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
- .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
+ .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False))
.all()
)
return data_source_api_key_bindings
@@ -36,7 +36,7 @@ class ApiKeyAuthService:
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
- .filter(
+ .where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
@@ -53,7 +53,7 @@ class ApiKeyAuthService:
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
- .filter(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
+ .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
diff --git a/api/services/billing_service.py b/api/services/billing_service.py
index d44483ad89..5a12aa2e54 100644
--- a/api/services/billing_service.py
+++ b/api/services/billing_service.py
@@ -75,14 +75,14 @@ class BillingService:
join: Optional[TenantAccountJoin] = (
db.session.query(TenantAccountJoin)
- .filter(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
+ .where(TenantAccountJoin.tenant_id == tenant_id, TenantAccountJoin.account_id == current_user.id)
.first()
)
if not join:
raise ValueError("Tenant account join not found")
- if not TenantAccountRole.is_privileged_role(join.role):
+ if not TenantAccountRole.is_privileged_role(TenantAccountRole(join.role)):
raise ValueError("Only team owner or team admin can perform this action")
@classmethod
diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py
index 1fd560d581..ad9b750d40 100644
--- a/api/services/clear_free_plan_tenant_expired_logs.py
+++ b/api/services/clear_free_plan_tenant_expired_logs.py
@@ -6,7 +6,7 @@ from concurrent.futures import ThreadPoolExecutor
import click
from flask import Flask, current_app
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -14,7 +14,7 @@ from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.model import App, Conversation, Message
-from models.workflow import WorkflowNodeExecutionModel, WorkflowRun
+from repositories.factory import DifyAPIRepositoryFactory
from services.billing_service import BillingService
logger = logging.getLogger(__name__)
@@ -24,13 +24,13 @@ class ClearFreePlanTenantExpiredLogs:
@classmethod
def process_tenant(cls, flask_app: Flask, tenant_id: str, days: int, batch: int):
with flask_app.app_context():
- apps = db.session.query(App).filter(App.tenant_id == tenant_id).all()
+ apps = db.session.query(App).where(App.tenant_id == tenant_id).all()
app_ids = [app.id for app in apps]
while True:
with Session(db.engine).no_autoflush as session:
messages = (
session.query(Message)
- .filter(
+ .where(
Message.app_id.in_(app_ids),
Message.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@@ -54,7 +54,7 @@ class ClearFreePlanTenantExpiredLogs:
message_ids = [message.id for message in messages]
# delete messages
- session.query(Message).filter(
+ session.query(Message).where(
Message.id.in_(message_ids),
).delete(synchronize_session=False)
@@ -70,7 +70,7 @@ class ClearFreePlanTenantExpiredLogs:
with Session(db.engine).no_autoflush as session:
conversations = (
session.query(Conversation)
- .filter(
+ .where(
Conversation.app_id.in_(app_ids),
Conversation.updated_at < datetime.datetime.now() - datetime.timedelta(days=days),
)
@@ -93,7 +93,7 @@ class ClearFreePlanTenantExpiredLogs:
)
conversation_ids = [conversation.id for conversation in conversations]
- session.query(Conversation).filter(
+ session.query(Conversation).where(
Conversation.id.in_(conversation_ids),
).delete(synchronize_session=False)
session.commit()
@@ -105,84 +105,99 @@ class ClearFreePlanTenantExpiredLogs:
)
)
- while True:
- with Session(db.engine).no_autoflush as session:
- workflow_node_executions = (
- session.query(WorkflowNodeExecutionModel)
- .filter(
- WorkflowNodeExecutionModel.tenant_id == tenant_id,
- WorkflowNodeExecutionModel.created_at
- < datetime.datetime.now() - datetime.timedelta(days=days),
- )
- .limit(batch)
- .all()
- )
+ # Process expired workflow node executions with backup
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+ before_date = datetime.datetime.now() - datetime.timedelta(days=days)
+ total_deleted = 0
- if len(workflow_node_executions) == 0:
- break
+ while True:
+ # Get a batch of expired executions for backup
+ workflow_node_executions = node_execution_repo.get_expired_executions_batch(
+ tenant_id=tenant_id,
+ before_date=before_date,
+ batch_size=batch,
+ )
- # save workflow node executions
- storage.save(
- f"free_plan_tenant_expired_logs/"
- f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
- f"-{time.time()}.json",
- json.dumps(
- jsonable_encoder(workflow_node_executions),
- ).encode("utf-8"),
- )
+ if len(workflow_node_executions) == 0:
+ break
+
+ # Save workflow node executions to storage
+ storage.save(
+ f"free_plan_tenant_expired_logs/"
+ f"{tenant_id}/workflow_node_executions/{datetime.datetime.now().strftime('%Y-%m-%d')}"
+ f"-{time.time()}.json",
+ json.dumps(
+ jsonable_encoder(workflow_node_executions),
+ ).encode("utf-8"),
+ )
- workflow_node_execution_ids = [
- workflow_node_execution.id for workflow_node_execution in workflow_node_executions
- ]
+ # Extract IDs for deletion
+ workflow_node_execution_ids = [
+ workflow_node_execution.id for workflow_node_execution in workflow_node_executions
+ ]
- # delete workflow node executions
- session.query(WorkflowNodeExecutionModel).filter(
- WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids),
- ).delete(synchronize_session=False)
- session.commit()
+ # Delete the backed up executions
+ deleted_count = node_execution_repo.delete_executions_by_ids(workflow_node_execution_ids)
+ total_deleted += deleted_count
- click.echo(
- click.style(
- f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
- f" workflow node executions for tenant {tenant_id}"
- )
+ click.echo(
+ click.style(
+ f"[{datetime.datetime.now()}] Processed {len(workflow_node_execution_ids)}"
+ f" workflow node executions for tenant {tenant_id}"
)
+ )
+
+ # If we got fewer than the batch size, we're done
+ if len(workflow_node_executions) < batch:
+ break
+
+ # Process expired workflow runs with backup
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+ before_date = datetime.datetime.now() - datetime.timedelta(days=days)
+ total_deleted = 0
while True:
- with Session(db.engine).no_autoflush as session:
- workflow_runs = (
- session.query(WorkflowRun)
- .filter(
- WorkflowRun.tenant_id == tenant_id,
- WorkflowRun.created_at < datetime.datetime.now() - datetime.timedelta(days=days),
- )
- .limit(batch)
- .all()
- )
+ # Get a batch of expired workflow runs for backup
+ workflow_runs = workflow_run_repo.get_expired_runs_batch(
+ tenant_id=tenant_id,
+ before_date=before_date,
+ batch_size=batch,
+ )
- if len(workflow_runs) == 0:
- break
+ if len(workflow_runs) == 0:
+ break
+
+ # Save workflow runs to storage
+ storage.save(
+ f"free_plan_tenant_expired_logs/"
+ f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
+ f"-{time.time()}.json",
+ json.dumps(
+ jsonable_encoder(
+ [workflow_run.to_dict() for workflow_run in workflow_runs],
+ ),
+ ).encode("utf-8"),
+ )
- # save workflow runs
+ # Extract IDs for deletion
+ workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
- storage.save(
- f"free_plan_tenant_expired_logs/"
- f"{tenant_id}/workflow_runs/{datetime.datetime.now().strftime('%Y-%m-%d')}"
- f"-{time.time()}.json",
- json.dumps(
- jsonable_encoder(
- [workflow_run.to_dict() for workflow_run in workflow_runs],
- ),
- ).encode("utf-8"),
- )
+ # Delete the backed up workflow runs
+ deleted_count = workflow_run_repo.delete_runs_by_ids(workflow_run_ids)
+ total_deleted += deleted_count
- workflow_run_ids = [workflow_run.id for workflow_run in workflow_runs]
+ click.echo(
+ click.style(
+ f"[{datetime.datetime.now()}] Processed {len(workflow_run_ids)}"
+ f" workflow runs for tenant {tenant_id}"
+ )
+ )
- # delete workflow runs
- session.query(WorkflowRun).filter(
- WorkflowRun.id.in_(workflow_run_ids),
- ).delete(synchronize_session=False)
- session.commit()
+ # If we got fewer than the batch size, we're done
+ if len(workflow_runs) < batch:
+ break
@classmethod
def process(cls, days: int, batch: int, tenant_ids: list[str]):
@@ -261,7 +276,7 @@ class ClearFreePlanTenantExpiredLogs:
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
- .filter(Tenant.created_at.between(current_time, current_time + test_interval))
+ .where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
@@ -286,7 +301,7 @@ class ClearFreePlanTenantExpiredLogs:
rs = (
session.query(Tenant.id)
- .filter(Tenant.created_at.between(current_time, batch_end))
+ .where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index afdaa49465..525c87fe4a 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -1,5 +1,4 @@
from collections.abc import Callable, Sequence
-from datetime import UTC, datetime
from typing import Optional, Union
from sqlalchemy import asc, desc, func, or_, select
@@ -8,6 +7,7 @@ from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import ConversationVariable
from models.account import Account
@@ -113,7 +113,7 @@ class ConversationService:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
return conversation
@@ -123,7 +123,7 @@ class ConversationService:
# get conversation first message
message = (
db.session.query(Message)
- .filter(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
+ .where(Message.app_id == app_model.id, Message.conversation_id == conversation.id)
.order_by(Message.created_at.asc())
.first()
)
@@ -148,7 +148,7 @@ class ConversationService:
def get_conversation(cls, app_model: App, conversation_id: str, user: Optional[Union[Account, EndUser]]):
conversation = (
db.session.query(Conversation)
- .filter(
+ .where(
Conversation.id == conversation_id,
Conversation.app_id == app_model.id,
Conversation.from_source == ("api" if isinstance(user, EndUser) else "console"),
@@ -169,7 +169,7 @@ class ConversationService:
conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.is_deleted = True
- conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ conversation.updated_at = naive_utc_now()
db.session.commit()
@classmethod
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index e42b5ace75..4872702a76 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -26,6 +26,7 @@ from events.document_event import document_was_deleted
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs import helper
+from libs.datetime_utils import naive_utc_now
from models.account import Account, TenantAccountRole
from models.dataset import (
AppDatasetJoin,
@@ -79,7 +80,7 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService:
@staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
- query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
+ query = select(Dataset).where(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user:
# get permitted dataset ids
@@ -91,14 +92,14 @@ class DatasetService:
if user.current_role == TenantAccountRole.DATASET_OPERATOR:
# only show datasets that the user has permission to access
if permitted_dataset_ids:
- query = query.filter(Dataset.id.in_(permitted_dataset_ids))
+ query = query.where(Dataset.id.in_(permitted_dataset_ids))
else:
return [], 0
else:
if user.current_role != TenantAccountRole.OWNER or not include_all:
# show all datasets that the user has permission to access
if permitted_dataset_ids:
- query = query.filter(
+ query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -111,7 +112,7 @@ class DatasetService:
)
)
else:
- query = query.filter(
+ query = query.where(
db.or_(
Dataset.permission == DatasetPermissionEnum.ALL_TEAM,
db.and_(
@@ -121,15 +122,15 @@ class DatasetService:
)
else:
# if no user, only show datasets that are shared with all team members
- query = query.filter(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
+ query = query.where(Dataset.permission == DatasetPermissionEnum.ALL_TEAM)
if search:
- query = query.filter(Dataset.name.ilike(f"%{search}%"))
+ query = query.where(Dataset.name.ilike(f"%{search}%"))
if tag_ids:
target_ids = TagService.get_target_ids_by_tag_ids("knowledge", tenant_id, tag_ids)
if target_ids:
- query = query.filter(Dataset.id.in_(target_ids))
+ query = query.where(Dataset.id.in_(target_ids))
else:
return [], 0
@@ -142,7 +143,7 @@ class DatasetService:
# get the latest process rule
dataset_process_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.dataset_id == dataset_id)
+ .where(DatasetProcessRule.dataset_id == dataset_id)
.order_by(DatasetProcessRule.created_at.desc())
.limit(1)
.one_or_none()
@@ -157,7 +158,7 @@ class DatasetService:
@staticmethod
def get_datasets_by_ids(ids, tenant_id):
- stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
+ stmt = select(Dataset).where(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
@@ -214,9 +215,9 @@ class DatasetService:
dataset.created_by = account.id
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
- dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
- dataset.embedding_model = embedding_model.model if embedding_model else None
- dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
+ dataset.embedding_model_provider = embedding_model.provider if embedding_model else None # type: ignore
+ dataset.embedding_model = embedding_model.model if embedding_model else None # type: ignore
+ dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None # type: ignore
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
db.session.add(dataset)
@@ -428,7 +429,7 @@ class DatasetService:
# Add metadata fields
filtered_data["updated_by"] = user.id
- filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ filtered_data["updated_at"] = naive_utc_now()
# update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"]
@@ -696,7 +697,7 @@ class DatasetService:
def get_related_apps(dataset_id: str):
return (
db.session.query(AppDatasetJoin)
- .filter(AppDatasetJoin.dataset_id == dataset_id)
+ .where(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at))
.all()
)
@@ -713,7 +714,7 @@ class DatasetService:
start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
- .filter(
+ .where(
DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date,
)
@@ -842,7 +843,7 @@ class DocumentService:
def get_document(dataset_id: str, document_id: Optional[str] = None) -> Optional[Document]:
if document_id:
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
return document
else:
@@ -850,7 +851,7 @@ class DocumentService:
@staticmethod
def get_document_by_id(document_id: str) -> Optional[Document]:
- document = db.session.query(Document).filter(Document.id == document_id).first()
+ document = db.session.query(Document).where(Document.id == document_id).first()
return document
@@ -858,7 +859,7 @@ class DocumentService:
def get_document_by_ids(document_ids: list[str]) -> list[Document]:
documents = (
db.session.query(Document)
- .filter(
+ .where(
Document.id.in_(document_ids),
Document.enabled == True,
Document.indexing_status == "completed",
@@ -872,7 +873,7 @@ class DocumentService:
def get_document_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
- .filter(
+ .where(
Document.dataset_id == dataset_id,
Document.enabled == True,
)
@@ -885,7 +886,7 @@ class DocumentService:
def get_working_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
- .filter(
+ .where(
Document.dataset_id == dataset_id,
Document.enabled == True,
Document.indexing_status == "completed",
@@ -900,7 +901,7 @@ class DocumentService:
def get_error_documents_by_dataset_id(dataset_id: str) -> list[Document]:
documents = (
db.session.query(Document)
- .filter(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
+ .where(Document.dataset_id == dataset_id, Document.indexing_status.in_(["error", "paused"]))
.all()
)
return documents
@@ -909,7 +910,7 @@ class DocumentService:
def get_batch_documents(dataset_id: str, batch: str) -> list[Document]:
documents = (
db.session.query(Document)
- .filter(
+ .where(
Document.batch == batch,
Document.dataset_id == dataset_id,
Document.tenant_id == current_user.current_tenant_id,
@@ -921,7 +922,7 @@ class DocumentService:
@staticmethod
def get_document_file_detail(file_id: str):
- file_detail = db.session.query(UploadFile).filter(UploadFile.id == file_id).one_or_none()
+ file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
return file_detail
@staticmethod
@@ -949,7 +950,7 @@ class DocumentService:
@staticmethod
def delete_documents(dataset: Dataset, document_ids: list[str]):
- documents = db.session.query(Document).filter(Document.id.in_(document_ids)).all()
+ documents = db.session.query(Document).where(Document.id.in_(document_ids)).all()
file_ids = [
document.data_source_info_dict["upload_file_id"]
for document in documents
@@ -994,7 +995,7 @@ class DocumentService:
# update document to be paused
document.is_paused = True
document.paused_by = current_user.id
- document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.paused_at = naive_utc_now()
db.session.add(document)
db.session.commit()
@@ -1188,7 +1189,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
+ .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1269,7 +1270,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
- .filter(
+ .where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -1412,7 +1413,7 @@ class DocumentService:
def get_tenant_documents_count():
documents_count = (
db.session.query(Document)
- .filter(
+ .where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
@@ -1468,7 +1469,7 @@ class DocumentService:
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
+ .where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
)
@@ -1488,7 +1489,7 @@ class DocumentService:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
- .filter(
+ .where(
db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -1539,8 +1540,10 @@ class DocumentService:
db.session.add(document)
db.session.commit()
# update document segment
- update_params = {DocumentSegment.status: "re_segment"}
- db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
+
+ db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
+ {DocumentSegment.status: "re_segment"}
+ ) # type: ignore
db.session.commit()
# trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id)
@@ -2002,7 +2005,7 @@ class SegmentService:
with redis_client.lock(lock_name, timeout=600):
max_position = (
db.session.query(func.max(DocumentSegment.position))
- .filter(DocumentSegment.document_id == document.id)
+ .where(DocumentSegment.document_id == document.id)
.scalar()
)
segment_document = DocumentSegment(
@@ -2040,7 +2043,7 @@ class SegmentService:
segment_document.status = "error"
segment_document.error = str(e)
db.session.commit()
- segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_document.id).first()
+ segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_document.id).first()
return segment
@classmethod
@@ -2059,7 +2062,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(DocumentSegment.position))
- .filter(DocumentSegment.document_id == document.id)
+ .where(DocumentSegment.document_id == document.id)
.scalar()
)
pre_segment_data_list = []
@@ -2198,7 +2201,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2225,7 +2228,7 @@ class SegmentService:
# calc embedding use tokens
if document.doc_form == "qa_model":
segment.answer = args.answer
- tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0]
+ tokens = embedding_model.get_text_embedding_num_tokens(texts=[content + segment.answer])[0] # type: ignore
else:
tokens = embedding_model.get_text_embedding_num_tokens(texts=[content])[0]
segment.content = content
@@ -2273,7 +2276,7 @@ class SegmentService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
@@ -2292,7 +2295,7 @@ class SegmentService:
segment.status = "error"
segment.error = str(e)
db.session.commit()
- new_segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment.id).first()
+ new_segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment.id).first()
return new_segment
@classmethod
@@ -2318,7 +2321,7 @@ class SegmentService:
index_node_ids = (
db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
- .filter(
+ .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2329,7 +2332,7 @@ class SegmentService:
index_node_ids = [index_node_id[0] for index_node_id in index_node_ids]
delete_segment_from_index_task.delay(index_node_ids, dataset.id, document.id)
- db.session.query(DocumentSegment).filter(DocumentSegment.id.in_(segment_ids)).delete()
+ db.session.query(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)).delete()
db.session.commit()
@classmethod
@@ -2337,7 +2340,7 @@ class SegmentService:
if action == "enable":
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2364,7 +2367,7 @@ class SegmentService:
elif action == "disable":
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.document_id == document.id,
@@ -2401,7 +2404,7 @@ class SegmentService:
index_node_hash = helper.generate_text_hash(content)
child_chunk_count = (
db.session.query(ChildChunk)
- .filter(
+ .where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2411,7 +2414,7 @@ class SegmentService:
)
max_position = (
db.session.query(func.max(ChildChunk.position))
- .filter(
+ .where(
ChildChunk.tenant_id == current_user.current_tenant_id,
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
@@ -2454,7 +2457,7 @@ class SegmentService:
) -> list[ChildChunk]:
child_chunks = (
db.session.query(ChildChunk)
- .filter(
+ .where(
ChildChunk.dataset_id == dataset.id,
ChildChunk.document_id == document.id,
ChildChunk.segment_id == segment.id,
@@ -2575,7 +2578,7 @@ class SegmentService:
"""Get a child chunk by its ID."""
result = (
db.session.query(ChildChunk)
- .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
+ .where(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None
@@ -2591,15 +2594,15 @@ class SegmentService:
limit: int = 20,
):
"""Get segments for a document with optional filtering."""
- query = select(DocumentSegment).filter(
+ query = select(DocumentSegment).where(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
)
if status_list:
- query = query.filter(DocumentSegment.status.in_(status_list))
+ query = query.where(DocumentSegment.status.in_(status_list))
if keyword:
- query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
+ query = query.where(DocumentSegment.content.ilike(f"%{keyword}%"))
query = query.order_by(DocumentSegment.position.asc())
paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@@ -2612,7 +2615,7 @@ class SegmentService:
) -> tuple[DocumentSegment, Document]:
"""Update a segment by its ID with validation and checks."""
# check dataset
- dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first()
if not dataset:
raise NotFound("Dataset not found.")
@@ -2644,7 +2647,7 @@ class SegmentService:
# check segment
segment = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
+ .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment:
@@ -2661,7 +2664,7 @@ class SegmentService:
"""Get a segment by its ID."""
result = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
+ .where(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None
@@ -2674,7 +2677,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
- .filter(
+ .where(
DatasetCollectionBinding.provider_name == provider_name,
DatasetCollectionBinding.model_name == model_name,
DatasetCollectionBinding.type == collection_type,
@@ -2700,7 +2703,7 @@ class DatasetCollectionBindingService:
) -> DatasetCollectionBinding:
dataset_collection_binding = (
db.session.query(DatasetCollectionBinding)
- .filter(
+ .where(
DatasetCollectionBinding.id == collection_binding_id, DatasetCollectionBinding.type == collection_type
)
.order_by(DatasetCollectionBinding.created_at)
@@ -2719,7 +2722,7 @@ class DatasetPermissionService:
db.session.query(
DatasetPermission.account_id,
)
- .filter(DatasetPermission.dataset_id == dataset_id)
+ .where(DatasetPermission.dataset_id == dataset_id)
.all()
)
@@ -2732,7 +2735,7 @@ class DatasetPermissionService:
@classmethod
def update_partial_member_list(cls, tenant_id, dataset_id, user_list):
try:
- db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
+ db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
permissions = []
for user in user_list:
permission = DatasetPermission(
@@ -2768,7 +2771,7 @@ class DatasetPermissionService:
@classmethod
def clear_partial_member_list(cls, dataset_id):
try:
- db.session.query(DatasetPermission).filter(DatasetPermission.dataset_id == dataset_id).delete()
+ db.session.query(DatasetPermission).where(DatasetPermission.dataset_id == dataset_id).delete()
db.session.commit()
except Exception as e:
db.session.rollback()
diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py
index 603064ca07..344c67885e 100644
--- a/api/services/entities/knowledge_entities/knowledge_entities.py
+++ b/api/services/entities/knowledge_entities/knowledge_entities.py
@@ -4,13 +4,6 @@ from typing import Literal, Optional
from pydantic import BaseModel
-class SegmentUpdateEntity(BaseModel):
- content: str
- answer: Optional[str] = None
- keywords: Optional[list[str]] = None
- enabled: Optional[bool] = None
-
-
class ParentMode(StrEnum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
@@ -95,7 +88,7 @@ class WeightKeywordSetting(BaseModel):
class WeightModel(BaseModel):
- weight_type: Optional[str] = None
+ weight_type: Optional[Literal["semantic_first", "keyword_first", "customized"]] = None
vector_setting: Optional[WeightVectorSetting] = None
keyword_setting: Optional[WeightKeywordSetting] = None
@@ -153,10 +146,6 @@ class MetadataUpdateArgs(BaseModel):
value: Optional[str | int | float] = None
-class MetadataValueUpdateArgs(BaseModel):
- fields: list[MetadataUpdateArgs]
-
-
class MetadataDetail(BaseModel):
id: str
name: str
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index eb50d79494..b7af03e91f 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -1,6 +1,5 @@
import json
from copy import deepcopy
-from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
@@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
from core.rag.entities.metadata_entities import MetadataCondition
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import (
Dataset,
ExternalKnowledgeApis,
@@ -30,11 +30,11 @@ class ExternalDatasetService:
) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
- .filter(ExternalKnowledgeApis.tenant_id == tenant_id)
+ .where(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
)
if search:
- query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
+ query = query.where(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
@@ -120,7 +120,7 @@ class ExternalDatasetService:
external_knowledge_api.description = args.get("description", "")
external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False)
external_knowledge_api.updated_by = user_id
- external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ external_knowledge_api.updated_at = naive_utc_now()
db.session.commit()
return external_knowledge_api
diff --git a/api/services/feature_service.py b/api/services/feature_service.py
index 188caf3505..1441e6ce16 100644
--- a/api/services/feature_service.py
+++ b/api/services/feature_service.py
@@ -123,7 +123,7 @@ class FeatureModel(BaseModel):
dataset_operator_enabled: bool = False
webapp_copyright_enabled: bool = False
workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0)
-
+ is_allow_transfer_workspace: bool = True
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
@@ -149,6 +149,7 @@ class SystemFeatureModel(BaseModel):
branding: BrandingModel = BrandingModel()
webapp_auth: WebAppAuthModel = WebAppAuthModel()
plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel()
+ enable_change_email: bool = True
class FeatureService:
@@ -186,6 +187,7 @@ class FeatureService:
if dify_config.ENTERPRISE_ENABLED:
system_features.branding.enabled = True
system_features.webapp_auth.enabled = True
+ system_features.enable_change_email = False
cls._fulfill_params_from_enterprise(system_features)
if dify_config.MARKETPLACE_ENABLED:
@@ -228,6 +230,8 @@ class FeatureService:
if features.billing.subscription.plan != "sandbox":
features.webapp_copyright_enabled = True
+ else:
+ features.is_allow_transfer_workspace = False
if "members" in billing_info:
features.members.size = billing_info["members"]["size"]
diff --git a/api/services/file_service.py b/api/services/file_service.py
index 286535bd18..e234c2f325 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -144,7 +144,7 @@ class FileService:
@staticmethod
def get_file_preview(file_id: str):
- upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found")
@@ -167,7 +167,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
- upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -187,7 +187,7 @@ class FileService:
if not result:
raise NotFound("File not found or signature is invalid")
- upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
@@ -198,7 +198,7 @@ class FileService:
@staticmethod
def get_public_image_preview(file_id: str):
- upload_file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
diff --git a/api/services/message_service.py b/api/services/message_service.py
index 51b070ece7..283b7b9b4b 100644
--- a/api/services/message_service.py
+++ b/api/services/message_service.py
@@ -50,7 +50,7 @@ class MessageService:
if first_id:
first_message = (
db.session.query(Message)
- .filter(Message.conversation_id == conversation.id, Message.id == first_id)
+ .where(Message.conversation_id == conversation.id, Message.id == first_id)
.first()
)
@@ -59,7 +59,7 @@ class MessageService:
history_messages = (
db.session.query(Message)
- .filter(
+ .where(
Message.conversation_id == conversation.id,
Message.created_at < first_message.created_at,
Message.id != first_message.id,
@@ -71,7 +71,7 @@ class MessageService:
else:
history_messages = (
db.session.query(Message)
- .filter(Message.conversation_id == conversation.id)
+ .where(Message.conversation_id == conversation.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
@@ -109,19 +109,19 @@ class MessageService:
app_model=app_model, user=user, conversation_id=conversation_id
)
- base_query = base_query.filter(Message.conversation_id == conversation.id)
+ base_query = base_query.where(Message.conversation_id == conversation.id)
if include_ids is not None:
- base_query = base_query.filter(Message.id.in_(include_ids))
+ base_query = base_query.where(Message.id.in_(include_ids))
if last_id:
- last_message = base_query.filter(Message.id == last_id).first()
+ last_message = base_query.where(Message.id == last_id).first()
if not last_message:
raise LastMessageNotExistsError()
history_messages = (
- base_query.filter(Message.created_at < last_message.created_at, Message.id != last_message.id)
+ base_query.where(Message.created_at < last_message.created_at, Message.id != last_message.id)
.order_by(Message.created_at.desc())
.limit(fetch_limit)
.all()
@@ -183,7 +183,7 @@ class MessageService:
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
- .filter(MessageFeedback.app_id == app_model.id)
+ .where(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
@@ -196,7 +196,7 @@ class MessageService:
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = (
db.session.query(Message)
- .filter(
+ .where(
Message.id == message_id,
Message.app_id == app_model.id,
Message.from_source == ("api" if isinstance(user, EndUser) else "console"),
@@ -248,9 +248,7 @@ class MessageService:
if not conversation.override_model_configs:
app_model_config = (
db.session.query(AppModelConfig)
- .filter(
- AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id
- )
+ .where(AppModelConfig.id == conversation.app_model_config_id, AppModelConfig.app_id == app_model.id)
.first()
)
else:
diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py
index 26311a6377..a200cfa146 100644
--- a/api/services/model_load_balancing_service.py
+++ b/api/services/model_load_balancing_service.py
@@ -103,7 +103,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations
load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -219,7 +219,7 @@ class ModelLoadBalancingService:
# Get load balancing configurations
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -307,7 +307,7 @@ class ModelLoadBalancingService:
current_load_balancing_configs = (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider_configuration.provider.provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
@@ -457,7 +457,7 @@ class ModelLoadBalancingService:
# Get load balancing config
load_balancing_model_config = (
db.session.query(LoadBalancingModelConfig)
- .filter(
+ .where(
LoadBalancingModelConfig.tenant_id == tenant_id,
LoadBalancingModelConfig.provider_name == provider,
LoadBalancingModelConfig.model_type == model_type_enum.to_origin_model_type(),
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index dbeb4f1908..62f37c1588 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -17,7 +17,7 @@ class OpsService:
"""
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
- .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+ .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -25,7 +25,7 @@ class OpsService:
return None
# decrypt_token and obfuscated_token
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -148,7 +148,7 @@ class OpsService:
# check if trace config already exists
trace_config_data: Optional[TraceAppConfig] = (
db.session.query(TraceAppConfig)
- .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+ .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -156,7 +156,7 @@ class OpsService:
return None
# get tenant id
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -190,7 +190,7 @@ class OpsService:
# check if trace config already exists
current_trace_config = (
db.session.query(TraceAppConfig)
- .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+ .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
@@ -198,7 +198,7 @@ class OpsService:
return None
# get tenant id
- app = db.session.query(App).filter(App.id == app_id).first()
+ app = db.session.query(App).where(App.id == app_id).first()
if not app:
return None
tenant_id = app.tenant_id
@@ -227,7 +227,7 @@ class OpsService:
"""
trace_config = (
db.session.query(TraceAppConfig)
- .filter(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
+ .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider)
.first()
)
diff --git a/api/services/plugin/plugin_auto_upgrade_service.py b/api/services/plugin/plugin_auto_upgrade_service.py
new file mode 100644
index 0000000000..3774050445
--- /dev/null
+++ b/api/services/plugin/plugin_auto_upgrade_service.py
@@ -0,0 +1,110 @@
+from sqlalchemy.orm import Session
+
+from extensions.ext_database import db
+from models.account import TenantPluginAutoUpgradeStrategy
+
+
+class PluginAutoUpgradeService:
+    """Read and update a tenant's plugin auto-upgrade strategy.
+
+    Each public method opens its own ``Session`` on ``db.engine``.
+    """
+
+    @staticmethod
+    def _find_strategy(session: Session, tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
+        """Return the first strategy row stored for ``tenant_id``, or ``None`` if there is none."""
+        return (
+            session.query(TenantPluginAutoUpgradeStrategy)
+            .where(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
+            .first()
+        )
+
+    @staticmethod
+    def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
+        """Return the strategy stored for ``tenant_id``, or ``None`` if there is none."""
+        with Session(db.engine) as session:
+            return PluginAutoUpgradeService._find_strategy(session, tenant_id)
+
+    @staticmethod
+    def change_strategy(
+        tenant_id: str,
+        strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
+        upgrade_time_of_day: int,
+        upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
+        exclude_plugins: list[str],
+        include_plugins: list[str],
+    ) -> bool:
+        """Create the tenant's strategy, or overwrite every field of the existing one.
+
+        :param tenant_id: tenant whose strategy is written
+        :param strategy_setting: value stored in ``strategy_setting``
+        :param upgrade_time_of_day: value stored in ``upgrade_time_of_day``
+        :param upgrade_mode: value stored in ``upgrade_mode``
+        :param exclude_plugins: plugin ids stored in ``exclude_plugins``
+        :param include_plugins: plugin ids stored in ``include_plugins``
+        :return: always ``True``; the change is committed before returning
+        """
+        with Session(db.engine) as session:
+            exist_strategy = PluginAutoUpgradeService._find_strategy(session, tenant_id)
+            if not exist_strategy:
+                strategy = TenantPluginAutoUpgradeStrategy(
+                    tenant_id=tenant_id,
+                    strategy_setting=strategy_setting,
+                    upgrade_time_of_day=upgrade_time_of_day,
+                    upgrade_mode=upgrade_mode,
+                    exclude_plugins=exclude_plugins,
+                    include_plugins=include_plugins,
+                )
+                session.add(strategy)
+            else:
+                exist_strategy.strategy_setting = strategy_setting
+                exist_strategy.upgrade_time_of_day = upgrade_time_of_day
+                exist_strategy.upgrade_mode = upgrade_mode
+                exist_strategy.exclude_plugins = exclude_plugins
+                exist_strategy.include_plugins = include_plugins
+
+            session.commit()
+            return True
+
+    @staticmethod
+    def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
+        """Stop auto-upgrading ``plugin_id`` for ``tenant_id``.
+
+        With no stored strategy, one is created via ``change_strategy`` (FIX_ONLY, time 0,
+        EXCLUDE mode) excluding only this plugin. Otherwise the effect depends on the mode:
+        EXCLUDE adds the plugin to ``exclude_plugins``, PARTIAL removes it from
+        ``include_plugins``, and ALL switches to EXCLUDE with this plugin as the only exclusion.
+
+        :return: always ``True``
+        """
+        with Session(db.engine) as session:
+            exist_strategy = PluginAutoUpgradeService._find_strategy(session, tenant_id)
+            if not exist_strategy:
+                # create for this tenant
+                PluginAutoUpgradeService.change_strategy(
+                    tenant_id,
+                    TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
+                    0,
+                    TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
+                    [plugin_id],
+                    [],
+                )
+                return True
+            else:
+                if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
+                    if plugin_id not in exist_strategy.exclude_plugins:
+                        # NOTE(review): a copy is reassigned instead of appending in place, presumably so the ORM sees the change - confirm
+                        new_exclude_plugins = exist_strategy.exclude_plugins.copy()
+                        new_exclude_plugins.append(plugin_id)
+                        exist_strategy.exclude_plugins = new_exclude_plugins
+                elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
+                    if plugin_id in exist_strategy.include_plugins:
+                        new_include_plugins = exist_strategy.include_plugins.copy()
+                        new_include_plugins.remove(plugin_id)
+                        exist_strategy.include_plugins = new_include_plugins
+                elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
+                    exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
+                    exist_strategy.exclude_plugins = [plugin_id]
+
+                session.commit()
+                return True
diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py
index dbaaa7160e..1806fbcfd6 100644
--- a/api/services/plugin/plugin_migration.py
+++ b/api/services/plugin/plugin_migration.py
@@ -101,7 +101,7 @@ class PluginMigration:
for test_interval in test_intervals:
tenant_count = (
session.query(Tenant.id)
- .filter(Tenant.created_at.between(current_time, current_time + test_interval))
+ .where(Tenant.created_at.between(current_time, current_time + test_interval))
.count()
)
if tenant_count <= 100:
@@ -126,7 +126,7 @@ class PluginMigration:
rs = (
session.query(Tenant.id)
- .filter(Tenant.created_at.between(current_time, batch_end))
+ .where(Tenant.created_at.between(current_time, batch_end))
.order_by(Tenant.created_at)
)
@@ -212,7 +212,7 @@ class PluginMigration:
Extract tool tables.
"""
with Session(db.engine) as session:
- rs = session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all()
+ rs = session.query(BuiltinToolProvider).where(BuiltinToolProvider.tenant_id == tenant_id).all()
result = []
for row in rs:
result.append(ToolProviderID(row.provider).plugin_id)
@@ -226,7 +226,7 @@ class PluginMigration:
"""
with Session(db.engine) as session:
- rs = session.query(Workflow).filter(Workflow.tenant_id == tenant_id).all()
+ rs = session.query(Workflow).where(Workflow.tenant_id == tenant_id).all()
result = []
for row in rs:
graph = row.graph_dict
@@ -249,7 +249,7 @@ class PluginMigration:
Extract app tables.
"""
with Session(db.engine) as session:
- apps = session.query(App).filter(App.tenant_id == tenant_id).all()
+ apps = session.query(App).where(App.tenant_id == tenant_id).all()
if not apps:
return []
@@ -257,7 +257,7 @@ class PluginMigration:
app.app_model_config_id for app in apps if app.is_agent or app.mode == AppMode.AGENT_CHAT.value
]
- rs = session.query(AppModelConfig).filter(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
+ rs = session.query(AppModelConfig).where(AppModelConfig.id.in_(agent_app_model_config_ids)).all()
result = []
for row in rs:
agent_config = row.agent_mode_dict
diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py
index 393213c0e2..00b59dacb3 100644
--- a/api/services/plugin/plugin_parameter_service.py
+++ b/api/services/plugin/plugin_parameter_service.py
@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from core.plugin.entities.parameters import PluginParameterOption
from core.plugin.impl.dynamic_select import DynamicSelectClient
from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_tool_provider_encrypter
from extensions.ext_database import db
from models.tools import BuiltinToolProvider
@@ -38,11 +38,9 @@ class PluginParameterService:
case "tool":
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
# check if credentials are required
@@ -53,7 +51,7 @@ class PluginParameterService:
with Session(db.engine) as session:
db_record = (
session.query(BuiltinToolProvider)
- .filter(
+ .where(
BuiltinToolProvider.tenant_id == tenant_id,
BuiltinToolProvider.provider == provider,
)
@@ -63,7 +61,7 @@ class PluginParameterService:
if db_record is None:
raise ValueError(f"Builtin provider {provider} not found when fetching credentials")
- credentials = tool_configuration.decrypt(db_record.credentials)
+ credentials = encrypter.decrypt(db_record.credentials)
case _:
raise ValueError(f"Invalid provider type: {provider_type}")
diff --git a/api/services/plugin/plugin_permission_service.py b/api/services/plugin/plugin_permission_service.py
index 275e496037..60fa269640 100644
--- a/api/services/plugin/plugin_permission_service.py
+++ b/api/services/plugin/plugin_permission_service.py
@@ -8,7 +8,7 @@ class PluginPermissionService:
@staticmethod
def get_permission(tenant_id: str) -> TenantPluginPermission | None:
with Session(db.engine) as session:
- return session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
+ return session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
@staticmethod
def change_permission(
@@ -18,7 +18,7 @@ class PluginPermissionService:
):
with Session(db.engine) as session:
permission = (
- session.query(TenantPluginPermission).filter(TenantPluginPermission.tenant_id == tenant_id).first()
+ session.query(TenantPluginPermission).where(TenantPluginPermission.tenant_id == tenant_id).first()
)
if not permission:
permission = TenantPluginPermission(
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index 0f22afd8dd..9005f0669b 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -38,6 +38,9 @@ class PluginService:
plugin_id: str
version: str
unique_identifier: str
+ status: str
+ deprecated_reason: str
+ alternative_plugin_id: str
REDIS_KEY_PREFIX = "plugin_service:latest_plugin:"
REDIS_TTL = 60 * 5 # 5 minutes
@@ -71,6 +74,9 @@ class PluginService:
plugin_id=plugin_id,
version=manifest.latest_version,
unique_identifier=manifest.latest_package_identifier,
+ status=manifest.status,
+ deprecated_reason=manifest.deprecated_reason,
+ alternative_plugin_id=manifest.alternative_plugin_id,
)
# Store in Redis
@@ -196,6 +202,17 @@ class PluginService:
manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
+    @staticmethod
+    def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool:
+        """
+        Return the ``verified`` flag of the plugin's manifest; any fetch error is reported as False.
+        """
+        manager = PluginInstaller()
+        try:
+            return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified
+        except Exception:  # best-effort: a manifest that cannot be fetched counts as unverified
+            return False
+
@staticmethod
def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
"""
diff --git a/api/services/recommend_app/database/database_retrieval.py b/api/services/recommend_app/database/database_retrieval.py
index 3295516cce..b97d13d012 100644
--- a/api/services/recommend_app/database/database_retrieval.py
+++ b/api/services/recommend_app/database/database_retrieval.py
@@ -33,14 +33,14 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
"""
recommended_apps = (
db.session.query(RecommendedApp)
- .filter(RecommendedApp.is_listed == True, RecommendedApp.language == language)
+ .where(RecommendedApp.is_listed == True, RecommendedApp.language == language)
.all()
)
if len(recommended_apps) == 0:
recommended_apps = (
db.session.query(RecommendedApp)
- .filter(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
+ .where(RecommendedApp.is_listed == True, RecommendedApp.language == languages[0])
.all()
)
@@ -83,7 +83,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
# is in public recommended list
recommended_app = (
db.session.query(RecommendedApp)
- .filter(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
+ .where(RecommendedApp.is_listed == True, RecommendedApp.app_id == app_id)
.first()
)
@@ -91,7 +91,7 @@ class DatabaseRecommendAppRetrieval(RecommendAppRetrievalBase):
return None
# get app detail
- app_model = db.session.query(App).filter(App.id == app_id).first()
+ app_model = db.session.query(App).where(App.id == app_id).first()
if not app_model or not app_model.is_public:
return None
diff --git a/api/services/saved_message_service.py b/api/services/saved_message_service.py
index 4cb8700117..641e03c3cf 100644
--- a/api/services/saved_message_service.py
+++ b/api/services/saved_message_service.py
@@ -17,7 +17,7 @@ class SavedMessageService:
raise ValueError("User is required")
saved_messages = (
db.session.query(SavedMessage)
- .filter(
+ .where(
SavedMessage.app_id == app_model.id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
SavedMessage.created_by == user.id,
@@ -37,7 +37,7 @@ class SavedMessageService:
return
saved_message = (
db.session.query(SavedMessage)
- .filter(
+ .where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@@ -67,7 +67,7 @@ class SavedMessageService:
return
saved_message = (
db.session.query(SavedMessage)
- .filter(
+ .where(
SavedMessage.app_id == app_model.id,
SavedMessage.message_id == message_id,
SavedMessage.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 74c6150b44..75fa52a75c 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -16,10 +16,10 @@ class TagService:
query = (
db.session.query(Tag.id, Tag.type, Tag.name, func.count(TagBinding.id).label("binding_count"))
.outerjoin(TagBinding, Tag.id == TagBinding.tag_id)
- .filter(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
+ .where(Tag.type == tag_type, Tag.tenant_id == current_tenant_id)
)
if keyword:
- query = query.filter(db.and_(Tag.name.ilike(f"%{keyword}%")))
+ query = query.where(db.and_(Tag.name.ilike(f"%{keyword}%")))
query = query.group_by(Tag.id, Tag.type, Tag.name, Tag.created_at)
results: list = query.order_by(Tag.created_at.desc()).all()
return results
@@ -28,7 +28,7 @@ class TagService:
def get_target_ids_by_tag_ids(tag_type: str, current_tenant_id: str, tag_ids: list) -> list:
tags = (
db.session.query(Tag)
- .filter(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+ .where(Tag.id.in_(tag_ids), Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
@@ -36,7 +36,7 @@ class TagService:
tag_ids = [tag.id for tag in tags]
tag_bindings = (
db.session.query(TagBinding.target_id)
- .filter(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
+ .where(TagBinding.tag_id.in_(tag_ids), TagBinding.tenant_id == current_tenant_id)
.all()
)
if not tag_bindings:
@@ -50,7 +50,7 @@ class TagService:
return []
tags = (
db.session.query(Tag)
- .filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
+ .where(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type)
.all()
)
if not tags:
@@ -62,7 +62,7 @@ class TagService:
tags = (
db.session.query(Tag)
.join(TagBinding, Tag.id == TagBinding.tag_id)
- .filter(
+ .where(
TagBinding.target_id == target_id,
TagBinding.tenant_id == current_tenant_id,
Tag.tenant_id == current_tenant_id,
@@ -92,7 +92,7 @@ class TagService:
def update_tags(args: dict, tag_id: str) -> Tag:
if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")):
raise ValueError("Tag name already exists")
- tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
+ tag = db.session.query(Tag).where(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
tag.name = args["name"]
@@ -101,17 +101,17 @@ class TagService:
@staticmethod
def get_tag_binding_count(tag_id: str) -> int:
- count = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).count()
+ count = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).count()
return count
@staticmethod
def delete_tag(tag_id: str):
- tag = db.session.query(Tag).filter(Tag.id == tag_id).first()
+ tag = db.session.query(Tag).where(Tag.id == tag_id).first()
if not tag:
raise NotFound("Tag not found")
db.session.delete(tag)
# delete tag binding
- tag_bindings = db.session.query(TagBinding).filter(TagBinding.tag_id == tag_id).all()
+ tag_bindings = db.session.query(TagBinding).where(TagBinding.tag_id == tag_id).all()
if tag_bindings:
for tag_binding in tag_bindings:
db.session.delete(tag_binding)
@@ -125,7 +125,7 @@ class TagService:
for tag_id in args["tag_ids"]:
tag_binding = (
db.session.query(TagBinding)
- .filter(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
+ .where(TagBinding.tag_id == tag_id, TagBinding.target_id == args["target_id"])
.first()
)
if tag_binding:
@@ -146,7 +146,7 @@ class TagService:
# delete tag binding
tag_bindings = (
db.session.query(TagBinding)
- .filter(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
+ .where(TagBinding.target_id == args["target_id"], TagBinding.tag_id == (args["tag_id"]))
.first()
)
if tag_bindings:
@@ -158,7 +158,7 @@ class TagService:
if type == "knowledge":
dataset = (
db.session.query(Dataset)
- .filter(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
+ .where(Dataset.tenant_id == current_user.current_tenant_id, Dataset.id == target_id)
.first()
)
if not dataset:
@@ -166,7 +166,7 @@ class TagService:
elif type == "app":
app = (
db.session.query(App)
- .filter(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
+ .where(App.tenant_id == current_user.current_tenant_id, App.id == target_id)
.first()
)
if not app:
diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py
index 6f848d49c4..78e587abee 100644
--- a/api/services/tools/api_tools_manage_service.py
+++ b/api/services/tools/api_tools_manage_service.py
@@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import (
)
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_tool_provider_encrypter
from core.tools.utils.parser import ApiBasedToolSchemaParser
from extensions.ext_database import db
from models.tools import ApiToolProvider
@@ -119,7 +119,7 @@ class ApiToolManageService:
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -164,15 +164,11 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# encrypt credentials
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
-
- encrypted_credentials = tool_configuration.encrypt(credentials)
- db_provider.credentials_str = json.dumps(encrypted_credentials)
+ db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials))
db.session.add(db_provider)
db.session.commit()
@@ -214,7 +210,7 @@ class ApiToolManageService:
"""
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -261,7 +257,7 @@ class ApiToolManageService:
# check if the provider exists
provider = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == original_provider,
)
@@ -297,28 +293,26 @@ class ApiToolManageService:
provider_controller.load_bundled_tools(tool_bundles)
# get original credentials if exists
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, cache = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
- original_credentials = tool_configuration.decrypt(provider.credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
+ original_credentials = encrypter.decrypt(provider.credentials)
+ masked_credentials = encrypter.mask_tool_credentials(original_credentials)
# check if the credential has changed, save the original credential
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = original_credentials[name]
- credentials = tool_configuration.encrypt(credentials)
+ credentials = encrypter.encrypt(credentials)
provider.credentials_str = json.dumps(credentials)
db.session.add(provider)
db.session.commit()
# delete cache
- tool_configuration.delete_tool_credentials_cache()
+ cache.delete()
# update labels
ToolLabelManager.update_tool_labels(provider_controller, labels)
@@ -332,7 +326,7 @@ class ApiToolManageService:
"""
provider = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -382,7 +376,7 @@ class ApiToolManageService:
db_provider = (
db.session.query(ApiToolProvider)
- .filter(
+ .where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider_name,
)
@@ -416,15 +410,13 @@ class ApiToolManageService:
# decrypt credentials
if db_provider.id:
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=tenant_id,
- config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
- decrypted_credentials = tool_configuration.decrypt(credentials)
+ decrypted_credentials = encrypter.decrypt(credentials)
# check if the credential has changed, save the original credential
- masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials)
+ masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials)
for name, value in credentials.items():
if name in masked_credentials and value == masked_credentials[name]:
credentials[name] = decrypted_credentials[name]
@@ -446,13 +438,13 @@ class ApiToolManageService:
return {"result": result or "empty response"}
@staticmethod
- def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]:
+ def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]:
"""
list api tools
"""
# get all api providers
db_providers: list[ApiToolProvider] = (
- db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or []
+ db.session.query(ApiToolProvider).where(ApiToolProvider.tenant_id == tenant_id).all() or []
)
result: list[ToolProviderApiEntity] = []
@@ -474,7 +466,7 @@ class ApiToolManageService:
for tool in tools or []:
user_provider.tools.append(
ToolTransformService.convert_tool_entity_to_api_entity(
- tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels
+ tenant_id=tenant_id, tool=tool, labels=labels
)
)
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 58a4b2f179..65f05d2986 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -1,28 +1,85 @@
import json
import logging
+import re
+from collections.abc import Mapping
from pathlib import Path
+from typing import Any, Optional
from sqlalchemy.orm import Session
from configs import dify_config
+from constants import HIDDEN_VALUE, UNKNOWN_VALUE
from core.helper.position_helper import is_filtered
-from core.model_runtime.utils.encoders import jsonable_encoder
+from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
-from core.plugin.impl.exc import PluginDaemonClientSideError
+from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
-from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
-from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
+from core.tools.entities.api_entities import (
+ ToolApiEntity,
+ ToolProviderApiEntity,
+ ToolProviderCredentialApiEntity,
+ ToolProviderCredentialInfoApiEntity,
+)
+from core.tools.entities.tool_entities import CredentialType
+from core.tools.errors import ToolProviderNotFoundError
+from core.tools.plugin_tool.provider import PluginToolProviderController
from core.tools.tool_label_manager import ToolLabelManager
from core.tools.tool_manager import ToolManager
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter
+from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params
from extensions.ext_database import db
-from models.tools import BuiltinToolProvider
+from extensions.ext_redis import redis_client
+from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient
+from services.plugin.plugin_service import PluginService
from services.tools.tools_transform_service import ToolTransformService
logger = logging.getLogger(__name__)
class BuiltinToolManageService:
+ __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
+ __DEFAULT_EXPIRES_AT__ = 2147483647
+
+ @staticmethod
+ def delete_custom_oauth_client_params(tenant_id: str, provider: str):
+ """
+ delete custom oauth client params
+ """
+ tool_provider = ToolProviderID(provider)
+ with Session(db.engine) as session:
+ session.query(ToolOAuthTenantClient).filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ ).delete()
+ session.commit()
+ return {"result": "success"}
+
+ @staticmethod
+ def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str):
+ """
+ get builtin tool provider oauth client schema
+ """
+ provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
+ verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified(
+ tenant_id, provider.plugin_unique_identifier
+ )
+
+ is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled(
+ tenant_id, provider_name
+ )
+ is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists(
+ provider_name
+ )
+ result = {
+ "schema": provider.get_oauth_client_schema(),
+ "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled,
+ "is_system_oauth_params_exists": is_system_oauth_params_exists,
+ "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name),
+ "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback",
+ }
+ return result
+
@staticmethod
def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]:
"""
@@ -36,27 +93,11 @@ class BuiltinToolManageService:
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
tools = provider_controller.get_tools()
- tool_provider_configurations = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
- # check if user has added the provider
- builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
-
- credentials = {}
- if builtin_provider is not None:
- # get credentials
- credentials = builtin_provider.credentials
- credentials = tool_provider_configurations.decrypt(credentials)
-
result: list[ToolApiEntity] = []
for tool in tools or []:
result.append(
ToolTransformService.convert_tool_entity_to_api_entity(
tool=tool,
- credentials=credentials,
tenant_id=tenant_id,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
@@ -65,25 +106,15 @@ class BuiltinToolManageService:
return result
@staticmethod
- def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str):
+ def get_builtin_tool_provider_info(tenant_id: str, provider: str):
"""
get builtin tool provider info
"""
provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
- tool_provider_configurations = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
# check if user has added the provider
- builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id)
-
- credentials = {}
- if builtin_provider is not None:
- # get credentials
- credentials = builtin_provider.credentials
- credentials = tool_provider_configurations.decrypt(credentials)
+ builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id)
+ if builtin_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
entity = ToolTransformService.builtin_provider_to_user_provider(
provider_controller=provider_controller,
@@ -92,127 +123,410 @@ class BuiltinToolManageService:
)
entity.original_credentials = {}
-
return entity
@staticmethod
- def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str):
+ def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str):
"""
list builtin provider credentials schema
+ :param credential_type: credential type
:param provider_name: the name of the provider
:param tenant_id: the id of the tenant
:return: the list of tool providers
"""
provider = ToolManager.get_builtin_provider(provider_name, tenant_id)
- return jsonable_encoder(provider.get_credentials_schema())
+ return provider.get_credentials_schema_by_type(credential_type)
@staticmethod
def update_builtin_tool_provider(
- session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict
+ user_id: str,
+ tenant_id: str,
+ provider: str,
+ credential_id: str,
+ credentials: dict | None = None,
+ name: str | None = None,
):
"""
update builtin tool provider
"""
- # get if the provider exists
- provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with Session(db.engine) as session:
+ # get if the provider exists
+ db_provider = (
+ session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
+ if db_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
+
+ try:
+ if CredentialType.of(db_provider.credential_type).is_editable() and credentials:
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller.need_credentials:
+ raise ValueError(f"provider {provider} does not need credentials")
+
+ encrypter, cache = BuiltinToolManageService.create_tool_encrypter(
+ tenant_id, db_provider, provider, provider_controller
+ )
+ original_credentials = encrypter.decrypt(db_provider.credentials)
+ new_credentials: dict = {
+ key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE)
+ for key, value in credentials.items()
+ }
+
+ if CredentialType.of(db_provider.credential_type).is_validate_allowed():
+ provider_controller.validate_credentials(user_id, new_credentials)
+
+ # encrypt credentials
+ db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials))
+
+ cache.delete()
+
+ # update name if provided
+ if name and name != db_provider.name:
+ # check if the name is already used
+ if (
+ session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider, name=name)
+ .count()
+ > 0
+ ):
+ raise ValueError(f"the credential name '{name}' is already used")
+
+ db_provider.name = name
+
+ session.commit()
+ except Exception as e:
+ session.rollback()
+ raise ValueError(str(e))
+ return {"result": "success"}
+
+ @staticmethod
+ def add_builtin_tool_provider(
+ user_id: str,
+ api_type: CredentialType,
+ tenant_id: str,
+ provider: str,
+ credentials: dict,
+ expires_at: int = -1,
+ name: str | None = None,
+ ):
+ """
+ add builtin tool provider
+ """
try:
- # get provider
- provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
- if not provider_controller.need_credentials:
- raise ValueError(f"provider {provider_name} does not need credentials")
- tool_configuration = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
- )
+ with Session(db.engine) as session:
+ lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}"
+ with redis_client.lock(lock, timeout=20):
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller.need_credentials:
+ raise ValueError(f"provider {provider} does not need credentials")
+
+ provider_count = (
+ session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count()
+ )
+
+ # check if the provider count is reached the limit
+ if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__:
+ raise ValueError(f"you have reached the maximum number of providers for {provider}")
+
+ # validate credentials if allowed
+ if CredentialType.of(api_type).is_validate_allowed():
+ provider_controller.validate_credentials(user_id, credentials)
- # get original credentials if exists
- if provider is not None:
- original_credentials = tool_configuration.decrypt(provider.credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(original_credentials)
- # check if the credential has changed, save the original credential
- for name, value in credentials.items():
- if name in masked_credentials and value == masked_credentials[name]:
- credentials[name] = original_credentials[name]
- # validate credentials
- provider_controller.validate_credentials(user_id, credentials)
- # encrypt credentials
- credentials = tool_configuration.encrypt(credentials)
- except (
- PluginDaemonClientSideError,
- ToolProviderNotFoundError,
- ToolNotFoundError,
- ToolProviderCredentialValidationError,
- ) as e:
+ # generate name if not provided
+ if name is None or name == "":
+ name = BuiltinToolManageService.generate_builtin_tool_provider_name(
+ session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type
+ )
+ else:
+ # check if the name is already used
+ if (
+ session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider, name=name)
+ .count()
+ > 0
+ ):
+ raise ValueError(f"the credential name '{name}' is already used")
+
+ # create encrypter
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(api_type)
+ ],
+ cache=NoOpProviderCredentialCache(),
+ )
+
+ db_provider = BuiltinToolProvider(
+ tenant_id=tenant_id,
+ user_id=user_id,
+ provider=provider,
+ encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
+ credential_type=api_type.value,
+ name=name,
+ expires_at=expires_at
+ if expires_at is not None
+ else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
+ )
+
+ session.add(db_provider)
+ session.commit()
+        except Exception as e:
+            # no rollback here: the `with Session(...)` block has already closed the session (or it was never bound)
raise ValueError(str(e))
+ return {"result": "success"}
- if provider is None:
- # create provider
- provider = BuiltinToolProvider(
- tenant_id=tenant_id,
- user_id=user_id,
- provider=provider_name,
- encrypted_credentials=json.dumps(credentials),
+ @staticmethod
+ def create_tool_encrypter(
+ tenant_id: str,
+ db_provider: BuiltinToolProvider,
+ provider: str,
+ provider_controller: BuiltinToolProviderController,
+ ):
+ encrypter, cache = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type)
+ ],
+ cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id),
+ )
+ return encrypter, cache
+
+ @staticmethod
+ def generate_builtin_tool_provider_name(
+ session: Session, tenant_id: str, provider: str, credential_type: CredentialType
+ ) -> str:
+ try:
+ db_providers = (
+ session.query(BuiltinToolProvider)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=provider,
+ credential_type=credential_type.value,
+ )
+ .order_by(BuiltinToolProvider.created_at.desc())
+ .all()
)
- db.session.add(provider)
- else:
- provider.encrypted_credentials = json.dumps(credentials)
+ # Get the default name pattern
+ default_pattern = f"{credential_type.get_name()}"
- # delete cache
- tool_configuration.delete_tool_credentials_cache()
+ # Find all names that match the default pattern: "{default_pattern} {number}"
+ pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$"
+ numbers = []
- db.session.commit()
- return {"result": "success"}
+ for db_provider in db_providers:
+ if db_provider.name:
+ match = re.match(pattern, db_provider.name.strip())
+ if match:
+ numbers.append(int(match.group(1)))
+
+ # If no default pattern names found, start with 1
+ if not numbers:
+ return f"{default_pattern} 1"
+
+ # Find the next number
+ max_number = max(numbers)
+ return f"{default_pattern} {max_number + 1}"
+ except Exception as e:
+ logger.warning(f"Error generating next provider name for {provider}: {str(e)}")
+ # fallback
+ return f"{credential_type.get_name()} 1"
@staticmethod
- def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str):
+ def get_builtin_tool_provider_credentials(
+ tenant_id: str, provider_name: str
+ ) -> list[ToolProviderCredentialApiEntity]:
"""
get builtin tool provider credentials
"""
- provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with db.session.no_autoflush:
+ providers = (
+ db.session.query(BuiltinToolProvider)
+ .filter_by(tenant_id=tenant_id, provider=provider_name)
+ .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
+ .all()
+ )
- if provider_obj is None:
- return {}
+ if len(providers) == 0:
+ return []
+
+ default_provider = providers[0]
+ default_provider.is_default = True
+ provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id)
+
+ credentials: list[ToolProviderCredentialApiEntity] = []
+            for provider in providers:
+                # create_tool_encrypter binds its cache to this credential's id, so build one
+                # encrypter per credential; sharing one per credential type could serve another
+                # credential's cached values.
+                # NOTE(review): assumes encrypter.decrypt consults that cache - confirm.
+                encrypter, _ = BuiltinToolManageService.create_tool_encrypter(
+                    tenant_id, provider, provider.provider, provider_controller
+                )
+ decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials))
+ credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity(
+ provider=provider,
+ credentials=decrypt_credential,
+ )
+ credentials.append(credential_entity)
+ return credentials
- provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id)
- tool_configuration = ProviderConfigEncrypter(
- tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ @staticmethod
+ def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity:
+ """
+ get builtin tool provider credential info
+ """
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ supported_credential_types = provider_controller.get_supported_credential_types()
+ credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider)
+ credential_info = ToolProviderCredentialInfoApiEntity(
+ supported_credential_types=supported_credential_types,
+ is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider),
+ credentials=credentials,
)
- credentials = tool_configuration.decrypt(provider_obj.credentials)
- credentials = tool_configuration.mask_tool_credentials(credentials)
- return credentials
+
+ return credential_info
@staticmethod
- def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str):
+ def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str):
"""
delete tool provider
"""
- provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id)
+ with Session(db.engine) as session:
+ db_provider = (
+ session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.id == credential_id,
+ )
+ .first()
+ )
+
+ if db_provider is None:
+ raise ValueError(f"you have not added provider {provider}")
+
+ session.delete(db_provider)
+ session.commit()
+
+ # delete cache
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ _, cache = BuiltinToolManageService.create_tool_encrypter(
+ tenant_id, db_provider, provider, provider_controller
+ )
+ cache.delete()
+
+ return {"result": "success"}
+
+ @staticmethod
+ def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
+ """
+ set default provider
+ """
+ with Session(db.engine) as session:
+            # get provider, scoped to the caller's tenant so another tenant's credential cannot be targeted
+            target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
+ if target_provider is None:
+ raise ValueError("provider not found")
+
+ # clear default provider
+ session.query(BuiltinToolProvider).filter_by(
+ tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
+ ).update({"is_default": False})
+
+ # set new default provider
+ target_provider.is_default = True
+ session.commit()
+ return {"result": "success"}
- if provider_obj is None:
- raise ValueError(f"you have not added provider {provider_name}")
+ @staticmethod
+ def is_oauth_system_client_exists(provider_name: str) -> bool:
+ """
+ check if oauth system client exists
+ """
+ tool_provider = ToolProviderID(provider_name)
+ with Session(db.engine).no_autoflush as session:
+ system_client: ToolOAuthSystemClient | None = (
+ session.query(ToolOAuthSystemClient)
+ .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
+ .first()
+ )
+ return system_client is not None
- db.session.delete(provider_obj)
- db.session.commit()
+ @staticmethod
+ def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool:
+ """
+ check if oauth custom client is enabled
+ """
+ tool_provider = ToolProviderID(provider)
+ with Session(db.engine).no_autoflush as session:
+ user_client: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ enabled=True,
+ )
+ .first()
+ )
+ return user_client is not None and user_client.enabled
- # delete cache
- provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id)
- tool_configuration = ProviderConfigEncrypter(
+ @staticmethod
+ def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None:
+ """
+ get builtin tool provider
+ """
+ tool_provider = ToolProviderID(provider)
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ encrypter, _ = create_provider_encrypter(
tenant_id=tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
)
- tool_configuration.delete_tool_credentials_cache()
+ with Session(db.engine).no_autoflush as session:
+ user_client: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ provider=tool_provider.provider_name,
+ plugin_id=tool_provider.plugin_id,
+ enabled=True,
+ )
+ .first()
+ )
+ oauth_params: Mapping[str, Any] | None = None
+ if user_client:
+ oauth_params = encrypter.decrypt(user_client.oauth_params)
+ return oauth_params
+
+            # only verified providers may use the system oauth client (check the controller, not the name string)
+            is_verified = not isinstance(
+                provider_controller, PluginToolProviderController
+            ) or PluginService.is_plugin_verified(tenant_id, provider_controller.plugin_unique_identifier)
+ if not is_verified:
+ return oauth_params
- return {"result": "success"}
+ system_client: ToolOAuthSystemClient | None = (
+ session.query(ToolOAuthSystemClient)
+ .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name)
+ .first()
+ )
+ if system_client:
+ try:
+ oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params)
+ except Exception as e:
+ raise ValueError(f"Error decrypting system oauth params: {e}")
+
+ return oauth_params
@staticmethod
def get_builtin_tool_provider_icon(provider: str):
@@ -234,9 +548,7 @@ class BuiltinToolManageService:
with db.session.no_autoflush:
# get all user added providers
- db_providers: list[BuiltinToolProvider] = (
- db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or []
- )
+ db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id)
# rewrite db_providers
for db_provider in db_providers:
@@ -275,7 +587,6 @@ class BuiltinToolManageService:
ToolTransformService.convert_tool_entity_to_api_entity(
tenant_id=tenant_id,
tool=tool,
- credentials=user_builtin_provider.original_credentials,
labels=ToolLabelManager.get_tool_labels(provider_controller),
)
)
@@ -287,43 +598,153 @@ class BuiltinToolManageService:
return BuiltinToolProviderSort.sort(result)
@staticmethod
- def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None:
- try:
- full_provider_name = provider_name
- provider_id_entity = ToolProviderID(provider_name)
- provider_name = provider_id_entity.provider_name
- if provider_id_entity.organization != "langgenius":
- provider_obj = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- BuiltinToolProvider.provider == full_provider_name,
+ def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]:
+ """
+ This method is used to fetch the builtin provider from the database
+ 1. If a default provider exists, return the default provider.
+ 2. If no default provider exists, return the oldest provider.
+ """
+ with Session(db.engine) as session:
+ try:
+ full_provider_name = provider_name
+ provider_id_entity = ToolProviderID(provider_name)
+ provider_name = provider_id_entity.provider_name
+
+ if provider_id_entity.organization != "langgenius":
+ provider = (
+ session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ BuiltinToolProvider.provider == full_provider_name,
+ )
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
+ )
+ .first()
)
- .first()
- )
- else:
- provider_obj = (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == provider_name)
- | (BuiltinToolProvider.provider == full_provider_name),
+ else:
+ provider = (
+ session.query(BuiltinToolProvider)
+ .where(
+ BuiltinToolProvider.tenant_id == tenant_id,
+ (BuiltinToolProvider.provider == provider_name)
+ | (BuiltinToolProvider.provider == full_provider_name),
+ )
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
+ )
+ .first()
+ )
+
+ if provider is None:
+ return None
+
+ provider.provider = ToolProviderID(provider.provider).to_string()
+ return provider
+ except Exception:
+ # it's an old provider without organization
+ return (
+ session.query(BuiltinToolProvider)
+ .where(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name)
+ .order_by(
+ BuiltinToolProvider.is_default.desc(), # default=True first
+ BuiltinToolProvider.created_at.asc(), # oldest first
)
.first()
)
- if provider_obj is None:
- return None
+ @staticmethod
+ def save_custom_oauth_client_params(
+ tenant_id: str,
+ provider: str,
+ client_params: Optional[dict] = None,
+ enable_oauth_custom_client: Optional[bool] = None,
+ ):
+ """
+ Set up the custom OAuth client: save its params and/or toggle whether it is enabled.
+ """
+ if client_params is None and enable_oauth_custom_client is None:
+ return {"result": "success"}
- provider_obj.provider = ToolProviderID(provider_obj.provider).to_string()
- return provider_obj
- except Exception:
- # it's an old provider without organization
- return (
- db.session.query(BuiltinToolProvider)
- .filter(
- BuiltinToolProvider.tenant_id == tenant_id,
- (BuiltinToolProvider.provider == provider_name),
+ tool_provider = ToolProviderID(provider)
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller:
+ raise ToolProviderNotFoundError(f"Provider {provider} not found")
+
+ if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)):
+ raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
+
+ with Session(db.engine) as session:
+ custom_client_params = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
+ )
+ .first()
+ )
+
+ # if the record does not exist, create a basic record
+ if custom_client_params is None:
+ custom_client_params = ToolOAuthTenantClient(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
+ )
+ session.add(custom_client_params)
+
+ if client_params is not None:
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
+ )
+ original_params = encrypter.decrypt(custom_client_params.oauth_params)
+ new_params: dict = {
+ key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE)
+ for key, value in client_params.items()
+ }
+ custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params))
+
+ if enable_oauth_custom_client is not None:
+ custom_client_params.enabled = enable_oauth_custom_client
+
+ session.commit()
+ return {"result": "success"}
+
+ @staticmethod
+ def get_custom_oauth_client_params(tenant_id: str, provider: str):
+ """
+ Get the tenant's custom OAuth client params for a provider, masked; returns {} if none are saved.
+ """
+ with Session(db.engine) as session:
+ tool_provider = ToolProviderID(provider)
+ custom_oauth_client_params: ToolOAuthTenantClient | None = (
+ session.query(ToolOAuthTenantClient)
+ .filter_by(
+ tenant_id=tenant_id,
+ plugin_id=tool_provider.plugin_id,
+ provider=tool_provider.provider_name,
)
.first()
)
+ if custom_oauth_client_params is None:
+ return {}
+
+ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id)
+ if not provider_controller:
+ raise ToolProviderNotFoundError(f"Provider {provider} not found")
+
+ if not isinstance(provider_controller, BuiltinToolProviderController):
+ raise ValueError(f"Provider {provider} is not a builtin or plugin provider")
+
+ encrypter, _ = create_provider_encrypter(
+ tenant_id=tenant_id,
+ config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()],
+ cache=NoOpProviderCredentialCache(),
+ )
+
+ return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params))
diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_manage_service.py
similarity index 72%
rename from api/services/tools/mcp_tools_mange_service.py
rename to api/services/tools/mcp_tools_manage_service.py
index 7c23abda4b..23be449a5a 100644
--- a/api/services/tools/mcp_tools_mange_service.py
+++ b/api/services/tools/mcp_tools_manage_service.py
@@ -7,13 +7,14 @@ from sqlalchemy import or_
from sqlalchemy.exc import IntegrityError
from core.helper import encrypter
+from core.helper.provider_cache import NoOpProviderCredentialCache
from core.mcp.error import MCPAuthError, MCPError
from core.mcp.mcp_client import MCPClient
from core.tools.entities.api_entities import ToolProviderApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.mcp_tool.provider import MCPToolProviderController
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import ProviderConfigEncrypter
from extensions.ext_database import db
from models.tools import MCPToolProvider
from services.tools.tools_transform_service import ToolTransformService
@@ -30,7 +31,7 @@ class MCPToolManageService:
def get_mcp_provider_by_provider_id(provider_id: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
- .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
+ .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.id == provider_id)
.first()
)
if not res:
@@ -41,7 +42,7 @@ class MCPToolManageService:
def get_mcp_provider_by_server_identifier(server_identifier: str, tenant_id: str) -> MCPToolProvider:
res = (
db.session.query(MCPToolProvider)
- .filter(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
+ .where(MCPToolProvider.tenant_id == tenant_id, MCPToolProvider.server_identifier == server_identifier)
.first()
)
if not res:
@@ -62,7 +63,7 @@ class MCPToolManageService:
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
existing_provider = (
db.session.query(MCPToolProvider)
- .filter(
+ .where(
MCPToolProvider.tenant_id == tenant_id,
or_(
MCPToolProvider.name == name,
@@ -75,9 +76,9 @@ class MCPToolManageService:
if existing_provider:
if existing_provider.name == name:
raise ValueError(f"MCP tool {name} already exists")
- elif existing_provider.server_url_hash == server_url_hash:
+ if existing_provider.server_url_hash == server_url_hash:
raise ValueError(f"MCP tool {server_url} already exists")
- elif existing_provider.server_identifier == server_identifier:
+ if existing_provider.server_identifier == server_identifier:
raise ValueError(f"MCP tool {server_identifier} already exists")
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_tool = MCPToolProvider(
@@ -99,7 +100,7 @@ class MCPToolManageService:
def retrieve_mcp_tools(tenant_id: str, for_list: bool = False) -> list[ToolProviderApiEntity]:
mcp_providers = (
db.session.query(MCPToolProvider)
- .filter(MCPToolProvider.tenant_id == tenant_id)
+ .where(MCPToolProvider.tenant_id == tenant_id)
.order_by(MCPToolProvider.name)
.all()
)
@@ -109,22 +110,29 @@ class MCPToolManageService:
]
@classmethod
- def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str):
+ def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ server_url = mcp_provider.decrypted_server_url
+ authed = mcp_provider.authed
try:
- with MCPClient(
- mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
- ) as mcp_client:
+ with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
tools = mcp_client.list_tools()
- except MCPAuthError as e:
+ except MCPAuthError:
raise ValueError("Please auth the tool first")
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
- mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
- mcp_provider.authed = True
- mcp_provider.updated_at = datetime.now()
- db.session.commit()
+
+ try:
+ mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
+ mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ mcp_provider.authed = True
+ mcp_provider.updated_at = datetime.now()
+ db.session.commit()
+ except Exception:
+ db.session.rollback()
+ raise
+
user = mcp_provider.load_user()
return ToolProviderApiEntity(
id=mcp_provider.id,
@@ -160,34 +168,49 @@ class MCPToolManageService:
server_identifier: str,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
- mcp_provider.updated_at = datetime.now()
- mcp_provider.name = name
- mcp_provider.icon = (
- json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
- )
- mcp_provider.server_identifier = server_identifier
+
+ reconnect_result = None
+ encrypted_server_url = None
+ server_url_hash = None
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
- mcp_provider.server_url = encrypted_server_url
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
- cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
- mcp_provider.server_url_hash = server_url_hash
+ reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
+
try:
+ mcp_provider.updated_at = datetime.now()
+ mcp_provider.name = name
+ mcp_provider.icon = (
+ json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
+ )
+ mcp_provider.server_identifier = server_identifier
+
+ if encrypted_server_url is not None and server_url_hash is not None:
+ mcp_provider.server_url = encrypted_server_url
+ mcp_provider.server_url_hash = server_url_hash
+
+ if reconnect_result:
+ mcp_provider.authed = reconnect_result["authed"]
+ mcp_provider.tools = reconnect_result["tools"]
+ mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
+
db.session.commit()
except IntegrityError as e:
db.session.rollback()
error_msg = str(e.orig)
if "unique_mcp_provider_name" in error_msg:
raise ValueError(f"MCP tool {name} already exists")
- elif "unique_mcp_provider_server_url" in error_msg:
+ if "unique_mcp_provider_server_url" in error_msg:
raise ValueError(f"MCP tool {server_url} already exists")
- elif "unique_mcp_provider_server_identifier" in error_msg:
+ if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
- else:
- raise
+ raise
+ except Exception:
+ db.session.rollback()
+ raise
@classmethod
def update_mcp_provider_credentials(
@@ -197,8 +220,7 @@ class MCPToolManageService:
tool_configuration = ProviderConfigEncrypter(
tenant_id=mcp_provider.tenant_id,
config=list(provider_controller.get_credentials_schema()),
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.provider_id,
+ provider_config_cache=NoOpProviderCredentialCache(),
)
credentials = tool_configuration.encrypt(credentials)
mcp_provider.updated_at = datetime.now()
@@ -209,23 +231,22 @@ class MCPToolManageService:
db.session.commit()
@classmethod
- def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
- """re-connect mcp provider"""
+ def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
try:
with MCPClient(
- mcp_provider.decrypted_server_url,
+ server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
) as mcp_client:
tools = mcp_client.list_tools()
- mcp_provider.authed = True
- mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
+ return {
+ "authed": True,
+ "tools": json.dumps([tool.model_dump() for tool in tools]),
+ "encrypted_credentials": "{}",
+ }
except MCPAuthError:
- mcp_provider.authed = False
- mcp_provider.tools = "[]"
+ return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
- # reset credentials
- mcp_provider.encrypted_credentials = "{}"
diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py
index 3d0c35cd9b..2d192e6f7f 100644
--- a/api/services/tools/tools_transform_service.py
+++ b/api/services/tools/tools_transform_service.py
@@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast
from yarl import URL
from configs import dify_config
+from core.helper.provider_cache import ToolProviderCredentialsCache
from core.mcp.types import Tool as MCPTool
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.custom_tool.provider import ApiToolProviderController
-from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
+from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_bundle import ApiToolBundle
from core.tools.entities.tool_entities import (
ApiProviderAuthType,
+ CredentialType,
ToolParameter,
ToolProviderType,
)
from core.tools.plugin_tool.provider import PluginToolProviderController
-from core.tools.utils.configuration import ProviderConfigEncrypter
+from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter
from core.tools.workflow_as_tool.provider import WorkflowToolProviderController
from core.tools.workflow_as_tool.tool import WorkflowTool
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
@@ -119,7 +121,12 @@ class ToolTransformService:
result.plugin_unique_identifier = provider_controller.plugin_unique_identifier
# get credentials schema
- schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()}
+ schema = {
+ x.to_basic_provider_config().name: x
+ for x in provider_controller.get_credentials_schema_by_type(
+ CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY
+ )
+ }
for name, value in schema.items():
if result.masked_credentials:
@@ -136,15 +143,23 @@ class ToolTransformService:
credentials = db_provider.credentials
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_provider_encrypter(
tenant_id=db_provider.tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ config=[
+ x.to_basic_provider_config()
+ for x in provider_controller.get_credentials_schema_by_type(
+ CredentialType.of(db_provider.credential_type)
+ )
+ ],
+ cache=ToolProviderCredentialsCache(
+ tenant_id=db_provider.tenant_id,
+ provider=db_provider.provider,
+ credential_id=db_provider.id,
+ ),
)
# decrypt the credentials and mask the credentials
- decrypted_credentials = tool_configuration.decrypt(data=credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
+ decrypted_credentials = encrypter.decrypt(data=credentials)
+ masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
result.original_credentials = decrypted_credentials
@@ -287,16 +302,14 @@ class ToolTransformService:
if decrypt_credentials:
# init tool configuration
- tool_configuration = ProviderConfigEncrypter(
+ encrypter, _ = create_tool_provider_encrypter(
tenant_id=db_provider.tenant_id,
- config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()],
- provider_type=provider_controller.provider_type.value,
- provider_identity=provider_controller.entity.identity.name,
+ controller=provider_controller,
)
# decrypt the credentials and mask the credentials
- decrypted_credentials = tool_configuration.decrypt(data=credentials)
- masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials)
+ decrypted_credentials = encrypter.decrypt(data=credentials)
+ masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials)
result.masked_credentials = masked_credentials
@@ -306,7 +319,6 @@ class ToolTransformService:
def convert_tool_entity_to_api_entity(
tool: Union[ApiToolBundle, WorkflowTool, Tool],
tenant_id: str,
- credentials: dict | None = None,
labels: list[str] | None = None,
) -> ToolApiEntity:
"""
@@ -316,27 +328,39 @@ class ToolTransformService:
# fork tool runtime
tool = tool.fork_tool_runtime(
runtime=ToolRuntime(
- credentials=credentials or {},
+ credentials={},
tenant_id=tenant_id,
)
)
# get tool parameters
- parameters = tool.entity.parameters or []
+ base_parameters = tool.entity.parameters or []
# get tool runtime parameters
runtime_parameters = tool.get_runtime_parameters()
- # override parameters
- current_parameters = parameters.copy()
- for runtime_parameter in runtime_parameters:
- found = False
- for index, parameter in enumerate(current_parameters):
- if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
- current_parameters[index] = runtime_parameter
- found = True
- break
- if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
- current_parameters.append(runtime_parameter)
+ # merge parameters using a functional approach to avoid type issues
+ merged_parameters: list[ToolParameter] = []
+
+ # create a mapping of runtime parameters for quick lookup
+ runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters}
+
+ # process base parameters, replacing with runtime versions if they exist
+ for base_param in base_parameters:
+ key = (base_param.name, base_param.form)
+ if key in runtime_param_map:
+ merged_parameters.append(runtime_param_map[key])
+ else:
+ merged_parameters.append(base_param)
+
+ # add any runtime parameters that weren't in base parameters
+ for runtime_parameter in runtime_parameters:
+ if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
+ # check if this parameter is already in merged_parameters
+ already_exists = any(
+ p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters
+ )
+ if not already_exists:
+ merged_parameters.append(runtime_parameter)
return ToolApiEntity(
author=tool.entity.identity.author,
@@ -344,10 +368,10 @@ class ToolTransformService:
label=tool.entity.identity.label,
description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""),
output_schema=tool.entity.output_schema,
- parameters=current_parameters,
+ parameters=merged_parameters,
labels=labels or [],
)
- if isinstance(tool, ApiToolBundle):
+ elif isinstance(tool, ApiToolBundle):
return ToolApiEntity(
author=tool.author,
name=tool.operation_id or "",
@@ -356,6 +380,22 @@ class ToolTransformService:
parameters=tool.parameters,
labels=labels or [],
)
+ else:
+ # Any other tool type is not supported here
+ raise ValueError(f"Unsupported tool type: {type(tool)}")
+
+ @staticmethod
+ def convert_builtin_provider_to_credential_entity(
+ provider: BuiltinToolProvider, credentials: dict
+ ) -> ToolProviderCredentialApiEntity:
+ return ToolProviderCredentialApiEntity(
+ id=provider.id,
+ name=provider.name,
+ provider=provider.provider,
+ credential_type=CredentialType.of(provider.credential_type),
+ is_default=provider.is_default,
+ credentials=credentials,
+ )
@staticmethod
def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]:
diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py
index c6b205557a..75da5e5eaa 100644
--- a/api/services/tools/workflow_tools_manage_service.py
+++ b/api/services/tools/workflow_tools_manage_service.py
@@ -43,7 +43,7 @@ class WorkflowToolManageService:
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
- .filter(
+ .where(
WorkflowToolProvider.tenant_id == tenant_id,
# name or app_id
or_(WorkflowToolProvider.name == name, WorkflowToolProvider.app_id == workflow_app_id),
@@ -54,7 +54,7 @@ class WorkflowToolManageService:
if existing_workflow_tool_provider is not None:
raise ValueError(f"Tool with name {name} or app_id {workflow_app_id} already exists")
- app: App | None = db.session.query(App).filter(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
+ app: App | None = db.session.query(App).where(App.id == workflow_app_id, App.tenant_id == tenant_id).first()
if app is None:
raise ValueError(f"App {workflow_app_id} not found")
@@ -123,7 +123,7 @@ class WorkflowToolManageService:
# check if the name is unique
existing_workflow_tool_provider = (
db.session.query(WorkflowToolProvider)
- .filter(
+ .where(
WorkflowToolProvider.tenant_id == tenant_id,
WorkflowToolProvider.name == name,
WorkflowToolProvider.id != workflow_tool_id,
@@ -136,7 +136,7 @@ class WorkflowToolManageService:
workflow_tool_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
@@ -144,7 +144,7 @@ class WorkflowToolManageService:
raise ValueError(f"Tool {workflow_tool_id} not found")
app: App | None = (
- db.session.query(App).filter(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
+ db.session.query(App).where(App.id == workflow_tool_provider.app_id, App.tenant_id == tenant_id).first()
)
if app is None:
@@ -186,7 +186,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:return: the list of tools
"""
- db_tools = db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.tenant_id == tenant_id).all()
+ db_tools = db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.tenant_id == tenant_id).all()
tools: list[WorkflowToolProviderController] = []
for provider in db_tools:
@@ -224,7 +224,7 @@ class WorkflowToolManageService:
:param tenant_id: the tenant id
:param workflow_tool_id: the workflow tool id
"""
- db.session.query(WorkflowToolProvider).filter(
+ db.session.query(WorkflowToolProvider).where(
WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id
).delete()
@@ -243,7 +243,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -259,7 +259,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.app_id == workflow_app_id)
.first()
)
return cls._get_workflow_tool(tenant_id, db_tool)
@@ -275,7 +275,7 @@ class WorkflowToolManageService:
raise ValueError("Tool not found")
workflow_app: App | None = (
- db.session.query(App).filter(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
+ db.session.query(App).where(App.id == db_tool.app_id, App.tenant_id == db_tool.tenant_id).first()
)
if workflow_app is None:
@@ -318,7 +318,7 @@ class WorkflowToolManageService:
"""
db_tool: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
- .filter(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
+ .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == workflow_tool_id)
.first()
)
diff --git a/api/services/vector_service.py b/api/services/vector_service.py
index 9165139193..f9ec054593 100644
--- a/api/services/vector_service.py
+++ b/api/services/vector_service.py
@@ -36,7 +36,7 @@ class VectorService:
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
- .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
+ .where(DatasetProcessRule.id == dataset_document.dataset_process_rule_id)
.first()
)
if not processing_rule:
diff --git a/api/services/web_conversation_service.py b/api/services/web_conversation_service.py
index f698ed3084..c48e24f244 100644
--- a/api/services/web_conversation_service.py
+++ b/api/services/web_conversation_service.py
@@ -65,7 +65,7 @@ class WebConversationService:
return
pinned_conversation = (
db.session.query(PinnedConversation)
- .filter(
+ .where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
@@ -97,7 +97,7 @@ class WebConversationService:
return
pinned_conversation = (
db.session.query(PinnedConversation)
- .filter(
+ .where(
PinnedConversation.app_id == app_model.id,
PinnedConversation.conversation_id == conversation_id,
PinnedConversation.created_by_role == ("account" if isinstance(user, Account) else "end_user"),
diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py
index 8f92b3f070..a9df8d0d73 100644
--- a/api/services/webapp_auth_service.py
+++ b/api/services/webapp_auth_service.py
@@ -52,7 +52,7 @@ class WebAppAuthService:
@classmethod
def get_user_through_email(cls, email: str):
- account = db.session.query(Account).filter(Account.email == email).first()
+ account = db.session.query(Account).where(Account.email == email).first()
if not account:
return None
@@ -91,10 +91,10 @@ class WebAppAuthService:
@classmethod
def create_end_user(cls, app_code, email) -> EndUser:
- site = db.session.query(Site).filter(Site.code == app_code).first()
+ site = db.session.query(Site).where(Site.code == app_code).first()
if not site:
raise NotFound("Site not found.")
- app_model = db.session.query(App).filter(App.id == site.app_id).first()
+ app_model = db.session.query(App).where(App.id == site.app_id).first()
if not app_model:
raise NotFound("App not found.")
end_user = EndUser(
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 6720932a3a..991b669737 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -1,6 +1,7 @@
import datetime
import json
-from typing import Any
+from dataclasses import dataclass
+from typing import Any, Optional
import requests
from flask_login import current_user
@@ -13,241 +14,392 @@ from extensions.ext_storage import storage
from services.auth.api_key_auth_service import ApiKeyAuthService
+@dataclass
+class CrawlOptions:
+ """Options for crawling operations."""
+
+ limit: int = 1
+ crawl_sub_pages: bool = False
+ only_main_content: bool = False
+ includes: Optional[str] = None
+ excludes: Optional[str] = None
+ max_depth: Optional[int] = None
+ use_sitemap: bool = True
+
+ def get_include_paths(self) -> list[str]:
+ """Get list of include paths from comma-separated string."""
+ return self.includes.split(",") if self.includes else []
+
+ def get_exclude_paths(self) -> list[str]:
+ """Get list of exclude paths from comma-separated string."""
+ return self.excludes.split(",") if self.excludes else []
+
+
+@dataclass
+class CrawlRequest:
+ """Request container for crawling operations."""
+
+ url: str
+ provider: str
+ options: CrawlOptions
+
+
+@dataclass
+class ScrapeRequest:
+ """Request container for scraping operations."""
+
+ provider: str
+ url: str
+ tenant_id: str
+ only_main_content: bool
+
+
+@dataclass
+class WebsiteCrawlApiRequest:
+ """Request container for website crawl API arguments."""
+
+ provider: str
+ url: str
+ options: dict[str, Any]
+
+ def to_crawl_request(self) -> CrawlRequest:
+ """Convert API request to internal CrawlRequest."""
+ options = CrawlOptions(
+ limit=self.options.get("limit", 1),
+ crawl_sub_pages=self.options.get("crawl_sub_pages", False),
+ only_main_content=self.options.get("only_main_content", False),
+ includes=self.options.get("includes"),
+ excludes=self.options.get("excludes"),
+ max_depth=self.options.get("max_depth"),
+ use_sitemap=self.options.get("use_sitemap", True),
+ )
+ return CrawlRequest(url=self.url, provider=self.provider, options=options)
+
+ @classmethod
+ def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest":
+ """Create from Flask-RESTful parsed arguments."""
+ provider = args.get("provider")
+ url = args.get("url")
+ options = args.get("options", {})
+
+ if not provider:
+ raise ValueError("Provider is required")
+ if not url:
+ raise ValueError("URL is required")
+ if not options:
+ raise ValueError("Options are required")
+
+ return cls(provider=provider, url=url, options=options)
+
+
+@dataclass
+class WebsiteCrawlStatusApiRequest:
+ """Request container for website crawl status API arguments."""
+
+ provider: str
+ job_id: str
+
+ @classmethod
+ def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest":
+ """Create from Flask-RESTful parsed arguments."""
+ provider = args.get("provider")
+
+ if not provider:
+ raise ValueError("Provider is required")
+ if not job_id:
+ raise ValueError("Job ID is required")
+
+ return cls(provider=provider, job_id=job_id)
+
+
class WebsiteService:
+ """Service class for website crawling operations using different providers."""
+
@classmethod
- def document_create_args_validate(cls, args: dict):
- if "url" not in args or not args["url"]:
- raise ValueError("url is required")
- if "options" not in args or not args["options"]:
- raise ValueError("options is required")
- if "limit" not in args["options"] or not args["options"]["limit"]:
- raise ValueError("limit is required")
+ def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]:
+ """Get and validate credentials for a provider."""
+ credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
+ if not credentials or "config" not in credentials:
+ raise ValueError("No valid credentials found for the provider")
+ return credentials, credentials["config"]
@classmethod
- def crawl_url(cls, args: dict) -> dict:
- provider = args.get("provider", "")
- url = args.get("url")
- options = args.get("options", "")
- credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- crawl_sub_pages = options.get("crawl_sub_pages", False)
- only_main_content = options.get("only_main_content", False)
- if not crawl_sub_pages:
- params = {
- "includePaths": [],
- "excludePaths": [],
- "limit": 1,
- "scrapeOptions": {"onlyMainContent": only_main_content},
- }
- else:
- includes = options.get("includes").split(",") if options.get("includes") else []
- excludes = options.get("excludes").split(",") if options.get("excludes") else []
- params = {
- "includePaths": includes,
- "excludePaths": excludes,
- "limit": options.get("limit", 1),
- "scrapeOptions": {"onlyMainContent": only_main_content},
- }
- if options.get("max_depth"):
- params["maxDepth"] = options.get("max_depth")
- job_id = firecrawl_app.crawl_url(url, params)
- website_crawl_time_cache_key = f"website_crawl_{job_id}"
- time = str(datetime.datetime.now().timestamp())
- redis_client.setex(website_crawl_time_cache_key, 3600, time)
- return {"status": "active", "job_id": job_id}
- elif provider == "watercrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options)
+ def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str:
+ """Decrypt and return the API key from config."""
+ api_key = config.get("api_key")
+ if not api_key:
+ raise ValueError("API key not found in configuration")
+ return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key)
- elif provider == "jinareader":
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- crawl_sub_pages = options.get("crawl_sub_pages", False)
- if not crawl_sub_pages:
- response = requests.get(
- f"https://r.jina.ai/{url}",
- headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return {"status": "active", "data": response.json().get("data")}
- else:
- response = requests.post(
- "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
- json={
- "url": url,
- "maxPages": options.get("limit", 1),
- "useSitemap": options.get("use_sitemap", True),
- },
- headers={
- "Content-Type": "application/json",
- "Authorization": f"Bearer {api_key}",
- },
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
+ @classmethod
+ def document_create_args_validate(cls, args: dict) -> None:
+ """Validate arguments for document creation."""
+ try:
+ WebsiteCrawlApiRequest.from_args(args)
+ except ValueError as e:
+ raise ValueError(f"Invalid arguments: {e}")
+
+ @classmethod
+ def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]:
+ """Crawl a URL using the specified provider with typed request."""
+ request = api_request.to_crawl_request()
+
+ _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider)
+ api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+
+ if request.provider == "firecrawl":
+ return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "watercrawl":
+ return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "jinareader":
+ return cls._crawl_with_jinareader(request=request, api_key=api_key)
else:
raise ValueError("Invalid provider")
@classmethod
- def get_crawl_status(cls, job_id: str, provider: str) -> dict:
- credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
- )
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- result = firecrawl_app.check_crawl_status(job_id)
- crawl_status_data = {
- "status": result.get("status", "active"),
- "job_id": job_id,
- "total": result.get("total", 0),
- "current": result.get("current", 0),
- "data": result.get("data", []),
+ def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+
+ if not request.options.crawl_sub_pages:
+ params = {
+ "includePaths": [],
+ "excludePaths": [],
+ "limit": 1,
+ "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
}
- if crawl_status_data["status"] == "completed":
- website_crawl_time_cache_key = f"website_crawl_{job_id}"
- start_time = redis_client.get(website_crawl_time_cache_key)
- if start_time:
- end_time = datetime.datetime.now().timestamp()
- time_consuming = abs(end_time - float(start_time))
- crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
- redis_client.delete(website_crawl_time_cache_key)
- elif provider == "watercrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+ else:
+ params = {
+ "includePaths": request.options.get_include_paths(),
+ "excludePaths": request.options.get_exclude_paths(),
+ "limit": request.options.limit,
+ "scrapeOptions": {"onlyMainContent": request.options.only_main_content},
+ }
+ if request.options.max_depth:
+ params["maxDepth"] = request.options.max_depth
+
+ job_id = firecrawl_app.crawl_url(request.url, params)
+ website_crawl_time_cache_key = f"website_crawl_{job_id}"
+ time = str(datetime.datetime.now().timestamp())
+ redis_client.setex(website_crawl_time_cache_key, 3600, time)
+ return {"status": "active", "job_id": job_id}
+
+ @classmethod
+ def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]:
+ # Convert CrawlOptions back to dict format for WaterCrawlProvider
+ options = {
+ "limit": request.options.limit,
+ "crawl_sub_pages": request.options.crawl_sub_pages,
+ "only_main_content": request.options.only_main_content,
+ "includes": request.options.includes,
+ "excludes": request.options.excludes,
+ "max_depth": request.options.max_depth,
+ "use_sitemap": request.options.use_sitemap,
+ }
+ return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url(
+ url=request.url, options=options
+ )
+
+ @classmethod
+ def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]:
+ if not request.options.crawl_sub_pages:
+ response = requests.get(
+ f"https://r.jina.ai/{request.url}",
+ headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
)
- crawl_status_data = WaterCrawlProvider(
- api_key, credentials.get("config").get("base_url", None)
- ).get_crawl_status(job_id)
- elif provider == "jinareader":
- api_key = encrypter.decrypt_token(
- tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key")
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return {"status": "active", "data": response.json().get("data")}
+ else:
+ response = requests.post(
+ "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app",
+ json={
+ "url": request.url,
+ "maxPages": request.options.limit,
+ "useSitemap": request.options.use_sitemap,
+ },
+ headers={
+ "Content-Type": "application/json",
+ "Authorization": f"Bearer {api_key}",
+ },
)
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")}
+
+ @classmethod
+ def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]:
+ """Get crawl status using string parameters."""
+ api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id)
+ return cls.get_crawl_status_typed(api_request)
+
+ @classmethod
+ def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]:
+ """Get crawl status using typed request."""
+ _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider)
+ api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config)
+
+ if api_request.provider == "firecrawl":
+ return cls._get_firecrawl_status(api_request.job_id, api_key, config)
+ elif api_request.provider == "watercrawl":
+ return cls._get_watercrawl_status(api_request.job_id, api_key, config)
+ elif api_request.provider == "jinareader":
+ return cls._get_jinareader_status(api_request.job_id, api_key)
+ else:
+ raise ValueError("Invalid provider")
+
+ @classmethod
+ def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ result = firecrawl_app.check_crawl_status(job_id)
+ crawl_status_data = {
+ "status": result.get("status", "active"),
+ "job_id": job_id,
+ "total": result.get("total", 0),
+ "current": result.get("current", 0),
+ "data": result.get("data", []),
+ }
+ if crawl_status_data["status"] == "completed":
+ website_crawl_time_cache_key = f"website_crawl_{job_id}"
+ start_time = redis_client.get(website_crawl_time_cache_key)
+ if start_time:
+ end_time = datetime.datetime.now().timestamp()
+ time_consuming = abs(end_time - float(start_time))
+ crawl_status_data["time_consuming"] = f"{time_consuming:.2f}"
+ redis_client.delete(website_crawl_time_cache_key)
+ return crawl_status_data
+
+ @classmethod
+ def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]:
+ return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id)
+
+ @classmethod
+ def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]:
+ response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id},
+ )
+ data = response.json().get("data", {})
+ crawl_status_data = {
+ "status": data.get("status", "active"),
+ "job_id": job_id,
+ "total": len(data.get("urls", [])),
+ "current": len(data.get("processed", [])) + len(data.get("failed", [])),
+ "data": [],
+ "time_consuming": data.get("duration", 0) / 1000,
+ }
+
+ if crawl_status_data["status"] == "completed":
response = requests.post(
"https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id},
+ json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
)
data = response.json().get("data", {})
- crawl_status_data = {
- "status": data.get("status", "active"),
- "job_id": job_id,
- "total": len(data.get("urls", [])),
- "current": len(data.get("processed", [])) + len(data.get("failed", [])),
- "data": [],
- "time_consuming": data.get("duration", 0) / 1000,
- }
-
- if crawl_status_data["status"] == "completed":
- response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())},
- )
- data = response.json().get("data", {})
- formatted_data = [
- {
- "title": item.get("data", {}).get("title"),
- "source_url": item.get("data", {}).get("url"),
- "description": item.get("data", {}).get("description"),
- "markdown": item.get("data", {}).get("content"),
- }
- for item in data.get("processed", {}).values()
- ]
- crawl_status_data["data"] = formatted_data
- else:
- raise ValueError("Invalid provider")
+ formatted_data = [
+ {
+ "title": item.get("data", {}).get("title"),
+ "source_url": item.get("data", {}).get("url"),
+ "description": item.get("data", {}).get("description"),
+ "markdown": item.get("data", {}).get("content"),
+ }
+ for item in data.get("processed", {}).values()
+ ]
+ crawl_status_data["data"] = formatted_data
return crawl_status_data
@classmethod
def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None:
- credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
- # decrypt api_key
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
+ _, config = cls._get_credentials_and_config(tenant_id, provider)
+ api_key = cls._get_decrypted_api_key(tenant_id, config)
if provider == "firecrawl":
- crawl_data: list[dict[str, Any]] | None = None
- file_key = "website_files/" + job_id + ".txt"
- if storage.exists(file_key):
- stored_data = storage.load_once(file_key)
- if stored_data:
- crawl_data = json.loads(stored_data.decode("utf-8"))
- else:
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- result = firecrawl_app.check_crawl_status(job_id)
- if result.get("status") != "completed":
- raise ValueError("Crawl job is not completed")
- crawl_data = result.get("data")
-
- if crawl_data:
- for item in crawl_data:
- if item.get("source_url") == url:
- return dict(item)
- return None
+ return cls._get_firecrawl_url_data(job_id, url, api_key, config)
elif provider == "watercrawl":
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data(
- job_id, url
- )
+ return cls._get_watercrawl_url_data(job_id, url, api_key, config)
elif provider == "jinareader":
- if not job_id:
- response = requests.get(
- f"https://r.jina.ai/{url}",
- headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
- )
- if response.json().get("code") != 200:
- raise ValueError("Failed to crawl")
- return dict(response.json().get("data", {}))
- else:
- # Get crawl status first
- status_response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id},
- )
- status_data = status_response.json().get("data", {})
- if status_data.get("status") != "completed":
- raise ValueError("Crawl job is not completed")
-
- # Get processed data
- data_response = requests.post(
- "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
- headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
- json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
- )
- processed_data = data_response.json().get("data", {})
- for item in processed_data.get("processed", {}).values():
- if item.get("data", {}).get("url") == url:
- return dict(item.get("data", {}))
- return None
+ return cls._get_jinareader_url_data(job_id, url, api_key)
else:
raise ValueError("Invalid provider")
@classmethod
- def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict:
- credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider)
- if provider == "firecrawl":
- # decrypt api_key
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None))
- params = {"onlyMainContent": only_main_content}
- result = firecrawl_app.scrape_url(url, params)
- return result
- elif provider == "watercrawl":
- api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key"))
- return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url)
+ def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
+ crawl_data: list[dict[str, Any]] | None = None
+ file_key = "website_files/" + job_id + ".txt"
+ if storage.exists(file_key):
+ stored_data = storage.load_once(file_key)
+ if stored_data:
+ crawl_data = json.loads(stored_data.decode("utf-8"))
+ else:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ result = firecrawl_app.check_crawl_status(job_id)
+ if result.get("status") != "completed":
+ raise ValueError("Crawl job is not completed")
+ crawl_data = result.get("data")
+
+ if crawl_data:
+ for item in crawl_data:
+ if item.get("source_url") == url:
+ return dict(item)
+ return None
+
+ @classmethod
+ def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None:
+ return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url)
+
+ @classmethod
+ def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None:
+ if not job_id:
+ response = requests.get(
+ f"https://r.jina.ai/{url}",
+ headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"},
+ )
+ if response.json().get("code") != 200:
+ raise ValueError("Failed to crawl")
+ return dict(response.json().get("data", {}))
+ else:
+ # Get crawl status first
+ status_response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id},
+ )
+ status_data = status_response.json().get("data", {})
+ if status_data.get("status") != "completed":
+ raise ValueError("Crawl job is not completed")
+
+ # Get processed data
+ data_response = requests.post(
+ "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app",
+ headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"},
+ json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())},
+ )
+ processed_data = data_response.json().get("data", {})
+ for item in processed_data.get("processed", {}).values():
+ if item.get("data", {}).get("url") == url:
+ return dict(item.get("data", {}))
+ return None
+
+ @classmethod
+ def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]:
+ request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content)
+
+ _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider)
+ api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config)
+
+ if request.provider == "firecrawl":
+ return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config)
+ elif request.provider == "watercrawl":
+ return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config)
else:
raise ValueError("Invalid provider")
+
+ @classmethod
+ def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
+ firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url"))
+ params = {"onlyMainContent": request.only_main_content}
+ return firecrawl_app.scrape_url(url=request.url, params=params)
+
+ @classmethod
+ def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]:
+ return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url)
diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py
index 2b0d57bdfd..abf6824d73 100644
--- a/api/services/workflow/workflow_converter.py
+++ b/api/services/workflow/workflow_converter.py
@@ -620,7 +620,7 @@ class WorkflowConverter:
"""
api_based_extension = (
db.session.query(APIBasedExtension)
- .filter(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
+ .where(APIBasedExtension.tenant_id == tenant_id, APIBasedExtension.id == api_based_extension_id)
.first()
)
diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py
index 44fd72b5e4..3164e010b4 100644
--- a/api/services/workflow_draft_variable_service.py
+++ b/api/services/workflow_draft_variable_service.py
@@ -5,9 +5,9 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum
from typing import Any, ClassVar
-from sqlalchemy import Engine, orm, select
+from sqlalchemy import Engine, orm
from sqlalchemy.dialects.postgresql import insert
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.sql.expression import and_, or_
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -25,7 +25,8 @@ from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
-from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
+from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable
+from repositories.factory import DifyAPIRepositoryFactory
_logger = logging.getLogger(__name__)
@@ -117,10 +118,27 @@ class WorkflowDraftVariableService:
_session: Session
def __init__(self, session: Session) -> None:
+ """
+ Initialize the WorkflowDraftVariableService with a SQLAlchemy session.
+
+ Args:
+ session (Session): The SQLAlchemy session used to execute database queries.
+ The provided session must be bound to an `Engine` object, not a specific `Connection`.
+
+ Raises:
+ AssertionError: If the provided session is not bound to an `Engine` object.
+ """
self._session = session
+ engine = session.get_bind()
+ # Ensure the session is bound to a engine.
+ assert isinstance(engine, Engine)
+ session_maker = sessionmaker(bind=engine, expire_on_commit=False)
+ self._api_node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
+ )
def get_variable(self, variable_id: str) -> WorkflowDraftVariable | None:
- return self._session.query(WorkflowDraftVariable).filter(WorkflowDraftVariable.id == variable_id).first()
+ return self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == variable_id).first()
def get_draft_variables_by_selectors(
self,
@@ -148,7 +166,7 @@ class WorkflowDraftVariableService:
def list_variables_without_values(self, app_id: str, page: int, limit: int) -> WorkflowDraftVariableList:
criteria = WorkflowDraftVariable.app_id == app_id
total = None
- query = self._session.query(WorkflowDraftVariable).filter(criteria)
+ query = self._session.query(WorkflowDraftVariable).where(criteria)
if page == 1:
total = query.count()
variables = (
@@ -167,7 +185,7 @@ class WorkflowDraftVariableService:
WorkflowDraftVariable.app_id == app_id,
WorkflowDraftVariable.node_id == node_id,
)
- query = self._session.query(WorkflowDraftVariable).filter(*criteria)
+ query = self._session.query(WorkflowDraftVariable).where(*criteria)
variables = query.order_by(WorkflowDraftVariable.created_at.desc()).all()
return WorkflowDraftVariableList(variables=variables)
@@ -248,8 +266,7 @@ class WorkflowDraftVariableService:
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
- query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
- node_exec = self._session.scalars(query).first()
+ node_exec = self._api_node_execution_repo.get_execution_by_id(variable.node_execution_id)
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
@@ -298,6 +315,8 @@ class WorkflowDraftVariableService:
def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
variable_type = variable.get_variable_type()
+ if variable_type == DraftVariableType.SYS and not is_system_variable_editable(variable.name):
+ raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
if variable_type == DraftVariableType.CONVERSATION:
return self._reset_conv_var(workflow, variable)
else:
@@ -309,7 +328,7 @@ class WorkflowDraftVariableService:
def delete_workflow_variables(self, app_id: str):
(
self._session.query(WorkflowDraftVariable)
- .filter(WorkflowDraftVariable.app_id == app_id)
+ .where(WorkflowDraftVariable.app_id == app_id)
.delete(synchronize_session=False)
)
@@ -360,7 +379,7 @@ class WorkflowDraftVariableService:
if conv_id is not None:
conversation = (
self._session.query(Conversation)
- .filter(
+ .where(
Conversation.id == conv_id,
Conversation.app_id == workflow.app_id,
)
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index 483c0d3086..e43999a8c9 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -2,9 +2,9 @@ import threading
from collections.abc import Sequence
from typing import Optional
+from sqlalchemy.orm import sessionmaker
+
import contexts
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
-from core.workflow.repositories.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import (
@@ -15,10 +15,18 @@ from models import (
WorkflowRun,
WorkflowRunTriggeredFrom,
)
-from models.workflow import WorkflowNodeExecutionTriggeredFrom
+from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunService:
+ def __init__(self):
+ """Initialize WorkflowRunService with repository dependencies."""
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
+ )
+ self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
def get_paginate_advanced_chat_workflow_runs(self, app_model: App, args: dict) -> InfiniteScrollPagination:
"""
Get advanced chat app workflow run list
@@ -62,45 +70,16 @@ class WorkflowRunService:
:param args: request args
"""
limit = int(args.get("limit", 20))
+ last_id = args.get("last_id")
- base_query = db.session.query(WorkflowRun).filter(
- WorkflowRun.tenant_id == app_model.tenant_id,
- WorkflowRun.app_id == app_model.id,
- WorkflowRun.triggered_from == WorkflowRunTriggeredFrom.DEBUGGING.value,
+ return self._workflow_run_repo.get_paginated_workflow_runs(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ triggered_from=WorkflowRunTriggeredFrom.DEBUGGING.value,
+ limit=limit,
+ last_id=last_id,
)
- if args.get("last_id"):
- last_workflow_run = base_query.filter(
- WorkflowRun.id == args.get("last_id"),
- ).first()
-
- if not last_workflow_run:
- raise ValueError("Last workflow run not exists")
-
- workflow_runs = (
- base_query.filter(
- WorkflowRun.created_at < last_workflow_run.created_at, WorkflowRun.id != last_workflow_run.id
- )
- .order_by(WorkflowRun.created_at.desc())
- .limit(limit)
- .all()
- )
- else:
- workflow_runs = base_query.order_by(WorkflowRun.created_at.desc()).limit(limit).all()
-
- has_more = False
- if len(workflow_runs) == limit:
- current_page_first_workflow_run = workflow_runs[-1]
- rest_count = base_query.filter(
- WorkflowRun.created_at < current_page_first_workflow_run.created_at,
- WorkflowRun.id != current_page_first_workflow_run.id,
- ).count()
-
- if rest_count > 0:
- has_more = True
-
- return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
-
def get_workflow_run(self, app_model: App, run_id: str) -> Optional[WorkflowRun]:
"""
Get workflow run detail
@@ -108,18 +87,12 @@ class WorkflowRunService:
:param app_model: app model
:param run_id: workflow run id
"""
- workflow_run = (
- db.session.query(WorkflowRun)
- .filter(
- WorkflowRun.tenant_id == app_model.tenant_id,
- WorkflowRun.app_id == app_model.id,
- WorkflowRun.id == run_id,
- )
- .first()
+ return self._workflow_run_repo.get_workflow_run_by_id(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ run_id=run_id,
)
- return workflow_run
-
def get_workflow_run_node_executions(
self,
app_model: App,
@@ -137,17 +110,13 @@ class WorkflowRunService:
if not workflow_run:
return []
- repository = SQLAlchemyWorkflowNodeExecutionRepository(
- session_factory=db.engine,
- user=user,
- app_id=app_model.id,
- triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
- )
+ # Get tenant_id from user
+ tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
+ if tenant_id is None:
+ raise ValueError("User tenant_id cannot be None")
- # Use the repository to get the database models directly
- order_config = OrderConfig(order_by=["index"], order_direction="desc")
- workflow_node_executions = repository.get_db_models_by_workflow_run(
- workflow_run_id=run_id, order_config=order_config
+ return self._node_execution_service_repo.get_executions_by_workflow_run(
+ tenant_id=tenant_id,
+ app_id=app_model.id,
+ workflow_run_id=run_id,
)
-
- return workflow_node_executions
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index 2be57fd51c..e9f21fc5f1 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -2,23 +2,22 @@ import json
import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
-from datetime import UTC, datetime
-from typing import Any, Optional
+from typing import Any, Optional, cast
from uuid import uuid4
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import Session, sessionmaker
from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
-from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
+from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
-from core.workflow.enums import SystemVariableKey
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes import NodeType
@@ -28,10 +27,12 @@ from core.workflow.nodes.event import RunCompletedEvent
from core.workflow.nodes.event.types import NodeEvent
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.start.entities import StartNodeData
+from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
from factories.file_factory import build_from_mapping, build_from_mappings
+from libs.datetime_utils import naive_utc_now
from models.account import Account
from models.model import App, AppMode
from models.tools import WorkflowToolProvider
@@ -41,6 +42,7 @@ from models.workflow import (
WorkflowNodeExecutionTriggeredFrom,
WorkflowType,
)
+from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError
from services.workflow.workflow_converter import WorkflowConverter
@@ -57,26 +59,37 @@ class WorkflowService:
Workflow Service
"""
- def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
- # TODO(QuantumGhost): This query is not fully covered by index.
- criteria = (
- WorkflowNodeExecutionModel.tenant_id == app_model.tenant_id,
- WorkflowNodeExecutionModel.app_id == app_model.id,
- WorkflowNodeExecutionModel.workflow_id == workflow.id,
- WorkflowNodeExecutionModel.node_id == node_id,
+ def __init__(self, session_maker: sessionmaker | None = None):
+ """Initialize WorkflowService with repository dependencies."""
+ if session_maker is None:
+ session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
+ self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
+ session_maker
)
- node_exec = (
- db.session.query(WorkflowNodeExecutionModel)
- .filter(*criteria)
- .order_by(WorkflowNodeExecutionModel.created_at.desc())
- .first()
+
+ def get_node_last_run(self, app_model: App, workflow: Workflow, node_id: str) -> WorkflowNodeExecutionModel | None:
+ """
+ Get the most recent execution for a specific node.
+
+ Args:
+ app_model: The application model
+ workflow: The workflow model
+ node_id: The node identifier
+
+ Returns:
+ The most recent WorkflowNodeExecutionModel for the node, or None if not found
+ """
+ return self._node_execution_service_repo.get_node_last_execution(
+ tenant_id=app_model.tenant_id,
+ app_id=app_model.id,
+ workflow_id=workflow.id,
+ node_id=node_id,
)
- return node_exec
def is_workflow_exist(self, app_model: App) -> bool:
return (
db.session.query(Workflow)
- .filter(
+ .where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.version == Workflow.VERSION_DRAFT,
@@ -91,7 +104,7 @@ class WorkflowService:
# fetch draft workflow by app_model
workflow = (
db.session.query(Workflow)
- .filter(
+ .where(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.version == "draft"
)
.first()
@@ -104,7 +117,7 @@ class WorkflowService:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
- .filter(
+ .where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == workflow_id,
@@ -128,7 +141,7 @@ class WorkflowService:
# fetch published workflow by workflow_id
workflow = (
db.session.query(Workflow)
- .filter(
+ .where(
Workflow.tenant_id == app_model.tenant_id,
Workflow.app_id == app_model.id,
Workflow.id == app_model.workflow_id,
@@ -219,7 +232,7 @@ class WorkflowService:
workflow.graph = json.dumps(graph)
workflow.features = json.dumps(features)
workflow.updated_by = account.id
- workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.updated_at = naive_utc_now()
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
@@ -255,7 +268,7 @@ class WorkflowService:
tenant_id=app_model.tenant_id,
app_id=app_model.id,
type=draft_workflow.type,
- version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)),
+ version=Workflow.version_from_datetime(naive_utc_now()),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
@@ -357,7 +370,7 @@ class WorkflowService:
else:
variable_pool = VariablePool(
- system_variables={},
+ system_variables=SystemVariable.empty(),
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -396,7 +409,7 @@ class WorkflowService:
node_execution.workflow_id = draft_workflow.id
# Create repository and save the node execution
- repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=db.engine,
user=account,
app_id=app_model.id,
@@ -404,8 +417,9 @@ class WorkflowService:
)
repository.save(node_execution)
- # Convert node_execution to WorkflowNodeExecution after save
- workflow_node_execution = repository.to_db_model(node_execution)
+ workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id)
+ if workflow_node_execution is None:
+ raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
@@ -418,6 +432,7 @@ class WorkflowService:
)
draft_var_saver.save(process_data=node_execution.process_data, outputs=node_execution.outputs)
session.commit()
+
return workflow_node_execution
def run_free_workflow_node(
@@ -429,7 +444,7 @@ class WorkflowService:
# run draft workflow node
start_at = time.perf_counter()
- workflow_node_execution = self._handle_node_run_result(
+ node_execution = self._handle_node_run_result(
invoke_node_fn=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
@@ -441,7 +456,7 @@ class WorkflowService:
node_id=node_id,
)
- return workflow_node_execution
+ return node_execution
def _handle_node_run_result(
self,
@@ -450,10 +465,10 @@ class WorkflowService:
node_id: str,
) -> WorkflowNodeExecution:
try:
- node_instance, generator = invoke_node_fn()
+ node, node_events = invoke_node_fn()
node_run_result: NodeRunResult | None = None
- for event in generator:
+ for event in node_events:
if isinstance(event, RunCompletedEvent):
node_run_result = event.run_result
@@ -464,18 +479,18 @@ class WorkflowService:
if not node_run_result:
raise ValueError("Node run failed with no run result")
# single step debug mode error handling return
- if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error:
+ if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error:
node_error_args: dict[str, Any] = {
"status": WorkflowNodeExecutionStatus.EXCEPTION,
"error": node_run_result.error,
"inputs": node_run_result.inputs,
- "metadata": {"error_strategy": node_instance.node_data.error_strategy},
+ "metadata": {"error_strategy": node.error_strategy},
}
- if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE:
+ if node.error_strategy is ErrorStrategy.DEFAULT_VALUE:
node_run_result = NodeRunResult(
**node_error_args,
outputs={
- **node_instance.node_data.default_value_dict,
+ **node.default_value_dict,
"error_message": node_run_result.error,
"error_type": node_run_result.error_type,
},
@@ -494,10 +509,10 @@ class WorkflowService:
)
error = node_run_result.error if not run_succeeded else None
except WorkflowNodeRunFailedError as e:
- node_instance = e.node_instance
+ node = e._node
run_succeeded = False
node_run_result = None
- error = e.error
+ error = e._error
# Create a NodeExecution domain model
node_execution = WorkflowNodeExecution(
@@ -505,11 +520,11 @@ class WorkflowService:
workflow_id="", # This is a single-step execution, so no workflow ID
index=1,
node_id=node_id,
- node_type=node_instance.node_type,
- title=node_instance.node_data.title,
+ node_type=node.type_,
+ title=node.title,
elapsed_time=time.perf_counter() - start_at,
- created_at=datetime.now(UTC).replace(tzinfo=None),
- finished_at=datetime.now(UTC).replace(tzinfo=None),
+ created_at=naive_utc_now(),
+ finished_at=naive_utc_now(),
)
if run_succeeded and node_run_result:
@@ -606,7 +621,7 @@ class WorkflowService:
setattr(workflow, field, value)
workflow.updated_by = account_id
- workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
+ workflow.updated_at = naive_utc_now()
return workflow
@@ -643,7 +658,7 @@ class WorkflowService:
# Check if there's a tool provider using this specific workflow version
tool_provider = (
session.query(WorkflowToolProvider)
- .filter(
+ .where(
WorkflowToolProvider.tenant_id == workflow.tenant_id,
WorkflowToolProvider.app_id == workflow.app_id,
WorkflowToolProvider.version == workflow.version,
@@ -671,36 +686,30 @@ def _setup_variable_pool(
):
# Only inject system variables for START node type.
if node_type == NodeType.START:
- # Create a variable pool.
- system_inputs: dict[SystemVariableKey, Any] = {
- # From inputs:
- SystemVariableKey.FILES: files,
- SystemVariableKey.USER_ID: user_id,
- # From workflow model
- SystemVariableKey.APP_ID: workflow.app_id,
- SystemVariableKey.WORKFLOW_ID: workflow.id,
- # Randomly generated.
- SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()),
- }
+ system_variable = SystemVariable(
+ user_id=user_id,
+ app_id=workflow.app_id,
+ workflow_id=workflow.id,
+ files=files or [],
+ workflow_execution_id=str(uuid.uuid4()),
+ )
# Only add chatflow-specific variables for non-workflow types
if workflow.type != WorkflowType.WORKFLOW.value:
- system_inputs.update(
- {
- SystemVariableKey.QUERY: query,
- SystemVariableKey.CONVERSATION_ID: conversation_id,
- SystemVariableKey.DIALOGUE_COUNT: 0,
- }
- )
+ system_variable.query = query
+ system_variable.conversation_id = conversation_id
+ system_variable.dialogue_count = 0
else:
- system_inputs = {}
+ system_variable = SystemVariable.empty()
# init variable pool
variable_pool = VariablePool(
- system_variables=system_inputs,
+ system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
- conversation_variables=conversation_variables,
+ # Based on the definition of `VariableUnion`,
+ # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
+ conversation_variables=cast(list[VariableUnion], conversation_variables),
)
return variable_pool
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index 125e0c1b1e..d4fc68a084 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -25,13 +25,13 @@ class WorkspaceService:
# Get role of user
tenant_account_join = (
db.session.query(TenantAccountJoin)
- .filter(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
+ .where(TenantAccountJoin.tenant_id == tenant.id, TenantAccountJoin.account_id == current_user.id)
.first()
)
assert tenant_account_join is not None, "TenantAccountJoin not found"
tenant_info["role"] = tenant_account_join.role
- can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo
+ can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo
if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]):
base_url = dify_config.FILES_URL
diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py
index 75d648e1b7..204c1a4f5b 100644
--- a/api/tasks/add_document_to_index_task.py
+++ b/api/tasks/add_document_to_index_task.py
@@ -25,7 +25,7 @@ def add_document_to_index_task(dataset_document_id: str):
logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green"))
start_at = time.perf_counter()
- dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document_id).first()
+ dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document_id).first()
if not dataset_document:
logging.info(click.style("Document not found: {}".format(dataset_document_id), fg="red"))
db.session.close()
@@ -43,7 +43,7 @@ def add_document_to_index_task(dataset_document_id: str):
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.document_id == dataset_document.id,
DocumentSegment.enabled == False,
DocumentSegment.status == "completed",
@@ -86,12 +86,10 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor.load(dataset, documents)
# delete auto disable log
- db.session.query(DatasetAutoDisableLog).filter(
- DatasetAutoDisableLog.document_id == dataset_document.id
- ).delete()
+ db.session.query(DatasetAutoDisableLog).where(DatasetAutoDisableLog.document_id == dataset_document.id).delete()
# update segment to enable
- db.session.query(DocumentSegment).filter(DocumentSegment.document_id == dataset_document.id).update(
+ db.session.query(DocumentSegment).where(DocumentSegment.document_id == dataset_document.id).update(
{
DocumentSegment.enabled: True,
DocumentSegment.disabled_at: None,
diff --git a/api/tasks/annotation/batch_import_annotations_task.py b/api/tasks/annotation/batch_import_annotations_task.py
index 6144a4fe3e..6d48f5df89 100644
--- a/api/tasks/annotation/batch_import_annotations_task.py
+++ b/api/tasks/annotation/batch_import_annotations_task.py
@@ -29,7 +29,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
start_at = time.perf_counter()
indexing_cache_key = "app_annotation_batch_import_{}".format(str(job_id))
# get app info
- app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if app:
try:
@@ -48,7 +48,7 @@ def batch_import_annotations_task(job_id: str, content_list: list[dict], app_id:
documents.append(document)
# if annotation reply is enabled , batch add annotations' index
app_annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
)
if app_annotation_setting:
diff --git a/api/tasks/annotation/disable_annotation_reply_task.py b/api/tasks/annotation/disable_annotation_reply_task.py
index 747fce5784..5d5d1d3ad8 100644
--- a/api/tasks/annotation/disable_annotation_reply_task.py
+++ b/api/tasks/annotation/disable_annotation_reply_task.py
@@ -19,16 +19,14 @@ def disable_annotation_reply_task(job_id: str, app_id: str, tenant_id: str):
logging.info(click.style("Start delete app annotations index: {}".format(app_id), fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
- annotations_count = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).count()
+ app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ annotations_count = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).count()
if not app:
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
db.session.close()
return
- app_annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
- )
+ app_annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if not app_annotation_setting:
logging.info(click.style("App annotation setting not found: {}".format(app_id), fg="red"))
diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py
index c04f1be845..12d10df442 100644
--- a/api/tasks/annotation/enable_annotation_reply_task.py
+++ b/api/tasks/annotation/enable_annotation_reply_task.py
@@ -30,14 +30,14 @@ def enable_annotation_reply_task(
logging.info(click.style("Start add app annotation to index: {}".format(app_id), fg="green"))
start_at = time.perf_counter()
# get app info
- app = db.session.query(App).filter(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
+ app = db.session.query(App).where(App.id == app_id, App.tenant_id == tenant_id, App.status == "normal").first()
if not app:
logging.info(click.style("App not found: {}".format(app_id), fg="red"))
db.session.close()
return
- annotations = db.session.query(MessageAnnotation).filter(MessageAnnotation.app_id == app_id).all()
+ annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
enable_app_annotation_key = "enable_app_annotation_{}".format(str(app_id))
enable_app_annotation_job_key = "enable_app_annotation_job_{}".format(str(job_id))
@@ -46,9 +46,7 @@ def enable_annotation_reply_task(
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
embedding_provider_name, embedding_model_name, "annotation"
)
- annotation_setting = (
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.app_id == app_id).first()
- )
+ annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
if annotation_setting:
if dataset_collection_binding.id != annotation_setting.collection_binding_id:
old_dataset_collection_binding = (
diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py
index 97efc47b33..49bff72a96 100644
--- a/api/tasks/batch_clean_document_task.py
+++ b/api/tasks/batch_clean_document_task.py
@@ -27,12 +27,12 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
start_at = time.perf_counter()
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -42,7 +42,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
try:
if image_file and image_file.key:
storage.delete(image_file.key)
@@ -56,7 +56,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
db.session.commit()
if file_ids:
- files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
+ files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py
index 51b6343fdc..64df3175e1 100644
--- a/api/tasks/batch_create_segment_to_index_task.py
+++ b/api/tasks/batch_create_segment_to_index_task.py
@@ -81,7 +81,7 @@ def batch_create_segment_to_index_task(
segment_hash = helper.generate_text_hash(content) # type: ignore
max_position = (
db.session.query(func.max(DocumentSegment.position))
- .filter(DocumentSegment.document_id == dataset_document.id)
+ .where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
diff --git a/api/tasks/clean_dataset_task.py b/api/tasks/clean_dataset_task.py
index 6bac718395..fad090141a 100644
--- a/api/tasks/clean_dataset_task.py
+++ b/api/tasks/clean_dataset_task.py
@@ -53,8 +53,8 @@ def clean_dataset_task(
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
- documents = db.session.query(Document).filter(Document.dataset_id == dataset_id).all()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset_id).all()
+ documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
if documents is None or len(documents) == 0:
logging.info(click.style("No documents found for dataset: {}".format(dataset_id), fg="green"))
@@ -72,7 +72,7 @@ def clean_dataset_task(
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
@@ -85,12 +85,12 @@ def clean_dataset_task(
db.session.delete(image_file)
db.session.delete(segment)
- db.session.query(DatasetProcessRule).filter(DatasetProcessRule.dataset_id == dataset_id).delete()
- db.session.query(DatasetQuery).filter(DatasetQuery.dataset_id == dataset_id).delete()
- db.session.query(AppDatasetJoin).filter(AppDatasetJoin.dataset_id == dataset_id).delete()
+ db.session.query(DatasetProcessRule).where(DatasetProcessRule.dataset_id == dataset_id).delete()
+ db.session.query(DatasetQuery).where(DatasetQuery.dataset_id == dataset_id).delete()
+ db.session.query(AppDatasetJoin).where(AppDatasetJoin.dataset_id == dataset_id).delete()
# delete dataset metadata
- db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id == dataset_id).delete()
- db.session.query(DatasetMetadataBinding).filter(DatasetMetadataBinding.dataset_id == dataset_id).delete()
+ db.session.query(DatasetMetadata).where(DatasetMetadata.dataset_id == dataset_id).delete()
+ db.session.query(DatasetMetadataBinding).where(DatasetMetadataBinding.dataset_id == dataset_id).delete()
# delete files
if documents:
for document in documents:
@@ -102,7 +102,7 @@ def clean_dataset_task(
file_id = data_source_info["upload_file_id"]
file = (
db.session.query(UploadFile)
- .filter(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
+ .where(UploadFile.tenant_id == document.tenant_id, UploadFile.id == file_id)
.first()
)
if not file:
diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py
index c72a3319c1..dd7a544ff5 100644
--- a/api/tasks/clean_document_task.py
+++ b/api/tasks/clean_document_task.py
@@ -28,12 +28,12 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
start_at = time.perf_counter()
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -43,7 +43,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
- image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
+ image_file = db.session.query(UploadFile).where(UploadFile.id == upload_file_id).first()
if image_file is None:
continue
try:
@@ -58,7 +58,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
db.session.commit()
if file_id:
- file = db.session.query(UploadFile).filter(UploadFile.id == file_id).first()
+ file = db.session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
@@ -68,7 +68,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
db.session.commit()
# delete dataset metadata binding
- db.session.query(DatasetMetadataBinding).filter(
+ db.session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
diff --git a/api/tasks/clean_notion_document_task.py b/api/tasks/clean_notion_document_task.py
index 1087a37761..0f72f87f15 100644
--- a/api/tasks/clean_notion_document_task.py
+++ b/api/tasks/clean_notion_document_task.py
@@ -24,17 +24,17 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
start_at = time.perf_counter()
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
for document_id in document_ids:
- document = db.session.query(Document).filter(Document.id == document_id).first()
+ document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py
index a3f811faa1..5eda24674a 100644
--- a/api/tasks/create_segment_to_index_task.py
+++ b/api/tasks/create_segment_to_index_task.py
@@ -24,7 +24,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
logging.info(click.style("Start create segment to index: {}".format(segment_id), fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
+ segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
@@ -37,11 +37,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
try:
# update segment status to indexing
- update_params = {
- DocumentSegment.status: "indexing",
- DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
- }
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
+ db.session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "indexing",
+ DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ }
+ )
db.session.commit()
document = Document(
page_content=segment.content,
@@ -74,11 +75,12 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
index_processor.load(dataset, [document])
# update segment to completed
- update_params = {
- DocumentSegment.status: "completed",
- DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
- }
- db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
+ db.session.query(DocumentSegment).filter_by(id=segment.id).update(
+ {
+ DocumentSegment.status: "completed",
+ DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
+ }
+ )
db.session.commit()
end_at = time.perf_counter()
diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py
index a27207f2f1..7478bf5a90 100644
--- a/api/tasks/deal_dataset_vector_index_task.py
+++ b/api/tasks/deal_dataset_vector_index_task.py
@@ -35,7 +35,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
elif action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
- .filter(
+ .where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -46,7 +46,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -56,7 +56,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
# add from vector index
segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+ .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -76,19 +76,19 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
- db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
- .filter(
+ .where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
@@ -100,7 +100,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
- db.session.query(DatasetDocument).filter(DatasetDocument.id.in_(dataset_documents_ids)).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
@@ -113,7 +113,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
try:
segments = (
db.session.query(DocumentSegment)
- .filter(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
+ .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
@@ -148,12 +148,12 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
- db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
- db.session.query(DatasetDocument).filter(DatasetDocument.id == dataset_document.id).update(
+ db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
diff --git a/api/tasks/delete_account_task.py b/api/tasks/delete_account_task.py
index 52c884ca29..d3b33e3052 100644
--- a/api/tasks/delete_account_task.py
+++ b/api/tasks/delete_account_task.py
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_account_task(account_id):
- account = db.session.query(Account).filter(Account.id == account_id).first()
+ account = db.session.query(Account).where(Account.id == account_id).first()
try:
BillingService.delete_account(account_id)
except Exception as e:
diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py
index a93babc310..66ff0f9a0a 100644
--- a/api/tasks/delete_segment_from_index_task.py
+++ b/api/tasks/delete_segment_from_index_task.py
@@ -22,11 +22,11 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
logging.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter()
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
- dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
+ dataset_document = db.session.query(Document).where(Document.id == document_id).first()
if not dataset_document:
return
diff --git a/api/tasks/disable_segment_from_index_task.py b/api/tasks/disable_segment_from_index_task.py
index 327eed4721..e67ba5c76e 100644
--- a/api/tasks/disable_segment_from_index_task.py
+++ b/api/tasks/disable_segment_from_index_task.py
@@ -21,7 +21,7 @@ def disable_segment_from_index_task(segment_id: str):
logging.info(click.style("Start disable segment from index: {}".format(segment_id), fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
+ segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
diff --git a/api/tasks/disable_segments_from_index_task.py b/api/tasks/disable_segments_from_index_task.py
index 8b77b290c8..0c8b1aabc7 100644
--- a/api/tasks/disable_segments_from_index_task.py
+++ b/api/tasks/disable_segments_from_index_task.py
@@ -23,13 +23,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
db.session.close()
return
- dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+ dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
@@ -44,7 +44,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -64,7 +64,7 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
# update segment error msg
- db.session.query(DocumentSegment).filter(
+ db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py
index b4848be192..dcc748ef18 100644
--- a/api/tasks/document_indexing_sync_task.py
+++ b/api/tasks/document_indexing_sync_task.py
@@ -25,7 +25,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
logging.info(click.style("Start sync document: {}".format(document_id), fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
@@ -46,7 +46,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_edited_time = data_source_info["last_edited_time"]
data_source_binding = (
db.session.query(DataSourceOauthBinding)
- .filter(
+ .where(
db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion",
@@ -77,13 +77,13 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
# delete all document segment and index
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py
index 55cac6a9af..ec6d10d93b 100644
--- a/api/tasks/document_indexing_task.py
+++ b/api/tasks/document_indexing_task.py
@@ -1,4 +1,3 @@
-import datetime
import logging
import time
@@ -8,6 +7,7 @@ from celery import shared_task # type: ignore
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from extensions.ext_database import db
+from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document
from services.feature_service import FeatureService
@@ -24,7 +24,7 @@ def document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset is not found: {}".format(dataset_id), fg="yellow"))
db.session.close()
@@ -48,12 +48,12 @@ def document_indexing_task(dataset_id: str, document_ids: list):
except Exception as e:
for document_id in document_ids:
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
- document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
db.session.close()
@@ -63,12 +63,12 @@ def document_indexing_task(dataset_id: str, document_ids: list):
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "parsing"
- document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
+ document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py
index 167b928f5d..e53c38ddc3 100644
--- a/api/tasks/document_indexing_update_task.py
+++ b/api/tasks/document_indexing_update_task.py
@@ -23,7 +23,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logging.info(click.style("Start update document: {}".format(document_id), fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
@@ -36,14 +36,14 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
# delete all document segment and index
try:
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py
index a6c93e110e..b3ddface59 100644
--- a/api/tasks/duplicate_document_indexing_task.py
+++ b/api/tasks/duplicate_document_indexing_task.py
@@ -25,7 +25,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
@@ -50,7 +50,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
except Exception as e:
for document_id in document_ids:
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
@@ -66,7 +66,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
logging.info(click.style("Start process document: {}".format(document_id), fg="green"))
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
@@ -74,7 +74,7 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py
index 21f08f40a7..13822f078e 100644
--- a/api/tasks/enable_segment_to_index_task.py
+++ b/api/tasks/enable_segment_to_index_task.py
@@ -24,7 +24,7 @@ def enable_segment_to_index_task(segment_id: str):
logging.info(click.style("Start enable segment to index: {}".format(segment_id), fg="green"))
start_at = time.perf_counter()
- segment = db.session.query(DocumentSegment).filter(DocumentSegment.id == segment_id).first()
+ segment = db.session.query(DocumentSegment).where(DocumentSegment.id == segment_id).first()
if not segment:
logging.info(click.style("Segment not found: {}".format(segment_id), fg="red"))
db.session.close()
diff --git a/api/tasks/enable_segments_to_index_task.py b/api/tasks/enable_segments_to_index_task.py
index 625a3b582e..e3fdf04d8c 100644
--- a/api/tasks/enable_segments_to_index_task.py
+++ b/api/tasks/enable_segments_to_index_task.py
@@ -25,12 +25,12 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
Usage: enable_segments_to_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
- dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
+ dataset_document = db.session.query(DatasetDocument).where(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
@@ -45,7 +45,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
segments = (
db.session.query(DocumentSegment)
- .filter(
+ .where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
@@ -95,7 +95,7 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
except Exception as e:
logging.exception("enable segments to index failed")
# update segment error msg
- db.session.query(DocumentSegment).filter(
+ db.session.query(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
diff --git a/api/tasks/mail_account_deletion_task.py b/api/tasks/mail_account_deletion_task.py
index 0c60ae53d5..a6f8ce2f0b 100644
--- a/api/tasks/mail_account_deletion_task.py
+++ b/api/tasks/mail_account_deletion_task.py
@@ -3,14 +3,20 @@ import time
import click
from celery import shared_task # type: ignore
-from flask import render_template
from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
@shared_task(queue="mail")
-def send_deletion_success_task(to):
- """Send email to user regarding account deletion."""
+def send_deletion_success_task(to: str, language: str = "en-US") -> None:
+ """
+ Send account deletion success email with internationalization support.
+
+ Args:
+ to: Recipient email address
+ language: Language code for email localization
+ """
if not mail.is_inited():
return
@@ -18,12 +24,16 @@ def send_deletion_success_task(to):
start_at = time.perf_counter()
try:
- html_content = render_template(
- "delete_account_success_template_en-US.html",
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.ACCOUNT_DELETION_SUCCESS,
+ language_code=language,
to=to,
- email=to,
+ template_context={
+ "to": to,
+ "email": to,
+ },
)
- mail.send(to=to, subject="Your Dify.AI Account Has Been Successfully Deleted", html=html_content)
end_at = time.perf_counter()
logging.info(
@@ -36,12 +46,14 @@ def send_deletion_success_task(to):
@shared_task(queue="mail")
-def send_account_deletion_verification_code(to, code):
- """Send email to user regarding account deletion verification code.
+def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None:
+ """
+ Send account deletion verification code email with internationalization support.
Args:
- to (str): Recipient email address
- code (str): Verification code
+ to: Recipient email address
+ code: Verification code
+ language: Language code for email localization
"""
if not mail.is_inited():
return
@@ -50,8 +62,16 @@ def send_account_deletion_verification_code(to, code):
start_at = time.perf_counter()
try:
- html_content = render_template("delete_account_code_email_template_en-US.html", to=to, code=code)
- mail.send(to=to, subject="Dify.AI Account Deletion and Verification", html=html_content)
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.ACCOUNT_DELETION_VERIFICATION,
+ language_code=language,
+ to=to,
+ template_context={
+ "to": to,
+ "code": code,
+ },
+ )
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py
new file mode 100644
index 0000000000..6334fb22de
--- /dev/null
+++ b/api/tasks/mail_change_mail_task.py
@@ -0,0 +1,80 @@
+import logging
+import time
+
+import click
+from celery import shared_task # type: ignore
+
+from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
+
+
+@shared_task(queue="mail")
+def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None:
+    """
+    Deliver the verification mail for one phase of an email-address change.
+
+    Returns without sending when the mail extension is not initialised; a
+    failure while sending is logged with its traceback and not re-raised.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        code: Email verification code
+        phase: Change email phase ('old_email' or 'new_email')
+    """
+    if not mail.is_inited():
+        return
+
+    start_text = "Start change email mail to {}".format(to)
+    logging.info(click.style(start_text, fg="green"))
+    began = time.perf_counter()
+
+    try:
+        i18n_mailer = get_email_i18n_service()
+        i18n_mailer.send_change_email(language_code=language, to=to, code=code, phase=phase)
+
+        elapsed = time.perf_counter() - began
+        done_text = "Send change email mail to {} succeeded: latency: {}".format(to, elapsed)
+        logging.info(click.style(done_text, fg="green"))
+    except Exception:
+        # Best effort: the task reports the failure instead of propagating it.
+        fail_text = "Send change email mail to {} failed".format(to)
+        logging.exception(fail_text)
+
+
+@shared_task(queue="mail")
+def send_change_mail_completed_notification_task(language: str, to: str) -> None:
+    """
+    Send the notification mail for a completed email-address change.
+
+    Returns without sending when the mail extension is not initialised; a
+    failure while sending is logged with its traceback and not re-raised.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+    """
+    if not mail.is_inited():
+        return
+
+    start_text = "Start change email completed notify mail to {}".format(to)
+    logging.info(click.style(start_text, fg="green"))
+    began = time.perf_counter()
+
+    try:
+        # The address is exposed to the template under both "to" and "email".
+        context = {"to": to, "email": to}
+        i18n_mailer = get_email_i18n_service()
+        i18n_mailer.send_email(
+            email_type=EmailType.CHANGE_EMAIL_COMPLETED,
+            language_code=language,
+            to=to,
+            template_context=context,
+        )
+
+        elapsed = time.perf_counter() - began
+        done_text = "Send change email completed mail to {} succeeded: latency: {}".format(to, elapsed)
+        logging.info(click.style(done_text, fg="green"))
+    except Exception:
+        fail_text = "Send change email completed mail to {} failed".format(to)
+        logging.exception(fail_text)
diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py
index ddad331725..34220784e9 100644
--- a/api/tasks/mail_email_code_login.py
+++ b/api/tasks/mail_email_code_login.py
@@ -3,19 +3,20 @@ import time
import click
from celery import shared_task # type: ignore
-from flask import render_template
from extensions.ext_mail import mail
-from services.feature_service import FeatureService
+from libs.email_i18n import EmailType, get_email_i18n_service
@shared_task(queue="mail")
-def send_email_code_login_mail_task(language: str, to: str, code: str):
+def send_email_code_login_mail_task(language: str, to: str, code: str) -> None:
"""
- Async Send email code login mail
- :param language: Language in which the email should be sent (e.g., 'en', 'zh')
- :param to: Recipient email address
- :param code: Email code to be included in the email
+ Send email code login email with internationalization support.
+
+ Args:
+ language: Language code for email localization
+ to: Recipient email address
+ code: Email verification code
"""
if not mail.is_inited():
return
@@ -23,28 +24,17 @@ def send_email_code_login_mail_task(language: str, to: str, code: str):
logging.info(click.style("Start email code login mail to {}".format(to), fg="green"))
start_at = time.perf_counter()
- # send email code login mail using different languages
try:
- if language == "zh-Hans":
- template = "email_code_login_mail_template_zh-CN.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/email_code_login_mail_template_zh-CN.html"
- html_content = render_template(template, to=to, code=code, application_title=application_title)
- else:
- html_content = render_template(template, to=to, code=code)
- mail.send(to=to, subject="邮箱验证码", html=html_content)
- else:
- template = "email_code_login_mail_template_en-US.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/email_code_login_mail_template_en-US.html"
- html_content = render_template(template, to=to, code=code, application_title=application_title)
- else:
- html_content = render_template(template, to=to, code=code)
- mail.send(to=to, subject="Email Code", html=html_content)
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.EMAIL_CODE_LOGIN,
+ language_code=language,
+ to=to,
+ template_context={
+ "to": to,
+ "code": code,
+ },
+ )
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py
index b9d8fd55df..a1c2908624 100644
--- a/api/tasks/mail_enterprise_task.py
+++ b/api/tasks/mail_enterprise_task.py
@@ -1,15 +1,17 @@
import logging
import time
+from collections.abc import Mapping
import click
from celery import shared_task # type: ignore
from flask import render_template_string
from extensions.ext_mail import mail
+from libs.email_i18n import get_email_i18n_service
@shared_task(queue="mail")
-def send_enterprise_email_task(to, subject, body, substitutions):
+def send_enterprise_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
if not mail.is_inited():
return
@@ -19,11 +21,8 @@ def send_enterprise_email_task(to, subject, body, substitutions):
try:
html_content = render_template_string(body, **substitutions)
- if isinstance(to, list):
- for t in to:
- mail.send(to=t, subject=subject, html=html_content)
- else:
- mail.send(to=to, subject=subject, html=html_content)
+ email_service = get_email_i18n_service()
+ email_service.send_raw_email(to=to, subject=subject, html_content=html_content)
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py
index 7ca85c7f2d..8c73de0111 100644
--- a/api/tasks/mail_invite_member_task.py
+++ b/api/tasks/mail_invite_member_task.py
@@ -3,24 +3,23 @@ import time
import click
from celery import shared_task # type: ignore
-from flask import render_template
from configs import dify_config
from extensions.ext_mail import mail
-from services.feature_service import FeatureService
+from libs.email_i18n import EmailType, get_email_i18n_service
@shared_task(queue="mail")
-def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
+def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None:
"""
- Async Send invite member mail
- :param language
- :param to
- :param token
- :param inviter_name
- :param workspace_name
-
- Usage: send_invite_member_mail_task.delay(language, to, token, inviter_name, workspace_name)
+ Send invite member email with internationalization support.
+
+ Args:
+ language: Language code for email localization
+ to: Recipient email address
+ token: Invitation token
+ inviter_name: Name of the person sending the invitation
+ workspace_name: Name of the workspace
"""
if not mail.is_inited():
return
@@ -30,49 +29,20 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam
)
start_at = time.perf_counter()
- # send invite member mail using different languages
try:
url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}"
- if language == "zh-Hans":
- template = "invite_member_mail_template_zh-CN.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/invite_member_mail_template_zh-CN.html"
- html_content = render_template(
- template,
- to=to,
- inviter_name=inviter_name,
- workspace_name=workspace_name,
- url=url,
- application_title=application_title,
- )
- mail.send(to=to, subject=f"立即加入 {application_title} 工作空间", html=html_content)
- else:
- html_content = render_template(
- template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url
- )
- mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content)
- else:
- template = "invite_member_mail_template_en-US.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/invite_member_mail_template_en-US.html"
- html_content = render_template(
- template,
- to=to,
- inviter_name=inviter_name,
- workspace_name=workspace_name,
- url=url,
- application_title=application_title,
- )
- mail.send(to=to, subject=f"Join {application_title} Workspace Now", html=html_content)
- else:
- html_content = render_template(
- template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url
- )
- mail.send(to=to, subject="Join Dify Workspace Now", html=html_content)
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.INVITE_MEMBER,
+ language_code=language,
+ to=to,
+ template_context={
+ "to": to,
+ "inviter_name": inviter_name,
+ "workspace_name": workspace_name,
+ "url": url,
+ },
+ )
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py
new file mode 100644
index 0000000000..e566a6bc56
--- /dev/null
+++ b/api/tasks/mail_owner_transfer_task.py
@@ -0,0 +1,129 @@
+import logging
+import time
+
+import click
+from celery import shared_task # type: ignore
+
+from extensions.ext_mail import mail
+from libs.email_i18n import EmailType, get_email_i18n_service
+
+
+@shared_task(queue="mail")
+def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None:
+    """
+    Send the owner transfer confirmation mail carrying a verification code.
+
+    Returns without sending when the mail extension is not initialised; a
+    failure while sending is logged with its traceback and not re-raised.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        code: Verification code
+        workspace: Workspace name
+    """
+    if not mail.is_inited():
+        return
+
+    start_text = "Start owner transfer confirm mail to {}".format(to)
+    logging.info(click.style(start_text, fg="green"))
+    began = time.perf_counter()
+
+    try:
+        # Template variables; the workspace name is passed under the key "WorkspaceName".
+        context = {"to": to, "code": code, "WorkspaceName": workspace}
+        i18n_mailer = get_email_i18n_service()
+        i18n_mailer.send_email(
+            email_type=EmailType.OWNER_TRANSFER_CONFIRM,
+            language_code=language,
+            to=to,
+            template_context=context,
+        )
+
+        elapsed = time.perf_counter() - began
+        done_text = "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, elapsed)
+        logging.info(click.style(done_text, fg="green"))
+    except Exception:
+        # Best effort: the task reports the failure instead of propagating it.
+        fail_text = "owner transfer confirm email mail to {} failed".format(to)
+        logging.exception(fail_text)
+
+
+@shared_task(queue="mail")
+def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None:
+    """
+    Send the owner-transfer notification mail addressed to the old owner.
+
+    Returns without sending when the mail extension is not initialised; a
+    failure while sending is logged with its traceback and not re-raised.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        workspace: Workspace name
+        new_owner_email: New owner email address
+    """
+    if not mail.is_inited():
+        return
+
+    start_text = "Start old owner transfer notify mail to {}".format(to)
+    logging.info(click.style(start_text, fg="green"))
+    began = time.perf_counter()
+
+    try:
+        # Template variables; note the CamelCase keys "WorkspaceName" and "NewOwnerEmail".
+        context = {"to": to, "WorkspaceName": workspace, "NewOwnerEmail": new_owner_email}
+        i18n_mailer = get_email_i18n_service()
+        i18n_mailer.send_email(
+            email_type=EmailType.OWNER_TRANSFER_OLD_NOTIFY,
+            language_code=language,
+            to=to,
+            template_context=context,
+        )
+
+        elapsed = time.perf_counter() - began
+        done_text = "Send old owner transfer notify mail to {} succeeded: latency: {}".format(to, elapsed)
+        logging.info(click.style(done_text, fg="green"))
+    except Exception:
+        # Best effort: the task reports the failure instead of propagating it.
+        fail_text = "old owner transfer notify email mail to {} failed".format(to)
+        logging.exception(fail_text)
+
+
+@shared_task(queue="mail")
+def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None:
+    """
+    Send the owner-transfer notification mail addressed to the new owner.
+
+    Returns without sending when the mail extension is not initialised; a
+    failure while sending is logged with its traceback and not re-raised.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+        workspace: Workspace name
+    """
+    if not mail.is_inited():
+        return
+
+    start_text = "Start new owner transfer notify mail to {}".format(to)
+    logging.info(click.style(start_text, fg="green"))
+    began = time.perf_counter()
+
+    try:
+        # Template variables; the workspace name is passed under the key "WorkspaceName".
+        context = {"to": to, "WorkspaceName": workspace}
+        i18n_mailer = get_email_i18n_service()
+        i18n_mailer.send_email(
+            email_type=EmailType.OWNER_TRANSFER_NEW_NOTIFY,
+            language_code=language,
+            to=to,
+            template_context=context,
+        )
+
+        elapsed = time.perf_counter() - began
+        done_text = "Send new owner transfer notify mail to {} succeeded: latency: {}".format(to, elapsed)
+        logging.info(click.style(done_text, fg="green"))
+    except Exception:
+        fail_text = "new owner transfer notify email mail to {} failed".format(to)
+        logging.exception(fail_text)
diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py
index d4f4482a48..e2482f2101 100644
--- a/api/tasks/mail_reset_password_task.py
+++ b/api/tasks/mail_reset_password_task.py
@@ -3,19 +3,20 @@ import time
import click
from celery import shared_task # type: ignore
-from flask import render_template
from extensions.ext_mail import mail
-from services.feature_service import FeatureService
+from libs.email_i18n import EmailType, get_email_i18n_service
@shared_task(queue="mail")
-def send_reset_password_mail_task(language: str, to: str, code: str):
+def send_reset_password_mail_task(language: str, to: str, code: str) -> None:
"""
- Async Send reset password mail
- :param language: Language in which the email should be sent (e.g., 'en', 'zh')
- :param to: Recipient email address
- :param code: Reset password code
+ Send reset password email with internationalization support.
+
+ Args:
+ language: Language code for email localization
+ to: Recipient email address
+ code: Reset password code
"""
if not mail.is_inited():
return
@@ -23,30 +24,17 @@ def send_reset_password_mail_task(language: str, to: str, code: str):
logging.info(click.style("Start password reset mail to {}".format(to), fg="green"))
start_at = time.perf_counter()
- # send reset password mail using different languages
try:
- if language == "zh-Hans":
- template = "reset_password_mail_template_zh-CN.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/reset_password_mail_template_zh-CN.html"
- html_content = render_template(template, to=to, code=code, application_title=application_title)
- mail.send(to=to, subject=f"设置您的 {application_title} 密码", html=html_content)
- else:
- html_content = render_template(template, to=to, code=code)
- mail.send(to=to, subject="设置您的 Dify 密码", html=html_content)
- else:
- template = "reset_password_mail_template_en-US.html"
- system_features = FeatureService.get_system_features()
- if system_features.branding.enabled:
- application_title = system_features.branding.application_title
- template = "without-brand/reset_password_mail_template_en-US.html"
- html_content = render_template(template, to=to, code=code, application_title=application_title)
- mail.send(to=to, subject=f"Set Your {application_title} Password", html=html_content)
- else:
- html_content = render_template(template, to=to, code=code)
- mail.send(to=to, subject="Set Your Dify Password", html=html_content)
+ email_service = get_email_i18n_service()
+ email_service.send_email(
+ email_type=EmailType.RESET_PASSWORD,
+ language_code=language,
+ to=to,
+ template_context={
+ "to": to,
+ "code": code,
+ },
+ )
end_at = time.perf_counter()
logging.info(
diff --git a/api/tasks/process_tenant_plugin_autoupgrade_check_task.py b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
new file mode 100644
index 0000000000..6fcdad0525
--- /dev/null
+++ b/api/tasks/process_tenant_plugin_autoupgrade_check_task.py
@@ -0,0 +1,166 @@
+import traceback
+import typing
+
+import click
+from celery import shared_task # type: ignore
+
+from core.helper import marketplace
+from core.helper.marketplace import MarketplacePluginDeclaration
+from core.plugin.entities.plugin import PluginInstallationSource
+from core.plugin.impl.plugin import PluginInstaller
+from models.account import TenantPluginAutoUpgradeStrategy
+
+# NOTE(review): not referenced anywhere else in this module - presumably a retry budget; confirm or remove.
+RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
+
+
+# Module-level cache of marketplace manifests keyed by plugin id. A None value records a lookup
+# that came back without a manifest. Nothing in this module evicts or refreshes entries.
+cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
+
+
+def marketplace_batch_fetch_plugin_manifests(
+    plugin_ids_plain_list: list[str],
+) -> list[MarketplacePluginDeclaration]:
+    """Return marketplace manifests for the given plugin ids, using the module-level cache.
+
+    Ids missing from the cache are fetched in one batch. Every requested id the marketplace
+    does not return is cached as None, so it is not fetched again and is left out of the result.
+    """
+    not_included_plugin_ids = [
+        plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
+    ]
+    if not_included_plugin_ids:
+        manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
+        for manifest in manifests:
+            cached_plugin_manifests[manifest.plugin_id] = manifest
+
+        # A partial response must also mark its missing ids, not only a completely empty one.
+        for plugin_id in not_included_plugin_ids:
+            cached_plugin_manifests.setdefault(plugin_id, None)
+
+    result: list[MarketplacePluginDeclaration] = []
+    for plugin_id in plugin_ids_plain_list:
+        final_manifest = cached_plugin_manifests.get(plugin_id)
+        if final_manifest is not None:
+            result.append(final_manifest)
+    return result
+
+
+@shared_task(queue="plugin")
+def process_tenant_plugin_autoupgrade_check_task(
+    tenant_id: str,
+    strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
+    upgrade_time_of_day: int,
+    upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
+    exclude_plugins: list[str],
+    include_plugins: list[str],
+):
+    """Check one tenant's marketplace plugins and start the upgrades its strategy allows.
+
+    Args:
+        tenant_id: Tenant whose installed plugins are listed and upgraded.
+        strategy_setting: DISABLED returns immediately; LATEST upgrades whenever the marketplace
+            version string differs from the installed one; FIX_ONLY upgrades only when major and
+            minor match and the patch number differs.
+        upgrade_time_of_day: Not read by this task.
+        upgrade_mode: PARTIAL checks only ``include_plugins``, EXCLUDE checks everything except
+            ``exclude_plugins``, ALL checks every marketplace-installed plugin.
+        exclude_plugins: Plugin ids skipped in EXCLUDE mode.
+        include_plugins: Plugin ids checked in PARTIAL mode; an empty list means nothing is checked.
+
+    Exceptions are caught, echoed together with a traceback, and not re-raised. A failure while
+    handling one plugin does not stop the remaining plugins from being processed.
+    """
+    try:
+        manager = PluginInstaller()
+
+        click.echo(
+            click.style(
+                "Checking upgradable plugin for tenant: {}".format(tenant_id),
+                fg="green",
+            )
+        )
+
+        if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
+            return
+
+        click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green"))
+
+        # Translate the mode into an allow-list and a deny-list so one filter serves all modes.
+        only_these: typing.Optional[list[str]] = None
+        skip_these: list[str] = []
+        if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
+            only_these = include_plugins
+        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
+            skip_these = exclude_plugins
+        elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
+            pass  # no id-based filtering
+        else:
+            # PARTIAL with an empty include list, or an unrecognised mode: nothing to check,
+            # so the installed plugins are not even listed.
+            return
+
+        # (plugin_id, installed version, installed unique identifier) - marketplace installs only.
+        plugin_ids: list[tuple[str, str, str]] = [
+            (plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
+            for plugin in manager.list_plugins(tenant_id)
+            if plugin.source == PluginInstallationSource.Marketplace
+            and (only_these is None or plugin.plugin_id in only_these)
+            and plugin.plugin_id not in skip_these
+        ]
+        if not plugin_ids:
+            return
+
+        manifests = marketplace_batch_fetch_plugin_manifests([plugin_id for plugin_id, _, _ in plugin_ids])
+        if not manifests:
+            return
+
+        def fix_only_checker(latest_version, current_version):
+            # Versions are parsed as dot-separated integers; anything else raises and is reported
+            # by the per-plugin handler below.
+            latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
+            current_version_tuple = tuple(int(val) for val in current_version.split("."))
+
+            if (
+                latest_version_tuple[0] == current_version_tuple[0]
+                and latest_version_tuple[1] == current_version_tuple[1]
+            ):
+                # NOTE(review): "!=" also fires when the marketplace patch number is lower than
+                # the installed one - confirm whether ">" is the intended comparison.
+                return latest_version_tuple[2] != current_version_tuple[2]
+            return False
+
+        # Built once per task run instead of once per plugin.
+        version_checker = {
+            TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest, current: latest != current,
+            TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
+        }
+
+        for manifest in manifests:
+            # Only the first installed entry carrying this plugin id is considered.
+            installed = next((entry for entry in plugin_ids if manifest.plugin_id == entry[0]), None)
+            if installed is None:
+                continue
+            _, current_version, original_unique_identifier = installed
+
+            try:
+                # The strategy lookup stays inside the try so an unknown strategy is reported
+                # per plugin, exactly like any other upgrade failure.
+                if not version_checker[strategy_setting](manifest.latest_version, current_version):
+                    continue
+
+                # execute upgrade
+                new_unique_identifier = manifest.latest_package_identifier
+
+                marketplace.record_install_plugin_event(new_unique_identifier)
+                click.echo(
+                    click.style(
+                        "Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
+                        fg="green",
+                    )
+                )
+                manager.upgrade_plugin(
+                    tenant_id,
+                    original_unique_identifier,
+                    new_unique_identifier,
+                    PluginInstallationSource.Marketplace,
+                    {
+                        "plugin_unique_identifier": new_unique_identifier,
+                    },
+                )
+            except Exception as e:
+                click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
+                traceback.print_exc()
+
+    except Exception as e:
+        click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
+        traceback.print_exc()
+        return
diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py
index e7d49c78dc..dfb2389579 100644
--- a/api/tasks/recover_document_indexing_task.py
+++ b/api/tasks/recover_document_indexing_task.py
@@ -21,7 +21,7 @@ def recover_document_indexing_task(dataset_id: str, document_id: str):
logging.info(click.style("Recover document: {}".format(document_id), fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 4a62cb74b4..1619f8c546 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -6,6 +6,7 @@ import click
from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm import sessionmaker
from extensions.ext_database import db
from models import (
@@ -31,7 +32,8 @@ from models import (
)
from models.tools import WorkflowToolProvider
from models.web import PinnedConversation, SavedMessage
-from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun
+from models.workflow import ConversationVariable, Workflow, WorkflowAppLog
+from repositories.factory import DifyAPIRepositoryFactory
@shared_task(queue="app_deletion", bind=True, max_retries=3)
@@ -74,7 +76,7 @@ def remove_app_and_related_data_task(self, tenant_id: str, app_id: str):
def _delete_app_model_configs(tenant_id: str, app_id: str):
def del_model_config(model_config_id: str):
- db.session.query(AppModelConfig).filter(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
+ db.session.query(AppModelConfig).where(AppModelConfig.id == model_config_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_model_configs where app_id=:app_id limit 1000""",
@@ -86,14 +88,14 @@ def _delete_app_model_configs(tenant_id: str, app_id: str):
def _delete_app_site(tenant_id: str, app_id: str):
def del_site(site_id: str):
- db.session.query(Site).filter(Site.id == site_id).delete(synchronize_session=False)
+ db.session.query(Site).where(Site.id == site_id).delete(synchronize_session=False)
_delete_records("""select id from sites where app_id=:app_id limit 1000""", {"app_id": app_id}, del_site, "site")
def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def del_mcp_server(mcp_server_id: str):
- db.session.query(AppMCPServer).filter(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
+ db.session.query(AppMCPServer).where(AppMCPServer.id == mcp_server_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_mcp_servers where app_id=:app_id limit 1000""",
@@ -105,7 +107,7 @@ def _delete_app_mcp_servers(tenant_id: str, app_id: str):
def _delete_app_api_tokens(tenant_id: str, app_id: str):
def del_api_token(api_token_id: str):
- db.session.query(ApiToken).filter(ApiToken.id == api_token_id).delete(synchronize_session=False)
+ db.session.query(ApiToken).where(ApiToken.id == api_token_id).delete(synchronize_session=False)
_delete_records(
"""select id from api_tokens where app_id=:app_id limit 1000""", {"app_id": app_id}, del_api_token, "api token"
@@ -114,7 +116,7 @@ def _delete_app_api_tokens(tenant_id: str, app_id: str):
def _delete_installed_apps(tenant_id: str, app_id: str):
def del_installed_app(installed_app_id: str):
- db.session.query(InstalledApp).filter(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
+ db.session.query(InstalledApp).where(InstalledApp.id == installed_app_id).delete(synchronize_session=False)
_delete_records(
"""select id from installed_apps where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -126,7 +128,7 @@ def _delete_installed_apps(tenant_id: str, app_id: str):
def _delete_recommended_apps(tenant_id: str, app_id: str):
def del_recommended_app(recommended_app_id: str):
- db.session.query(RecommendedApp).filter(RecommendedApp.id == recommended_app_id).delete(
+ db.session.query(RecommendedApp).where(RecommendedApp.id == recommended_app_id).delete(
synchronize_session=False
)
@@ -140,9 +142,9 @@ def _delete_recommended_apps(tenant_id: str, app_id: str):
def _delete_app_annotation_data(tenant_id: str, app_id: str):
def del_annotation_hit_history(annotation_hit_history_id: str):
- db.session.query(AppAnnotationHitHistory).filter(
- AppAnnotationHitHistory.id == annotation_hit_history_id
- ).delete(synchronize_session=False)
+ db.session.query(AppAnnotationHitHistory).where(AppAnnotationHitHistory.id == annotation_hit_history_id).delete(
+ synchronize_session=False
+ )
_delete_records(
"""select id from app_annotation_hit_histories where app_id=:app_id limit 1000""",
@@ -152,7 +154,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
)
def del_annotation_setting(annotation_setting_id: str):
- db.session.query(AppAnnotationSetting).filter(AppAnnotationSetting.id == annotation_setting_id).delete(
+ db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.id == annotation_setting_id).delete(
synchronize_session=False
)
@@ -166,7 +168,7 @@ def _delete_app_annotation_data(tenant_id: str, app_id: str):
def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def del_dataset_join(dataset_join_id: str):
- db.session.query(AppDatasetJoin).filter(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
+ db.session.query(AppDatasetJoin).where(AppDatasetJoin.id == dataset_join_id).delete(synchronize_session=False)
_delete_records(
"""select id from app_dataset_joins where app_id=:app_id limit 1000""",
@@ -178,7 +180,7 @@ def _delete_app_dataset_joins(tenant_id: str, app_id: str):
def _delete_app_workflows(tenant_id: str, app_id: str):
def del_workflow(workflow_id: str):
- db.session.query(Workflow).filter(Workflow.id == workflow_id).delete(synchronize_session=False)
+ db.session.query(Workflow).where(Workflow.id == workflow_id).delete(synchronize_session=False)
_delete_records(
"""select id from workflows where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -189,34 +191,36 @@ def _delete_app_workflows(tenant_id: str, app_id: str):
def _delete_app_workflow_runs(tenant_id: str, app_id: str):
- def del_workflow_run(workflow_run_id: str):
- db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).delete(synchronize_session=False)
-
- _delete_records(
- """select id from workflow_runs where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
- {"tenant_id": tenant_id, "app_id": app_id},
- del_workflow_run,
- "workflow run",
+ """Delete all workflow runs for an app using the service repository."""
+ # The repository is handed its own session factory bound to db.engine instead of db.session.
+ session_maker = sessionmaker(bind=db.engine)
+ workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
+
+ # batch_size=1000 mirrors the "limit 1000" page size of the query this call replaces.
+ deleted_count = workflow_run_repo.delete_runs_by_app(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ batch_size=1000,
)
+ logging.info(f"Deleted {deleted_count} workflow runs for app {app_id}")
-def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
- def del_workflow_node_execution(workflow_node_execution_id: str):
- db.session.query(WorkflowNodeExecutionModel).filter(
- WorkflowNodeExecutionModel.id == workflow_node_execution_id
- ).delete(synchronize_session=False)
- _delete_records(
- """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
- {"tenant_id": tenant_id, "app_id": app_id},
- del_workflow_node_execution,
- "workflow node execution",
+def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
+ """Delete all workflow node executions for an app using the service repository."""
+ # The repository is handed its own session factory bound to db.engine instead of db.session.
+ session_maker = sessionmaker(bind=db.engine)
+ node_execution_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
+
+ # batch_size=1000 mirrors the "limit 1000" page size of the query this call replaces.
+ deleted_count = node_execution_repo.delete_executions_by_app(
+ tenant_id=tenant_id,
+ app_id=app_id,
+ batch_size=1000,
)
+ logging.info(f"Deleted {deleted_count} workflow node executions for app {app_id}")
+
def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def del_workflow_app_log(workflow_app_log_id: str):
- db.session.query(WorkflowAppLog).filter(WorkflowAppLog.id == workflow_app_log_id).delete(
+ db.session.query(WorkflowAppLog).where(WorkflowAppLog.id == workflow_app_log_id).delete(
synchronize_session=False
)
@@ -230,10 +234,10 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_conversations(tenant_id: str, app_id: str):
def del_conversation(conversation_id: str):
- db.session.query(PinnedConversation).filter(PinnedConversation.conversation_id == conversation_id).delete(
+ db.session.query(PinnedConversation).where(PinnedConversation.conversation_id == conversation_id).delete(
synchronize_session=False
)
- db.session.query(Conversation).filter(Conversation.id == conversation_id).delete(synchronize_session=False)
+ db.session.query(Conversation).where(Conversation.id == conversation_id).delete(synchronize_session=False)
_delete_records(
"""select id from conversations where app_id=:app_id limit 1000""",
@@ -253,19 +257,19 @@ def _delete_conversation_variables(*, app_id: str):
def _delete_app_messages(tenant_id: str, app_id: str):
def del_message(message_id: str):
- db.session.query(MessageFeedback).filter(MessageFeedback.message_id == message_id).delete(
+ db.session.query(MessageFeedback).where(MessageFeedback.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageAnnotation).filter(MessageAnnotation.message_id == message_id).delete(
+ db.session.query(MessageAnnotation).where(MessageAnnotation.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageChain).filter(MessageChain.message_id == message_id).delete(synchronize_session=False)
- db.session.query(MessageAgentThought).filter(MessageAgentThought.message_id == message_id).delete(
+ db.session.query(MessageChain).where(MessageChain.message_id == message_id).delete(synchronize_session=False)
+ db.session.query(MessageAgentThought).where(MessageAgentThought.message_id == message_id).delete(
synchronize_session=False
)
- db.session.query(MessageFile).filter(MessageFile.message_id == message_id).delete(synchronize_session=False)
- db.session.query(SavedMessage).filter(SavedMessage.message_id == message_id).delete(synchronize_session=False)
- db.session.query(Message).filter(Message.id == message_id).delete()
+ db.session.query(MessageFile).where(MessageFile.message_id == message_id).delete(synchronize_session=False)
+ db.session.query(SavedMessage).where(SavedMessage.message_id == message_id).delete(synchronize_session=False)
+ db.session.query(Message).where(Message.id == message_id).delete()
_delete_records(
"""select id from messages where app_id=:app_id limit 1000""", {"app_id": app_id}, del_message, "message"
@@ -274,7 +278,7 @@ def _delete_app_messages(tenant_id: str, app_id: str):
def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def del_tool_provider(tool_provider_id: str):
- db.session.query(WorkflowToolProvider).filter(WorkflowToolProvider.id == tool_provider_id).delete(
+ db.session.query(WorkflowToolProvider).where(WorkflowToolProvider.id == tool_provider_id).delete(
synchronize_session=False
)
@@ -288,7 +292,7 @@ def _delete_workflow_tool_providers(tenant_id: str, app_id: str):
def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def del_tag_binding(tag_binding_id: str):
- db.session.query(TagBinding).filter(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
+ db.session.query(TagBinding).where(TagBinding.id == tag_binding_id).delete(synchronize_session=False)
_delete_records(
"""select id from tag_bindings where tenant_id=:tenant_id and target_id=:app_id limit 1000""",
@@ -300,7 +304,7 @@ def _delete_app_tag_bindings(tenant_id: str, app_id: str):
def _delete_end_users(tenant_id: str, app_id: str):
def del_end_user(end_user_id: str):
- db.session.query(EndUser).filter(EndUser.id == end_user_id).delete(synchronize_session=False)
+ db.session.query(EndUser).where(EndUser.id == end_user_id).delete(synchronize_session=False)
_delete_records(
"""select id from end_users where tenant_id=:tenant_id and app_id=:app_id limit 1000""",
@@ -312,7 +316,7 @@ def _delete_end_users(tenant_id: str, app_id: str):
def _delete_trace_app_configs(tenant_id: str, app_id: str):
def del_trace_app_config(trace_app_config_id: str):
- db.session.query(TraceAppConfig).filter(TraceAppConfig.id == trace_app_config_id).delete(
+ db.session.query(TraceAppConfig).where(TraceAppConfig.id == trace_app_config_id).delete(
synchronize_session=False
)
diff --git a/api/tasks/remove_document_from_index_task.py b/api/tasks/remove_document_from_index_task.py
index 0e2960788d..3f73cc7b40 100644
--- a/api/tasks/remove_document_from_index_task.py
+++ b/api/tasks/remove_document_from_index_task.py
@@ -22,7 +22,7 @@ def remove_document_from_index_task(document_id: str):
logging.info(click.style("Start remove document segments from index: {}".format(document_id), fg="green"))
start_at = time.perf_counter()
- document = db.session.query(Document).filter(Document.id == document_id).first()
+ document = db.session.query(Document).where(Document.id == document_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="red"))
db.session.close()
@@ -43,7 +43,7 @@ def remove_document_from_index_task(document_id: str):
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:
@@ -51,7 +51,7 @@ def remove_document_from_index_task(document_id: str):
except Exception:
logging.exception(f"clean dataset {dataset.id} from index failed")
# update segment to disable
- db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).update(
+ db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).update(
{
DocumentSegment.enabled: False,
DocumentSegment.disabled_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py
index 8f8c3f9d81..58f0156afb 100644
--- a/api/tasks/retry_document_indexing_task.py
+++ b/api/tasks/retry_document_indexing_task.py
@@ -25,7 +25,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
documents: list[Document] = []
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red"))
db.session.close()
@@ -45,7 +45,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
)
except Exception as e:
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
@@ -59,7 +59,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
logging.info(click.style("Start retry document: {}".format(document_id), fg="green"))
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
@@ -69,7 +69,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py
index dba0a39c2d..539c2db80f 100644
--- a/api/tasks/sync_website_document_indexing_task.py
+++ b/api/tasks/sync_website_document_indexing_task.py
@@ -24,7 +24,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
"""
start_at = time.perf_counter()
- dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
+ dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
raise ValueError("Dataset not found")
@@ -41,7 +41,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
)
except Exception as e:
document = (
- db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
@@ -53,7 +53,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
return
logging.info(click.style("Start sync website document: {}".format(document_id), fg="green"))
- document = db.session.query(Document).filter(Document.id == document_id, Document.dataset_id == dataset_id).first()
+ document = db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logging.info(click.style("Document not found: {}".format(document_id), fg="yellow"))
return
@@ -61,7 +61,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
- segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
+ segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
diff --git a/api/templates/change_mail_completed_template_en-US.html b/api/templates/change_mail_completed_template_en-US.html
new file mode 100644
index 0000000000..ecaf35868d
--- /dev/null
+++ b/api/templates/change_mail_completed_template_en-US.html
@@ -0,0 +1,135 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Your login email has been changed
+
+
You can now log in to Dify with your new email address:
Some Documents in Your Knowledge Base Have Been Disabled
-
Some Documents in Your Knowledge Base Have Been Disabled
-
Dear {{userName}},
-
+
Dear {{userName}},
+
We're sorry for the inconvenience. To ensure optimal performance, documents
that haven’t been updated or accessed in the past 30 days have been disabled in
your knowledge bases:
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
-
Click the button below to log in to Dify and join the workspace.
{{ inviter_name }} is pleased to invite you to join our workspace on Dify, a platform specifically designed for LLM application development. On Dify, you can explore, create, and collaborate to build and operate AI applications.
+
Click the button below to log in to Dify and join the workspace.
{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.
-
Click the button below to log in to {{application_title}} and join the workspace.
{{ inviter_name }} is pleased to invite you to join our workspace on {{application_title}}, a platform specifically designed for LLM application development. On {{application_title}}, you can explore, create, and collaborate to build and operate AI applications.
+
Click the button below to log in to {{application_title}} and join the workspace.