diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index e7a8c98d26..44c1ddf739 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,5 +1,4 @@ FROM mcr.microsoft.com/devcontainers/python:3.12 -# [Optional] Uncomment this section to install additional OS packages. -# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ -# && apt-get -y install --no-install-recommends +RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index cc8eb552b0..93ecac48f2 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -1,12 +1,13 @@ #!/bin/bash -npm add -g pnpm@10.8.0 +npm add -g pnpm@10.11.1 cd web && pnpm install pipx install uv echo 'alias start-api="cd /workspaces/dify/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug"' >> ~/.bashrc echo 'alias start-worker="cd /workspaces/dify/api && uv run python -m celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion"' >> ~/.bashrc echo 'alias start-web="cd /workspaces/dify/web && pnpm dev"' >> ~/.bashrc +echo 'alias start-web-prod="cd /workspaces/dify/web && pnpm build && pnpm start"' >> ~/.bashrc echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d"' >> ~/.bashrc echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc diff --git a/.github/actions/setup-uv/action.yml b/.github/actions/setup-uv/action.yml index a596be63f7..0499b44dba 100644 --- a/.github/actions/setup-uv/action.yml +++ b/.github/actions/setup-uv/action.yml @@ -8,7 +8,7 @@ inputs: uv-version: description: UV version to set up required: true - default: '0.6.14' + 
default: '~=0.7.11' uv-lockfile: description: Path to the UV lockfile to restore cache from required: true diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index b4a6eb9adb..f4a5f754e0 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,25 +1,23 @@ -# Summary +> [!IMPORTANT] +> +> 1. Make sure you have read our [contribution guidelines](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) +> 2. Ensure there is an associated issue and you have been assigned to it +> 3. Use the correct syntax to link this PR: `Fixes #`. -Please include a summary of the change and which issue is fixed. Please also include relevant motivation and context. List any dependencies that are required for this change. +## Summary -> [!Tip] -> Close issue syntax: `Fixes #` or `Resolves #`, see [documentation](https://docs.github.com/en/issues/tracking-your-work-with-issues/linking-a-pull-request-to-an-issue#linking-a-pull-request-to-an-issue-using-a-keyword) for more details. + - -# Screenshots +## Screenshots | Before | After | |--------|-------| | ... | ... | -# Checklist - -> [!IMPORTANT] -> Please review the checklist below before submitting your pull request. +## Checklist - [ ] This change requires a documentation update, included: [Dify Document](https://github.com/langgenius/dify-docs) - [x] I understand that this PR may be closed in case there was no previous discussion or issues. (This doesn't apply to typos!) - [x] I've added a test for each change that was introduced, and I tried as much as possible to make a single atomic change. - [x] I've updated the documentation accordingly. 
- [x] I ran `dev/reformat`(backend) and `cd web && npx lint-staged`(frontend) to appease the lint gods - diff --git a/.github/workflows/deploy-rag-dev.yml b/.github/workflows/deploy-rag-dev.yml new file mode 100644 index 0000000000..86265aad6d --- /dev/null +++ b/.github/workflows/deploy-rag-dev.yml @@ -0,0 +1,28 @@ +name: Deploy RAG Dev + +permissions: + contents: read + +on: + workflow_run: + workflows: ["Build and Push API & Web"] + branches: + - "deploy/rag-dev" + types: + - completed + +jobs: + deploy: + runs-on: ubuntu-latest + if: | + github.event.workflow_run.conclusion == 'success' && + github.event.workflow_run.head_branch == 'deploy/rag-dev' + steps: + - name: Deploy to server + uses: appleboy/ssh-action@v0.1.8 + with: + host: ${{ secrets.RAG_SSH_HOST }} + username: ${{ secrets.SSH_USER }} + key: ${{ secrets.SSH_PRIVATE_KEY }} + script: | + ${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }} diff --git a/.github/workflows/expose_service_ports.sh b/.github/workflows/expose_service_ports.sh index 10d95cb736..01772ccf9f 100755 --- a/.github/workflows/expose_service_ports.sh +++ b/.github/workflows/expose_service_ports.sh @@ -10,6 +10,7 @@ yq eval '.services["elasticsearch"].ports += ["9200:9200"]' -i docker/docker-com yq eval '.services.couchbase-server.ports += ["8091-8096:8091-8096"]' -i docker/docker-compose.yaml yq eval '.services.couchbase-server.ports += ["11210:11210"]' -i docker/docker-compose.yaml yq eval '.services.tidb.ports += ["4000:4000"]' -i docker/tidb/docker-compose.yaml +yq eval '.services.oceanbase.ports += ["2881:2881"]' -i docker/docker-compose.yaml yq eval '.services.opengauss.ports += ["6600:6600"]' -i docker/docker-compose.yaml echo "Ports exposed for sandbox, weaviate, tidb, qdrant, chroma, milvus, pgvector, pgvecto-rs, elasticsearch, couchbase, opengauss" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 30c0ff000d..b06ab9653e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -139,6 
+139,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: + fetch-depth: 0 persist-credentials: false - name: Check changed files diff --git a/.github/workflows/translate-i18n-base-on-english.yml b/.github/workflows/translate-i18n-base-on-english.yml index 3f8082eb69..c79d58563f 100644 --- a/.github/workflows/translate-i18n-base-on-english.yml +++ b/.github/workflows/translate-i18n-base-on-english.yml @@ -31,11 +31,19 @@ jobs: echo "FILES_CHANGED=false" >> $GITHUB_ENV fi + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 10 + run_install: false + - name: Set up Node.js if: env.FILES_CHANGED == 'true' uses: actions/setup-node@v4 with: node-version: 'lts/*' + cache: pnpm + cache-dependency-path: ./web/package.json - name: Install dependencies if: env.FILES_CHANGED == 'true' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index c784817e72..7d0a873ebd 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -31,6 +31,13 @@ jobs: with: persist-credentials: false + - name: Free Disk Space + uses: endersonmenezes/free-disk-space@v2 + with: + remove_dotnet: true + remove_haskell: true + remove_tool_cache: true + - name: Setup UV and Python uses: ./.github/actions/setup-uv with: @@ -59,7 +66,7 @@ jobs: tidb tiflash - - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase) + - name: Set up Vector Stores (Weaviate, Qdrant, PGVector, Milvus, PgVecto-RS, Chroma, MyScale, ElasticSearch, Couchbase, OceanBase) uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: | @@ -75,8 +82,9 @@ jobs: pgvector chroma elasticsearch + oceanbase - - name: Check TiDB Ready + - name: Check VDB Ready (TiDB) run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py - name: Test Vector Stores diff --git a/.gitignore b/.gitignore index 8818ab6f65..8f82bea00d 100644 --- a/.gitignore +++ 
b/.gitignore @@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/* docker/volumes/couchbase/* docker/volumes/oceanbase/* docker/volumes/plugin_daemon/* +docker/volumes/matrixone/* !docker/volumes/oceanbase/init.d docker/nginx/conf.d/default.conf @@ -192,12 +193,12 @@ sdks/python-client/dist sdks/python-client/dify_client.egg-info .vscode/* -!.vscode/launch.json +!.vscode/launch.json.template +!.vscode/README.md pyrightconfig.json api/.vscode .idea/ -.vscode # pnpm /.pnpm-store @@ -207,3 +208,9 @@ plugins.jsonl # mise mise.toml + +# Next.js build output +.next/ + +# AI Assistant +.roo/ diff --git a/.vscode/README.md b/.vscode/README.md new file mode 100644 index 0000000000..26516f0540 --- /dev/null +++ b/.vscode/README.md @@ -0,0 +1,14 @@ +# Debugging with VS Code + +The `launch.json.template` file in this directory provides various debug configurations for the Dify project within VS Code / Cursor. To use these configurations, copy the contents of that template into a new file named `launch.json` in the same `.vscode` directory. + +## How to Use + +1. **Create `launch.json`**: If you don't have one, create a file named `launch.json` inside the `.vscode` directory. +2. **Copy Content**: Copy the entire content from `launch.json.template` into your newly created `launch.json` file. +3. **Select Debug Configuration**: Go to the Run and Debug view in VS Code / Cursor (Ctrl+Shift+D or Cmd+Shift+D). +4. **Start Debugging**: Select the desired configuration from the dropdown menu and click the green play button. + +## Tips + +- To debug with Microsoft Edge instead of Chrome, edit the `serverReadyAction` configuration in the "Next.js: debug full stack" section and change `"debugWithChrome"` to `"debugWithEdge"`. 
diff --git a/.vscode/launch.json.template b/.vscode/launch.json.template new file mode 100644 index 0000000000..f5a7f0893b --- /dev/null +++ b/.vscode/launch.json.template @@ -0,0 +1,68 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Flask API", + "type": "debugpy", + "request": "launch", + "module": "flask", + "env": { + "FLASK_APP": "app.py", + "FLASK_ENV": "development", + "GEVENT_SUPPORT": "True" + }, + "args": [ + "run", + "--host=0.0.0.0", + "--port=5001", + "--no-debugger", + "--no-reload" + ], + "jinja": true, + "justMyCode": true, + "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/api/.venv/bin/python" + }, + { + "name": "Python: Celery Worker (Solo)", + "type": "debugpy", + "request": "launch", + "module": "celery", + "env": { + "GEVENT_SUPPORT": "True" + }, + "args": [ + "-A", + "app.celery", + "worker", + "-P", + "solo", + "-c", + "1", + "-Q", + "dataset,generation,mail,ops_trace", + "--loglevel", + "INFO" + ], + "justMyCode": false, + "cwd": "${workspaceFolder}/api", + "python": "${workspaceFolder}/api/.venv/bin/python" + }, + { + "name": "Next.js: debug full stack", + "type": "node", + "request": "launch", + "program": "${workspaceFolder}/web/node_modules/next/dist/bin/next", + "runtimeArgs": ["--inspect"], + "skipFiles": ["<node_internals>/**"], + "serverReadyAction": { + "action": "debugWithChrome", + "killOnServerStop": true, + "pattern": "- Local:.+(https?://.+)", + "uriFormat": "%s", + "webRoot": "${workspaceFolder}/web" + }, + "cwd": "${workspaceFolder}/web" + } + ] +} diff --git a/README.md b/README.md index efb37d6083..ec399e49ee 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast @@ -87,8 +87,6 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. -https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - **2. Comprehensive model support**: Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). @@ -228,6 +226,11 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Using Alibaba Cloud Computing Nest + +Quickly deploy Dify to Alibaba Cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). @@ -237,7 +240,7 @@ At the same time, please consider supporting Dify by sharing it on social media ## Community & contact -- [Github Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). Best for: sharing feedback and asking questions. - [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). 
Best for: sharing your applications and hanging out with the community. - [X(Twitter)](https://twitter.com/dify_ai). Best for: sharing your applications and hanging out with the community. diff --git a/README_AR.md b/README_AR.md index 4f93802fda..5214da4894 100644 --- a/README_AR.md +++ b/README_AR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -54,8 +54,6 @@ **1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر. - - **2. الدعم الشامل للنماذج**: تكامل سلس مع مئات من LLMs الخاصة / مفتوحة المصدر من عشرات من موفري التحليل والحلول المستضافة ذاتيًا، مما يغطي GPT و Mistral و Llama3 وأي نماذج متوافقة مع واجهة OpenAI API. يمكن العثور على قائمة كاملة بمزودي النموذج المدعومين [هنا](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) @@ -211,6 +209,9 @@ docker compose up -d - [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### استخدام Alibaba Cloud للنشر + [بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + ## المساهمة لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا. @@ -225,7 +226,7 @@ docker compose up -d ## المجتمع والاتصال -- [مناقشة Github](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة. +- [مناقشة GitHub](https://github.com/langgenius/dify/discussions). الأفضل لـ: مشاركة التعليقات وطرح الأسئلة. - [المشكلات على GitHub](https://github.com/langgenius/dify/issues). الأفضل لـ: الأخطاء التي تواجهها في استخدام Dify.AI، واقتراحات الميزات. انظر [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). - [Discord](https://discord.gg/FngNHpbcY7). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع. - [تويتر](https://twitter.com/dify_ai). الأفضل لـ: مشاركة تطبيقاتك والترفيه مع المجتمع. 
diff --git a/README_BN.md b/README_BN.md index 7599fae9ff..1911f186d7 100644 --- a/README_BN.md +++ b/README_BN.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 ডিফাই ওয়ার্কফ্লো ফাইল আপলোড পরিচিতি: গুগল নোটবুক-এলএম পডকাস্ট পুনর্নির্মাণ @@ -84,8 +84,6 @@ docker compose up -d **১. ওয়ার্কফ্লো**: ভিজ্যুয়াল ক্যানভাসে AI ওয়ার্কফ্লো তৈরি এবং পরীক্ষা করুন, নিম্নলিখিত সব ফিচার এবং তার বাইরেও আরও অনেক কিছু ব্যবহার করে। - - **২. মডেল সাপোর্ট**: GPT, Mistral, Llama3, এবং যেকোনো OpenAI API-সামঞ্জস্যপূর্ণ মডেলসহ, কয়েক ডজন ইনফারেন্স প্রদানকারী এবং সেল্ফ-হোস্টেড সমাধান থেকে শুরু করে প্রোপ্রাইটরি/ওপেন-সোর্স LLM-এর সাথে সহজে ইন্টিগ্রেশন। সমর্থিত মডেল প্রদানকারীদের একটি সম্পূর্ণ তালিকা পাওয়া যাবে [এখানে](https://docs.dify.ai/getting-started/readme/model-providers)। @@ -227,6 +225,11 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud ব্যবহার করে ডিপ্লয় + + [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contributing যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। @@ -236,7 +239,7 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন ## কমিউনিটি এবং যোগাযোগ -- [Github Discussion](https://github.com/langgenius/dify/discussions) ফিডব্যাক এবং প্রতিক্রিয়া জানানোর মাধ্যম। +- [GitHub Discussion](https://github.com/langgenius/dify/discussions) ফিডব্যাক এবং প্রতিক্রিয়া জানানোর মাধ্যম। - [GitHub Issues](https://github.com/langgenius/dify/issues). 
Dify.AI ব্যবহার করে আপনি যেসব বাগের সম্মুখীন হন এবং ফিচার প্রস্তাবনা। আমাদের [অবদান নির্দেশিকা](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) দেখুন। - [Discord](https://discord.gg/FngNHpbcY7) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। - [X(Twitter)](https://twitter.com/dify_ai) আপনার এপ্লিকেশন শেয়ার এবং কমিউনিটি আড্ডার মাধ্যম। diff --git a/README_CN.md b/README_CN.md index 973629f459..a194b01937 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify 云服务 · @@ -61,11 +61,6 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI **1. 工作流**: 在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。 - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 全面的模型支持**: 与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 @@ -226,6 +221,11 @@ docker compose up -d ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### 使用 阿里云计算巢 部署 + +使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云 + + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) @@ -248,7 +248,7 @@ docker compose up -d 我们欢迎您为 Dify 做出贡献,以帮助改善 Dify。包括:提交代码、问题、新想法,或分享您基于 Dify 创建的有趣且有用的 AI 应用程序。同时,我们也欢迎您在不同的活动、会议和社交媒体上分享 Dify。 -- [Github Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。 +- [GitHub Discussion](https://github.com/langgenius/dify/discussions). 👉:分享您的应用程序并与社区交流。 - [GitHub Issues](https://github.com/langgenius/dify/issues)。👉:使用 Dify.AI 时遇到的错误和问题,请参阅[贡献指南](CONTRIBUTING.md)。 - [电子邮件支持](mailto:hello@dify.ai?subject=[GitHub]Questions%20About%20Dify)。👉:关于使用 Dify.AI 的问题。 - [Discord](https://discord.gg/FngNHpbcY7)。👉:分享您的应用程序并与社区交流。 diff --git a/README_DE.md b/README_DE.md index 738c0e3b67..fd550a5b96 100644 --- a/README_DE.md +++ b/README_DE.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Einführung in Dify Workflow File Upload: Google NotebookLM Podcast nachbilden @@ -83,11 +83,6 @@ Bitte beachten Sie unsere [FAQ](https://docs.dify.ai/getting-started/install-sel **1. Workflow**: Erstellen und testen Sie leistungsstarke KI-Workflows auf einer visuellen Oberfläche, wobei Sie alle der folgenden Funktionen und darüber hinaus nutzen können. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Umfassende Modellunterstützung**: Nahtlose Integration mit Hunderten von proprietären und Open-Source-LLMs von Dutzenden Inferenzanbietern und selbstgehosteten Lösungen, die GPT, Mistral, Llama3 und alle mit der OpenAI API kompatiblen Modelle abdecken. Eine vollständige Liste der unterstützten Modellanbieter finden Sie [hier](https://docs.dify.ai/getting-started/readme/model-providers). @@ -226,6 +221,11 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contributing Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. @@ -235,7 +235,7 @@ Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide]( ## Gemeinschaft & Kontakt -* [Github Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen. +* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Am besten geeignet für: den Austausch von Feedback und das Stellen von Fragen. 
* [GitHub Issues](https://github.com/langgenius/dify/issues). Am besten für: Fehler, auf die Sie bei der Verwendung von Dify.AI stoßen, und Funktionsvorschläge. Siehe unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. * [X(Twitter)](https://twitter.com/dify_ai). Am besten geeignet für: den Austausch von Bewerbungen und den Austausch mit der Community. diff --git a/README_ES.md b/README_ES.md index 212268b73d..38dea09be1 100644 --- a/README_ES.md +++ b/README_ES.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto. **1. Flujo de trabajo**: Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Soporte de modelos completo**: Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers). @@ -226,6 +221,10 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + ## Contribuir Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_FR.md b/README_FR.md index 89eea7d058..925918e47e 100644 --- a/README_FR.md +++ b/README_FR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify est une plateforme de développement d'applications LLM open source. Son in **1. Flux de travail** : Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Prise en charge complète des modèles** : Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). @@ -224,6 +219,11 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contribuer Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_JA.md b/README_JA.md index adca219753..3f8a5b859d 100644 --- a/README_JA.md +++ b/README_JA.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -60,11 +60,6 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ **1. ワークフロー**: 強力なAIワークフローをビジュアルキャンバス上で構築し、テストできます。すべての機能、および以下の機能を使用できます。 - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 総合的なモデルサポート**: 数百ものプロプライエタリ/オープンソースのLLMと、数十もの推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、OpenAI APIと互換性のあるすべてのモデルを統合されています。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。 @@ -225,6 +220,10 @@ docker compose up -d ##### AWS - [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## 貢献 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 @@ -241,7 +240,7 @@ docker compose up -d ## コミュニティ & お問い合わせ -* [Github Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 +* [GitHub Discussion](https://github.com/langgenius/dify/discussions). 主に: フィードバックの共有や質問。 * [GitHub Issues](https://github.com/langgenius/dify/issues). 主に: Dify.AIを使用する際に発生するエラーや問題については、[貢献ガイド](CONTRIBUTING_JA.md)を参照してください * [Discord](https://discord.gg/FngNHpbcY7). 主に: アプリケーションの共有やコミュニティとの交流。 * [X(Twitter)](https://twitter.com/dify_ai). 主に: アプリケーションの共有やコミュニティとの交流。 diff --git a/README_KL.md b/README_KL.md index 17e6c9d509..9e562a4d73 100644 --- a/README_KL.md +++ b/README_KL.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Comprehensive model support**: Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). @@ -224,6 +219,11 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo ##### AWS - [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). @@ -240,7 +240,7 @@ At the same time, please consider supporting Dify by sharing it on social media ## Community & Contact -* [Github Discussion](https://github.com/langgenius/dify/discussions +* [GitHub Discussion](https://github.com/langgenius/dify/discussions ). Best for: sharing feedback and asking questions. * [GitHub Issues](https://github.com/langgenius/dify/issues). Best for: bugs you encounter using Dify.AI, and feature proposals. See our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). 
diff --git a/README_KR.md b/README_KR.md index d44723f9b6..683b3a86f4 100644 --- a/README_KR.md +++ b/README_KR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify 클라우드 · @@ -54,11 +54,6 @@ **1. 워크플로우**: 다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 포괄적인 모델 지원:**: 수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다. @@ -218,6 +213,11 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ##### AWS - [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## 기여 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. @@ -234,7 +234,7 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ## 커뮤니티 & 연락처 -* [Github 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. +* [GitHub 토론](https://github.com/langgenius/dify/discussions). 피드백 공유 및 질문하기에 적합합니다. * [GitHub 이슈](https://github.com/langgenius/dify/issues). Dify.AI 사용 중 발견한 버그와 기능 제안에 적합합니다. [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. * [디스코드](https://discord.gg/FngNHpbcY7). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. * [트위터](https://twitter.com/dify_ai). 애플리케이션 공유 및 커뮤니티와 소통하기에 적합합니다. diff --git a/README_PT.md b/README_PT.md index 9dc2207279..b81127b70b 100644 --- a/README_PT.md +++ b/README_PT.md @@ -1,5 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) - +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM

@@ -59,11 +58,6 @@ Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. **1. Workflow**: Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Suporte abrangente a modelos**: Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers). @@ -224,6 +218,11 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Contribuindo Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_SI.md b/README_SI.md index 9a38b558b4..7034233233 100644 --- a/README_SI.md +++ b/README_SI.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast @@ -81,11 +81,6 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star **1. Potek dela**: Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Celovita podpora za modele**: Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers). @@ -224,6 +219,11 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Prispevam Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. @@ -234,7 +234,7 @@ Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkra ## Skupnost in stik -* [Github Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj. +* [GitHub Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj. * [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. 
Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). * [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. * [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo. diff --git a/README_TR.md b/README_TR.md index ab2853a019..51156933d4 100644 --- a/README_TR.md +++ b/README_TR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Bulut · @@ -55,11 +55,6 @@ Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel aray **1. Workflow**: Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin, aşağıdaki tüm özellikleri ve daha fazlasını kullanarak. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Kapsamlı model desteği**: Çok sayıda çıkarım sağlayıcısı ve kendi kendine barındırılan çözümlerden yüzlerce özel / açık kaynaklı LLM ile sorunsuz entegrasyon sağlar. GPT, Mistral, Llama3 ve OpenAI API uyumlu tüm modelleri kapsar. Desteklenen model sağlayıcılarının tam listesine [buradan](https://docs.dify.ai/getting-started/readme/model-providers) ulaşabilirsiniz. @@ -217,6 +212,11 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter ##### AWS - [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Katkıda Bulunma Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. @@ -232,7 +232,7 @@ Aynı zamanda, lütfen Dify'ı sosyal medyada, etkinliklerde ve konferanslarda p ## Topluluk & iletişim -* [Github Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. +* [GitHub Tartışmaları](https://github.com/langgenius/dify/discussions). En uygun: geri bildirim paylaşmak ve soru sormak için. * [GitHub Sorunları](https://github.com/langgenius/dify/issues). En uygun: Dify.AI kullanırken karşılaştığınız hatalar ve özellik önerileri için. [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakın. * [Discord](https://discord.gg/FngNHpbcY7). 
En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. * [X(Twitter)](https://twitter.com/dify_ai). En uygun: uygulamalarınızı paylaşmak ve toplulukla vakit geçirmek için. diff --git a/README_TW.md b/README_TW.md index 8263a22b64..291da28825 100644 --- a/README_TW.md +++ b/README_TW.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 介紹 Dify 工作流程檔案上傳功能:重現 Google NotebookLM Podcast @@ -86,8 +86,6 @@ docker compose up -d **1. 工作流程**: 在視覺化畫布上建立和測試強大的 AI 工作流程,利用以下所有功能及更多。 -https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - **2. 全面的模型支援**: 無縫整合來自數十個推理提供商和自託管解決方案的數百個專有/開源 LLM,涵蓋 GPT、Mistral、Llama3 和任何與 OpenAI API 兼容的模型。您可以在[此處](https://docs.dify.ai/getting-started/readme/model-providers)找到支援的模型提供商完整列表。 @@ -226,6 +224,11 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify - [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) +#### 使用 阿里云计算巢進行部署 + +[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## 貢獻 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 @@ -235,7 +238,7 @@ Dify 的所有功能都提供相應的 API,因此您可以輕鬆地將 Dify ## 社群與聯絡方式 -- [Github Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。 +- [GitHub Discussion](https://github.com/langgenius/dify/discussions):最適合分享反饋和提問。 - [GitHub Issues](https://github.com/langgenius/dify/issues):最適合報告使用 Dify.AI 時遇到的問題和提出功能建議。請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 - [Discord](https://discord.gg/FngNHpbcY7):最適合分享您的應用程式並與社群互動。 - [X(Twitter)](https://twitter.com/dify_ai):最適合分享您的應用程式並與社群互動。 diff --git a/README_VI.md b/README_VI.md index 852ed7aaa0..51a2e9e9e6 100644 --- a/README_VI.md +++ b/README_VI.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -55,11 +55,6 @@ Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Gia **1. Quy trình làm việc**: Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Hỗ trợ mô hình toàn diện**: Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers). @@ -219,6 +214,12 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) ##### AWS - [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + +#### Alibaba Cloud + +[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) + + ## Đóng góp Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. 
diff --git a/api/.env.example b/api/.env.example index 2cc6410cdd..baa9c382c8 100644 --- a/api/.env.example +++ b/api/.env.example @@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* # Vector database configuration -# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore +# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone VECTOR_STORE=weaviate # Weaviate configuration @@ -152,6 +152,7 @@ QDRANT_API_KEY=difyai123456 QDRANT_CLIENT_TIMEOUT=20 QDRANT_GRPC_ENABLED=false QDRANT_GRPC_PORT=6334 +QDRANT_REPLICATION_FACTOR=1 #Couchbase configuration COUCHBASE_CONNECTION_STRING=127.0.0.1 @@ -269,6 +270,7 @@ OPENSEARCH_PORT=9200 OPENSEARCH_USER=admin OPENSEARCH_PASSWORD=admin OPENSEARCH_SECURE=true +OPENSEARCH_VERIFY_CERTS=true # Baidu configuration BAIDU_VECTOR_DB_ENDPOINT=http://127.0.0.1:5287 @@ -292,6 +294,13 @@ VIKINGDB_SCHEMA=http VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30 +# Matrixone configuration +MATRIXONE_HOST=127.0.0.1 +MATRIXONE_PORT=6001 +MATRIXONE_USER=dump +MATRIXONE_PASSWORD=111 +MATRIXONE_DATABASE=dify + # Lindorm configuration LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 LINDORM_USERNAME=admin @@ -330,9 +339,11 @@ PROMPT_GENERATION_MAX_TOKENS=512 CODE_GENERATION_MAX_TOKENS=1024 PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false -# Mail configuration, support: resend, smtp +# Mail configuration, support: resend, smtp, sendgrid MAIL_TYPE= +# If using SendGrid, use the 'from' field for authentication if necessary. 
MAIL_DEFAULT_SEND_FROM=no-reply +# resend configuration RESEND_API_KEY= RESEND_API_URL=https://api.resend.com # smtp configuration @@ -342,12 +353,14 @@ SMTP_USERNAME=123 SMTP_PASSWORD=abc SMTP_USE_TLS=true SMTP_OPPORTUNISTIC_TLS=false - +# SendGrid configuration +SENDGRID_API_KEY= # Sentry configuration SENTRY_DSN= # DEBUG DEBUG=false +ENABLE_REQUEST_LOGGING=False SQLALCHEMY_ECHO=false # Notion import configuration, support public and internal @@ -476,6 +489,7 @@ LOGIN_LOCKOUT_DURATION=86400 ENABLE_OTEL=false OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= +OTEL_EXPORTER_OTLP_PROTOCOL= OTEL_EXPORTER_TYPE=otlp OTEL_SAMPLING_RATE=0.1 OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000 @@ -487,3 +501,10 @@ OTEL_METRIC_EXPORT_TIMEOUT=30000 # Prevent Clickjacking ALLOW_EMBED=false + +# Dataset queue monitor configuration +QUEUE_MONITOR_THRESHOLD=200 +# You can configure multiple ones, separated by commas. eg: test1@dify.ai,test2@dify.ai +QUEUE_MONITOR_ALERT_EMAILS= +# Monitor interval in minutes, default is 30 minutes +QUEUE_MONITOR_INTERVAL=30 diff --git a/api/.ruff.toml b/api/.ruff.toml index 41a24abad9..facb0d5419 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -43,6 +43,7 @@ select = [ "S307", # suspicious-eval-usage, disallow use of `eval` and `ast.literal_eval` "S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers. 
"S302", # suspicious-marshal-usage, disallow use of `marshal` module + "S311", # suspicious-non-cryptographic-random-usage ] ignore = [ diff --git a/api/Dockerfile b/api/Dockerfile index cff696ff56..7e4997507f 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -4,7 +4,7 @@ FROM python:3.12-slim-bookworm AS base WORKDIR /app/api # Install uv -ENV UV_VERSION=0.6.14 +ENV UV_VERSION=0.7.11 RUN pip install --no-cache-dir uv==${UV_VERSION} diff --git a/api/app_factory.py b/api/app_factory.py index 1c886ac5c7..3a258be28f 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp): ext_otel, ext_proxy_fix, ext_redis, + ext_request_logging, ext_sentry, ext_set_secretkey, ext_storage, @@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp): ext_blueprints, ext_commands, ext_otel, + ext_request_logging, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/commands.py b/api/commands.py index dc31dc0d80..86769847c1 100644 --- a/api/commands.py +++ b/api/commands.py @@ -6,6 +6,7 @@ from typing import Optional import click from flask import current_app +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -26,7 +27,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel -from services.account_service import RegisterService, TenantService +from services.account_service import AccountService, RegisterService, TenantService from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration from services.plugin.plugin_migration import PluginMigration @@ -67,6 +68,7 @@ def reset_password(email, new_password, password_confirm): 
account.password = base64_password_hashed account.password_salt = base64_salt db.session.commit() + AccountService.reset_login_error_rate_limit(email) click.echo(click.style("Password reset successfully.", fg="green")) @@ -279,6 +281,7 @@ def migrate_knowledge_vector_database(): VectorType.ELASTICSEARCH, VectorType.OPENGAUSS, VectorType.TABLESTORE, + VectorType.MATRIXONE, } lower_collection_vector_types = { VectorType.ANALYTICDB, @@ -297,11 +300,11 @@ def migrate_knowledge_vector_database(): page = 1 while True: try: - datasets = ( - Dataset.query.filter(Dataset.indexing_technique == "high_quality") - .order_by(Dataset.created_at.desc()) - .paginate(page=page, per_page=50) + stmt = ( + select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc()) ) + + datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) except NotFound: break @@ -551,11 +554,12 @@ def old_metadata_migration(): page = 1 while True: try: - documents = ( - DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None) + stmt = ( + select(DatasetDocument) + .filter(DatasetDocument.doc_metadata.is_not(None)) .order_by(DatasetDocument.created_at.desc()) - .paginate(page=page, per_page=50) ) + documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False) except NotFound: break if not documents: @@ -592,11 +596,15 @@ def old_metadata_migration(): ) db.session.add(dataset_metadata_binding) else: - dataset_metadata_binding = DatasetMetadataBinding.query.filter( - DatasetMetadataBinding.dataset_id == document.dataset_id, - DatasetMetadataBinding.document_id == document.id, - DatasetMetadataBinding.metadata_id == dataset_metadata.id, - ).first() + dataset_metadata_binding = ( + db.session.query(DatasetMetadataBinding) # type: ignore + .filter( + DatasetMetadataBinding.dataset_id == document.dataset_id, + DatasetMetadataBinding.document_id == document.id, + 
DatasetMetadataBinding.metadata_id == dataset_metadata.id, + ) + .first() + ) if not dataset_metadata_binding: dataset_metadata_binding = DatasetMetadataBinding( tenant_id=document.tenant_id, @@ -840,6 +848,9 @@ def clear_orphaned_file_records(force: bool): {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, {"type": "text", "table": "conversations", "column": "introduction"}, {"type": "text", "table": "conversations", "column": "system_instruction"}, + {"type": "text", "table": "accounts", "column": "avatar"}, + {"type": "text", "table": "apps", "column": "icon"}, + {"type": "text", "table": "sites", "column": "icon"}, {"type": "json", "table": "messages", "column": "inputs"}, {"type": "json", "table": "messages", "column": "message"}, ] diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 950936d3c6..63f4dfba63 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -17,6 +17,12 @@ class DeploymentConfig(BaseSettings): default=False, ) + # Request logging configuration + ENABLE_REQUEST_LOGGING: bool = Field( + description="Enable request and response body logging", + default=False, + ) + EDITION: str = Field( description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", default="SELF_HOSTED", diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 4890b5f746..df15b92c35 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -74,7 +74,7 @@ class CodeExecutionSandboxConfig(BaseSettings): CODE_EXECUTION_ENDPOINT: HttpUrl = Field( description="URL endpoint for the code execution service", - default="http://sandbox:8194", + default=HttpUrl("http://sandbox:8194"), ) CODE_EXECUTION_API_KEY: str = Field( @@ -145,7 +145,7 @@ class PluginConfig(BaseSettings): PLUGIN_DAEMON_URL: HttpUrl = Field( description="Plugin API URL", - default="http://localhost:5002", + default=HttpUrl("http://localhost:5002"), ) 
PLUGIN_DAEMON_KEY: str = Field( @@ -188,7 +188,7 @@ class MarketplaceConfig(BaseSettings): MARKETPLACE_API_URL: HttpUrl = Field( description="Marketplace API URL", - default="https://marketplace.dify.ai", + default=HttpUrl("https://marketplace.dify.ai"), ) @@ -609,7 +609,7 @@ class MailConfig(BaseSettings): """ MAIL_TYPE: Optional[str] = Field( - description="Email service provider type ('smtp' or 'resend'), default to None.", + description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.", default=None, ) @@ -663,6 +663,11 @@ class MailConfig(BaseSettings): default=50, ) + SENDGRID_API_KEY: Optional[str] = Field( + description="API key for SendGrid service", + default=None, + ) + class RagEtlConfig(BaseSettings): """ diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index d285515998..60ba272ec9 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -1,8 +1,8 @@ import os from typing import Any, Literal, Optional -from urllib.parse import quote_plus +from urllib.parse import parse_qsl, quote_plus -from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field +from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic_settings import BaseSettings from .cache.redis_config import RedisConfig @@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.lindorm_config import LindormConfig +from .vdb.matrixone_config import MatrixoneConfig from .vdb.milvus_config import MilvusConfig from .vdb.myscale_config import MyScaleConfig from .vdb.oceanbase_config import OceanBaseVectorConfig @@ -173,17 +174,31 @@ class DatabaseConfig(BaseSettings): RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field( description="Number of processes for the retrieval service, default 
to CPU cores.", - default=os.cpu_count(), + default=os.cpu_count() or 1, ) - @computed_field + @computed_field # type: ignore[misc] + @property def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: + # Parse DB_EXTRAS for 'options' + db_extras_dict = dict(parse_qsl(self.DB_EXTRAS)) + options = db_extras_dict.get("options", "") + # Always include timezone + timezone_opt = "-c timezone=UTC" + if options: + # Merge user options and timezone + merged_options = f"{options} {timezone_opt}" + else: + merged_options = timezone_opt + + connect_args = {"options": merged_options} + return { "pool_size": self.SQLALCHEMY_POOL_SIZE, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, - "connect_args": {"options": "-c timezone=UTC"}, + "connect_args": connect_args, } @@ -242,6 +257,25 @@ class InternalTestConfig(BaseSettings): ) +class DatasetQueueMonitorConfig(BaseSettings): + """ + Configuration settings for Dataset Queue Monitor + """ + + QUEUE_MONITOR_THRESHOLD: Optional[NonNegativeInt] = Field( + description="Threshold for dataset queue monitor", + default=200, + ) + QUEUE_MONITOR_ALERT_EMAILS: Optional[str] = Field( + description="Emails for dataset queue monitor alert, separated by commas", + default=None, + ) + QUEUE_MONITOR_INTERVAL: Optional[NonNegativeFloat] = Field( + description="Interval for dataset queue monitor in minutes", + default=30, + ) + + class MiddlewareConfig( # place the configs in alphabet order CeleryConfig, @@ -289,5 +323,7 @@ class MiddlewareConfig( BaiduVectorDBConfig, OpenGaussConfig, TableStoreConfig, + DatasetQueueMonitorConfig, + MatrixoneConfig, ): pass diff --git a/api/configs/middleware/cache/redis_config.py b/api/configs/middleware/cache/redis_config.py index 2e98c31ec3..916f52e165 100644 --- a/api/configs/middleware/cache/redis_config.py +++ b/api/configs/middleware/cache/redis_config.py @@ -83,3 +83,13 @@ class RedisConfig(BaseSettings): 
description="Password for Redis Clusters authentication (if required)", default=None, ) + + REDIS_SERIALIZATION_PROTOCOL: int = Field( + description="Redis serialization protocol (RESP) version", + default=3, + ) + + REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field( + description="Enable client side cache in redis", + default=False, + ) diff --git a/api/configs/middleware/storage/amazon_s3_storage_config.py b/api/configs/middleware/storage/amazon_s3_storage_config.py index f2d94b12ff..e14c210718 100644 --- a/api/configs/middleware/storage/amazon_s3_storage_config.py +++ b/api/configs/middleware/storage/amazon_s3_storage_config.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional from pydantic import Field from pydantic_settings import BaseSettings @@ -34,7 +34,7 @@ class S3StorageConfig(BaseSettings): default=None, ) - S3_ADDRESS_STYLE: str = Field( + S3_ADDRESS_STYLE: Literal["auto", "virtual", "path"] = Field( description="S3 addressing style: 'auto', 'path', or 'virtual'", default="auto", ) diff --git a/api/configs/middleware/vdb/matrixone_config.py b/api/configs/middleware/vdb/matrixone_config.py new file mode 100644 index 0000000000..9400612d8e --- /dev/null +++ b/api/configs/middleware/vdb/matrixone_config.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel, Field + + +class MatrixoneConfig(BaseModel): + """Matrixone vector database configuration.""" + + MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server") + MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server") + MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone") + MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone") + MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to") + MATRIXONE_METRIC: str = Field( + default="l2", description="Distance 
metric type for vector similarity search (cosine or l2)" + ) diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py index 96f478e9a6..9fd9b60194 100644 --- a/api/configs/middleware/vdb/opensearch_config.py +++ b/api/configs/middleware/vdb/opensearch_config.py @@ -33,6 +33,11 @@ class OpenSearchConfig(BaseSettings): default=False, ) + OPENSEARCH_VERIFY_CERTS: bool = Field( + description="Whether to verify SSL certificates for HTTPS connections (recommended to set True in production)", + default=True, + ) + OPENSEARCH_AUTH_METHOD: AuthMethod = Field( description="Authentication method for OpenSearch connection (default is 'basic')", default=AuthMethod.BASIC, diff --git a/api/configs/middleware/vdb/qdrant_config.py b/api/configs/middleware/vdb/qdrant_config.py index b70f624652..0a753eddec 100644 --- a/api/configs/middleware/vdb/qdrant_config.py +++ b/api/configs/middleware/vdb/qdrant_config.py @@ -33,3 +33,8 @@ class QdrantConfig(BaseSettings): description="Port number for gRPC connection to Qdrant server (default is 6334)", default=6334, ) + + QDRANT_REPLICATION_FACTOR: PositiveInt = Field( + description="Replication factor for Qdrant collections (default is 1)", + default=1, + ) diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 568a800d10..1b88ddcfe6 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -27,6 +27,11 @@ class OTelConfig(BaseSettings): default="otlp", ) + OTEL_EXPORTER_OTLP_PROTOCOL: str = Field( + description="OTLP exporter protocol ('grpc' or 'http')", + default="http", + ) + OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)") OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field( diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index c7960e1356..0107df22c5 100644 --- 
a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="1.3.1", + default="1.4.3", ) COMMIT_SHA: str = Field( diff --git a/api/configs/remote_settings_sources/nacos/http_request.py b/api/configs/remote_settings_sources/nacos/http_request.py index 2785bd955b..9b3359c6ad 100644 --- a/api/configs/remote_settings_sources/nacos/http_request.py +++ b/api/configs/remote_settings_sources/nacos/http_request.py @@ -60,8 +60,7 @@ class NacosHttpClient: sign_str = tenant + "+" if group: sign_str = sign_str + group + "+" - if sign_str: - sign_str += ts + sign_str += ts # Directly concatenate ts without conditional checks, because the nacos auth header forced it. return sign_str def get_access_token(self, force_refresh=False): diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index 127b8fe76d..ae41a2c03a 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -11,10 +11,6 @@ if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool -tenant_id: ContextVar[str] = ContextVar("tenant_id") - -workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool") - """ To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with """ diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index 79869916ed..3466eea1f6 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -1,5 +1,7 @@ from flask_restful import fields +from libs.helper import AppIconUrlField + parameters__system_parameters = { "image_file_size_limit": fields.Integer, "video_file_size_limit": fields.Integer, @@ -22,3 +24,20 @@ parameters_fields = { "file_upload": fields.Raw, "system_parameters": fields.Nested(parameters__system_parameters), } + +site_fields = { + "title": fields.String, + "chat_color_theme": 
fields.String, + "chat_color_theme_inverted": fields.Boolean, + "icon_type": fields.String, + "icon": fields.String, + "icon_background": fields.String, + "icon_url": AppIconUrlField, + "description": fields.String, + "copyright": fields.String, + "privacy_policy": fields.String, + "custom_disclaimer": fields.String, + "default_language": fields.String, + "show_workflow_steps": fields.Boolean, + "use_icon_as_answer_icon": fields.Boolean, +} diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py index 8cb7ad9f5b..f5257fae79 100644 --- a/api/controllers/console/admin.py +++ b/api/controllers/console/admin.py @@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource): parser.add_argument("position", type=int, required=True, nullable=False, location="json") args = parser.parse_args() - with Session(db.engine) as session: - app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() + app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none() if not app: raise NotFound(f"App '{args['app_id']}' is not found") @@ -78,38 +77,38 @@ class InsertExploreAppListApi(Resource): select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) ).scalar_one_or_none() - if not recommended_app: - recommended_app = RecommendedApp( - app_id=app.id, - description=desc, - copyright=copy_right, - privacy_policy=privacy_policy, - custom_disclaimer=custom_disclaimer, - language=args["language"], - category=args["category"], - position=args["position"], - ) - - db.session.add(recommended_app) - - app.is_public = True - db.session.commit() - - return {"result": "success"}, 201 - else: - recommended_app.description = desc - recommended_app.copyright = copy_right - recommended_app.privacy_policy = privacy_policy - recommended_app.custom_disclaimer = custom_disclaimer - recommended_app.language = args["language"] - recommended_app.category = args["category"] - recommended_app.position = args["position"] + 
if not recommended_app: + recommended_app = RecommendedApp( + app_id=app.id, + description=desc, + copyright=copy_right, + privacy_policy=privacy_policy, + custom_disclaimer=custom_disclaimer, + language=args["language"], + category=args["category"], + position=args["position"], + ) + + db.session.add(recommended_app) + + app.is_public = True + db.session.commit() + + return {"result": "success"}, 201 + else: + recommended_app.description = desc + recommended_app.copyright = copy_right + recommended_app.privacy_policy = privacy_policy + recommended_app.custom_disclaimer = custom_disclaimer + recommended_app.language = args["language"] + recommended_app.category = args["category"] + recommended_app.position = args["position"] - app.is_public = True + app.is_public = True - db.session.commit() + db.session.commit() - return {"result": "success"}, 200 + return {"result": "success"}, 200 class InsertExploreAppApi(Resource): diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index 91058767eb..2b48afd550 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource): if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. 
Only CSV files are allowed") return AppAnnotationService.batch_import_app_annotations(app_id, file) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index f97209c369..860166a61a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -17,15 +17,13 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db -from fields.app_fields import ( - app_detail_fields, - app_detail_fields_with_site, - app_pagination_fields, -) +from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from libs.login import login_required from models import Account, App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] @@ -75,7 +73,17 @@ class AppListApi(Resource): if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} - return marshal(app_pagination, app_pagination_fields) + if FeatureService.get_system_features().webapp_auth.enabled: + app_ids = [str(app.id) for app in app_pagination.items] + res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) + if len(res) != len(app_ids): + raise BadRequest("Invalid app id in webapp auth") + + for app in app_pagination.items: + if str(app.id) in res: + app.access_mode = res[str(app.id)].access_mode + + return marshal(app_pagination, app_pagination_fields), 200 @setup_required @login_required @@ -119,6 +127,10 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) + if FeatureService.get_system_features().webapp_auth.enabled: + app_setting = 
EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_model.access_mode = app_setting.access_mode + return app_model @setup_required diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py index 5dc6515ce0..9ffb94e9f9 100644 --- a/api/controllers/console/app/app_import.py +++ b/api/controllers/console/app/app_import.py @@ -17,6 +17,8 @@ from libs.login import login_required from models import Account from models.model import App from services.app_dsl_service import AppDslService, ImportStatus +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService class AppImportApi(Resource): @@ -60,7 +62,9 @@ class AppImportApi(Resource): app_id=args.get("app_id"), ) session.commit() - + if result.app_id and FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private") # Return appropriate status code based on result status = result.status if status == ImportStatus.FAILED.value: diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 0c13adce9b..cbbdd324ba 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -81,8 +81,7 @@ class DraftWorkflowApi(Resource): parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") parser.add_argument("features", type=dict, required=True, nullable=False, location="json") parser.add_argument("hash", type=str, required=False, location="json") - # TODO: set this to required=True after frontend is updated - parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("environment_variables", type=list, required=True, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") args = 
parser.parse_args() elif "text/plain" in content_type: diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index c475aea9fc..310146a5e7 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -6,12 +6,12 @@ from sqlalchemy.orm import Session from controllers.console import api from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs.login import login_required from models import App from models.model import AppMode -from models.workflow import WorkflowRunStatus from services.workflow_app_service import WorkflowAppService @@ -34,11 +34,25 @@ class WorkflowAppLogApi(Resource): parser.add_argument( "created_at__after", type=str, location="args", help="Filter logs created after this timestamp" ) + parser.add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, + ) + parser.add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, + ) parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - args.status = WorkflowRunStatus(args.status) if args.status else None + args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: args.created_at__before = isoparse(args.created_at__before) @@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource): created_at_after=args.created_at__after, page=args.page, limit=args.limit, + created_by_end_user_session_id=args.created_by_end_user_session_id, + 
created_by_account=args.created_by_account, ) return workflow_app_log_pagination diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9..9099700213 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,3 +1,6 @@ +from typing import cast + +from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range @@ -12,8 +15,7 @@ from fields.workflow_run_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from models import App -from models.model import AppMode +from models import Account, App, AppMode, EndUser from services.workflow_run_service import WorkflowRunService @@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource): run_id = str(run_id) workflow_run_service = WorkflowRunService() - node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) + user = cast("Account | EndUser", current_user) + node_executions = workflow_run_service.get_workflow_run_node_executions( + app_model=app_model, + run_id=run_id, + user=user, + ) return {"data": node_executions} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d73d8ce701..3bbe3177fc 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -24,7 +24,7 @@ from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService @@ -168,6 +168,8 @@ class 
ForgotPasswordResetApi(Resource): ) except WorkSpaceNotAllowedCreateError: pass + except WorkspacesLimitExceededError: + pass except AccountRegisterError: raise AccountInFreezeError() diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 27864bab3d..5f2a24322d 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -21,6 +21,7 @@ from controllers.console.error import ( AccountNotFound, EmailSendIpLimitError, NotAllowedCreateWorkspace, + WorkspacesLimitExceeded, ) from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created @@ -30,7 +31,7 @@ from models.account import Account from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService @@ -88,10 +89,15 @@ class LoginApi(Resource): # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return { - "result": "fail", - "data": "workspace not found, please contact system admin to invite you to join in a workspace", - } + system_features = FeatureService.get_system_features() + + if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available(): + raise WorkspacesLimitExceeded() + else: + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) @@ -196,15 +202,18 @@ class EmailCodeLoginApi(Resource): except 
AccountRegisterError as are: raise AccountInFreezeError() if account: - tenant = TenantService.get_join_tenants(account) - if not tenant: + tenants = TenantService.get_join_tenants(account) + if not tenants: + workspaces = FeatureService.get_system_features().license.workspaces + if not workspaces.is_available(): + raise WorkspacesLimitExceeded() if not FeatureService.get_system_features().is_allow_create_workspace: raise NotAllowedCreateWorkspace() else: - tenant = TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(new_tenant, account, role="owner") + account.current_tenant = new_tenant + tenant_was_created.send(new_tenant) if account is None: try: @@ -215,6 +224,8 @@ class EmailCodeLoginApi(Resource): return NotAllowedCreateWorkspace() except AccountRegisterError as are: raise AccountInFreezeError() + except WorkspacesLimitExceededError: + raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index f5284cc43b..395367c9e2 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -148,15 +148,15 @@ def _generate_account(provider: str, user_info: OAuthUserInfo): account = _get_account_by_openid_or_email(provider, user_info) if account: - tenant = TenantService.get_join_tenants(account) - if not tenant: + tenants = TenantService.get_join_tenants(account) + if not tenants: if not FeatureService.get_system_features().is_allow_create_workspace: raise WorkSpaceNotAllowedCreateError() else: - tenant = 
TenantService.create_tenant(f"{account.name}'s Workspace") - TenantService.create_tenant_member(tenant, account, role="owner") - account.current_tenant = tenant - tenant_was_created.send(tenant) + new_tenant = TenantService.create_tenant(f"{account.name}'s Workspace") + TenantService.create_tenant_member(new_tenant, account, role="owner") + account.current_tenant = new_tenant + tenant_was_created.send(new_tenant) if not account: if not FeatureService.get_system_features().is_allow_register: diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 571a395780..1611214cb3 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -526,17 +526,36 @@ class DatasetIndexingStatusApi(Resource): ) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() - document.completed_segments = completed_segments - document.total_segments = total_segments - documents_status.append(marshal(document, document_status_fields)) + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) + # Create a dictionary with document attributes and additional fields + document_dict = { + "id": document.id, + "indexing_status": document.indexing_status, + "processing_started_at": document.processing_started_at, + 
"parsing_completed_at": document.parsing_completed_at, + "cleaning_completed_at": document.cleaning_completed_at, + "splitting_completed_at": document.splitting_completed_at, + "completed_at": document.completed_at, + "paused_at": document.paused_at, + "error": document.error, + "stopped_at": document.stopped_at, + "completed_segments": completed_segments, + "total_segments": total_segments, + } + documents_status.append(marshal(document_dict, document_status_fields)) data = {"data": documents_status} return data @@ -667,6 +686,7 @@ class DatasetRetrievalSettingApi(Resource): | VectorType.TABLESTORE | VectorType.HUAWEI_CLOUD | VectorType.TENCENT + | VectorType.MATRIXONE ): return { "retrieval_method": [ @@ -714,6 +734,7 @@ class DatasetRetrievalSettingMockApi(Resource): | VectorType.TABLESTORE | VectorType.TENCENT | VectorType.HUAWEI_CLOUD + | VectorType.MATRIXONE ): return { "retrieval_method": [ diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 68601adfed..7ac60a0dc2 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -6,7 +6,7 @@ from typing import cast from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal, marshal_with, reqparse -from sqlalchemy import asc, desc +from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -43,7 +43,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db -from extensions.ext_redis import redis_client from fields.document_fields import ( dataset_and_document_fields, document_fields, @@ -54,8 +53,6 @@ from libs.login import login_required from models import Dataset, 
DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig -from tasks.add_document_to_index_task import add_document_to_index_task -from tasks.remove_document_from_index_task import remove_document_from_index_task class DocumentResource(Resource): @@ -112,7 +109,7 @@ class GetProcessRuleApi(Resource): limits = DocumentService.DEFAULT_RULES["limits"] if document_id: # get the latest process rule - document = Document.query.get_or_404(document_id) + document = db.get_or_404(Document, document_id) dataset = DatasetService.get_dataset(document.dataset_id) @@ -175,7 +172,7 @@ class DatasetDocumentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) - query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) if search: search = f"%{search}%" @@ -209,18 +206,24 @@ class DatasetDocumentListApi(Resource): desc(Document.position), ) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items if fetch: for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + 
DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) document.completed_segments = completed_segments document.total_segments = total_segments data = marshal(documents, document_with_segments_fields) @@ -563,19 +566,36 @@ class DocumentBatchIndexingStatusApi(DocumentResource): documents = self.get_batch_documents(dataset_id, batch) documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() - document.completed_segments = completed_segments - document.total_segments = total_segments - if document.is_paused: - document.indexing_status = "paused" - documents_status.append(marshal(document, document_status_fields)) + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) + # Create a dictionary with document attributes and additional fields + document_dict = { + "id": document.id, + "indexing_status": "paused" if document.is_paused else document.indexing_status, + "processing_started_at": document.processing_started_at, + "parsing_completed_at": document.parsing_completed_at, + "cleaning_completed_at": document.cleaning_completed_at, + "splitting_completed_at": document.splitting_completed_at, + "completed_at": 
document.completed_at, + "paused_at": document.paused_at, + "error": document.error, + "stopped_at": document.stopped_at, + "completed_segments": completed_segments, + "total_segments": total_segments, + } + documents_status.append(marshal(document_dict, document_status_fields)) data = {"data": documents_status} return data @@ -589,20 +609,37 @@ class DocumentIndexingStatusApi(DocumentResource): document_id = str(document_id) document = self.get_document(dataset_id, document_id) - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document_id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" - ).count() + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document_id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment") + .count() + ) - document.completed_segments = completed_segments - document.total_segments = total_segments - if document.is_paused: - document.indexing_status = "paused" - return marshal(document, document_status_fields) + # Create a dictionary with document attributes and additional fields + document_dict = { + "id": document.id, + "indexing_status": "paused" if document.is_paused else document.indexing_status, + "processing_started_at": document.processing_started_at, + "parsing_completed_at": document.parsing_completed_at, + "cleaning_completed_at": document.cleaning_completed_at, + "splitting_completed_at": document.splitting_completed_at, + "completed_at": document.completed_at, + "paused_at": document.paused_at, + "error": document.error, + "stopped_at": 
document.stopped_at, + "completed_segments": completed_segments, + "total_segments": total_segments, + } + return marshal(document_dict, document_status_fields) class DocumentDetailApi(DocumentResource): @@ -822,77 +859,16 @@ class DocumentStatusApi(DocumentResource): DatasetService.check_dataset_permission(dataset, current_user) document_ids = request.args.getlist("document_id") - for document_id in document_ids: - document = self.get_document(dataset_id, document_id) - - indexing_cache_key = "document_{}_indexing".format(document.id) - cache_result = redis_client.get(indexing_cache_key) - if cache_result is not None: - raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later") - - if action == "enable": - if document.enabled: - continue - document.enabled = True - document.disabled_at = None - document.disabled_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - - elif action == "disable": - if not document.completed_at or document.indexing_status != "completed": - raise InvalidActionError(f"Document: {document.name} is not completed.") - if not document.enabled: - continue - - document.enabled = False - document.disabled_at = datetime.now(UTC).replace(tzinfo=None) - document.disabled_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "archive": - if document.archived: - continue - - document.archived = True - document.archived_at = datetime.now(UTC).replace(tzinfo=None) - document.archived_by = current_user.id - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - 
db.session.commit() - - if document.enabled: - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - remove_document_from_index_task.delay(document_id) - - elif action == "un_archive": - if not document.archived: - continue - document.archived = False - document.archived_at = None - document.archived_by = None - document.updated_at = datetime.now(UTC).replace(tzinfo=None) - db.session.commit() - - # Set cache to prevent indexing the same document multiple times - redis_client.setex(indexing_cache_key, 600, 1) - - add_document_to_index_task.delay(document_id) - else: - raise InvalidActionError() + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + except NotFound as e: + raise NotFound(str(e)) + return {"result": "success"}, 200 diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index a145038672..48142dbe73 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -4,6 +4,7 @@ import pandas as pd from flask import request from flask_login import current_user from flask_restful import Resource, marshal, reqparse +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound import services @@ -26,6 +27,7 @@ from controllers.console.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType +from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import login_required @@ -74,9 +76,14 @@ class 
DatasetDocumentSegmentListApi(Resource): hit_count_gte = args["hit_count_gte"] keyword = args["keyword"] - query = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).order_by(DocumentSegment.position.asc()) + query = ( + select(DocumentSegment) + .filter( + DocumentSegment.document_id == str(document_id), + DocumentSegment.tenant_id == current_user.current_tenant_id, + ) + .order_by(DocumentSegment.position.asc()) + ) if status_list: query = query.filter(DocumentSegment.status.in_(status_list)) @@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource): elif args["enabled"].lower() == "false": query = query.filter(DocumentSegment.enabled == False) - segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) response = { "data": marshal(segments.items, segment_fields), @@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise ProviderNotInitializeError(ex.description) # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + 
db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -363,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource): if len(request.files) > 1: raise TooManyFilesError() # check file type - if not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: @@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") if not current_user.is_dataset_editor: @@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") parser = reqparse.RequestParser() @@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id 
== current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .first() + ) if not child_chunk: raise NotFound("Child chunk not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor @@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource): raise NotFound("Document not found.") # check segment segment_id = str(segment_id) - segment = DocumentSegment.query.filter( - DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id) + .first() + ) if not segment: raise NotFound("Segment not found.") # check child chunk child_chunk_id = 
str(child_chunk_id) - child_chunk = ChildChunk.query.filter( - ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id) + .first() + ) if not child_chunk: raise NotFound("Child chunk not found.") # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index b8fd1f0358..6944c56bf8 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -46,6 +46,18 @@ class NotAllowedCreateWorkspace(BaseHTTPException): code = 400 +class WorkspaceMembersLimitExceeded(BaseHTTPException): + error_code = "limit_exceeded" + description = "Unable to add member because the maximum workspace's member limit was exceeded" + code = 400 + + +class WorkspacesLimitExceeded(BaseHTTPException): + error_code = "limit_exceeded" + description = "Unable to create workspace because the maximum workspace limit was exceeded" + code = 400 + + class AccountBannedError(BaseHTTPException): error_code = "account_banned" description = "Account is banned." diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 18221b7797..1e05ff4206 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -23,3 +23,9 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 + + +class AppAccessDeniedError(BaseHTTPException): + error_code = "access_denied" + description = "App access denied." 
+ code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 9336c35a0d..9d0c08564e 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,3 +1,4 @@ +import logging from datetime import UTC, datetime from typing import Any @@ -15,6 +16,11 @@ from fields.installed_app_fields import installed_app_list_fields from libs.login import login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) class InstalledAppsListApi(Resource): @@ -48,6 +54,28 @@ class InstalledAppsListApi(Resource): for installed_app in installed_apps if installed_app.app is not None ] + + # filter out apps that user doesn't have access to + if FeatureService.get_system_features().webapp_auth.enabled: + user_id = current_user.id + res = [] + app_ids = [installed_app["app"].id for installed_app in installed_app_list] + webapp_settings = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids) + for installed_app in installed_app_list: + webapp_setting = webapp_settings.get(installed_app["app"].id) + if not webapp_setting: + continue + if webapp_setting.access_mode == "sso_verified": + continue + app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) + if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=user_id, + app_code=app_code, + ): + res.append(installed_app) + installed_app_list = res + logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}") + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], @@ -66,7 +94,7 @@ class InstalledAppsListApi(Resource): parser.add_argument("app_id", type=str, required=True, 
help="Invalid app_id") args = parser.parse_args() - recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() + recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first() if recommended_app is None: raise NotFound("App not found") @@ -79,9 +107,11 @@ class InstalledAppsListApi(Resource): if not app.is_public: raise Forbidden("You can't install a non-public app") - installed_app = InstalledApp.query.filter( - and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) - ).first() + installed_app = ( + db.session.query(InstalledApp) + .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id)) + .first() + ) if installed_app is None: # todo: position diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0..afbd78bd5b 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,10 +4,14 @@ from flask_login import current_user from flask_restful import Resource from werkzeug.exceptions import NotFound +from controllers.console.explore.error import AppAccessDeniedError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required from models import InstalledApp +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService def installed_app_required(view=None): @@ -48,6 +52,36 @@ def installed_app_required(view=None): return decorator +def user_allowed_to_access_app(view=None): + def decorator(view): + @wraps(view) + def decorated(installed_app: InstalledApp, *args, **kwargs): + feature = FeatureService.get_system_features() + if feature.webapp_auth.enabled: + app_id = installed_app.app_id + app_code = 
AppService.get_app_code_by_id(app_id) + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=str(current_user.id), + app_code=app_code, + ) + if not res: + raise AppAccessDeniedError() + + return view(installed_app, *args, **kwargs) + + return decorated + + if view: + return decorator(view) + return decorator + + class InstalledAppResource(Resource): # must be reversed if there are multiple decorators - method_decorators = [installed_app_required, account_initialization_required, login_required] + + method_decorators = [ + user_allowed_to_access_app, + installed_app_required, + account_initialization_required, + login_required, + ] diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index ba74e2c074..b4eb5e246b 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id @@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource): @login_required @account_initialization_required def post(self, provider: str, config_id: str): - if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): + if not TenantAccountRole.is_privileged_role(current_user.current_role): raise Forbidden() tenant_id = current_user.current_tenant_id diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index b9918b0d32..db49da7840 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -6,6 +6,7 @@ 
from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api +from controllers.console.error import WorkspaceMembersLimitExceeded from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -17,6 +18,7 @@ from libs.login import login_required from models.account import Account, TenantAccountRole from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError +from services.feature_service import FeatureService class MemberListApi(Resource): @@ -54,6 +56,12 @@ class MemberInviteEmailApi(Resource): inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL + + workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members + + if not workspace_members.is_available(len(invitee_emails)): + raise WorkspaceMembersLimitExceeded() + for invitee_email in invitee_emails: try: token = RegisterService.invite_new_member( @@ -71,7 +79,6 @@ class MemberInviteEmailApi(Resource): invitation_results.append( {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} ) - break except Exception as e: invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index fda5a7d3bb..9bddbb4b4b 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -41,12 +41,16 @@ class PluginListApi(Resource): @account_initialization_required def get(self): tenant_id = current_user.current_tenant_id + parser = reqparse.RequestParser() + parser.add_argument("page", type=int, required=False, location="args", default=1) + parser.add_argument("page_size", type=int, required=False, location="args", default=256) + args = 
parser.parse_args() try: - plugins = PluginService.list(tenant_id) + plugins_with_total = PluginService.list_with_total(tenant_id, args["page"], args["page_size"]) except PluginDaemonClientSideError as e: raise ValueError(e) - return jsonable_encoder({"plugins": plugins}) + return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) class PluginListLatestVersionsApi(Resource): diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py index 71e6f9178f..19999e7361 100644 --- a/api/controllers/console/workspace/workspace.py +++ b/api/controllers/console/workspace/workspace.py @@ -3,6 +3,7 @@ import logging from flask import request from flask_login import current_user from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse +from sqlalchemy import select from werkzeug.exceptions import Unauthorized import services @@ -67,16 +68,24 @@ class TenantListApi(Resource): @account_initialization_required def get(self): tenants = TenantService.get_join_tenants(current_user) + tenant_dicts = [] for tenant in tenants: features = FeatureService.get_features(tenant.id) - if features.billing.enabled: - tenant.plan = features.billing.subscription.plan - else: - tenant.plan = "sandbox" - if tenant.id == current_user.current_tenant_id: - tenant.current = True # Set current=True for current tenant - return {"workspaces": marshal(tenants, tenants_fields)}, 200 + + # Create a dictionary with tenant attributes + tenant_dict = { + "id": tenant.id, + "name": tenant.name, + "status": tenant.status, + "created_at": tenant.created_at, + "plan": features.billing.subscription.plan if features.billing.enabled else "sandbox", + "current": tenant.id == current_user.current_tenant_id, + } + + tenant_dicts.append(tenant_dict) + + return {"workspaces": marshal(tenant_dicts, tenants_fields)}, 200 class WorkspaceListApi(Resource): @@ -88,9 +97,8 @@ class WorkspaceListApi(Resource): 
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") args = parser.parse_args() - tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate( - page=args["page"], per_page=args["limit"], error_out=False - ) + stmt = select(Tenant).order_by(Tenant.created_at.desc()) + tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False) has_more = False if tenants.has_next: @@ -162,7 +170,7 @@ class CustomConfigWorkspaceApi(Resource): parser.add_argument("replace_webapp_logo", type=str, location="json") args = parser.parse_args() - tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = db.get_or_404(Tenant, current_user.current_tenant_id) custom_config_dict = { "remove_webapp_brand": args["remove_webapp_brand"], @@ -226,7 +234,7 @@ class WorkspaceInfoApi(Resource): parser.add_argument("name", type=str, required=True, location="json") args = parser.parse_args() - tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() + tenant = db.get_or_404(Tenant, current_user.current_tenant_id) tenant.name = args["name"] db.session.commit() diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index 360cbd9246..ca122772de 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -44,6 +44,17 @@ def only_edition_cloud(view): return decorated +def only_edition_enterprise(view): + @wraps(view) + def decorated(*args, **kwargs): + if not dify_config.ENTERPRISE_ENABLED: + abort(404) + + return view(*args, **kwargs) + + return decorated + + def only_edition_self_hosted(view): @wraps(view) def decorated(*args, **kwargs): diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 6641632169..f1a15793c7 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -64,9 +64,24 @@ class PluginUploadFileApi(Resource): extension = 
guess_extension(tool_file.mimetype) or ".bin" preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension) - tool_file.mime_type = mimetype - tool_file.extension = extension - tool_file.preview_url = preview_url + + # Create a dictionary with all the necessary attributes + result = { + "id": tool_file.id, + "user_id": tool_file.user_id, + "tenant_id": tool_file.tenant_id, + "conversation_id": tool_file.conversation_id, + "file_key": tool_file.file_key, + "mimetype": tool_file.mimetype, + "original_url": tool_file.original_url, + "name": tool_file.name, + "size": tool_file.size, + "mime_type": mimetype, + "extension": extension, + "preview_url": preview_url, + } + + return result, 201 except services.errors.file.FileTooLargeError as file_too_large_error: raise FileTooLargeError(file_too_large_error.description) except services.errors.file.UnsupportedFileTypeError: diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index f147a3453f..d51db4322a 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -5,5 +5,6 @@ from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) +from . 
import mail from .plugin import plugin from .workspace import workspace diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py new file mode 100644 index 0000000000..ce3373d65c --- /dev/null +++ b/api/controllers/inner_api/mail.py @@ -0,0 +1,27 @@ +from flask_restful import ( + Resource, # type: ignore + reqparse, +) + +from controllers.console.wraps import setup_required +from controllers.inner_api import api +from controllers.inner_api.wraps import enterprise_inner_api_only +from services.enterprise.mail_service import DifyMail, EnterpriseMailService + + +class EnterpriseMail(Resource): + @setup_required + @enterprise_inner_api_only + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("to", type=str, action="append", required=True) + parser.add_argument("subject", type=str, required=True) + parser.add_argument("body", type=str, required=True) + parser.add_argument("substitutions", type=dict, required=False) + args = parser.parse_args() + + EnterpriseMailService.send_mail(DifyMail(**args)) + return {"message": "success"}, 200 + + +api.add_resource(EnterpriseMail, "/enterprise/mail") diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index f3a1bd8fa5..41063b35a5 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -29,7 +29,7 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from libs.helper import compact_generate_response +from libs.helper import length_prefixed_response from models.account import Account, Tenant from models.model import EndUser @@ -44,7 +44,7 @@ class PluginInvokeLLMApi(Resource): response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload) return PluginModelBackwardsInvocation.convert_to_event_stream(response) - return compact_generate_response(generator()) + return 
length_prefixed_response(0xF, generator()) class PluginInvokeTextEmbeddingApi(Resource): @@ -101,7 +101,7 @@ class PluginInvokeTTSApi(Resource): ) return PluginModelBackwardsInvocation.convert_to_event_stream(response) - return compact_generate_response(generator()) + return length_prefixed_response(0xF, generator()) class PluginInvokeSpeech2TextApi(Resource): @@ -162,7 +162,7 @@ class PluginInvokeToolApi(Resource): ), ) - return compact_generate_response(generator()) + return length_prefixed_response(0xF, generator()) class PluginInvokeParameterExtractorNodeApi(Resource): @@ -228,7 +228,7 @@ class PluginInvokeAppApi(Resource): files=payload.files, ) - return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response)) + return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response)) class PluginInvokeEncryptApi(Resource): diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py index 709bba3f30..50408e0929 100644 --- a/api/controllers/inner_api/plugin/wraps.py +++ b/api/controllers/inner_api/plugin/wraps.py @@ -2,12 +2,14 @@ from collections.abc import Callable from functools import wraps from typing import Optional -from flask import request +from flask import current_app, request +from flask_login import user_logged_in from flask_restful import reqparse from pydantic import BaseModel from sqlalchemy.orm import Session from extensions.ext_database import db +from libs.login import _get_user from models.account import Account, Tenant from models.model import EndUser from services.account_service import AccountService @@ -30,6 +32,7 @@ def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser: ) session.add(user_model) session.commit() + session.refresh(user_model) else: user_model = AccountService.load_user(user_id) if not user_model: @@ -80,7 +83,12 @@ def get_user_tenant(view: Optional[Callable] = None): raise ValueError("tenant not found") 
kwargs["tenant_model"] = tenant_model - kwargs["user_model"] = get_user(tenant_id, user_id) + + user = get_user(tenant_id, user_id) + kwargs["user_model"] = user + + current_app.login_manager._update_request_context_with_user(user) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore return view_func(*args, **kwargs) diff --git a/api/controllers/service_api/__init__.py b/api/controllers/service_api/__init__.py index d97074e8b9..d964e27819 100644 --- a/api/controllers/service_api/__init__.py +++ b/api/controllers/service_api/__init__.py @@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1") api = ExternalApi(bp) from . import index -from .app import annotation, app, audio, completion, conversation, file, message, workflow +from .app import annotation, app, audio, completion, conversation, file, message, site, workflow from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .workspace import models diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index bd1a23b723..595ae118ef 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -3,19 +3,19 @@ from flask_restful import Resource, marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden from controllers.service_api import api -from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token +from controllers.service_api.wraps import validate_app_token from extensions.ext_redis import redis_client from fields.annotation_fields import ( annotation_fields, ) from libs.login import current_user -from models.model import App, EndUser +from models.model import App from services.annotation_service import AppAnnotationService class AnnotationReplyActionApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) - def post(self, 
app_model: App, end_user: EndUser, action): + @validate_app_token + def post(self, app_model: App, action): parser = reqparse.RequestParser() parser.add_argument("score_threshold", required=True, type=float, location="json") parser.add_argument("embedding_provider_name", required=True, type=str, location="json") @@ -31,8 +31,8 @@ class AnnotationReplyActionApi(Resource): class AnnotationReplyActionStatusApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - def get(self, app_model: App, end_user: EndUser, job_id, action): + @validate_app_token + def get(self, app_model: App, job_id, action): job_id = str(job_id) app_annotation_job_key = "{}_app_annotation_job_{}".format(action, str(job_id)) cache_result = redis_client.get(app_annotation_job_key) @@ -49,8 +49,8 @@ class AnnotationReplyActionStatusApi(Resource): class AnnotationListApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - def get(self, app_model: App, end_user: EndUser): + @validate_app_token + def get(self, app_model: App): page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) keyword = request.args.get("keyword", default="", type=str) @@ -65,9 +65,9 @@ class AnnotationListApi(Resource): } return response, 200 - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token @marshal_with(annotation_fields) - def post(self, app_model: App, end_user: EndUser): + def post(self, app_model: App): parser = reqparse.RequestParser() parser.add_argument("question", required=True, type=str, location="json") parser.add_argument("answer", required=True, type=str, location="json") @@ -77,9 +77,9 @@ class AnnotationListApi(Resource): class AnnotationUpdateDeleteApi(Resource): - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON)) + @validate_app_token @marshal_with(annotation_fields) - def put(self, 
app_model: App, end_user: EndUser, annotation_id): + def put(self, app_model: App, annotation_id): if not current_user.is_editor: raise Forbidden() @@ -91,8 +91,8 @@ class AnnotationUpdateDeleteApi(Resource): annotation = AppAnnotationService.update_app_annotation_directly(args, app_model.id, annotation_id) return annotation - @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY)) - def delete(self, app_model: App, end_user: EndUser, annotation_id): + @validate_app_token + def delete(self, app_model: App, annotation_id): if not current_user.is_editor: raise Forbidden() diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 2c03aba33d..89222d5e83 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -47,7 +47,13 @@ class AppInfoApi(Resource): def get(self, app_model: App): """Get app information""" tags = [tag.name for tag in app_model.tags] - return {"name": app_model.name, "description": app_model.description, "tags": tags, "mode": app_model.mode} + return { + "name": app_model.name, + "description": app_model.description, + "tags": tags, + "mode": app_model.mode, + "author_name": app_model.author_name, + } api.add_resource(AppParameterApi, "/parameters") diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py index 1b148a9756..d90fa2081f 100644 --- a/api/controllers/service_api/app/message.py +++ b/api/controllers/service_api/app/message.py @@ -93,6 +93,18 @@ class MessageFeedbackApi(Resource): return {"result": "success"} +class AppGetFeedbacksApi(Resource): + @validate_app_token + def get(self, app_model: App): + """Get All Feedbacks of an app""" + parser = reqparse.RequestParser() + parser.add_argument("page", type=int, default=1, location="args") + parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args") + args = parser.parse_args() + feedbacks = 
MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"]) + return {"data": feedbacks} + + class MessageSuggestedApi(Resource): @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) def get(self, app_model: App, end_user: EndUser, message_id): @@ -119,3 +131,4 @@ class MessageSuggestedApi(Resource): api.add_resource(MessageListApi, "/messages") api.add_resource(MessageFeedbackApi, "/messages//feedbacks") api.add_resource(MessageSuggestedApi, "/messages//suggested") +api.add_resource(AppGetFeedbacksApi, "/app/feedbacks") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py new file mode 100644 index 0000000000..e752dfee30 --- /dev/null +++ b/api/controllers/service_api/app/site.py @@ -0,0 +1,30 @@ +from flask_restful import Resource, marshal_with +from werkzeug.exceptions import Forbidden + +from controllers.common import fields +from controllers.service_api import api +from controllers.service_api.wraps import validate_app_token +from extensions.ext_database import db +from models.account import TenantStatus +from models.model import App, Site + + +class AppSiteApi(Resource): + """Resource for app sites.""" + + @validate_app_token + @marshal_with(fields.site_fields) + def get(self, app_model: App): + """Retrieve app site info.""" + site = db.session.query(Site).filter(Site.app_id == app_model.id).first() + + if not site: + raise Forbidden() + + if app_model.tenant.status == TenantStatus.ARCHIVE: + raise Forbidden() + + return site + + +api.add_resource(AppSiteApi, "/site") diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index e9bb2b046a..efb4acc5fb 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -24,12 +24,13 @@ from core.errors.error import ( QuotaExceededError, ) from core.model_runtime.errors.invoke import InvokeError +from 
core.workflow.entities.workflow_execution import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields import workflow_app_log_pagination_fields from libs import helper from libs.helper import TimestampField from models.model import App, AppMode, EndUser -from models.workflow import WorkflowRun, WorkflowRunStatus +from models.workflow import WorkflowRun from services.app_generate_service import AppGenerateService from services.errors.llm import InvokeRateLimitError from services.workflow_app_service import WorkflowAppService @@ -134,11 +135,25 @@ class WorkflowAppLogApi(Resource): parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") parser.add_argument("created_at__before", type=str, location="args") parser.add_argument("created_at__after", type=str, location="args") + parser.add_argument( + "created_by_end_user_session_id", + type=str, + location="args", + required=False, + default=None, + ) + parser.add_argument( + "created_by_account", + type=str, + location="args", + required=False, + default=None, + ) parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") args = parser.parse_args() - args.status = WorkflowRunStatus(args.status) if args.status else None + args.status = WorkflowExecutionStatus(args.status) if args.status else None if args.created_at__before: args.created_at__before = isoparse(args.created_at__before) @@ -157,6 +172,8 @@ class WorkflowAppLogApi(Resource): created_at_after=args.created_at__after, page=args.page, limit=args.limit, + created_by_end_user_session_id=args.created_by_end_user_session_id, + created_by_account=args.created_by_account, ) return workflow_app_log_pagination diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index ee190245d5..839afdb9fd 100644 --- 
a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -1,19 +1,25 @@ from flask import request -from flask_restful import marshal, reqparse +from flask_restful import marshal, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound import services.dataset_service from controllers.service_api import api -from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError +from controllers.service_api.wraps import ( + DatasetApiResource, + cloud_edition_billing_rate_limit_check, + validate_dataset_token, +) from core.model_runtime.entities.model_entities import ModelType from core.plugin.entities.plugin import ModelProviderID from core.provider_manager import ProviderManager from fields.dataset_fields import dataset_detail_fields +from fields.tag_fields import tag_fields from libs.login import current_user from models.dataset import Dataset, DatasetPermissionEnum -from services.dataset_service import DatasetPermissionService, DatasetService +from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import RetrievalModel +from services.tag_service import TagService def _validate_name(name): @@ -68,6 +74,7 @@ class DatasetListApi(DatasetApiResource): response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} return response, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id): """Resource for creating datasets.""" parser = reqparse.RequestParser() @@ -191,6 +198,7 @@ class DatasetApi(DatasetApiResource): return data, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, _, 
dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -291,6 +299,7 @@ class DatasetApi(DatasetApiResource): return result_data, 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, _, dataset_id): """ Deletes a dataset given its ID. @@ -313,12 +322,193 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {"result": "success"}, 204 + return 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: raise DatasetInUseError() +class DocumentStatusApi(DatasetApiResource): + """Resource for batch document status operations.""" + + def patch(self, tenant_id, dataset_id, action): + """ + Batch update document status. + + Args: + tenant_id: tenant id + dataset_id: dataset id + action: action to perform (enable, disable, archive, un_archive) + + Returns: + dict: A dictionary with a key 'result' and a value 'success' + int: HTTP status code 200 indicating that the operation was successful. + + Raises: + NotFound: If the dataset with the given ID does not exist. + Forbidden: If the user does not have permission. + InvalidActionError: If the action is invalid or cannot be performed. 
+ """ + dataset_id_str = str(dataset_id) + dataset = DatasetService.get_dataset(dataset_id_str) + + if dataset is None: + raise NotFound("Dataset not found.") + + # Check user's permission + try: + DatasetService.check_dataset_permission(dataset, current_user) + except services.errors.account.NoPermissionError as e: + raise Forbidden(str(e)) + + # Check dataset model setting + DatasetService.check_dataset_model_setting(dataset) + + # Get document IDs from request body + data = request.get_json() + document_ids = data.get("document_ids", []) + + try: + DocumentService.batch_update_document_status(dataset, document_ids, action, current_user) + except services.errors.document.DocumentIndexingError as e: + raise InvalidActionError(str(e)) + except ValueError as e: + raise InvalidActionError(str(e)) + + return {"result": "success"}, 200 + + +class DatasetTagsApi(DatasetApiResource): + @validate_dataset_token + @marshal_with(tag_fields) + def get(self, _, dataset_id): + """Get all knowledge type tags.""" + tags = TagService.get_tags("knowledge", current_user.current_tenant_id) + + return tags, 200 + + @validate_dataset_token + def post(self, _, dataset_id): + """Add a knowledge type tag.""" + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 characters.", + type=DatasetTagsApi._validate_tag_name, + ) + + args = parser.parse_args() + args["type"] = "knowledge" + tag = TagService.save_tags(args) + + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": 0} + + return response, 200 + + @validate_dataset_token + def patch(self, _, dataset_id): + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "name", + nullable=False, + required=True, + help="Name must be between 1 to 50 
characters.", + type=DatasetTagsApi._validate_tag_name, + ) + parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + args = parser.parse_args() + args["type"] = "knowledge" + tag = TagService.update_tags(args, args.get("tag_id")) + + binding_count = TagService.get_tag_binding_count(args.get("tag_id")) + + response = {"id": tag.id, "name": tag.name, "type": tag.type, "binding_count": binding_count} + + return response, 200 + + @validate_dataset_token + def delete(self, _, dataset_id): + """Delete a knowledge type tag.""" + if not current_user.is_editor: + raise Forbidden() + parser = reqparse.RequestParser() + parser.add_argument("tag_id", nullable=False, required=True, help="Id of a tag.", type=str) + args = parser.parse_args() + TagService.delete_tag(args.get("tag_id")) + + return 204 + + @staticmethod + def _validate_tag_name(name): + if not name or len(name) < 1 or len(name) > 50: + raise ValueError("Name must be between 1 to 50 characters.") + return name + + +class DatasetTagBindingApi(DatasetApiResource): + @validate_dataset_token + def post(self, _, dataset_id): + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument( + "tag_ids", type=list, nullable=False, required=True, location="json", help="Tag IDs is required." + ) + parser.add_argument( + "target_id", type=str, nullable=False, required=True, location="json", help="Target Dataset ID is required." 
+ ) + + args = parser.parse_args() + args["type"] = "knowledge" + TagService.save_tag_binding(args) + + return 204 + + +class DatasetTagUnbindingApi(DatasetApiResource): + @validate_dataset_token + def post(self, _, dataset_id): + # The role of the current user in the ta table must be admin, owner, editor, or dataset_operator + if not (current_user.is_editor or current_user.is_dataset_editor): + raise Forbidden() + + parser = reqparse.RequestParser() + parser.add_argument("tag_id", type=str, nullable=False, required=True, help="Tag ID is required.") + parser.add_argument("target_id", type=str, nullable=False, required=True, help="Target ID is required.") + + args = parser.parse_args() + args["type"] = "knowledge" + TagService.delete_tag_binding(args) + + return 204 + + +class DatasetTagsBindingStatusApi(DatasetApiResource): + @validate_dataset_token + def get(self, _, *args, **kwargs): + """Get all knowledge type tags.""" + dataset_id = kwargs.get("dataset_id") + tags = TagService.get_tags_by_target_id("knowledge", current_user.current_tenant_id, str(dataset_id)) + tags_list = [{"id": tag.id, "name": tag.name} for tag in tags] + response = {"data": tags_list, "total": len(tags)} + return response, 200 + + api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetApi, "/datasets/") +api.add_resource(DocumentStatusApi, "/datasets//documents/status/") +api.add_resource(DatasetTagsApi, "/datasets/tags") +api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") +api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") +api.add_resource(DatasetTagsBindingStatusApi, "/datasets//tags") diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index 33eda37014..e4779f3bdf 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -2,10 +2,10 @@ import json from flask import request from flask_restful import marshal, reqparse -from 
sqlalchemy import desc +from sqlalchemy import desc, select from werkzeug.exceptions import NotFound -import services.dataset_service +import services from controllers.common.errors import FilenameNotExistsError from controllers.service_api import api from controllers.service_api.app.error import ( @@ -19,7 +19,11 @@ from controllers.service_api.dataset.error import ( ArchivedDocumentImmutableError, DocumentIndexingError, ) -from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check +from controllers.service_api.wraps import ( + DatasetApiResource, + cloud_edition_billing_rate_limit_check, + cloud_edition_billing_resource_check, +) from core.errors.error import ProviderTokenNotInitError from extensions.ext_database import db from fields.document_fields import document_fields, document_status_fields @@ -35,6 +39,7 @@ class DocumentAddByTextApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by text.""" parser = reqparse.RequestParser() @@ -99,6 +104,7 @@ class DocumentUpdateByTextApi(DatasetApiResource): """Resource for update documents.""" @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by text.""" parser = reqparse.RequestParser() @@ -158,6 +164,7 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("documents", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" args = {} @@ -175,8 +182,11 @@ class DocumentAddByFileApi(DatasetApiResource): if not 
dataset: raise ValueError("Dataset does not exist.") - if not dataset.indexing_technique and not args.get("indexing_technique"): + + indexing_technique = args.get("indexing_technique") or dataset.indexing_technique + if not indexing_technique: raise ValueError("indexing_technique is required.") + args["indexing_technique"] = indexing_technique # save file info file = request.files["file"] @@ -206,12 +216,16 @@ class DocumentAddByFileApi(DatasetApiResource): knowledge_config = KnowledgeConfig(**args) DocumentService.document_create_args_validate(knowledge_config) + dataset_process_rule = dataset.latest_process_rule if "process_rule" not in args else None + if not knowledge_config.original_document_id and not dataset_process_rule and not knowledge_config.process_rule: + raise ValueError("process_rule is required.") + try: documents, batch = DocumentService.save_document_with_dataset_id( dataset=dataset, knowledge_config=knowledge_config, account=dataset.created_by_account, - dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, + dataset_process_rule=dataset_process_rule, created_from="api", ) except ProviderTokenNotInitError as ex: @@ -225,6 +239,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): """Resource for update documents.""" @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" args = {} @@ -295,6 +310,7 @@ class DocumentUpdateByFileApi(DatasetApiResource): class DocumentDeleteApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id): """Delete document.""" document_id = str(document_id) @@ -323,7 +339,7 @@ class DocumentDeleteApi(DatasetApiResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during 
indexing.") - return {"result": "success"}, 204 + return 204 class DocumentListApi(DatasetApiResource): @@ -337,7 +353,7 @@ class DocumentListApi(DatasetApiResource): if not dataset: raise NotFound("Dataset not found.") - query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) + query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) if search: search = f"%{search}%" @@ -345,7 +361,7 @@ class DocumentListApi(DatasetApiResource): query = query.order_by(desc(Document.created_at), desc(Document.position)) - paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) documents = paginated_documents.items response = { @@ -374,19 +390,36 @@ class DocumentIndexingStatusApi(DatasetApiResource): raise NotFound("Documents not found.") documents_status = [] for document in documents: - completed_segments = DocumentSegment.query.filter( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != "re_segment", - ).count() - total_segments = DocumentSegment.query.filter( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" - ).count() - document.completed_segments = completed_segments - document.total_segments = total_segments - if document.is_paused: - document.indexing_status = "paused" - documents_status.append(marshal(document, document_status_fields)) + completed_segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != "re_segment", + ) + .count() + ) + total_segments = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment") + .count() + ) + # Create a dictionary with 
document attributes and additional fields + document_dict = { + "id": document.id, + "indexing_status": "paused" if document.is_paused else document.indexing_status, + "processing_started_at": document.processing_started_at, + "parsing_completed_at": document.parsing_completed_at, + "cleaning_completed_at": document.cleaning_completed_at, + "splitting_completed_at": document.splitting_completed_at, + "completed_at": document.completed_at, + "paused_at": document.paused_at, + "error": document.error, + "stopped_at": document.stopped_at, + "completed_segments": completed_segments, + "total_segments": total_segments, + } + documents_status.append(marshal(document_dict, document_status_fields)) data = {"data": documents_status} return data diff --git a/api/controllers/service_api/dataset/hit_testing.py b/api/controllers/service_api/dataset/hit_testing.py index 465f71bf03..52e9bca5da 100644 --- a/api/controllers/service_api/dataset/hit_testing.py +++ b/api/controllers/service_api/dataset/hit_testing.py @@ -1,9 +1,10 @@ from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.service_api import api -from controllers.service_api.wraps import DatasetApiResource +from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): dataset_id_str = str(dataset_id) diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 35582feea0..1968696ee5 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse from werkzeug.exceptions import NotFound from controllers.service_api import api -from controllers.service_api.wraps import DatasetApiResource +from 
controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check from fields.dataset_fields import dataset_metadata_fields from services.dataset_service import DatasetService from services.entities.knowledge_entities.knowledge_entities import ( @@ -14,6 +14,7 @@ from services.metadata_service import MetadataService class DatasetMetadataCreateServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): parser = reqparse.RequestParser() parser.add_argument("type", type=str, required=True, nullable=True, location="json") @@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, metadata_id): parser = reqparse.RequestParser() parser.add_argument("name", type=str, required=True, nullable=True, location="json") @@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) return marshal(metadata, dataset_metadata_fields), 200 + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, metadata_id): dataset_id_str = str(dataset_id) metadata_id_str = str(metadata_id) @@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, action): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) @@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, 
tenant_id, dataset_id): dataset_id_str = str(dataset_id) dataset = DatasetService.get_dataset(dataset_id_str) diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index fb3ca1e15f..403b7f0a0c 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_knowledge_limit_check, + cloud_edition_billing_rate_limit_check, cloud_edition_billing_resource_check, ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError @@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource): @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Create single segment.""" # check dataset @@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource): class DatasetSegmentApi(DatasetApiResource): + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -159,9 +162,10 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {"result": "success"}, 204 + return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): # check dataset dataset_id = str(dataset_id) @@ -208,12 +212,35 @@ class DatasetSegmentApi(DatasetApiResource): ) return {"data": marshal(updated_segment, segment_fields), "doc_form": document.doc_form}, 200 + def get(self, 
tenant_id, dataset_id, document_id, segment_id): + # check dataset + dataset_id = str(dataset_id) + tenant_id = str(tenant_id) + dataset = db.session.query(Dataset).filter(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + if not dataset: + raise NotFound("Dataset not found.") + # check user's model setting + DatasetService.check_dataset_model_setting(dataset) + # check document + document_id = str(document_id) + document = DocumentService.get_document(dataset_id, document_id) + if not document: + raise NotFound("Document not found.") + # check segment + segment_id = str(segment_id) + segment = SegmentService.get_segment_by_id(segment_id=segment_id, tenant_id=current_user.current_tenant_id) + if not segment: + raise NotFound("Segment not found.") + + return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 + class ChildChunkApi(DatasetApiResource): """Resource for child chunks.""" @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): """Create child chunk.""" # check dataset @@ -310,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource): """Resource for updating child chunks.""" @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): """Delete child chunk.""" # check dataset @@ -344,10 +372,11 @@ class DatasetChildChunkApi(DatasetApiResource): except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return {"result": "success"}, 204 + return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") + 
@cloud_edition_billing_rate_limit_check("knowledge", "dataset") def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): """Update child chunk.""" # check dataset diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index cd35ceac1d..d3316a5159 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -99,7 +99,12 @@ def validate_app_token(view: Optional[Callable] = None, *, fetch_user_arg: Optio if user_id: user_id = str(user_id) - kwargs["end_user"] = create_or_update_end_user_for_user_id(app_model, user_id) + end_user = create_or_update_end_user_for_user_id(app_model, user_id) + kwargs["end_user"] = end_user + + # Set EndUser as current logged-in user for flask_login.current_user + current_app.login_manager._update_request_context_with_user(end_user) # type: ignore + user_logged_in.send(current_app._get_current_object(), user=end_user) # type: ignore return view_func(*args, **kwargs) diff --git a/api/controllers/web/__init__.py b/api/controllers/web/__init__.py index 50a04a6254..56749a0e25 100644 --- a/api/controllers/web/__init__.py +++ b/api/controllers/web/__init__.py @@ -15,4 +15,17 @@ api.add_resource(FileApi, "/files/upload") api.add_resource(RemoteFileInfoApi, "/remote-files/") api.add_resource(RemoteFileUploadApi, "/remote-files/upload") -from . import app, audio, completion, conversation, feature, message, passport, saved_message, site, workflow +from . 
import ( + app, + audio, + completion, + conversation, + feature, + forgot_password, + login, + message, + passport, + saved_message, + site, + workflow, +) diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index c9a37af5ed..94a525a75d 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,12 +1,17 @@ -from flask_restful import marshal_with +from flask import request +from flask_restful import Resource, marshal_with, reqparse from controllers.common import fields from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService class AppParameterApi(WebApiResource): @@ -40,5 +45,69 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) +class AppAccessMode(Resource): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=False, location="args") + parser.add_argument("appCode", type=str, required=False, location="args") + args = parser.parse_args() + + features = FeatureService.get_system_features() + if not features.webapp_auth.enabled: + return {"accessMode": "public"} + + app_id = args.get("appId") + if args.get("appCode"): + app_code = args["appCode"] + app_id = AppService.get_app_id_by_code(app_code) + + if not app_id: + raise ValueError("appId or appCode must be provided") + + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + + return {"accessMode": res.access_mode} + + +class AppWebAuthPermission(Resource): + def get(self): + user_id = "visitor" + try: + auth_header = 
request.headers.get("Authorization") + if auth_header is None: + raise + if " " not in auth_header: + raise + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise + + decoded = PassportService().verify(tk) + user_id = decoded.get("user_id", "visitor") + except Exception as e: + pass + + features = FeatureService.get_system_features() + if not features.webapp_auth.enabled: + return {"result": True} + + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + app_code = AppService.get_app_code_by_id(app_id) + + res = True + if WebAppAuthService.is_app_require_permission_check(app_id=app_id): + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + return {"result": res} + + api.add_resource(AppParameterApi, "/parameters") api.add_resource(AppMeta, "/meta") +# webapp auth apis +api.add_resource(AppAccessMode, "/webapp/access-mode") +api.add_resource(AppWebAuthPermission, "/webapp/permission") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 9fe5d08d54..4371e679db 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class WebSSOAuthRequiredError(BaseHTTPException): +class WebAppAuthRequiredError(BaseHTTPException): error_code = "web_sso_auth_required" - description = "Web SSO authentication required." + description = "Web app authentication required." + code = 401 + + +class WebAppAuthAccessDeniedError(BaseHTTPException): + error_code = "web_app_access_denied" + description = "You do not have permission to access this web app." 
code = 401 diff --git a/api/controllers/web/forgot_password.py b/api/controllers/web/forgot_password.py new file mode 100644 index 0000000000..0da8d65efc --- /dev/null +++ b/api/controllers/web/forgot_password.py @@ -0,0 +1,147 @@ +import base64 +import secrets + +from flask import request +from flask_restful import Resource, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.console.auth.error import ( + EmailCodeError, + EmailPasswordResetLimitError, + InvalidEmailError, + InvalidTokenError, + PasswordMismatchError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError +from controllers.console.wraps import email_password_login_enabled, only_edition_enterprise, setup_required +from controllers.web import api +from extensions.ext_database import db +from libs.helper import email, extract_remote_ip +from libs.password import hash_password, valid_password +from models.account import Account +from services.account_service import AccountService + + +class ForgotPasswordSendEmailApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + token = None + if account is None: + raise AccountNotFound() + else: + token = AccountService.send_reset_password_email(account=account, email=args["email"], language=language) + + return {"result": "success", 
"data": token} + + +class ForgotPasswordCheckApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args["email"]) + if is_forgot_password_error_rate_limit: + raise EmailPasswordResetLimitError() + + token_data = AccountService.get_reset_password_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_forgot_password_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_reset_password_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_reset_password_token( + user_email, code=args["code"], additional_data={"phase": "reset"} + ) + + AccountService.reset_forgot_password_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ForgotPasswordResetApi(Resource): + @only_edition_enterprise + @setup_required + @email_password_login_enabled + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + parser.add_argument("new_password", type=valid_password, required=True, nullable=False, location="json") + parser.add_argument("password_confirm", type=valid_password, required=True, nullable=False, location="json") + args = parser.parse_args() + + # Validate passwords match + 
if args["new_password"] != args["password_confirm"]: + raise PasswordMismatchError() + + # Validate token and get reset data + reset_data = AccountService.get_reset_password_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() + + # Revoke token to prevent reuse + AccountService.revoke_reset_password_token(args["token"]) + + # Generate secure salt and hash password + salt = secrets.token_bytes(16) + password_hashed = hash_password(args["new_password"], salt) + + email = reset_data.get("email", "") + + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none() + + if account: + self._update_existing_account(account, password_hashed, salt, session) + else: + raise AccountNotFound() + + return {"result": "success"} + + def _update_existing_account(self, account, password_hashed, salt, session): + # Update existing account credentials + account.password = base64.b64encode(password_hashed).decode() + account.password_salt = base64.b64encode(salt).decode() + session.commit() + + +api.add_resource(ForgotPasswordSendEmailApi, "/forgot-password") +api.add_resource(ForgotPasswordCheckApi, "/forgot-password/validity") +api.add_resource(ForgotPasswordResetApi, "/forgot-password/resets") diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py new file mode 100644 index 0000000000..01c4f4a262 --- /dev/null +++ b/api/controllers/web/login.py @@ -0,0 +1,108 @@ +from flask_restful import Resource, reqparse +from jwt import InvalidTokenError # type: ignore + +import services +from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError +from controllers.console.error import AccountBannedError, AccountNotFound +from controllers.console.wraps import only_edition_enterprise, setup_required +from controllers.web import api +from 
libs.helper import email +from libs.password import valid_password +from services.account_service import AccountService +from services.webapp_auth_service import WebAppAuthService + + +class LoginApi(Resource): + """Resource for web app email/password login.""" + + @setup_required + @only_edition_enterprise + def post(self): + """Authenticate user and login.""" + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + args = parser.parse_args() + + try: + account = WebAppAuthService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account) + return {"result": "success", "data": {"access_token": token}} + + +# class LogoutApi(Resource): +# @setup_required +# def get(self): +# account = cast(Account, flask_login.current_user) +# if isinstance(account, flask_login.AnonymousUserMixin): +# return {"result": "success"} +# flask_login.logout_user() +# return {"result": "success"} + + +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + @only_edition_enterprise + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = WebAppAuthService.get_user_through_email(args["email"]) + if account is None: + raise AccountNotFound() + else: + token = WebAppAuthService.send_email_code_login_email(account=account, language=language) + + 
return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + @only_edition_enterprise + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + + token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + WebAppAuthService.revoke_email_code_login_token(args["token"]) + account = WebAppAuthService.get_user_through_email(user_email) + if not account: + raise AccountNotFound() + + token = WebAppAuthService.login(account=account) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "data": {"access_token": token}} + + +api.add_resource(LoginApi, "/login") +# api.add_resource(LogoutApi, "/logout") +api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 267dac223d..10c3cdcf0e 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -1,16 +1,19 @@ import uuid +from datetime import UTC, datetime, timedelta from flask import request from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized +from configs import dify_config from controllers.web import api -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site from 
services.enterprise.enterprise_service import EnterpriseService from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService, WebAppAuthType class PassportResource(Resource): @@ -20,14 +23,23 @@ class PassportResource(Resource): system_features = FeatureService.get_system_features() app_code = request.headers.get("X-App-Code") user_id = request.args.get("user_id") + web_app_access_token = request.args.get("web_app_access_token") if app_code is None: raise Unauthorized("X-App-Code header is missing.") - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + # exchange token for enterprise logined web user + enterprise_user_decoded = decode_enterprise_webapp_user_id(web_app_access_token) + if enterprise_user_decoded: + # a web user has already logged in, exchange a token for this app without redirecting to the login page + return exchange_token_for_existing_web_user( + app_code=app_code, enterprise_user_decoded=enterprise_user_decoded + ) + + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() # get site from db and check if it is normal site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() @@ -84,6 +96,128 @@ class PassportResource(Resource): api.add_resource(PassportResource, "/passport") +def decode_enterprise_webapp_user_id(jwt_token: str | None): + """ + Decode the enterprise user session from the Authorization header. + """ + if not jwt_token: + return None + + decoded = PassportService().verify(jwt_token) + source = decoded.get("token_source") + if not source or source != "webapp_login_token": + raise Unauthorized("Invalid token source. 
Expected 'webapp_login_token'.") + return decoded + + +def exchange_token_for_existing_web_user(app_code: str, enterprise_user_decoded: dict): + """ + Exchange a token for an existing web user session. + """ + user_id = enterprise_user_decoded.get("user_id") + end_user_id = enterprise_user_decoded.get("end_user_id") + session_id = enterprise_user_decoded.get("session_id") + user_auth_type = enterprise_user_decoded.get("auth_type") + if not user_auth_type: + raise Unauthorized("Missing auth_type in the token.") + + site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() + if not site: + raise NotFound() + + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model or app_model.status != "normal" or not app_model.enable_site: + raise NotFound() + + app_auth_type = WebAppAuthService.get_app_auth_type(app_code=app_code) + + if app_auth_type == WebAppAuthType.PUBLIC: + return _exchange_for_public_app_token(app_model, site, enterprise_user_decoded) + elif app_auth_type == WebAppAuthType.EXTERNAL and user_auth_type != "external": + raise WebAppAuthRequiredError("Please login as external user.") + elif app_auth_type == WebAppAuthType.INTERNAL and user_auth_type != "internal": + raise WebAppAuthRequiredError("Please login as internal user.") + + end_user = None + if end_user_id: + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() + if session_id: + end_user = ( + db.session.query(EndUser) + .filter( + EndUser.session_id == session_id, + EndUser.tenant_id == app_model.tenant_id, + EndUser.app_id == app_model.id, + ) + .first() + ) + if not end_user: + if not session_id: + raise NotFound("Missing session_id for existing web user.") + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=True, + session_id=session_id, + ) + db.session.add(end_user) + db.session.commit() + exp_dt = datetime.now(UTC) + 
timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp = int(exp_dt.timestamp()) + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": user_id, + "end_user_id": end_user.id, + "auth_type": user_auth_type, + "granted_at": int(datetime.now(UTC).timestamp()), + "token_source": "webapp", + "exp": exp, + } + token: str = PassportService().issue(payload) + return { + "access_token": token, + } + + +def _exchange_for_public_app_token(app_model, site, token_decoded): + user_id = token_decoded.get("user_id") + end_user = None + if user_id: + end_user = ( + db.session.query(EndUser).filter(EndUser.app_id == app_model.id, EndUser.session_id == user_id).first() + ) + + if not end_user: + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=True, + session_id=generate_session_id(), + ) + + db.session.add(end_user) + db.session.commit() + + payload = { + "iss": site.app_id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "end_user_id": end_user.id, + } + + tk = PassportService().issue(payload) + + return { + "access_token": tk, + } + + def generate_session_id(): """ Generate a unique session ID. 
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18..154bddfc5c 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -1,15 +1,17 @@ +from datetime import UTC, datetime from functools import wraps from flask import request from flask_restful import Resource from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site -from services.enterprise.enterprise_service import EnterpriseService +from services.enterprise.enterprise_service import EnterpriseService, WebAppSettings from services.feature_service import FeatureService +from services.webapp_auth_service import WebAppAuthService def validate_jwt_token(view=None): @@ -29,7 +31,7 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - app_code = request.headers.get("X-App-Code") + app_code = str(request.headers.get("X-App-Code")) try: auth_header = request.headers.get("Authorization") if auth_header is None: @@ -45,7 +47,8 @@ def decode_jwt_token(): raise Unauthorized("Invalid Authorization header format. 
Expected 'Bearer ' format.") decoded = PassportService().verify(tk) app_code = decoded.get("app_code") - app_model = db.session.query(App).filter(App.id == decoded["app_id"]).first() + app_id = decoded.get("app_id") + app_model = db.session.query(App).filter(App.id == app_id).first() site = db.session.query(Site).filter(Site.code == app_code).first() if not app_model: raise NotFound() @@ -53,39 +56,90 @@ def decode_jwt_token(): raise BadRequest("Site URL is no longer valid.") if app_model.enable_site is False: raise BadRequest("Site is disabled.") - end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + end_user_id = decoded.get("end_user_id") + end_user = db.session.query(EndUser).filter(EndUser.id == end_user_id).first() if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features, app_code) + # for enterprise webapp auth + app_web_auth_enabled = False + webapp_settings = None + if system_features.webapp_auth.enabled: + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if not webapp_settings: + raise NotFound("Web app settings not found.") + app_web_auth_enabled = webapp_settings.access_mode != "public" + + _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility( + decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled, webapp_settings + ) return app_model, end_user except Unauthorized as e: - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + if not app_code: + raise Unauthorized("Please re-login to access the web app.") + app_web_auth_enabled = ( + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" + ) + if app_web_auth_enabled: 
+ raise WebAppAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features, app_code): - app_web_sso_enabled = False - - # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - source = decoded.get("token_source") - if not source or source != "sso": - raise WebSSOAuthRequiredError() +def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + # Check if authentication is enforced for web app, and if the token source is not webapp, + # raise an error and redirect to login + if system_webapp_auth_enabled and app_web_auth_enabled: + source = decoded.get("token_source") + if not source or source != "webapp": + raise WebAppAuthRequiredError() - # Check if SSO is not enforced for web, and if the token source is SSO, + # Check if authentication is not enforced for web, and if the token source is webapp, # raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + if not system_webapp_auth_enabled or not app_web_auth_enabled: source = decoded.get("token_source") - if source and source == "sso": - raise Unauthorized("sso token expired.") + if source and source == "webapp": + raise Unauthorized("webapp token expired.") + + +def _validate_user_accessibility( + decoded, + app_code, + app_web_auth_enabled: bool, + system_webapp_auth_enabled: bool, + webapp_settings: WebAppSettings | None, +): + if system_webapp_auth_enabled and app_web_auth_enabled: + # Check if the user is allowed to access the web app + user_id = decoded.get("user_id") + if not user_id: + raise WebAppAuthRequiredError() + + if not webapp_settings: + raise WebAppAuthRequiredError("Web app settings not found.") + + if 
WebAppAuthService.is_app_require_permission_check(access_mode=webapp_settings.access_mode): + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthAccessDeniedError() + + auth_type = decoded.get("auth_type") + granted_at = decoded.get("granted_at") + if not auth_type: + raise WebAppAuthAccessDeniedError("Missing auth_type in the token.") + if not granted_at: + raise WebAppAuthAccessDeniedError("Missing granted_at in the token.") + # check if sso has been updated + if auth_type == "external": + last_update_time = EnterpriseService.get_app_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. Please re-login.") + elif auth_type == "internal": + last_update_time = EnterpriseService.get_workspace_sso_settings_last_update_time() + if granted_at and datetime.fromtimestamp(granted_at, tz=UTC) < last_update_time: + raise WebAppAuthAccessDeniedError("SSO settings have been updated. 
Please re-login.") class WebApiResource(Resource): diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 5212d797d8..4979f63432 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -63,7 +63,7 @@ class CotAgentRunner(BaseAgentRunner, ABC): self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs) iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration if app_config.agent else 5, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 # convert tools into ModelRuntime Tool format tool_instances, prompt_messages_tools = self._init_prompt_tools() diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index e68b4f2356..143a3a51aa 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -82,7 +82,7 @@ class AgentEntity(BaseModel): strategy: Strategy prompt: Optional[AgentPromptEntity] = None tools: Optional[list[AgentToolEntity]] = None - max_iteration: int = 5 + max_iteration: int = 10 class AgentInvokeMessage(ToolInvokeMessage): diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 611a55b30a..5491689ece 100644 --- a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -48,7 +48,7 @@ class FunctionCallAgentRunner(BaseAgentRunner): assert app_config.agent iteration_step = 1 - max_iteration_steps = min(app_config.agent.max_iteration, 5) + 1 + max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1 # continue to run until there is not any tool call function_call_state = True diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index f503543d7b..590b944c0d 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -75,7 +75,7 @@ class AgentConfigManager: 
strategy=strategy, prompt=agent_prompt_entity, tools=agent_tools, - max_iteration=agent_dict.get("max_iteration", 5), + max_iteration=agent_dict.get("max_iteration", 10), ) return None diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index 20189053f4..a5492d70bd 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -138,14 +138,11 @@ class DatasetConfigManager: if not config.get("dataset_configs"): config["dataset_configs"] = {"retrieval_model": "single"} - if not config["dataset_configs"].get("datasets"): - config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} - if not isinstance(config["dataset_configs"], dict): raise ValueError("dataset_configs must be of object type") - if not isinstance(config["dataset_configs"], dict): - raise ValueError("dataset_configs must be of object type") + if not config["dataset_configs"].get("datasets"): + config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []} need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( "datasets", {} diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 5beb09c2aa..5b5eefe315 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -70,7 +70,7 @@ class ModelConfigConverter: if not model_mode: model_mode = LLMMode.CHAT.value if model_schema and model_schema.model_properties.get(ModelPropertyKey.MODE): - model_mode = LLMMode.value_of(model_schema.model_properties[ModelPropertyKey.MODE]).value + model_mode = LLMMode(model_schema.model_properties[ModelPropertyKey.MODE]).value if not model_schema: raise ValueError(f"Model {model_name} not exist.") 
diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 8ae52131f2..3f31b1c3d5 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -109,6 +109,7 @@ class VariableEntity(BaseModel): description: str = "" type: VariableEntityType required: bool = False + hide: bool = False max_length: Optional[int] = None options: Sequence[str] = Field(default_factory=list) allowed_file_types: Sequence[FileType] = Field(default_factory=list) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b0e64130b..a8848b9534 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -26,12 +26,14 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from models.account import Account -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from libs.flask_utils import preserve_flask_contexts +from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from 
services.errors.message import MessageNotExistsError @@ -157,16 +159,30 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): trace_manager=trace_manager, workflow_run_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -174,6 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=invoke_from, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=conversation, stream=streaming, @@ -223,16 +240,26 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): node_id=node_id, inputs=args["inputs"] ), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) 
contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -240,6 +267,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, @@ -287,16 +315,26 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): extras={"auto_generate_conversation_name": False}, single_loop_run=AdvancedChatAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + 
session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -304,6 +342,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, @@ -316,6 +355,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], invoke_from: InvokeFrom, application_generate_entity: AdvancedChatAppGenerateEntity, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, @@ -327,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param user: account or end user :param invoke_from: invoke from source :param application_generate_entity: application generate entity + :param workflow_execution_repository: repository for workflow execution :param workflow_node_execution_repository: repository for workflow node execution :param conversation: conversation :param stream: is stream @@ -357,7 +398,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread + # new thread with request context and contextvars + context = contextvars.copy_context() + worker_thread = threading.Thread( target=self._generate_worker, kwargs={ @@ 
-366,7 +409,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): "queue_manager": queue_manager, "conversation_id": conversation.id, "message_id": message.id, - "context": contextvars.copy_context(), + "context": context, }, ) @@ -380,6 +423,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=stream, ) @@ -404,9 +448,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param message_id: message ID :return: """ - for var, val in context.items(): - var.set(val) - with flask_app.app_context(): + + with preserve_flask_contexts(flask_app, context_vars=context): try: # get conversation and message conversation = self._get_conversation(conversation_id) @@ -452,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation: Conversation, message: Message, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: @@ -475,9 +519,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream, dialogue_count=self._dialogue_count, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + stream=stream, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index c83e06bf15..d9b3833862 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -140,7 +140,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): SystemVariableKey.DIALOGUE_COUNT: 
self._dialogue_count, SystemVariableKey.APP_ID: app_config.app_id, SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, } # init variable pool diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f71c49d112..8c5645bbb7 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator, Mapping @@ -10,6 +9,7 @@ from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -56,25 +56,23 @@ from core.app.entities.task_entities import ( WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType -from 
core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.enums import CreatedByRole -from models.workflow import ( - Workflow, - WorkflowRunStatus, -) +from models.enums import CreatorUserRole +from models.workflow import Workflow logger = logging.getLogger(__name__) @@ -94,6 +92,7 @@ class AdvancedChatAppGenerateTaskPipeline: user: Union[Account, EndUser], stream: bool, dialogue_count: int, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -105,11 +104,11 @@ class AdvancedChatAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") @@ -123,13 +122,24 @@ class AdvancedChatAppGenerateTaskPipeline: SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: 
application_generate_entity.workflow_run_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, }, + workflow_info=CycleManagerWorkflowInfo( + workflow_id=workflow.id, + workflow_type=WorkflowType(workflow.type), + version=workflow.version, + graph_data=workflow.graph_dict, + ), + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) + self._workflow_response_converter = WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + ) + self._task_state = WorkflowTaskState() - self._message_cycle_manager = MessageCycleManage( + self._message_cycle_manager = MessageCycleManager( application_generate_entity=application_generate_entity, task_state=self._task_state ) @@ -150,7 +160,7 @@ class AdvancedChatAppGenerateTaskPipeline: :return: """ # start generate conversation name thread - self._conversation_name_generate_thread = self._message_cycle_manager._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) @@ -294,19 +304,15 @@ class AdvancedChatAppGenerateTaskPipeline: with Session(db.engine, expire_on_commit=False) as session: # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( - session=session, - workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, - ) - self._workflow_run_id = workflow_run.id + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ message = self._get_message(session=session) if not message: raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_run.id - workflow_start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - 
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + message.workflow_run_id = workflow_execution.id_ + workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) session.commit() @@ -319,13 +325,10 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, event=event ) - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event - ) - node_retry_resp = self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -338,20 +341,15 @@ class AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - workflow_run=workflow_run, event=event - ) + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) - node_start_resp = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - event=event, - 
task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_start_resp: yield node_start_resp @@ -359,15 +357,15 @@ class AdvancedChatAppGenerateTaskPipeline: # Record files if it's an answer node or end node if event.node_type in [NodeType.ANSWER, NodeType.END]: self._recorded_files.extend( - self._workflow_cycle_manager._fetch_files_from_node_outputs(event.outputs or {}) + self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) ) with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( event=event ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -383,11 +381,11 @@ class AdvancedChatAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( event=event ) - node_finish_resp = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -399,132 +397,92 @@ class 
AdvancedChatAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_start_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + parallel_start_resp = ( + self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_start_resp elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_finish_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + parallel_finish_resp = ( + self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_finish_resp elif isinstance(event, QueueIterationStartEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_start_resp = 
self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_start_resp elif isinstance(event, QueueIterationNextEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_next_resp elif isinstance(event, QueueIterationCompletedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_finish_resp elif 
isinstance(event, QueueLoopStartEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_start_resp elif isinstance(event, QueueLoopNextEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_next_resp elif isinstance(event, QueueLoopCompletedEvent): if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - 
workflow_run=workflow_run, - event=event, - ) + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_finish_resp elif isinstance(event, QueueWorkflowSucceededEvent): @@ -535,10 +493,8 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -546,10 +502,11 @@ class AdvancedChatAppGenerateTaskPipeline: trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_finish_resp self._base_task_pipeline._queue_manager.publish( @@ -562,10 +519,8 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, 
total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -573,10 +528,11 @@ class AdvancedChatAppGenerateTaskPipeline: conversation_id=None, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - session.commit() yield workflow_finish_resp self._base_task_pipeline._queue_manager.publish( @@ -589,26 +545,25 @@ class AdvancedChatAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED, - error=event.error, + status=WorkflowExecutionStatus.FAILED, + error_message=event.error, conversation_id=self._conversation_id, trace_manager=trace_manager, exceptions_count=event.exceptions_count, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}")) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: 
{workflow_execution.error_message}")) err = self._base_task_pipeline._handle_error( event=err_event, session=session, message_id=self._message_id ) - session.commit() yield workflow_finish_resp yield self._base_task_pipeline._error_to_stream_response(err) @@ -616,21 +571,19 @@ class AdvancedChatAppGenerateTaskPipeline: elif isinstance(event, QueueStopEvent): if self._workflow_run_id and graph_runtime_state: with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.STOPPED, - error=event.get_stop_reason(), + status=WorkflowExecutionStatus.STOPPED, + error_message=event.get_stop_reason(), conversation_id=self._conversation_id, trace_manager=trace_manager, ) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, + workflow_execution=workflow_execution, ) # Save message self._save_message(session=session, graph_runtime_state=graph_runtime_state) @@ -650,22 +603,18 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._message_end_to_stream_response() break elif isinstance(event, QueueRetrieverResourcesEvent): - self._message_cycle_manager._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + 
message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueAnnotationReplyEvent): - self._message_cycle_manager._handle_annotation_reply(event) + self._message_cycle_manager.handle_annotation_reply(event) with Session(db.engine, expire_on_commit=False) as session: message = self._get_message(session=session) - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() session.commit() elif isinstance(event, QueueTextChunkEvent): delta_text = event.text @@ -682,12 +631,12 @@ class AdvancedChatAppGenerateTaskPipeline: tts_publisher.publish(queue_message) self._task_state.answer += delta_text - yield self._message_cycle_manager._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueMessageReplaceEvent): # published by moderation - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=event.text, reason=event.reason ) elif isinstance(event, QueueAdvancedChatMessageEndEvent): @@ -699,7 +648,7 @@ class AdvancedChatAppGenerateTaskPipeline: ) if output_moderation_answer: self._task_state.answer = output_moderation_answer - yield self._message_cycle_manager._message_replace_to_stream_response( + yield self._message_cycle_manager.message_replace_to_stream_response( answer=output_moderation_answer, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, ) @@ -711,7 +660,7 @@ class AdvancedChatAppGenerateTaskPipeline: yield self._message_end_to_stream_response() elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_cycle_manager._handle_agent_log( + yield 
self._workflow_response_converter.handle_agent_log( task_id=self._application_generate_entity.task_id, event=event ) else: @@ -728,9 +677,7 @@ class AdvancedChatAppGenerateTaskPipeline: message = self._get_message(session=session) message.answer = self._task_state.answer message.provider_response_latency = time.perf_counter() - self._base_task_pipeline._start_at - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() message_files = [ MessageFile( message_id=message.id, @@ -739,9 +686,9 @@ class AdvancedChatAppGenerateTaskPipeline: url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], - created_by_role=CreatedByRole.ACCOUNT + created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER, + else CreatorUserRole.END_USER, created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files @@ -758,9 +705,9 @@ class AdvancedChatAppGenerateTaskPipeline: message.answer_price_unit = usage.completion_price_unit message.total_price = usage.total_price message.currency = usage.currency - self._task_state.metadata["usage"] = jsonable_encoder(usage) + self._task_state.metadata.usage = usage else: - self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) + self._task_state.metadata.usage = LLMUsage.empty_usage() message_was_created.send( message, application_generate_entity=self._application_generate_entity, @@ -771,18 +718,16 @@ class AdvancedChatAppGenerateTaskPipeline: Message end to stream response. 
:return: """ - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata.copy() + extras = self._task_state.metadata.model_dump() - if "annotation_reply" in extras["metadata"]: - del extras["metadata"]["annotation_reply"] + if self._task_state.metadata.annotation_reply: + del extras["annotation_reply"] return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, files=self._recorded_files, - metadata=extras.get("metadata", {}), + metadata=extras, ) def _handle_output_moderation_chunk(self, text: str) -> bool: diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 3ed436c07a..a448bf8a94 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -23,6 +23,7 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from extensions.ext_database import db from factories import file_factory +from libs.flask_utils import preserve_flask_contexts from models import Account, App, EndUser from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError @@ -179,12 +180,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread + # new thread with request context and contextvars + context = contextvars.copy_context() + worker_thread = threading.Thread( target=self._generate_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore - "context": contextvars.copy_context(), + "context": context, "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, "conversation_id": conversation.id, @@ -224,10 +227,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param message_id: message ID :return: """ - for var, val in context.items(): - var.set(val) - with 
flask_app.app_context(): + with preserve_flask_contexts(flask_app, context_vars=context): try: # get conversation and message conversation = self._get_conversation(conversation_id) diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index c813dbb9d1..a3f0cf7f9f 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -1,3 +1,4 @@ +import logging import time from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional, Union @@ -33,6 +34,8 @@ from models.model import App, AppMode, Message, MessageAnnotation if TYPE_CHECKING: from core.file.models import File +_logger = logging.getLogger(__name__) + class AppRunner: def get_pre_calculate_rest_tokens( @@ -298,7 +301,7 @@ class AppRunner: ) def _handle_invoke_result_stream( - self, invoke_result: Generator, queue_manager: AppQueueManager, agent: bool + self, invoke_result: Generator[LLMResultChunk, None, None], queue_manager: AppQueueManager, agent: bool ) -> None: """ Handle invoke result @@ -317,18 +320,28 @@ class AppRunner: else: queue_manager.publish(QueueAgentMessageEvent(chunk=result), PublishFrom.APPLICATION_MANAGER) - text += result.delta.message.content + message = result.delta.message + if isinstance(message.content, str): + text += message.content + elif isinstance(message.content, list): + for content in message.content: + if not isinstance(content, str): + # TODO(QuantumGhost): Add multimodal output support for easy ui. 
+ _logger.warning("received multimodal output, type=%s", type(content)) + text += content.data + else: + text += content # failback to str if not model: model = result.model if not prompt_messages: - prompt_messages = result.prompt_messages + prompt_messages = list(result.prompt_messages) if result.delta.usage: usage = result.delta.usage - if not usage: + if usage is None: usage = LLMUsage.empty_usage() llm_result = LLMResult( diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 2d865795d8..a1329cb938 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError from configs import dify_config @@ -170,17 +170,18 @@ class ChatAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() diff --git a/api/core/app/apps/common/__init__.py b/api/core/app/apps/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py new file mode 100644 index 0000000000..6f524a5872 --- /dev/null +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -0,0 +1,561 @@ +import time +from collections.abc import Mapping, Sequence +from datetime import UTC, datetime +from typing import Any, Optional, Union, cast + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.app.entities.queue_entities import ( + QueueAgentLogEvent, + QueueIterationCompletedEvent, + QueueIterationNextEvent, + QueueIterationStartEvent, + QueueLoopCompletedEvent, + QueueLoopNextEvent, + QueueLoopStartEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeRetryEvent, + QueueNodeStartedEvent, + QueueNodeSucceededEvent, + QueueParallelBranchRunFailedEvent, + QueueParallelBranchRunStartedEvent, + QueueParallelBranchRunSucceededEvent, +) +from core.app.entities.task_entities import ( + AgentLogStreamResponse, + IterationNodeCompletedStreamResponse, + IterationNodeNextStreamResponse, + IterationNodeStartStreamResponse, + LoopNodeCompletedStreamResponse, + LoopNodeNextStreamResponse, + LoopNodeStartStreamResponse, + NodeFinishStreamResponse, + NodeRetryStreamResponse, + NodeStartStreamResponse, + ParallelBranchFinishedStreamResponse, + ParallelBranchStartStreamResponse, + WorkflowFinishStreamResponse, + WorkflowStartStreamResponse, +) +from core.file import FILE_MODEL_IDENTITY, File +from core.tools.tool_manager import ToolManager +from core.workflow.entities.workflow_execution import WorkflowExecution +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus +from core.workflow.nodes import NodeType +from core.workflow.nodes.tool.entities import 
ToolNodeData +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowRun, +) + + +class WorkflowResponseConverter: + def __init__( + self, + *, + application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], + ) -> None: + self._application_generate_entity = application_generate_entity + + def workflow_start_to_stream_response( + self, + *, + task_id: str, + workflow_execution: WorkflowExecution, + ) -> WorkflowStartStreamResponse: + return WorkflowStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_execution.id_, + data=WorkflowStartStreamResponse.Data( + id=workflow_execution.id_, + workflow_id=workflow_execution.workflow_id, + inputs=workflow_execution.inputs, + created_at=int(workflow_execution.started_at.timestamp()), + ), + ) + + def workflow_finish_to_stream_response( + self, + *, + session: Session, + task_id: str, + workflow_execution: WorkflowExecution, + ) -> WorkflowFinishStreamResponse: + created_by = None + workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) + assert workflow_run is not None + if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: + stmt = select(Account).where(Account.id == workflow_run.created_by) + account = session.scalar(stmt) + if account: + created_by = { + "id": account.id, + "name": account.name, + "email": account.email, + } + elif workflow_run.created_by_role == CreatorUserRole.END_USER: + stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) + end_user = session.scalar(stmt) + if end_user: + created_by = { + "id": end_user.id, + "user": end_user.session_id, + } + else: + raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") + + # Handle the case where finished_at is None by using current time as default + finished_at_timestamp = ( + int(workflow_execution.finished_at.timestamp()) + if workflow_execution.finished_at + else int(datetime.now(UTC).timestamp()) + ) + + 
return WorkflowFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_execution.id_, + data=WorkflowFinishStreamResponse.Data( + id=workflow_execution.id_, + workflow_id=workflow_execution.workflow_id, + status=workflow_execution.status, + outputs=workflow_execution.outputs, + error=workflow_execution.error_message, + elapsed_time=workflow_execution.elapsed_time, + total_tokens=workflow_execution.total_tokens, + total_steps=workflow_execution.total_steps, + created_by=created_by, + created_at=int(workflow_execution.started_at.timestamp()), + finished_at=finished_at_timestamp, + files=self.fetch_files_from_node_outputs(workflow_execution.outputs), + exceptions_count=workflow_execution.exceptions_count, + ), + ) + + def workflow_node_start_to_stream_response( + self, + *, + event: QueueNodeStartedEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeStartStreamResponse]: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: + return None + if not workflow_node_execution.workflow_execution_id: + return None + + response = NodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_execution_id, + data=NodeStartStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + title=workflow_node_execution.title, + index=workflow_node_execution.index, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs, + created_at=int(workflow_node_execution.created_at.timestamp()), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + parallel_run_id=event.parallel_mode_run_id, + agent_strategy=event.agent_strategy, + ), + ) + 
+ # extras logic + if event.node_type == NodeType.TOOL: + node_data = cast(ToolNodeData, event.node_data) + response.data.extras["icon"] = ToolManager.get_tool_icon( + tenant_id=self._application_generate_entity.app_config.tenant_id, + provider_type=node_data.provider_type, + provider_id=node_data.provider_id, + ) + + return response + + def workflow_node_finish_to_stream_response( + self, + *, + event: QueueNodeSucceededEvent + | QueueNodeFailedEvent + | QueueNodeInIterationFailedEvent + | QueueNodeInLoopFailedEvent + | QueueNodeExceptionEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[NodeFinishStreamResponse]: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: + return None + if not workflow_node_execution.workflow_execution_id: + return None + if not workflow_node_execution.finished_at: + return None + + return NodeFinishStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_execution_id, + data=NodeFinishStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.metadata, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + 
parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + ), + ) + + def workflow_node_retry_to_stream_response( + self, + *, + event: QueueNodeRetryEvent, + task_id: str, + workflow_node_execution: WorkflowNodeExecution, + ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: + return None + if not workflow_node_execution.workflow_execution_id: + return None + if not workflow_node_execution.finished_at: + return None + + return NodeRetryStreamResponse( + task_id=task_id, + workflow_run_id=workflow_node_execution.workflow_execution_id, + data=NodeRetryStreamResponse.Data( + id=workflow_node_execution.id, + node_id=workflow_node_execution.node_id, + node_type=workflow_node_execution.node_type, + index=workflow_node_execution.index, + title=workflow_node_execution.title, + predecessor_node_id=workflow_node_execution.predecessor_node_id, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, + status=workflow_node_execution.status, + error=workflow_node_execution.error, + elapsed_time=workflow_node_execution.elapsed_time, + execution_metadata=workflow_node_execution.metadata, + created_at=int(workflow_node_execution.created_at.timestamp()), + finished_at=int(workflow_node_execution.finished_at.timestamp()), + files=self.fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + retry_index=event.retry_index, + ), + ) + + def workflow_parallel_branch_start_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + 
event: QueueParallelBranchRunStartedEvent, + ) -> ParallelBranchStartStreamResponse: + return ParallelBranchStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_execution_id, + data=ParallelBranchStartStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + created_at=int(time.time()), + ), + ) + + def workflow_parallel_branch_finished_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, + ) -> ParallelBranchFinishedStreamResponse: + return ParallelBranchFinishedStreamResponse( + task_id=task_id, + workflow_run_id=workflow_execution_id, + data=ParallelBranchFinishedStreamResponse.Data( + parallel_id=event.parallel_id, + parallel_branch_id=event.parallel_start_node_id, + parent_parallel_id=event.parent_parallel_id, + parent_parallel_start_node_id=event.parent_parallel_start_node_id, + iteration_id=event.in_iteration_id, + loop_id=event.in_loop_id, + status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", + error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, + created_at=int(time.time()), + ), + ) + + def workflow_iteration_start_to_stream_response( + self, + *, + task_id: str, + workflow_execution_id: str, + event: QueueIterationStartEvent, + ) -> IterationNodeStartStreamResponse: + return IterationNodeStartStreamResponse( + task_id=task_id, + workflow_run_id=workflow_execution_id, + data=IterationNodeStartStreamResponse.Data( + id=event.node_id, + node_id=event.node_id, + node_type=event.node_type.value, + title=event.node_data.title, + created_at=int(time.time()), + extras={}, + inputs=event.inputs or {}, + metadata=event.metadata or {}, + 
def workflow_iteration_completed_to_stream_response(
    self,
    *,
    task_id: str,
    workflow_execution_id: str,
    event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse:
    """
    Convert an iteration completed event into its stream response.

    The status is FAILED when ``event.error`` is set and SUCCEEDED otherwise.
    The error text is forwarded in the payload so a FAILED status is never
    reported without its reason.

    :param task_id: task id echoed back in the response
    :param workflow_execution_id: id reported as ``workflow_run_id``
    :param event: iteration completed event
    :return: iteration completed stream response
    """
    # Sample the clock once so created_at and finished_at can never disagree.
    now_ts = int(time.time())
    metadata = event.metadata
    has_failed = event.error is not None

    return IterationNodeCompletedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_execution_id,
        data=IterationNodeCompletedStreamResponse.Data(
            id=event.node_id,
            node_id=event.node_id,
            node_type=event.node_type.value,
            title=event.node_data.title,
            outputs=event.outputs,
            created_at=now_ts,
            extras={},
            inputs=event.inputs or {},
            status=WorkflowNodeExecutionStatus.FAILED if has_failed else WorkflowNodeExecutionStatus.SUCCEEDED,
            # Previously hard-coded to None, which dropped the failure reason.
            error=event.error,
            # "now" is made naive before subtracting, so start_at is treated as naive UTC.
            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
            total_tokens=metadata.get("total_tokens", 0) if metadata else 0,
            execution_metadata=metadata,
            finished_at=now_ts,
            steps=event.steps,
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
        ),
    )
def workflow_loop_next_to_stream_response(
    self,
    *,
    task_id: str,
    workflow_execution_id: str,
    event: QueueLoopNextEvent,
) -> LoopNodeNextStreamResponse:
    """
    Convert a loop "next round" event into its stream response.

    :param task_id: task id echoed back in the response
    :param workflow_execution_id: id reported as ``workflow_run_id``
    :param event: loop next event; its ``output`` is reported as ``pre_loop_output``
    :return: loop next stream response stamped with the current time
    """
    # The loop node id is used both as the response id and as node_id.
    loop_node_id = event.node_id

    payload = LoopNodeNextStreamResponse.Data(
        id=loop_node_id,
        node_id=loop_node_id,
        node_type=event.node_type.value,
        title=event.node_data.title,
        index=event.index,
        pre_loop_output=event.output,
        created_at=int(time.time()),
        extras={},
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parallel_mode_run_id=event.parallel_mode_run_id,
        duration=event.duration,
    )
    return LoopNodeNextStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_execution_id,
        data=payload,
    )
event.metadata else 0, + execution_metadata=event.metadata, + finished_at=int(time.time()), + steps=event.steps, + parallel_id=event.parallel_id, + parallel_start_node_id=event.parallel_start_node_id, + ), + ) + + def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]: + """ + Fetch files from node outputs + :param outputs_dict: node outputs dict + :return: + """ + if not outputs_dict: + return [] + + files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] + # Remove None + files = [file for file in files if file] + # Flatten list + # Flatten the list of sequences into a single list of mappings + flattened_files = [file for sublist in files if sublist for file in sublist] + + # Convert to tuple to match Sequence type + return tuple(flattened_files) + + def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: + """ + Fetch files from variable value + :param value: variable value + :return: + """ + if not value: + return [] + + files = [] + if isinstance(value, list): + for item in value: + file = self._get_file_var_from_value(item) + if file: + files.append(file) + elif isinstance(value, dict): + file = self._get_file_var_from_value(value) + if file: + files.append(file) + + return files + + def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: + """ + Get file var from value + :param value: variable value + :return: + """ + if not value: + return None + + if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: + return value + elif isinstance(value, File): + return value.to_dict() + + return None + + def handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: + """ + Handle agent log + :param task_id: task id + :param event: agent log event + :return: + """ + return AgentLogStreamResponse( + task_id=task_id, + 
data=AgentLogStreamResponse.Data( + node_execution_id=event.node_execution_id, + id=event.id, + parent_id=event.parent_id, + label=event.label, + error=event.error, + status=event.status, + data=event.data, + metadata=event.metadata, + node_id=event.node_id, + ), + ) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index b1bc412616..adcbaad3ec 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -4,7 +4,7 @@ import uuid from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload -from flask import Flask, current_app +from flask import Flask, copy_current_request_context, current_app from pydantic import ValidationError from configs import dify_config @@ -151,16 +151,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "message_id": message.id, - }, - ) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() @@ -313,16 +314,17 @@ class CompletionAppGenerator(MessageBasedAppGenerator): message_id=message.id, ) - # new thread - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "message_id": message.id, - }, - 
) + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return self._generate_worker( + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) + + worker_thread = threading.Thread(target=worker_with_context) worker_thread.start() diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 995082b79d..58b94f4d43 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): belongs_to="user", url=file.remote_url, upload_file_id=file.related_id, - created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) db.session.add(message_file) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 1d67671974..fd15bd9f50 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -18,16 +18,20 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import 
WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter +from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline +from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from models import Account, App, EndUser, Workflow +from libs.flask_utils import preserve_flask_contexts +from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom +from models.enums import WorkflowRunTriggeredFrom logger = logging.getLogger(__name__) @@ -129,19 +133,33 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from=invoke_from, call_depth=call_depth, trace_manager=trace_manager, - workflow_run_id=workflow_run_id, + workflow_execution_id=workflow_run_id, ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) - # Create workflow node execution repository + # Create 
repositories + # + # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -150,6 +168,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, @@ -163,6 +182,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, @@ -175,6 +195,7 @@ class WorkflowAppGenerator(BaseAppGenerator): :param user: account or end user :param application_generate_entity: application generate entity :param invoke_from: invoke from source + :param workflow_execution_repository: repository for workflow execution :param 
workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream :param workflow_thread_pool_id: workflow thread pool id @@ -187,14 +208,16 @@ class WorkflowAppGenerator(BaseAppGenerator): app_mode=app_model.mode, ) - # new thread + # new thread with request context and contextvars + context = contextvars.copy_context() + worker_thread = threading.Thread( target=self._generate_worker, kwargs={ "flask_app": current_app._get_current_object(), # type: ignore "application_generate_entity": application_generate_entity, "queue_manager": queue_manager, - "context": contextvars.copy_context(), + "context": context, "workflow_thread_pool_id": workflow_thread_pool_id, }, ) @@ -207,6 +230,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, stream=streaming, ) @@ -254,18 +278,30 @@ class WorkflowAppGenerator(BaseAppGenerator): single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( node_id=node_id, inputs=args["inputs"] ), - workflow_run_id=str(uuid.uuid4()), + workflow_execution_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository 
= SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -274,6 +310,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -317,18 +354,30 @@ class WorkflowAppGenerator(BaseAppGenerator): invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, single_loop_run=WorkflowAppGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]), - workflow_run_id=str(uuid.uuid4()), + workflow_execution_id=str(uuid.uuid4()), ) - contexts.tenant_id.set(application_generate_entity.app_config.tenant_id) contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(threading.Lock()) + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, + ) # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( 
@@ -337,6 +386,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, ) @@ -357,9 +407,8 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_thread_pool_id: workflow thread pool id :return: """ - for var, val in context.items(): - var.set(val) - with flask_app.app_context(): + + with preserve_flask_contexts(flask_app, context_vars=context): try: # workflow app runner = WorkflowAppRunner( @@ -394,6 +443,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow: Workflow, queue_manager: AppQueueManager, user: Union[Account, EndUser], + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, stream: bool = False, ) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -413,8 +463,9 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream, + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, + stream=stream, ) try: diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index b38ee18ac4..b59e34e222 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -95,7 +95,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): SystemVariableKey.USER_ID: user_id, SystemVariableKey.APP_ID: app_config.app_id, SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, } variable_pool = VariablePool( 
diff --git a/api/core/workflow/workflow_app_generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py similarity index 69% rename from api/core/workflow/workflow_app_generate_task_pipeline.py rename to api/core/app/apps/workflow/generate_task_pipeline.py index 10a2d8b38b..1734dbb598 100644 --- a/api/core/workflow/workflow_app_generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -3,10 +3,12 @@ import time from collections.abc import Generator from typing import Optional, Union +from sqlalchemy import select from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import ( InvokeFrom, WorkflowAppGenerateEntity, @@ -48,24 +50,24 @@ from core.app.entities.task_entities import ( WorkflowAppStreamResponse, WorkflowFinishStreamResponse, WorkflowStartStreamResponse, - WorkflowTaskState, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager +from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from core.workflow.enums import SystemVariableKey -from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_cycle_manager import WorkflowCycleManager +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database 
import db from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser from models.workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowRun, - WorkflowRunStatus, ) logger = logging.getLogger(__name__) @@ -83,6 +85,7 @@ class WorkflowAppGenerateTaskPipeline: queue_manager: AppQueueManager, user: Union[Account, EndUser], stream: bool, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._base_task_pipeline = BasedGenerateTaskPipeline( @@ -94,11 +97,11 @@ class WorkflowAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise ValueError(f"Invalid user type: {type(user)}") @@ -109,15 +112,24 @@ class WorkflowAppGenerateTaskPipeline: SystemVariableKey.USER_ID: user_session_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, + SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, }, + workflow_info=CycleManagerWorkflowInfo( + workflow_id=workflow.id, + workflow_type=WorkflowType(workflow.type), + version=workflow.version, + graph_data=workflow.graph_dict, + ), + workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, ) + self._workflow_response_converter = WorkflowResponseConverter( + application_generate_entity=application_generate_entity, + ) + 
self._application_generate_entity = application_generate_entity - self._workflow_id = workflow.id self._workflow_features_dict = workflow.features_dict - self._task_state = WorkflowTaskState() self._workflow_run_id = "" def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: @@ -256,19 +268,13 @@ class WorkflowAppGenerateTaskPipeline: # override graph runtime state graph_runtime_state = event.graph_runtime_state - with Session(db.engine, expire_on_commit=False) as session: - # init workflow run - workflow_run = self._workflow_cycle_manager._handle_workflow_run_start( - session=session, - workflow_id=self._workflow_id, - user_id=self._user_id, - created_by_role=self._created_by_role, - ) - self._workflow_run_id = workflow_run.id - start_resp = self._workflow_cycle_manager._workflow_start_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run - ) - session.commit() + # init workflow run + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ + start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) yield start_resp elif isinstance( @@ -278,13 +284,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_retried( - workflow_run=workflow_run, event=event + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, + event=event, ) - response = 
self._workflow_cycle_manager._workflow_node_retry_to_stream_response( + response = self._workflow_response_converter.workflow_node_retry_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -297,27 +301,22 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - workflow_node_execution = self._workflow_cycle_manager._handle_node_execution_start( - workflow_run=workflow_run, event=event - ) - node_start_response = self._workflow_cycle_manager._workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) if node_start_response: yield node_start_response elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_success( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( event=event ) - node_success_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -332,10 +331,10 @@ class 
WorkflowAppGenerateTaskPipeline: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, ): - workflow_node_execution = self._workflow_cycle_manager._handle_workflow_node_execution_failed( + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( event=event, ) - node_failed_response = self._workflow_cycle_manager._workflow_node_finish_to_stream_response( + node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( event=event, task_id=self._application_generate_entity.task_id, workflow_node_execution=workflow_node_execution, @@ -348,18 +347,13 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_start_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + parallel_start_resp = ( + self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_start_resp @@ -367,18 +361,13 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - parallel_finish_resp = ( - self._workflow_cycle_manager._workflow_parallel_branch_finished_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + 
parallel_finish_resp = ( + self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, ) + ) yield parallel_finish_resp @@ -386,16 +375,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_start_resp = self._workflow_cycle_manager._workflow_iteration_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_start_resp @@ -403,16 +387,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_next_resp = self._workflow_cycle_manager._workflow_iteration_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_next_resp @@ -420,16 +399,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, 
expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - iter_finish_resp = self._workflow_cycle_manager._workflow_iteration_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield iter_finish_resp @@ -437,16 +411,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_start_resp = self._workflow_cycle_manager._workflow_loop_start_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_start_resp @@ -454,16 +423,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_next_resp = self._workflow_cycle_manager._workflow_loop_next_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_next_resp = 
self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_next_resp @@ -471,16 +435,11 @@ class WorkflowAppGenerateTaskPipeline: if not self._workflow_run_id: raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._get_workflow_run( - session=session, workflow_run_id=self._workflow_run_id - ) - loop_finish_resp = self._workflow_cycle_manager._workflow_loop_completed_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, - event=event, - ) + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) yield loop_finish_resp @@ -491,10 +450,8 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -503,12 +460,12 @@ class WorkflowAppGenerateTaskPipeline: ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( session=session, 
task_id=self._application_generate_entity.task_id, - workflow_run=workflow_run, + workflow_execution=workflow_execution, ) session.commit() @@ -520,10 +477,8 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_partial_success( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, @@ -533,10 +488,12 @@ class WorkflowAppGenerateTaskPipeline: ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) session.commit() @@ -548,26 +505,28 @@ class WorkflowAppGenerateTaskPipeline: raise ValueError("graph runtime state not initialized.") with Session(db.engine, expire_on_commit=False) as session: - workflow_run = self._workflow_cycle_manager._handle_workflow_run_failed( - session=session, + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( workflow_run_id=self._workflow_run_id, - start_at=graph_runtime_state.start_at, total_tokens=graph_runtime_state.total_tokens, total_steps=graph_runtime_state.node_run_steps, - status=WorkflowRunStatus.FAILED + status=WorkflowExecutionStatus.FAILED if isinstance(event, 
QueueWorkflowFailedEvent) - else WorkflowRunStatus.STOPPED, - error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + else WorkflowExecutionStatus.STOPPED, + error_message=event.error + if isinstance(event, QueueWorkflowFailedEvent) + else event.get_stop_reason(), conversation_id=None, trace_manager=trace_manager, exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, ) # save workflow app log - self._save_workflow_app_log(session=session, workflow_run=workflow_run) + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - workflow_finish_resp = self._workflow_cycle_manager._workflow_finish_to_stream_response( - session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, ) session.commit() @@ -581,12 +540,11 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(queue_message) - self._task_state.answer += delta_text yield self._text_chunk_to_stream_response( delta_text, from_variable_selector=event.from_variable_selector ) elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_cycle_manager._handle_agent_log( + yield self._workflow_response_converter.handle_agent_log( task_id=self._application_generate_entity.task_id, event=event ) else: @@ -595,11 +553,9 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: tts_publisher.publish(None) - def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None: - """ - Save workflow app log. 
- :return: - """ + def _save_workflow_app_log(self, *, session: Session, workflow_execution: WorkflowExecution) -> None: + workflow_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_execution.id_)) + assert workflow_run is not None invoke_from = self._application_generate_entity.invoke_from if invoke_from == InvokeFrom.SERVICE_API: created_from = WorkflowAppLogCreatedFrom.SERVICE_API diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 0884fac4a9..facc24b4ca 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -29,8 +29,8 @@ from core.app.entities.queue_entities import ( QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from core.workflow.entities.node_entities import NodeRunMetadataKey from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.event import ( AgentLogEvent, GraphEngineEvent, @@ -295,7 +295,7 @@ class WorkflowBasedAppRunner(AppRunner): inputs: Mapping[str, Any] | None = {} process_data: Mapping[str, Any] | None = {} outputs: Mapping[str, Any] | None = {} - execution_metadata: Mapping[NodeRunMetadataKey, Any] | None = {} + execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = {} if node_run_result: inputs = node_run_result.inputs process_data = node_run_result.process_data diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index 56e6b46a60..c0d99693b0 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -76,6 +76,8 @@ class AppGenerateEntity(BaseModel): App Generate Entity. 
""" + model_config = ConfigDict(arbitrary_types_allowed=True) + task_id: str # app config @@ -99,9 +101,6 @@ class AppGenerateEntity(BaseModel): # tracing instance trace_manager: Optional[TraceQueueManager] = None - class Config: - arbitrary_types_allowed = True - class EasyUIBasedAppGenerateEntity(AppGenerateEntity): """ @@ -205,7 +204,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity): # app config app_config: WorkflowUIBasedAppConfig - workflow_run_id: str + workflow_execution_id: str class SingleIterationRunEntity(BaseModel): """ diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 7228020e9b..42e6a1519c 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from enum import Enum, StrEnum from typing import Any, Optional @@ -6,7 +6,9 @@ from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey +from core.rag.entities.citation_metadata import RetrievalSourceMetadata +from core.workflow.entities.node_entities import AgentNodeStrategyInit +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNodeData @@ -282,7 +284,7 @@ class QueueRetrieverResourcesEvent(AppQueueEvent): """ event: QueueEvent = QueueEvent.RETRIEVER_RESOURCES - retriever_resources: list[dict] + retriever_resources: Sequence[RetrievalSourceMetadata] in_iteration_id: Optional[str] = None """iteration id if node is in iteration""" in_loop_id: Optional[str] = None @@ -412,7 +414,7 @@ class 
QueueNodeSucceededEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: Optional[str] = None """single iteration duration map""" @@ -446,7 +448,7 @@ class QueueNodeRetryEvent(QueueNodeStartedEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: str retry_index: int # retry index @@ -480,7 +482,7 @@ class QueueNodeInIterationFailedEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: str @@ -513,7 +515,7 @@ class QueueNodeInLoopFailedEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: str @@ -546,7 +548,7 @@ class QueueNodeExceptionEvent(AppQueueEvent): inputs: Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: str @@ -579,7 +581,7 @@ class QueueNodeFailedEvent(AppQueueEvent): inputs: 
Optional[Mapping[str, Any]] = None process_data: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None - execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None error: str diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 817699bd20..25c889e922 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,12 +2,29 @@ from collections.abc import Mapping, Sequence from enum import Enum from typing import Any, Optional -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.utils.encoders import jsonable_encoder +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus + + +class AnnotationReplyAccount(BaseModel): + id: str + name: str + + +class AnnotationReply(BaseModel): + id: str + account: AnnotationReplyAccount + + +class TaskStateMetadata(BaseModel): + annotation_reply: AnnotationReply | None = None + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(default_factory=list) + usage: LLMUsage | None = None class TaskState(BaseModel): @@ -15,7 +32,7 @@ class TaskState(BaseModel): TaskState entity """ - metadata: dict = {} + metadata: TaskStateMetadata = Field(default_factory=TaskStateMetadata) class EasyUITaskState(TaskState): @@ -189,8 +206,7 @@ class WorkflowStartStreamResponse(StreamResponse): id: str workflow_id: str - sequence_number: int - inputs: dict + inputs: Mapping[str, 
Any] created_at: int event: StreamEvent = StreamEvent.WORKFLOW_STARTED @@ -210,9 +226,8 @@ class WorkflowFinishStreamResponse(StreamResponse): id: str workflow_id: str - sequence_number: int status: str - outputs: Optional[dict] = None + outputs: Optional[Mapping[str, Any]] = None error: Optional[str] = None elapsed_time: float total_tokens: int @@ -244,7 +259,7 @@ class NodeStartStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None created_at: int extras: dict = {} parallel_id: Optional[str] = None @@ -301,13 +316,13 @@ class NodeFinishStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] @@ -370,13 +385,13 @@ class NodeRetryStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] @@ -788,7 +803,7 @@ class 
WorkflowAppBlockingResponse(AppBlockingResponse): id: str workflow_id: str status: str - outputs: Optional[dict] = None + outputs: Optional[Mapping[str, Any]] = None error: Optional[str] = None elapsed_time: float total_tokens: int diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index a98a42f5df..d535e1f835 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -1,4 +1,3 @@ -import json import logging import time from collections.abc import Generator @@ -43,15 +42,15 @@ from core.app.entities.task_entities import ( StreamResponse, ) from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline -from core.app.task_pipeline.message_cycle_manage import MessageCycleManage +from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_manager import ModelInstance from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + TextPromptMessageContent, ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -63,7 +62,7 @@ from models.model import AppMode, Conversation, Message, MessageAgentThought logger = logging.getLogger(__name__) -class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleManage): +class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline): """ EasyUIBasedGenerateTaskPipeline is a class 
that generate stream output and state management for Application. """ @@ -104,6 +103,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) ) + self._message_cycle_manager = MessageCycleManager( + application_generate_entity=application_generate_entity, + task_state=self._task_state, + ) + self._conversation_name_generate_thread: Optional[Thread] = None def process( @@ -115,7 +119,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ]: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: # start generate conversation name thread - self._conversation_name_generate_thread = self._generate_conversation_name( + self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query or "" ) @@ -136,9 +140,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if isinstance(stream_response, ErrorStreamResponse): raise stream_response.err elif isinstance(stream_response, MessageEndStreamResponse): - extras = {"usage": jsonable_encoder(self._task_state.llm_result.usage)} + extras = {"usage": self._task_state.llm_result.usage.model_dump()} if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata + extras["metadata"] = self._task_state.metadata.model_dump() response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] if self._conversation_mode == AppMode.COMPLETION.value: response = CompletionAppBlockingResponse( @@ -277,7 +281,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan ) if output_moderation_answer: self._task_state.llm_result.message.content = output_moderation_answer - yield self._message_replace_to_stream_response(answer=output_moderation_answer) + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer + ) with 
Session(db.engine) as session: # Save message @@ -286,9 +292,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_end_resp = self._message_end_to_stream_response() yield message_end_resp elif isinstance(event, QueueRetrieverResourcesEvent): - self._handle_retriever_resources(event) + self._message_cycle_manager.handle_retriever_resources(event) elif isinstance(event, QueueAnnotationReplyEvent): - annotation = self._handle_annotation_reply(event) + annotation = self._message_cycle_manager.handle_annotation_reply(event) if annotation: self._task_state.llm_result.message.content = annotation.content elif isinstance(event, QueueAgentThoughtEvent): @@ -296,7 +302,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan if agent_thought_response is not None: yield agent_thought_response elif isinstance(event, QueueMessageFileEvent): - response = self._message_file_to_stream_response(event) + response = self._message_cycle_manager.message_file_to_stream_response(event) if response: yield response elif isinstance(event, QueueLLMChunkEvent | QueueAgentMessageEvent): @@ -304,6 +310,23 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan delta_text = chunk.delta.message.content if delta_text is None: continue + if isinstance(chunk.delta.message.content, list): + delta_text = "" + for content in chunk.delta.message.content: + logger.debug( + "The content type %s in LLM chunk delta message content.: %r", type(content), content + ) + if isinstance(content, TextPromptMessageContent): + delta_text += content.data + elif isinstance(content, str): + delta_text += content # failback to str + else: + logger.warning( + "Unsupported content type %s in LLM chunk delta message content.: %r", + type(content), + content, + ) + continue if not self._task_state.llm_result.prompt_messages: self._task_state.llm_result.prompt_messages = chunk.prompt_messages @@ -318,7 +341,7 @@ class 
EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan self._task_state.llm_result.message.content = current_content if isinstance(event, QueueLLMChunkEvent): - yield self._message_to_stream_response( + yield self._message_cycle_manager.message_to_stream_response( answer=cast(str, delta_text), message_id=self._message_id, ) @@ -328,7 +351,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message_id=self._message_id, ) elif isinstance(event, QueueMessageReplaceEvent): - yield self._message_replace_to_stream_response(answer=event.text) + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text) elif isinstance(event, QueuePingEvent): yield self._ping_stream_response() else: @@ -372,9 +395,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan message.provider_response_latency = time.perf_counter() - self._start_at message.total_price = usage.total_price message.currency = usage.currency - message.message_metadata = ( - json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None - ) + message.message_metadata = self._task_state.metadata.model_dump_json() if trace_manager: trace_manager.add_trace_task( @@ -423,16 +444,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan Message end to stream response. 
:return: """ - self._task_state.metadata["usage"] = jsonable_encoder(self._task_state.llm_result.usage) - - extras = {} - if self._task_state.metadata: - extras["metadata"] = self._task_state.metadata - + self._task_state.metadata.usage = self._task_state.llm_result.usage + metadata_dict = self._task_state.metadata.model_dump() return MessageEndStreamResponse( task_id=self._application_generate_entity.task_id, id=self._message_id, - metadata=extras.get("metadata", {}), + metadata=metadata_dict, ) def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse: @@ -455,8 +472,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan agent_thought: Optional[MessageAgentThought] = ( db.session.query(MessageAgentThought).filter(MessageAgentThought.id == event.agent_thought_id).first() ) - db.session.refresh(agent_thought) - db.session.close() if agent_thought: return AgentThoughtStreamResponse( diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manager.py similarity index 85% rename from api/core/app/task_pipeline/message_cycle_manage.py rename to api/core/app/task_pipeline/message_cycle_manager.py index a6d826f08b..2343081eaf 100644 --- a/api/core/app/task_pipeline/message_cycle_manage.py +++ b/api/core/app/task_pipeline/message_cycle_manager.py @@ -17,6 +17,8 @@ from core.app.entities.queue_entities import ( QueueRetrieverResourcesEvent, ) from core.app.entities.task_entities import ( + AnnotationReply, + AnnotationReplyAccount, EasyUITaskState, MessageFileStreamResponse, MessageReplaceStreamResponse, @@ -30,7 +32,7 @@ from models.model import AppMode, Conversation, MessageAnnotation, MessageFile from services.annotation_service import AppAnnotationService -class MessageCycleManage: +class MessageCycleManager: def __init__( self, *, @@ -45,7 +47,7 @@ class MessageCycleManage: self._application_generate_entity = application_generate_entity 
self._task_state = task_state - def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: + def generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]: """ Generate conversation name. :param conversation_id: conversation id @@ -102,7 +104,7 @@ class MessageCycleManage: db.session.commit() db.session.close() - def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: + def handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]: """ Handle annotation reply. :param event: event @@ -111,25 +113,28 @@ class MessageCycleManage: annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id) if annotation: account = annotation.account - self._task_state.metadata["annotation_reply"] = { - "id": annotation.id, - "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"}, - } + self._task_state.metadata.annotation_reply = AnnotationReply( + id=annotation.id, + account=AnnotationReplyAccount( + id=annotation.account_id, + name=account.name if account else "Dify user", + ), + ) return annotation return None - def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: + def handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None: """ Handle retriever resources. :param event: event :return: """ if self._application_generate_entity.app_config.additional_features.show_retrieve_source: - self._task_state.metadata["retriever_resources"] = event.retriever_resources + self._task_state.metadata.retriever_resources = event.retriever_resources - def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: + def message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]: """ Message file to stream response. 
:param event: event @@ -166,7 +171,7 @@ class MessageCycleManage: return None - def _message_to_stream_response( + def message_to_stream_response( self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None ) -> MessageStreamResponse: """ @@ -182,7 +187,7 @@ class MessageCycleManage: from_variable_selector=from_variable_selector, ) - def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: + def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse: """ Message replace to stream response. :param answer: answer diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 56859df7f4..a3a7b4b812 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,12 +1,18 @@ +import logging +from collections.abc import Sequence + from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueRetrieverResourcesEvent +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.index_processor.constant.index_type import IndexType from core.rag.models.document import Document from extensions.ext_database import db from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import Document as DatasetDocument +_logger = logging.getLogger(__name__) + class DatasetIndexToolCallbackHandler: """Callback handler for dataset tool.""" @@ -42,18 +48,31 @@ class DatasetIndexToolCallbackHandler: """Handle tool end.""" for document in documents: if document.metadata is not None: - dataset_document = DatasetDocument.query.filter( - DatasetDocument.id == document.metadata["document_id"] - ).first() + document_id = 
document.metadata["document_id"] + dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first() + if not dataset_document: + _logger.warning( + "Expected DatasetDocument record to exist, but none was found, document_id=%s", + document_id, + ) + continue if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ChildChunk.query.filter( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + .first() + ) if child_chunk: - segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == child_chunk.segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + ) ) else: query = db.session.query(DocumentSegment).filter( @@ -68,7 +87,8 @@ class DatasetIndexToolCallbackHandler: db.session.commit() - def return_retriever_resource_info(self, resource: list): + # TODO(-LAN-): Improve type check + def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]): """Handle return_retriever_resource_info.""" self._queue_manager.publish( QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 5017835565..e1c021a44a 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -55,6 +55,25 @@ class ProviderModelWithStatusEntity(ProviderModel): status: 
ModelStatus load_balancing_enabled: bool = False + def raise_for_status(self) -> None: + """ + Check model status and raise ValueError if not active. + + :raises ValueError: When model status is not active, with a descriptive message + """ + if self.status == ModelStatus.ACTIVE: + return + + error_messages = { + ModelStatus.NO_CONFIGURE: "Model is not configured", + ModelStatus.QUOTA_EXCEEDED: "Model quota has been exceeded", + ModelStatus.NO_PERMISSION: "No permission to use this model", + ModelStatus.DISABLED: "Model is disabled", + } + + if self.status in error_messages: + raise ValueError(error_messages[self.status]) + class ModelWithProviderEntity(ProviderModelWithStatusEntity): """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 86887c9b4a..66d8d0f414 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -754,7 +754,7 @@ class ProviderConfiguration(BaseModel): :param only_active: return active model only :return: """ - provider_models = self.get_provider_models(model_type, only_active) + provider_models = self.get_provider_models(model_type, only_active, model) for provider_model in provider_models: if provider_model.model == model: @@ -763,12 +763,13 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False + self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None ) -> list[ModelWithProviderEntity]: """ Get provider models. 
:param model_type: model type :param only_active: only active models + :param model: model name :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) @@ -791,7 +792,10 @@ class ProviderConfiguration(BaseModel): ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map + model_types=model_types, + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model=model, ) if only_active: @@ -897,37 +901,36 @@ class ProviderConfiguration(BaseModel): ) except Exception as ex: logger.warning(f"get custom model schema failed, {ex}") - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][ - custom_model_schema.model - ] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, - ) + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + 
status = ModelStatus.DISABLED + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=status, ) + ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] @@ -944,6 +947,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], + model: Optional[str] = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -996,7 +1000,8 @@ class ProviderConfiguration(BaseModel): for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue - + if model and model != model_configuration.model: + continue try: custom_model_schema = self.get_model_schema( model_type=model_configuration.model_type, diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index 231743bf2a..06fdb089d4 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -41,45 +41,53 @@ class Extensible: extensions = [] position_map: dict[str, int] = {} - # get the path of the current class - current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + ".py") - current_dir_path = os.path.dirname(current_path) - - # traverse subdirectories - for subdir_name in os.listdir(current_dir_path): - if subdir_name.startswith("__"): - continue - - subdir_path = os.path.join(current_dir_path, subdir_name) - extension_name = subdir_name - if os.path.isdir(subdir_path): + # Get the package name from the module path + package_name = 
".".join(cls.__module__.split(".")[:-1]) + + try: + # Get package directory path + package_spec = importlib.util.find_spec(package_name) + if not package_spec or not package_spec.origin: + raise ImportError(f"Could not find package {package_name}") + + package_dir = os.path.dirname(package_spec.origin) + + # Traverse subdirectories + for subdir_name in os.listdir(package_dir): + if subdir_name.startswith("__"): + continue + + subdir_path = os.path.join(package_dir, subdir_name) + if not os.path.isdir(subdir_path): + continue + + extension_name = subdir_name file_names = os.listdir(subdir_path) - # is builtin extension, builtin extension - # in the front-end page and business logic, there are special treatments. + # Check for extension module file + if (extension_name + ".py") not in file_names: + logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") + continue + + # Check for builtin flag and position builtin = False - # default position is 0 can not be None for sort_to_dict_by_position_map position = 0 if "__builtin__" in file_names: builtin = True - builtin_file_path = os.path.join(subdir_path, "__builtin__") if os.path.exists(builtin_file_path): position = int(Path(builtin_file_path).read_text(encoding="utf-8").strip()) position_map[extension_name] = position - if (extension_name + ".py") not in file_names: - logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") - continue - - # Dynamic loading {subdir_name}.py file and find the subclass of Extensible - py_path = os.path.join(subdir_path, extension_name + ".py") - spec = importlib.util.spec_from_file_location(extension_name, py_path) + # Import the extension module + module_name = f"{package_name}.{extension_name}.{extension_name}" + spec = importlib.util.find_spec(module_name) if not spec or not spec.loader: - raise Exception(f"Failed to load module {extension_name} from {py_path}") + raise ImportError(f"Failed to load module {module_name}") mod = 
importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) + # Find extension class extension_class = None for name, obj in vars(mod).items(): if isinstance(obj, type) and issubclass(obj, cls) and obj != cls: @@ -87,21 +95,21 @@ class Extensible: break if not extension_class: - logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.") + logging.warning(f"Missing subclass of {cls.__name__} in {module_name}, Skip.") continue + # Load schema if not builtin json_data: dict[str, Any] = {} if not builtin: - if "schema.json" not in file_names: + json_path = os.path.join(subdir_path, "schema.json") + if not os.path.exists(json_path): logging.warning(f"Missing schema.json file in {subdir_path}, Skip.") continue - json_path = os.path.join(subdir_path, "schema.json") - json_data = {} - if os.path.exists(json_path): - with open(json_path, encoding="utf-8") as f: - json_data = json.load(f) + with open(json_path, encoding="utf-8") as f: + json_data = json.load(f) + # Create extension extensions.append( ModuleExtension( extension_class=extension_class, @@ -113,6 +121,11 @@ class Extensible: ) ) + except Exception as e: + logging.exception("Error scanning extensions") + raise + + # Sort extensions by position sorted_extensions = sort_to_dict_by_position_map( position_map=position_map, data=extensions, name_func=lambda x: x.name ) diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 5bb045cce9..2b580cb373 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -15,6 +15,7 @@ from core.helper.code_executor.python3.python3_transformer import Python3Templat from core.helper.code_executor.template_transformer import TemplateTransformer logger = logging.getLogger(__name__) +code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) class CodeExecutionError(Exception): @@ -64,7 +65,7 @@ class CodeExecutor: :param code: code 
:return: """ - url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) / "v1" / "sandbox" / "run" + url = code_execution_endpoint_url / "v1" / "sandbox" / "run" headers = {"X-Api-Key": dify_config.CODE_EXECUTION_API_KEY} diff --git a/api/core/helper/marketplace.py b/api/core/helper/marketplace.py index f4129b88ed..65bf4fc1db 100644 --- a/api/core/helper/marketplace.py +++ b/api/core/helper/marketplace.py @@ -7,29 +7,28 @@ from configs import dify_config from core.helper.download import download_with_size_limit from core.plugin.entities.marketplace import MarketplacePluginDeclaration +marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL)) -def get_plugin_pkg_url(plugin_unique_identifier: str): - return (URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/download").with_query( - unique_identifier=plugin_unique_identifier - ) + +def get_plugin_pkg_url(plugin_unique_identifier: str) -> str: + return str((marketplace_api_url / "api/v1/plugins/download").with_query(unique_identifier=plugin_unique_identifier)) def download_plugin_pkg(plugin_unique_identifier: str): - url = str(get_plugin_pkg_url(plugin_unique_identifier)) - return download_with_size_limit(url, dify_config.PLUGIN_MAX_PACKAGE_SIZE) + return download_with_size_limit(get_plugin_pkg_url(plugin_unique_identifier), dify_config.PLUGIN_MAX_PACKAGE_SIZE) def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplacePluginDeclaration]: if len(plugin_ids) == 0: return [] - url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/plugins/batch") + url = str(marketplace_api_url / "api/v1/plugins/batch") response = requests.post(url, json={"plugin_ids": plugin_ids}) response.raise_for_status() return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]] def record_install_plugin_event(plugin_unique_identifier: str): - url = str(URL(str(dify_config.MARKETPLACE_API_URL)) / "api/v1/stats/plugins/install_count") + url = str(marketplace_api_url / 
"api/v1/stats/plugins/install_count") response = requests.post(url, json={"unique_identifier": plugin_unique_identifier}) response.raise_for_status() diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 6a5982eca4..a324ac2767 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -1,5 +1,5 @@ import logging -import random +import secrets from typing import cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -38,7 +38,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt if len(text_chunks) == 0: return True - text_chunk = random.choice(text_chunks) + text_chunk = secrets.choice(text_chunks) try: model_provider_factory = ModelProviderFactory(tenant_id) diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 81bf59b2b6..848d897779 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -51,7 +51,7 @@ class IndexingRunner: for dataset_document in dataset_documents: try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") @@ -103,15 +103,17 @@ class IndexingRunner: """Run the indexing process when the index_status is splitting.""" try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, document_id=dataset_document.id - ).all() + document_segments = ( + db.session.query(DocumentSegment) + .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .all() + ) for document_segment in document_segments: 
db.session.delete(document_segment) @@ -162,15 +164,17 @@ class IndexingRunner: """Run the indexing process when the index_status is indexing.""" try: # get dataset - dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first() if not dataset: raise ValueError("no dataset found") # get exist document_segment list and delete - document_segments = DocumentSegment.query.filter_by( - dataset_id=dataset.id, document_id=dataset_document.id - ).all() + document_segments = ( + db.session.query(DocumentSegment) + .filter_by(dataset_id=dataset.id, document_id=dataset_document.id) + .all() + ) documents = [] if document_segments: @@ -254,7 +258,7 @@ class IndexingRunner: embedding_model_instance = None if dataset_id: - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("Dataset not found.") if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": @@ -587,7 +591,7 @@ class IndexingRunner: @staticmethod def _process_keyword_index(flask_app, dataset_id, document_id, documents): with flask_app.app_context(): - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if not dataset: raise ValueError("no dataset found") keyword = Keyword(dataset) @@ -656,10 +660,10 @@ class IndexingRunner: """ Update the document indexing status. 
""" - count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() + count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count() if count > 0: raise DocumentIsPausedError() - document = DatasetDocument.query.filter_by(id=document_id).first() + document = db.session.query(DatasetDocument).filter_by(id=document_id).first() if not document: raise DocumentIsDeletedPausedError() @@ -668,7 +672,7 @@ class IndexingRunner: if extra_update_params: update_params.update(extra_update_params) - DatasetDocument.query.filter_by(id=document_id).update(update_params) + db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params) db.session.commit() @staticmethod @@ -676,7 +680,7 @@ class IndexingRunner: """ Update the document segment by document id. """ - DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) + db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params) db.session.commit() def _transform( diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e5dbc30689..e01896a491 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -51,15 +51,19 @@ class LLMGenerator: response = cast( LLMResult, model_instance.invoke_llm( - prompt_messages=list(prompts), model_parameters={"max_tokens": 100, "temperature": 1}, stream=False + prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False ), ) answer = cast(str, response.message.content) cleaned_answer = re.sub(r"^.*(\{.*\}).*$", r"\1", answer, flags=re.DOTALL) if cleaned_answer is None: return "" - result_dict = json.loads(cleaned_answer) - answer = result_dict["Your Output"] + try: + result_dict = json.loads(cleaned_answer) + answer = result_dict["Your Output"] + except json.JSONDecodeError as e: + logging.exception("Failed to generate name after answer, use 
query instead") + answer = query name = answer.strip() if len(name) > 75: diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py index 34ea3aec26..ddfa1e7a66 100644 --- a/api/core/llm_generator/prompts.py +++ b/api/core/llm_generator/prompts.py @@ -1,61 +1,20 @@ -# Written by YORKI MINAKO🤡, Edited by Xiaoyi -CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is. -Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc. -ENSURE your output is in the SAME language as the user's input! -Your output is restricted only to: (Input language) Intention + Subject(short as possible) -Your output MUST be a valid JSON. +# Written by YORKI MINAKO🤡, Edited by Xiaoyi, Edited by yasu-oh +CONVERSATION_TITLE_PROMPT = """You are asked to generate a concise chat title by decomposing the user’s input into two parts: “Intention” and “Subject”. -Tip: When the user's question is directed at you (the language model), you can add an emoji to make it more fun. +1. Detect Input Language +Automatically identify the language of the user’s input (e.g. English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.). +2. Generate Title +- Combine Intention + Subject into a single, as-short-as-possible phrase. +- The title must be natural, friendly, and in the same language as the input. +- If the input is a direct question to the model, you may add an emoji at the end. -example 1: -User Input: hi, yesterday i had some burgers. +3. 
Output Format +Return **only** a valid JSON object with these exact keys and no additional text: { - "Language Type": "The user's input is pure English", - "Your Reasoning": "The language of my output must be pure English.", - "Your Output": "sharing yesterday's food" -} - -example 2: -User Input: hello -{ - "Language Type": "The user's input is pure English", - "Your Reasoning": "The language of my output must be pure English.", - "Your Output": "Greeting myself☺️" -} - - -example 3: -User Input: why mmap file: oom -{ - "Language Type": "The user's input is written in pure English", - "Your Reasoning": "The language of my output must be pure English.", - "Your Output": "Asking about the reason for mmap file: oom" -} - - -example 4: -User Input: www.convinceme.yesterday-you-ate-seafood.tv讲了什么? -{ - "Language Type": "The user's input English-Chinese mixed", - "Your Reasoning": "The English-part is an URL, the main intention is still written in Chinese, so the language of my output must be using Chinese.", - "Your Output": "询问网站www.convinceme.yesterday-you-ate-seafood.tv" -} - -example 5: -User Input: why小红的年龄is老than小明? -{ - "Language Type": "The user's input is English-Chinese mixed", - "Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.", - "Your Output": "询问小红和小明的年龄" -} - -example 6: -User Input: yo, 你今天咋样? 
-{ - "Language Type": "The user's input is English-Chinese mixed", - "Your Reasoning": "The English-part is a subjective particle, the main intention is written in Chinese, so the language of my output must be using Chinese.", - "Your Output": "查询今日我的状态☺️" + "Language Type": "", + "Your Reasoning": "", + "Your Output": "" } User Input: diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 995a30d44c..4886ffe244 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -542,8 +542,6 @@ class LBModelManager: return config - return None - def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: """ Cooldown model load balancing config diff --git a/api/core/model_runtime/entities/llm_entities.py b/api/core/model_runtime/entities/llm_entities.py index 9bb118622b..de5a748d4f 100644 --- a/api/core/model_runtime/entities/llm_entities.py +++ b/api/core/model_runtime/entities/llm_entities.py @@ -17,19 +17,6 @@ class LLMMode(StrEnum): COMPLETION = "completion" CHAT = "chat" - @classmethod - def value_of(cls, value: str) -> "LLMMode": - """ - Get value of given mode. 
- - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - class LLMUsage(ModelUsage): """ diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 373ef2bbe2..568149cc37 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -160,6 +160,10 @@ class ProviderModel(BaseModel): deprecated: bool = False model_config = ConfigDict(protected_namespaces=()) + @property + def support_structure_output(self) -> bool: + return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features + class ParameterRule(BaseModel): """ diff --git a/api/core/model_runtime/utils/encoders.py b/api/core/model_runtime/utils/encoders.py index 03e3506271..a5c11aeeba 100644 --- a/api/core/model_runtime/utils/encoders.py +++ b/api/core/model_runtime/utils/encoders.py @@ -129,17 +129,18 @@ def jsonable_encoder( sqlalchemy_safe=sqlalchemy_safe, ) if dataclasses.is_dataclass(obj): - # FIXME: mypy error, try to fix it instead of using type: ignore - obj_dict = dataclasses.asdict(obj) # type: ignore - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) + # Ensure obj is a dataclass instance, not a dataclass type + if not isinstance(obj, type): + obj_dict = dataclasses.asdict(obj) + return jsonable_encoder( + obj_dict, + by_alias=by_alias, + exclude_unset=exclude_unset, + exclude_defaults=exclude_defaults, + exclude_none=exclude_none, + custom_encoder=custom_encoder, + sqlalchemy_safe=sqlalchemy_safe, + ) if isinstance(obj, Enum): return obj.value if isinstance(obj, PurePath): diff --git a/api/core/ops/base_trace_instance.py b/api/core/ops/base_trace_instance.py index 
f7b882fc71..8593198bc2 100644 --- a/api/core/ops/base_trace_instance.py +++ b/api/core/ops/base_trace_instance.py @@ -1,7 +1,11 @@ from abc import ABC, abstractmethod +from sqlalchemy.orm import Session + from core.ops.entities.config_entity import BaseTracingConfig from core.ops.entities.trace_entity import BaseTraceInfo +from extensions.ext_database import db +from models import Account, App, TenantAccountJoin class BaseTraceInstance(ABC): @@ -24,3 +28,38 @@ class BaseTraceInstance(ABC): Subclasses must implement specific tracing logic for activities. """ ... + + def get_service_account_with_tenant(self, app_id: str) -> Account: + """ + Get service account for an app and set up its tenant. + + Args: + app_id: The ID of the app + + Returns: + Account: The service account with tenant set up + + Raises: + ValueError: If app, creator account or tenant cannot be found + """ + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + + current_tenant = ( + session.query(TenantAccountJoin).filter_by(account_id=service_account.id, current=True).first() + ) + if not current_tenant: + raise ValueError(f"Current tenant not found for account {service_account.id}") + service_account.set_tenant_id(current_tenant.tenant_id) + + return service_account diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py index 874b2800b2..c988bf48d1 100644 --- a/api/core/ops/entities/config_entity.py +++ b/api/core/ops/entities/config_entity.py @@ -1,9 +1,9 @@ -from enum import Enum 
+from enum import StrEnum from pydantic import BaseModel, ValidationInfo, field_validator -class TracingProviderEnum(Enum): +class TracingProviderEnum(StrEnum): LANGFUSE = "langfuse" LANGSMITH = "langsmith" OPIK = "opik" @@ -98,6 +98,7 @@ class WeaveConfig(BaseTracingConfig): entity: str | None = None project: str endpoint: str = "https://trace.wandb.ai" + host: str | None = None @field_validator("endpoint") @classmethod @@ -109,6 +110,14 @@ class WeaveConfig(BaseTracingConfig): return v + @field_validator("host") + @classmethod + def validate_host(cls, v, info: ValidationInfo): + if v is not None and v != "": + if not v.startswith(("https://", "http://")): + raise ValueError("host must start with https:// or http://") + return v + OPS_FILE_PATH = "ops_trace/" OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE" diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index f0e34c0cd7..151fa2aaf4 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -3,7 +3,7 @@ from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): @@ -24,10 +24,13 @@ class BaseTraceInfo(BaseModel): return v return "" - class Config: - json_encoders = { - datetime: lambda v: v.isoformat(), - } + model_config = ConfigDict(protected_namespaces=()) + + @field_serializer("start_time", "end_time") + def serialize_datetime(self, dt: datetime | None) -> str | None: + if dt is None: + return None + return dt.isoformat() class WorkflowTraceInfo(BaseTraceInfo): diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index f486da3a6d..46ba1c45b9 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ 
b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel): description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The output of the span. Can be any JSON object." ) version: Optional[str] = Field( diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index c74617e558..0ea74e9ef0 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,4 +1,3 @@ -import json import logging import os from datetime import datetime, timedelta @@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser +from models import EndUser, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -113,8 +113,18 @@ class LangFuseDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + service_account = 
self.get_service_account_with_tenant(app_id) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -124,23 +134,22 @@ class LangFuseDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -152,7 +161,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, 
} ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 348b7ba501..4fd01136ba 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel): class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index d1e16d3152..8a392940db 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import 
uuid @@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -137,8 +138,18 @@ class LangSmithDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + service_account = self.get_service_account_with_tenant(app_id) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -148,27 +159,23 @@ class LangSmithDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + 
if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - metadata = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 + metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -181,7 +188,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm @@ -191,7 +198,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == "knowledge-retrieval": + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 1484041447..f4d2760ba5 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -1,4 +1,3 @@ 
-import json import logging import os import uuid @@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -114,6 +115,7 @@ class OpikDataTrace(BaseTraceInstance): "metadata": workflow_metadata, "input": wrap_dict("input", trace_info.workflow_run_inputs), "output": wrap_dict("output", trace_info.workflow_run_outputs), + "thread_id": trace_info.conversation_id, "tags": ["message", "workflow"], "project_name": self.project, } @@ -143,6 +145,7 @@ class OpikDataTrace(BaseTraceInstance): "metadata": workflow_metadata, "input": wrap_dict("input", trace_info.workflow_run_inputs), "output": wrap_dict("output", trace_info.workflow_run_outputs), + "thread_id": trace_info.conversation_id, "tags": ["workflow"], "project_name": self.project, } @@ -150,8 +153,18 @@ class OpikDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + service_account = self.get_service_account_with_tenant(app_id) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -161,26 
+174,22 @@ class OpikDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - metadata = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -193,7 +202,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} provider = None model = None @@ -226,7 +235,7 @@ class OpikDataTrace(BaseTraceInstance): parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id if not total_tokens: - total_tokens 
= execution_metadata.get("total_tokens", 0) + total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, @@ -288,6 +297,7 @@ class OpikDataTrace(BaseTraceInstance): "metadata": wrap_metadata(metadata), "input": trace_info.inputs, "output": message_data.answer, + "thread_id": message_data.conversation_id, "tags": ["message", str(trace_info.conversation_mode)], "project_name": self.project, } @@ -402,6 +412,7 @@ class OpikDataTrace(BaseTraceInstance): "metadata": wrap_metadata(trace_info.metadata), "input": trace_info.inputs, "output": trace_info.outputs, + "thread_id": trace_info.conversation_id, "tags": ["generate_name"], "project_name": self.project, } diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 2c68055f87..a98904102c 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -16,11 +16,7 @@ from sqlalchemy.orm import Session from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.ops.entities.config_entity import ( OPS_FILE_PATH, - LangfuseConfig, - LangSmithConfig, - OpikConfig, TracingProviderEnum, - WeaveConfig, ) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, @@ -33,11 +29,8 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from core.ops.opik_trace.opik_trace import OpikDataTrace from core.ops.utils import get_message_data -from core.ops.weave_trace.weave_trace import WeaveDataTrace +from core.workflow.entities.workflow_execution import WorkflowExecution from extensions.ext_database import db from extensions.ext_storage import storage from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig @@ -45,36 +38,58 @@ from models.workflow import 
WorkflowAppLog, WorkflowRun from tasks.ops_trace_task import process_trace_tasks -def build_opik_trace_instance(config: OpikConfig): - return OpikDataTrace(config) - - -provider_config_map: dict[str, dict[str, Any]] = { - TracingProviderEnum.LANGFUSE.value: { - "config_class": LangfuseConfig, - "secret_keys": ["public_key", "secret_key"], - "other_keys": ["host", "project_key"], - "trace_instance": LangFuseDataTrace, - }, - TracingProviderEnum.LANGSMITH.value: { - "config_class": LangSmithConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "endpoint"], - "trace_instance": LangSmithDataTrace, - }, - TracingProviderEnum.OPIK.value: { - "config_class": OpikConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "url", "workspace"], - "trace_instance": lambda config: build_opik_trace_instance(config), - }, - TracingProviderEnum.WEAVE.value: { - "config_class": WeaveConfig, - "secret_keys": ["api_key"], - "other_keys": ["project", "entity", "endpoint"], - "trace_instance": WeaveDataTrace, - }, -} +class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]): + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: + case TracingProviderEnum.LANGFUSE: + from core.ops.entities.config_entity import LangfuseConfig + from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace + + return { + "config_class": LangfuseConfig, + "secret_keys": ["public_key", "secret_key"], + "other_keys": ["host", "project_key"], + "trace_instance": LangFuseDataTrace, + } + + case TracingProviderEnum.LANGSMITH: + from core.ops.entities.config_entity import LangSmithConfig + from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace + + return { + "config_class": LangSmithConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "endpoint"], + "trace_instance": LangSmithDataTrace, + } + + case TracingProviderEnum.OPIK: + from core.ops.entities.config_entity import OpikConfig + from core.ops.opik_trace.opik_trace import 
OpikDataTrace + + return { + "config_class": OpikConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "url", "workspace"], + "trace_instance": OpikDataTrace, + } + + case TracingProviderEnum.WEAVE: + from core.ops.entities.config_entity import WeaveConfig + from core.ops.weave_trace.weave_trace import WeaveDataTrace + + return { + "config_class": WeaveConfig, + "secret_keys": ["api_key"], + "other_keys": ["project", "entity", "endpoint", "host"], + "trace_instance": WeaveDataTrace, + } + + case _: + raise KeyError(f"Unsupported tracing provider: {provider}") + + +provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap() class OpsTraceManager: @@ -220,7 +235,11 @@ class OpsTraceManager: return None tracing_provider = app_ops_trace_config.get("tracing_provider") - if tracing_provider is None or tracing_provider not in provider_config_map: + if tracing_provider is None: + return None + try: + provider_config_map[tracing_provider] + except KeyError: return None # decrypt_token @@ -232,7 +251,7 @@ class OpsTraceManager: provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["config_class"], ) - decrypt_trace_config_key = str(decrypt_trace_config) + decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True) tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) if tracing_instance is None: # create new tracing_instance and update the cache if it absent @@ -273,8 +292,14 @@ class OpsTraceManager: :return: """ # auth check - if tracing_provider not in provider_config_map and tracing_provider is not None: - raise ValueError(f"Invalid tracing provider: {tracing_provider}") + if enabled == True: + try: + provider_config_map[tracing_provider] + except KeyError: + raise ValueError(f"Invalid tracing provider: {tracing_provider}") + else: + if tracing_provider is not None: + raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: Optional[App] = 
db.session.query(App).filter(App.id == app_id).first() if not app_config: @@ -353,7 +378,7 @@ class TraceTask: self, trace_type: Any, message_id: Optional[str] = None, - workflow_run: Optional[WorkflowRun] = None, + workflow_execution: Optional[WorkflowExecution] = None, conversation_id: Optional[str] = None, user_id: Optional[str] = None, timer: Optional[Any] = None, @@ -361,7 +386,7 @@ class TraceTask: ): self.trace_type = trace_type self.message_id = message_id - self.workflow_run_id = workflow_run.id if workflow_run else None + self.workflow_run_id = workflow_execution.id_ if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -462,6 +487,7 @@ class TraceTask: "file_list": file_list, "triggered_from": workflow_run.triggered_from, "user_id": user_id, + "app_id": workflow_run.app_id, } workflow_trace_info = WorkflowTraceInfo( diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index e423f5ccbb..7f489f37ac 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") 
attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( None, description="Metadata and attributes associated with trace" ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 49594cb0f1..3917348a91 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,6 +6,7 @@ from typing import Any, Optional, cast import wandb import weave +from sqlalchemy.orm import sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import WeaveConfig @@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution +from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -38,9 +40,14 @@ class WeaveDataTrace(BaseTraceInstance): self.weave_api_key = weave_config.api_key self.project_name = weave_config.project self.entity = weave_config.entity + self.host = weave_config.host + + # Login with API key first, including host if provided + if self.host: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) - # Login with API key first - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) if not login_status: logger.error("Failed to login to Weights & Biases with the provided API key") raise ValueError("Weave login 
failed") @@ -128,58 +135,46 @@ class WeaveDataTrace(BaseTraceInstance): self.start_call(workflow_run, parent_run_id=trace_info.message_id) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() - ) + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) + service_account = self.get_service_account_with_tenant(app_id) + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use 
from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - attributes = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) or 0 + attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -192,7 +187,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { @@ -396,7 +391,11 @@ class WeaveDataTrace(BaseTraceInstance): def api_check(self): try: - login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if self.host: + login_status = 
wandb.login(key=self.weave_api_key, verify=True, relogin=True, host=self.host) + else: + login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True) + if not login_status: raise ValueError("Weave login failed") else: diff --git a/api/core/plugin/backwards_invocation/base.py b/api/core/plugin/backwards_invocation/base.py index 3214e07469..2a5f857576 100644 --- a/api/core/plugin/backwards_invocation/base.py +++ b/api/core/plugin/backwards_invocation/base.py @@ -11,14 +11,12 @@ class BaseBackwardsInvocation: try: for chunk in response: if isinstance(chunk, BaseModel | dict): - yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() + b"\n\n" - elif isinstance(chunk, str): - yield f"event: {chunk}\n\n".encode() + yield BaseBackwardsInvocationResponse(data=chunk).model_dump_json().encode() except Exception as e: error_message = BaseBackwardsInvocationResponse(error=str(e)).model_dump_json() - yield f"{error_message}\n\n".encode() + yield error_message.encode() else: - yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() + b"\n\n" + yield BaseBackwardsInvocationResponse(data=response).model_dump_json().encode() T = TypeVar("T", bound=dict | Mapping | str | bool | int | BaseModel) diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 5ec9620f22..072644e53b 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -21,7 +21,7 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from core.workflow.nodes.llm.node import LLMNode +from core.workflow.nodes.llm import llm_utils from models.account import Tenant @@ -55,20 +55,21 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): def handle() -> Generator[LLMResultChunk, None, None]: for chunk in response: if 
chunk.delta.usage: - LLMNode.deduct_llm_quota( + llm_utils.deduct_llm_quota( tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage ) + chunk.prompt_messages = [] yield chunk return handle() else: if response.usage: - LLMNode.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) + llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage) def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]: yield LLMResultChunk( model=response.model, - prompt_messages=response.prompt_messages, + prompt_messages=[], system_fingerprint=response.system_fingerprint, delta=LLMResultChunkDelta( index=0, diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index db07e52f3f..7898795ce2 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) return { - "inputs": execution.inputs_dict, - "outputs": execution.outputs_dict, - "process_data": execution.process_data_dict, + "inputs": execution.inputs, + "outputs": execution.outputs, + "process_data": execution.process_data, } @classmethod @@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) return { - "inputs": execution.inputs_dict, - "outputs": execution.outputs_dict, - "process_data": execution.process_data_dict, + "inputs": execution.inputs, + "outputs": execution.outputs, + "process_data": execution.process_data, } diff --git a/api/core/plugin/endpoint/exc.py b/api/core/plugin/endpoint/exc.py new file mode 100644 index 0000000000..aa29f1e9a1 --- /dev/null +++ b/api/core/plugin/endpoint/exc.py @@ -0,0 +1,6 @@ +class EndpointSetupFailedError(ValueError): + """ + Endpoint setup failed error + """ + + pass diff --git a/api/core/plugin/entities/endpoint.py b/api/core/plugin/entities/endpoint.py 
index 6c6c8bf9bc..d7ba75bb4f 100644 --- a/api/core/plugin/entities/endpoint.py +++ b/api/core/plugin/entities/endpoint.py @@ -24,7 +24,7 @@ class EndpointProviderDeclaration(BaseModel): """ settings: list[ProviderConfig] = Field(default_factory=list) - endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list) + endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration]) class EndpointEntity(BasePluginEntity): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index c15f98c6ea..bdf7d5ce1f 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -52,7 +52,7 @@ class PluginResourceRequirements(BaseModel): model: Optional[Model] = Field(default=None) node: Optional[Node] = Field(default=None) endpoint: Optional[Endpoint] = Field(default=None) - storage: Storage = Field(default=None) + storage: Optional[Storage] = Field(default=None) permission: Optional[Permission] = Field(default=None) @@ -66,9 +66,9 @@ class PluginCategory(enum.StrEnum): class PluginDeclaration(BaseModel): class Plugins(BaseModel): - tools: Optional[list[str]] = Field(default_factory=list) - models: Optional[list[str]] = Field(default_factory=list) - endpoints: Optional[list[str]] = Field(default_factory=list) + tools: Optional[list[str]] = Field(default_factory=list[str]) + models: Optional[list[str]] = Field(default_factory=list[str]) + endpoints: Optional[list[str]] = Field(default_factory=list[str]) class Meta(BaseModel): minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 2bea07bea0..e0d2857e97 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -9,7 +9,7 @@ from core.agent.plugin_entities import AgentProviderEntityWithPlugin from 
core.model_runtime.entities.model_entities import AIModelEntity from core.model_runtime.entities.provider_entities import ProviderEntity from core.plugin.entities.base import BasePluginEntity -from core.plugin.entities.plugin import PluginDeclaration +from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin @@ -156,9 +156,23 @@ class PluginInstallTaskStartResponse(BaseModel): task_id: str = Field(description="The ID of the install task.") -class PluginUploadResponse(BaseModel): +class PluginVerification(BaseModel): + """ + Verification of the plugin. + """ + + class AuthorizedCategory(StrEnum): + Langgenius = "langgenius" + Partner = "partner" + Community = "community" + + authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.") + + +class PluginDecodeResponse(BaseModel): unique_identifier: str = Field(description="The unique identifier of the plugin.") manifest: PluginDeclaration + verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information") class PluginOAuthAuthorizationUrlResponse(BaseModel): @@ -167,3 +181,8 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel): class PluginOAuthCredentialsResponse(BaseModel): credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.") + + +class PluginListResponse(BaseModel): + list: list[PluginEntity] + total: int diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 6c0c7f2868..1692020ec8 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -55,8 +55,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel): mode: str completion_params: dict[str, Any] = Field(default_factory=dict) prompt_messages: list[PromptMessage] = Field(default_factory=list) - tools: 
Optional[list[PromptMessageTool]] = Field(default_factory=list) - stop: Optional[list[str]] = Field(default_factory=list) + tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool]) + stop: Optional[list[str]] = Field(default_factory=list[str]) stream: Optional[bool] = False model_config = ConfigDict(protected_namespaces=()) diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 4f1d808a3e..7375726fa9 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -6,6 +6,7 @@ from typing import TypeVar import requests from pydantic import BaseModel +from requests.exceptions import HTTPError from yarl import URL from configs import dify_config @@ -17,6 +18,7 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, @@ -29,8 +31,7 @@ from core.plugin.impl.exc import ( PluginUniqueIdentifierError, ) -plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL -plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY +plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) @@ -51,9 +52,9 @@ class BasePluginClient: """ Make a request to the plugin daemon inner API. 
""" - url = URL(str(plugin_daemon_inner_api_baseurl)) / path + url = plugin_daemon_inner_api_baseurl / path headers = headers or {} - headers["X-Api-Key"] = plugin_daemon_inner_api_key + headers["X-Api-Key"] = dify_config.PLUGIN_DAEMON_KEY headers["Accept-Encoding"] = "gzip, deflate, br" if headers.get("Content-Type") == "application/json" and isinstance(data, dict): @@ -135,12 +136,31 @@ class BasePluginClient: """ Make a request to the plugin daemon inner API and return the response as a model. """ - response = self._request(method, path, headers, data, params, files) - json_response = response.json() - if transformer: - json_response = transformer(json_response) + try: + response = self._request(method, path, headers, data, params, files) + response.raise_for_status() + except HTTPError as e: + msg = f"Failed to request plugin daemon, status: {e.response.status_code}, url: {path}" + logging.exception(msg) + raise e + except Exception as e: + msg = f"Failed to request plugin daemon, url: {path}" + logging.exception(msg) + raise ValueError(msg) from e + + try: + json_response = response.json() + if transformer: + json_response = transformer(json_response) + rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore + except Exception: + msg = ( + f"Failed to parse response from plugin daemon to PluginDaemonBasicResponse [{str(type.__name__)}]," + f" url: {path}" + ) + logging.exception(msg) + raise ValueError(msg) - rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore if rep.code != 0: try: error = PluginDaemonError(**json.loads(rep.message)) @@ -219,6 +239,8 @@ class BasePluginClient: raise InvokeServerUnavailableError(description=args.get("description")) case CredentialsValidateFailedError.__name__: raise CredentialsValidateFailedError(error_object.get("message")) + case EndpointSetupFailedError.__name__: + raise EndpointSetupFailedError(error_object.get("message")) case _: raise PluginInvokeError(description=message) case 
PluginDaemonInternalServerError.__name__: diff --git a/api/core/plugin/impl/plugin.py b/api/core/plugin/impl/plugin.py index 3349463ce5..b7f7b31655 100644 --- a/api/core/plugin/impl/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -9,7 +9,12 @@ from core.plugin.entities.plugin import ( PluginInstallation, PluginInstallationSource, ) -from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse +from core.plugin.entities.plugin_daemon import ( + PluginDecodeResponse, + PluginInstallTask, + PluginInstallTaskStartResponse, + PluginListResponse, +) from core.plugin.impl.base import BasePluginClient @@ -27,19 +32,28 @@ class PluginInstaller(BasePluginClient): ) def list_plugins(self, tenant_id: str) -> list[PluginEntity]: - return self._request_with_plugin_daemon_response( + result = self._request_with_plugin_daemon_response( "GET", f"plugin/{tenant_id}/management/list", - list[PluginEntity], + PluginListResponse, params={"page": 1, "page_size": 256}, ) + return result.list + + def list_plugins_with_total(self, tenant_id: str, page: int, page_size: int) -> PluginListResponse: + return self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/list", + PluginListResponse, + params={"page": page, "page_size": page_size}, + ) def upload_pkg( self, tenant_id: str, pkg: bytes, verify_signature: bool = False, - ) -> PluginUploadResponse: + ) -> PluginDecodeResponse: """ Upload a plugin package and return the plugin unique identifier. 
""" @@ -54,7 +68,7 @@ class PluginInstaller(BasePluginClient): return self._request_with_plugin_daemon_response( "POST", f"plugin/{tenant_id}/management/install/upload/package", - PluginUploadResponse, + PluginDecodeResponse, files=body, data=data, ) @@ -162,6 +176,18 @@ class PluginInstaller(BasePluginClient): params={"plugin_unique_identifier": plugin_unique_identifier}, ) + def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse: + """ + Decode a plugin from an identifier. + """ + return self._request_with_plugin_daemon_response( + "GET", + f"plugin/{tenant_id}/management/decode/from_identifier", + PluginDecodeResponse, + data={"plugin_unique_identifier": plugin_unique_identifier}, + headers={"Content-Type": "application/json"}, + ) + def fetch_plugin_installation_by_ids( self, tenant_id: str, plugin_ids: Sequence[str] ) -> Sequence[PluginInstallation]: diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 7570200175..488a394679 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -3,7 +3,9 @@ from collections import defaultdict from json import JSONDecodeError from typing import Any, Optional, cast +from sqlalchemy import select from sqlalchemy.exc import IntegrityError +from sqlalchemy.orm import Session from configs import dify_config from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity @@ -393,19 +395,13 @@ class ProviderManager: @staticmethod def _get_all_providers(tenant_id: str) -> dict[str, list[Provider]]: - """ - Get all provider records of the workspace. 
- - :param tenant_id: workspace id - :return: - """ - providers = db.session.query(Provider).filter(Provider.tenant_id == tenant_id, Provider.is_valid == True).all() - provider_name_to_provider_records_dict = defaultdict(list) - for provider in providers: - # TODO: Use provider name with prefix after the data migration - provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Provider).where(Provider.tenant_id == tenant_id, Provider.is_valid == True) + providers = session.scalars(stmt) + for provider in providers: + # Use provider name with prefix after the data migration + provider_name_to_provider_records_dict[str(ModelProviderID(provider.provider_name))].append(provider) return provider_name_to_provider_records_dict @staticmethod @@ -416,17 +412,12 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - # Get all provider model records of the workspace - provider_models = ( - db.session.query(ProviderModel) - .filter(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) - .all() - ) - provider_name_to_provider_model_records_dict = defaultdict(list) - for provider_model in provider_models: - provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) - + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModel).where(ProviderModel.tenant_id == tenant_id, ProviderModel.is_valid == True) + provider_models = session.scalars(stmt) + for provider_model in provider_models: + provider_name_to_provider_model_records_dict[provider_model.provider_name].append(provider_model) return provider_name_to_provider_model_records_dict @staticmethod @@ -437,17 +428,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - preferred_provider_types = ( - db.session.query(TenantPreferredModelProvider) - 
.filter(TenantPreferredModelProvider.tenant_id == tenant_id) - .all() - ) - - provider_name_to_preferred_provider_type_records_dict = { - preferred_provider_type.provider_name: preferred_provider_type - for preferred_provider_type in preferred_provider_types - } - + provider_name_to_preferred_provider_type_records_dict = {} + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(TenantPreferredModelProvider).where(TenantPreferredModelProvider.tenant_id == tenant_id) + preferred_provider_types = session.scalars(stmt) + provider_name_to_preferred_provider_type_records_dict = { + preferred_provider_type.provider_name: preferred_provider_type + for preferred_provider_type in preferred_provider_types + } return provider_name_to_preferred_provider_type_records_dict @staticmethod @@ -458,18 +446,14 @@ class ProviderManager: :param tenant_id: workspace id :return: """ - provider_model_settings = ( - db.session.query(ProviderModelSetting).filter(ProviderModelSetting.tenant_id == tenant_id).all() - ) - provider_name_to_provider_model_settings_dict = defaultdict(list) - for provider_model_setting in provider_model_settings: - ( + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(ProviderModelSetting).where(ProviderModelSetting.tenant_id == tenant_id) + provider_model_settings = session.scalars(stmt) + for provider_model_setting in provider_model_settings: provider_name_to_provider_model_settings_dict[provider_model_setting.provider_name].append( provider_model_setting ) - ) - return provider_name_to_provider_model_settings_dict @staticmethod @@ -492,15 +476,14 @@ class ProviderManager: if not model_load_balancing_enabled: return {} - provider_load_balancing_configs = ( - db.session.query(LoadBalancingModelConfig).filter(LoadBalancingModelConfig.tenant_id == tenant_id).all() - ) - provider_name_to_provider_load_balancing_model_configs_dict = defaultdict(list) - for provider_load_balancing_config in 
provider_load_balancing_configs: - provider_name_to_provider_load_balancing_model_configs_dict[ - provider_load_balancing_config.provider_name - ].append(provider_load_balancing_config) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(LoadBalancingModelConfig).where(LoadBalancingModelConfig.tenant_id == tenant_id) + provider_load_balancing_configs = session.scalars(stmt) + for provider_load_balancing_config in provider_load_balancing_configs: + provider_name_to_provider_load_balancing_model_configs_dict[ + provider_load_balancing_config.provider_name + ].append(provider_load_balancing_config) return provider_name_to_provider_load_balancing_model_configs_dict @@ -626,10 +609,9 @@ class ProviderManager: if not cached_provider_credentials: try: # fix origin data - if ( - custom_provider_record.encrypted_config - and not custom_provider_record.encrypted_config.startswith("{") - ): + if custom_provider_record.encrypted_config is None: + raise ValueError("No credentials found") + if not custom_provider_record.encrypted_config.startswith("{"): provider_credentials = {"openai_api_key": custom_provider_record.encrypted_config} else: provider_credentials = json.loads(custom_provider_record.encrypted_config) @@ -733,7 +715,7 @@ class ProviderManager: return SystemConfiguration(enabled=False) # Convert provider_records to dict - quota_type_to_provider_records_dict = {} + quota_type_to_provider_records_dict: dict[ProviderQuotaType, Provider] = {} for provider_record in provider_records: if provider_record.provider_type != ProviderType.SYSTEM.value: continue @@ -758,6 +740,11 @@ class ProviderManager: else: provider_record = quota_type_to_provider_records_dict[provider_quota.quota_type] + if provider_record.quota_used is None: + raise ValueError("quota_used is None") + if provider_record.quota_limit is None: + raise ValueError("quota_limit is None") + quota_configuration = QuotaConfiguration( quota_type=provider_quota.quota_type, 
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS, @@ -791,10 +778,9 @@ class ProviderManager: cached_provider_credentials = provider_credentials_cache.get() if not cached_provider_credentials: - try: - provider_credentials: dict[str, Any] = json.loads(provider_record.encrypted_config) - except JSONDecodeError: - provider_credentials = {} + provider_credentials: dict[str, Any] = {} + if provider_records and provider_records[0].encrypted_config: + provider_credentials = json.loads(provider_records[0].encrypted_config) # Get provider credential secret variables provider_credential_secret_variables = self._extract_secret_variables( diff --git a/api/core/rag/datasource/keyword/jieba/stopwords.py b/api/core/rag/datasource/keyword/jieba/stopwords.py index 9abe78d6ef..54b65d9a2d 100644 --- a/api/core/rag/datasource/keyword/jieba/stopwords.py +++ b/api/core/rag/datasource/keyword/jieba/stopwords.py @@ -720,7 +720,7 @@ STOPWORDS = { "〉", "〈", "…", - " ", + " ", "0", "1", "2", @@ -731,16 +731,6 @@ STOPWORDS = { "7", "8", "9", - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "9", "二", "三", "四", diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 01f74b4a22..2c5178241c 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -405,7 +405,29 @@ class RetrievalService: record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks") # type: ignore record["score"] = segment_child_map[record["segment"].id]["max_score"] - return [RetrievalSegments(**record) for record in records] + result = [] + for record in records: + # Extract segment + segment = record["segment"] + + # Extract child_chunks, ensuring it's a list or None + child_chunks = record.get("child_chunks") + if not isinstance(child_chunks, list): + child_chunks = None + + # Extract score, ensuring it's a float or None + score_value = record.get("score") 
+ score = ( + float(score_value) + if score_value is not None and isinstance(score_value, int | float | str) + else None + ) + + # Create RetrievalSegments object + retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) + result.append(retrieval_segment) + + return result except Exception as e: db.session.rollback() raise e diff --git a/api/core/rag/datasource/vdb/baidu/baidu_vector.py b/api/core/rag/datasource/vdb/baidu/baidu_vector.py index 86f1f5bfe4..db7ffc9c4f 100644 --- a/api/core/rag/datasource/vdb/baidu/baidu_vector.py +++ b/api/core/rag/datasource/vdb/baidu/baidu_vector.py @@ -85,7 +85,6 @@ class BaiduVector(BaseVector): end = min(start + batch_size, total_count) rows = [] assert len(metadatas) == total_count, "metadatas length should be equal to total_count" - # FIXME do you need this assert? for i in range(start, end, 1): row = Row( id=metadatas[i].get("doc_id", str(uuid.uuid4())), diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 033d05a077..44cc5d3e98 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -142,7 +142,7 @@ class ElasticSearchVector(BaseVector): if score > score_threshold: if doc.metadata is not None: doc.metadata["score"] = score - docs.append(doc) + docs.append(doc) return docs diff --git a/api/core/rag/datasource/vdb/matrixone/__init__.py b/api/core/rag/datasource/vdb/matrixone/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py new file mode 100644 index 0000000000..4894957382 --- /dev/null +++ b/api/core/rag/datasource/vdb/matrixone/matrixone_vector.py @@ -0,0 +1,233 @@ +import json +import logging +import uuid +from functools import wraps +from typing 
import Any, Optional + +from mo_vector.client import MoVectorClient # type: ignore +from pydantic import BaseModel, model_validator + +from configs import dify_config +from core.rag.datasource.vdb.vector_base import BaseVector +from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory +from core.rag.datasource.vdb.vector_type import VectorType +from core.rag.embedding.embedding_base import Embeddings +from core.rag.models.document import Document +from extensions.ext_redis import redis_client +from models.dataset import Dataset + +logger = logging.getLogger(__name__) + + +class MatrixoneConfig(BaseModel): + host: str = "localhost" + port: int = 6001 + user: str = "dump" + password: str = "111" + database: str = "dify" + metric: str = "l2" + + @model_validator(mode="before") + @classmethod + def validate_config(cls, values: dict) -> dict: + if not values["host"]: + raise ValueError("config host is required") + if not values["port"]: + raise ValueError("config port is required") + if not values["user"]: + raise ValueError("config user is required") + if not values["password"]: + raise ValueError("config password is required") + if not values["database"]: + raise ValueError("config database is required") + return values + + +def ensure_client(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + if self.client is None: + self.client = self._get_client(None, False) + return func(self, *args, **kwargs) + + return wrapper + + +class MatrixoneVector(BaseVector): + """ + Matrixone vector storage implementation. 
+ """ + + def __init__(self, collection_name: str, config: MatrixoneConfig): + super().__init__(collection_name) + self.config = config + self.collection_name = collection_name.lower() + self.client = None + + @property + def collection_name(self): + return self._collection_name + + @collection_name.setter + def collection_name(self, value): + self._collection_name = value + + def get_type(self) -> str: + return VectorType.MATRIXONE + + def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs): + if self.client is None: + self.client = self._get_client(len(embeddings[0]), True) + return self.add_texts(texts, embeddings) + + def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient: + """ + Create a new client for the collection. + + The collection will be created if it doesn't exist. + """ + lock_name = f"vector_indexing_lock_{self._collection_name}" + with redis_client.lock(lock_name, timeout=20): + client = MoVectorClient( + connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}", + table_name=self.collection_name, + vector_dimension=dimension, + create_table=create_table, + ) + collection_exist_cache_key = f"vector_indexing_{self._collection_name}" + if redis_client.get(collection_exist_cache_key): + return client + try: + client.create_full_text_index() + except Exception as e: + logger.exception("Failed to create full text index") + redis_client.set(collection_exist_cache_key, 1, ex=3600) + return client + + def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs): + if self.client is None: + self.client = self._get_client(len(embeddings[0]), True) + assert self.client is not None + ids = [] + for _, doc in enumerate(documents): + if doc.metadata is not None: + doc_id = doc.metadata.get("doc_id", str(uuid.uuid4())) + ids.append(doc_id) + self.client.insert( + texts=[doc.page_content 
for doc in documents], + embeddings=embeddings, + metadatas=[doc.metadata for doc in documents], + ids=ids, + ) + return ids + + @ensure_client + def text_exists(self, id: str) -> bool: + assert self.client is not None + result = self.client.get(ids=[id]) + return len(result) > 0 + + @ensure_client + def delete_by_ids(self, ids: list[str]) -> None: + assert self.client is not None + if not ids: + return + self.client.delete(ids=ids) + + @ensure_client + def get_ids_by_metadata_field(self, key: str, value: str): + assert self.client is not None + results = self.client.query_by_metadata(filter={key: value}) + return [result.id for result in results] + + @ensure_client + def delete_by_metadata_field(self, key: str, value: str) -> None: + assert self.client is not None + self.client.delete(filter={key: value}) + + @ensure_client + def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: + assert self.client is not None + top_k = kwargs.get("top_k", 5) + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = {"document_id": {"$in": document_ids_filter}} + + results = self.client.query( + query_vector=query_vector, + k=top_k, + filter=filter, + ) + + docs = [] + # TODO: add the score threshold to the query + for result in results: + metadata = result.metadata + docs.append( + Document( + page_content=result.document, + metadata=metadata, + ) + ) + return docs + + @ensure_client + def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: + assert self.client is not None + top_k = kwargs.get("top_k", 5) + document_ids_filter = kwargs.get("document_ids_filter") + filter = None + if document_ids_filter: + filter = {"document_id": {"$in": document_ids_filter}} + score_threshold = float(kwargs.get("score_threshold", 0.0)) + + results = self.client.full_text_query( + keywords=[query], + k=top_k, + filter=filter, + ) + + docs = [] + for result in results: + metadata = 
result.metadata + if isinstance(metadata, str): + import json + + metadata = json.loads(metadata) + score = 1 - result.distance + if score >= score_threshold: + metadata["score"] = score + docs.append( + Document( + page_content=result.document, + metadata=metadata, + ) + ) + return docs + + @ensure_client + def delete(self) -> None: + assert self.client is not None + self.client.delete() + + +class MatrixoneVectorFactory(AbstractVectorFactory): + def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector: + if dataset.index_struct_dict: + class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"] + collection_name = class_prefix + else: + dataset_id = dataset.id + collection_name = Dataset.gen_collection_name_by_id(dataset_id) + dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name)) + + config = MatrixoneConfig( + host=dify_config.MATRIXONE_HOST or "localhost", + port=dify_config.MATRIXONE_PORT or 6001, + user=dify_config.MATRIXONE_USER or "dump", + password=dify_config.MATRIXONE_PASSWORD or "111", + database=dify_config.MATRIXONE_DATABASE or "dify", + metric=dify_config.MATRIXONE_METRIC or "l2", + ) + return MatrixoneVector(collection_name=collection_name, config=config) diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py index 7b3f826082..63de6a0603 100644 --- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py +++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py @@ -97,6 +97,10 @@ class MilvusVector(BaseVector): try: milvus_version = self._client.get_server_version() + # Check if it's Zilliz Cloud - it supports full-text search with Milvus 2.5 compatibility + if "Zilliz Cloud" in milvus_version: + return True + # For standard Milvus installations, check version number return version.parse(milvus_version).base_version >= version.parse("2.5.0").base_version except Exception as e: 
logger.warning(f"Failed to check Milvus version: {str(e)}. Disabling hybrid search.") diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 2b47d179d2..dd196e1f09 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -80,6 +80,23 @@ class OceanBaseVector(BaseVector): self.delete() + vals = [] + params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'") + for row in params: + val = int(row[6]) + vals.append(val) + if len(vals) == 0: + raise ValueError("ob_vector_memory_limit_percentage not found in parameters.") + if any(val == 0 for val in vals): + try: + self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30") + except Exception as e: + raise Exception( + "Failed to set ob_vector_memory_limit_percentage. " + + "Maybe the database user has insufficient privilege.", + e, + ) + cols = [ Column("id", String(36), primary_key=True, autoincrement=False), Column("vector", VECTOR(self._vec_dim)), @@ -110,22 +127,6 @@ class OceanBaseVector(BaseVector): + "to support fulltext index and vector index in the same table", e, ) - vals = [] - params = self._client.perform_raw_text_sql("SHOW PARAMETERS LIKE '%ob_vector_memory_limit_percentage%'") - for row in params: - val = int(row[6]) - vals.append(val) - if len(vals) == 0: - raise ValueError("ob_vector_memory_limit_percentage not found in parameters.") - if any(val == 0 for val in vals): - try: - self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30") - except Exception as e: - raise Exception( - "Failed to set ob_vector_memory_limit_percentage. 
" - + "Maybe the database user has insufficient privilege.", - e, - ) redis_client.set(collection_exist_cache_key, 1, ex=3600) def _check_hybrid_search_support(self) -> bool: diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py index e23b8d197f..0abb3c0077 100644 --- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py +++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py @@ -23,7 +23,8 @@ logger = logging.getLogger(__name__) class OpenSearchConfig(BaseModel): host: str port: int - secure: bool = False + secure: bool = False # use_ssl + verify_certs: bool = True auth_method: Literal["basic", "aws_managed_iam"] = "basic" user: Optional[str] = None password: Optional[str] = None @@ -42,6 +43,8 @@ class OpenSearchConfig(BaseModel): raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method") if not values.get("aws_service"): raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method") + if not values.get("OPENSEARCH_SECURE") and values.get("OPENSEARCH_VERIFY_CERTS"): + raise ValueError("verify_certs=True requires secure (HTTPS) connection") return values def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth: @@ -57,7 +60,7 @@ class OpenSearchConfig(BaseModel): params = { "hosts": [{"host": self.host, "port": self.port}], "use_ssl": self.secure, - "verify_certs": self.secure, + "verify_certs": self.verify_certs, "connection_class": Urllib3HttpConnection, "pool_maxsize": 20, } @@ -181,7 +184,16 @@ class OpenSearchVector(BaseVector): } document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: - query["query"] = {"terms": {"metadata.document_id": document_ids_filter}} + query["query"] = { + "script_score": { + "query": {"bool": {"filter": [{"terms": {Field.DOCUMENT_ID.value: document_ids_filter}}]}}, + "script": { + "source": "knn_score", + "lang": "knn", + "params": 
{"field": Field.VECTOR.value, "query_value": query_vector, "space_type": "l2"}, + }, + } + } try: response = self._client.search(index=self._collection_name.lower(), body=query) @@ -206,10 +218,10 @@ class OpenSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - full_text_query = {"query": {"match": {Field.CONTENT_KEY.value: query}}} + full_text_query = {"query": {"bool": {"must": [{"match": {Field.CONTENT_KEY.value: query}}]}}} document_ids_filter = kwargs.get("document_ids_filter") if document_ids_filter: - full_text_query["query"]["terms"] = {"metadata.document_id": document_ids_filter} + full_text_query["query"]["bool"]["filter"] = [{"terms": {"metadata.document_id": document_ids_filter}}] response = self._client.search(index=self._collection_name.lower(), body=full_text_query) @@ -252,7 +264,8 @@ class OpenSearchVector(BaseVector): Field.METADATA_KEY.value: { "type": "object", "properties": { - "doc_id": {"type": "keyword"} # Map doc_id to keyword type + "doc_id": {"type": "keyword"}, # Map doc_id to keyword type + "document_id": {"type": "keyword"}, }, }, } @@ -279,6 +292,7 @@ class OpenSearchVectorFactory(AbstractVectorFactory): host=dify_config.OPENSEARCH_HOST or "localhost", port=dify_config.OPENSEARCH_PORT, secure=dify_config.OPENSEARCH_SECURE, + verify_certs=dify_config.OPENSEARCH_VERIFY_CERTS, auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value, user=dify_config.OPENSEARCH_USER, password=dify_config.OPENSEARCH_PASSWORD, diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 0a3738ac93..d1c8142b3d 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -261,7 +261,7 @@ class OracleVector(BaseVector): words = pseg.cut(query) current_entity = "" for word, pos in words: - if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名 + if pos 
in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名 current_entity += word else: if current_entity: @@ -303,7 +303,6 @@ class OracleVector(BaseVector): return docs else: return [Document(page_content="", metadata={})] - return [] def delete(self) -> None: with self._get_connection() as conn: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 366a21c381..04e9cf801e 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import uuid @@ -61,12 +62,12 @@ CREATE TABLE IF NOT EXISTS {table_name} ( """ SQL_CREATE_INDEX = """ -CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name} +CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx_{index_hash} ON {table_name} USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64); """ SQL_CREATE_INDEX_PG_BIGM = """ -CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name} +CREATE INDEX IF NOT EXISTS bigm_idx_{index_hash} ON {table_name} USING gin (text gin_bigm_ops); """ @@ -76,6 +77,7 @@ class PGVector(BaseVector): super().__init__(collection_name) self.pool = self._create_connection_pool(config) self.table_name = f"embedding_{collection_name}" + self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8] self.pg_bigm = config.pg_bigm def get_type(self) -> str: @@ -256,10 +258,9 @@ class PGVector(BaseVector): # PG hnsw index only support 2000 dimension or less # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing if dimension <= 2000: - cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) + cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)) if self.pg_bigm: - cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm") - cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name)) + 
cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name, index_hash=self.index_hash)) redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py index 1e040f415e..8ce194c683 100644 --- a/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py +++ b/api/core/rag/datasource/vdb/qdrant/qdrant_vector.py @@ -46,6 +46,7 @@ class QdrantConfig(BaseModel): root_path: Optional[str] = None grpc_port: int = 6334 prefer_grpc: bool = False + replication_factor: int = 1 def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): @@ -119,11 +120,13 @@ class QdrantVector(BaseVector): max_indexing_threads=0, on_disk=False, ) + self._client.create_collection( collection_name=collection_name, vectors_config=vectors_config, hnsw_config=hnsw_config, timeout=int(self._client_config.timeout), + replication_factor=self._client_config.replication_factor, ) # create group_id payload index @@ -466,5 +469,6 @@ class QdrantVectorFactory(AbstractVectorFactory): timeout=dify_config.QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.QDRANT_GRPC_PORT, prefer_grpc=dify_config.QDRANT_GRPC_ENABLED, + replication_factor=dify_config.QDRANT_REPLICATION_FACTOR, ), ) diff --git a/api/core/rag/datasource/vdb/tencent/tencent_vector.py b/api/core/rag/datasource/vdb/tencent/tencent_vector.py index e266659075..d2bf3eb92a 100644 --- a/api/core/rag/datasource/vdb/tencent/tencent_vector.py +++ b/api/core/rag/datasource/vdb/tencent/tencent_vector.py @@ -271,12 +271,15 @@ class TencentVector(BaseVector): for result in res[0]: meta = result.get(self.field_metadata) + if isinstance(meta, str): + # Compatible with version 1.1.3 and below. 
+ meta = json.loads(meta) + score = 1 - result.get("score", 0.0) score = result.get("score", 0.0) if score > score_threshold: meta["score"] = score doc = Document(page_content=result.get(self.field_text), metadata=meta) docs.append(doc) - return docs def delete(self) -> None: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index 6a61fe9496..6f895b12af 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -49,6 +49,7 @@ class TidbOnQdrantConfig(BaseModel): root_path: Optional[str] = None grpc_port: int = 6334 prefer_grpc: bool = False + replication_factor: int = 1 def to_qdrant_params(self): if self.endpoint and self.endpoint.startswith("path:"): @@ -134,6 +135,7 @@ class TidbOnQdrantVector(BaseVector): vectors_config=vectors_config, hnsw_config=hnsw_config, timeout=int(self._client_config.timeout), + replication_factor=self._client_config.replication_factor, ) # create group_id payload index @@ -484,6 +486,7 @@ class TidbOnQdrantVectorFactory(AbstractVectorFactory): timeout=dify_config.TIDB_ON_QDRANT_CLIENT_TIMEOUT, grpc_port=dify_config.TIDB_ON_QDRANT_GRPC_PORT, prefer_grpc=dify_config.TIDB_ON_QDRANT_GRPC_ENABLED, + replication_factor=dify_config.QDRANT_REPLICATION_FACTOR, ), ) diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py index 3958280bd5..184b5f2142 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py @@ -245,4 +245,4 @@ class TidbService: return cluster_infos else: response.raise_for_status() - return [] # FIXME for mypy, This line will not be reached as raise_for_status() will raise an exception + return [] diff --git a/api/core/rag/datasource/vdb/vector_factory.py 
b/api/core/rag/datasource/vdb/vector_factory.py index 66e002312a..67a4a515b1 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -164,6 +164,10 @@ class Vector: from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory return HuaweiCloudVectorFactory + case VectorType.MATRIXONE: + from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory + + return MatrixoneVectorFactory case _: raise ValueError(f"Vector store {vector_type} is not supported.") diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 7a81565e37..0d70947b72 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -29,3 +29,4 @@ class VectorType(StrEnum): OPENGAUSS = "opengauss" TABLESTORE = "tablestore" HUAWEI_CLOUD = "huawei_cloud" + MATRIXONE = "matrixone" diff --git a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py index 8fe6199517..7a8efb4068 100644 --- a/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py +++ b/api/core/rag/datasource/vdb/weaviate/weaviate_vector.py @@ -41,6 +41,13 @@ class WeaviateVector(BaseVector): weaviate.connect.connection.has_grpc = False + # Fix to minimize the performance impact of the deprecation check in weaviate-client 3.24.0, + # by changing the connection timeout to pypi.org from 1 second to 0.001 seconds. + # TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher, + # which does not contain the deprecation check. 
+ if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"): + weaviate.connect.connection.PYPI_TIMEOUT = 0.001 + try: client = weaviate.Client( url=config.endpoint, auth_client_secret=auth_config, timeout_config=(5, 60), startup_period=None diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 42fad111ce..f50f9f6b60 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -139,4 +139,4 @@ class CacheEmbedding(Embeddings): logging.exception(f"Failed to add embedding to redis for the text '{text[:10]}...({len(text)} chars)'") raise ex - return embedding_results + return embedding_results # type: ignore diff --git a/api/core/rag/entities/citation_metadata.py b/api/core/rag/entities/citation_metadata.py new file mode 100644 index 0000000000..00120425c9 --- /dev/null +++ b/api/core/rag/entities/citation_metadata.py @@ -0,0 +1,23 @@ +from typing import Any, Optional + +from pydantic import BaseModel + + +class RetrievalSourceMetadata(BaseModel): + position: Optional[int] = None + dataset_id: Optional[str] = None + dataset_name: Optional[str] = None + document_id: Optional[str] = None + document_name: Optional[str] = None + data_source_type: Optional[str] = None + segment_id: Optional[str] = None + retriever_from: Optional[str] = None + score: Optional[float] = None + hit_count: Optional[int] = None + word_count: Optional[int] = None + segment_position: Optional[int] = None + index_node_hash: Optional[str] = None + content: Optional[str] = None + page: Optional[int] = None + doc_metadata: Optional[dict[str, Any]] = None + title: Optional[str] = None diff --git a/api/core/rag/extractor/entity/extract_setting.py b/api/core/rag/extractor/entity/extract_setting.py index 7c00c668dd..1593ad1475 100644 --- a/api/core/rag/extractor/entity/extract_setting.py +++ b/api/core/rag/extractor/entity/extract_setting.py @@ -27,6 +27,8 @@ class WebsiteInfo(BaseModel): website import info. 
""" + model_config = ConfigDict(arbitrary_types_allowed=True) + provider: str job_id: str url: str @@ -34,12 +36,6 @@ class WebsiteInfo(BaseModel): tenant_id: str only_main_content: bool = False - class Config: - arbitrary_types_allowed = True - - def __init__(self, **data) -> None: - super().__init__(**data) - class ExtractSetting(BaseModel): """ diff --git a/api/core/rag/extractor/firecrawl/firecrawl_app.py b/api/core/rag/extractor/firecrawl/firecrawl_app.py index 836a1398bf..83a4ac651f 100644 --- a/api/core/rag/extractor/firecrawl/firecrawl_app.py +++ b/api/core/rag/extractor/firecrawl/firecrawl_app.py @@ -22,6 +22,7 @@ class FirecrawlApp: "formats": ["markdown"], "onlyMainContent": True, "timeout": 30000, + "integration": "dify", } if params: json_data.update(params) @@ -39,7 +40,7 @@ class FirecrawlApp: def crawl_url(self, url, params=None) -> str: # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post headers = self._prepare_headers() - json_data = {"url": url} + json_data = {"url": url, "integration": "dify"} if params: json_data.update(params) response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers) @@ -49,7 +50,6 @@ class FirecrawlApp: return cast(str, job_id) else: self._handle_error(response, "start crawl job") - # FIXME: unreachable code for mypy return "" # unreachable def check_crawl_status(self, job_id) -> dict[str, Any]: @@ -82,7 +82,6 @@ class FirecrawlApp: ) else: self._handle_error(response, "check crawl status") - # FIXME: unreachable code for mypy return {} # unreachable def _format_crawl_status_response( @@ -126,4 +125,31 @@ class FirecrawlApp: def _handle_error(self, response, action) -> None: error_message = response.json().get("error", "Unknown error occurred") - raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") + raise Exception(f"Failed to {action}. Status code: {response.status_code}. 
Error: {error_message}") # type: ignore[return] + + def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]: + # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search + headers = self._prepare_headers() + json_data = { + "query": query, + "limit": 5, + "lang": "en", + "country": "us", + "timeout": 60000, + "ignoreInvalidURLs": False, + "scrapeOptions": {}, + "integration": "dify", + } + if params: + json_data.update(params) + response = self._post_request(f"{self.base_url}/v1/search", json_data, headers) + if response.status_code == 200: + response_data = response.json() + if not response_data.get("success"): + raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}") + return cast(dict[str, Any], response_data) + elif response.status_code in {402, 409, 500, 429, 408}: + self._handle_error(response, "perform search") + return {} # Avoid additional exception after handling error + else: + raise Exception(f"Failed to perform search. 
Status code: {response.status_code}") diff --git a/api/core/rag/extractor/markdown_extractor.py b/api/core/rag/extractor/markdown_extractor.py index 849852ac23..c97765b1dc 100644 --- a/api/core/rag/extractor/markdown_extractor.py +++ b/api/core/rag/extractor/markdown_extractor.py @@ -68,22 +68,17 @@ class MarkdownExtractor(BaseExtractor): continue header_match = re.match(r"^#+\s", line) if header_match: - if current_header is not None: - markdown_tups.append((current_header, current_text)) - + markdown_tups.append((current_header, current_text)) current_header = line current_text = "" else: current_text += line + "\n" markdown_tups.append((current_header, current_text)) - if current_header is not None: - # pass linting, assert keys are defined - markdown_tups = [ - (re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups - ] - else: - markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups] + markdown_tups = [ + (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value)) + for key, value in markdown_tups + ] return markdown_tups diff --git a/api/core/rag/extractor/notion_extractor.py b/api/core/rag/extractor/notion_extractor.py index 7ab248199a..eca955ddd1 100644 --- a/api/core/rag/extractor/notion_extractor.py +++ b/api/core/rag/extractor/notion_extractor.py @@ -79,55 +79,71 @@ class NotionExtractor(BaseExtractor): def _get_notion_database_data(self, database_id: str, query_dict: dict[str, Any] = {}) -> list[Document]: """Get all the pages from a Notion database.""" assert self._notion_access_token is not None, "Notion access token is required" - res = requests.post( - DATABASE_URL_TMPL.format(database_id=database_id), - headers={ - "Authorization": "Bearer " + self._notion_access_token, - "Content-Type": "application/json", - "Notion-Version": "2022-06-28", - }, - json=query_dict, - ) - - data = res.json() database_content = [] - if "results" not in data or 
data["results"] is None: - return [] - for result in data["results"]: - properties = result["properties"] - data = {} - value: Any - for property_name, property_value in properties.items(): - type = property_value["type"] - if type == "multi_select": - value = [] - multi_select_list = property_value[type] - for multi_select in multi_select_list: - value.append(multi_select["name"]) - elif type in {"rich_text", "title"}: - if len(property_value[type]) > 0: - value = property_value[type][0]["plain_text"] + next_cursor = None + has_more = True + + while has_more: + current_query = query_dict.copy() + if next_cursor: + current_query["start_cursor"] = next_cursor + + res = requests.post( + DATABASE_URL_TMPL.format(database_id=database_id), + headers={ + "Authorization": "Bearer " + self._notion_access_token, + "Content-Type": "application/json", + "Notion-Version": "2022-06-28", + }, + json=current_query, + ) + + response_data = res.json() + + if "results" not in response_data or response_data["results"] is None: + break + + for result in response_data["results"]: + properties = result["properties"] + data = {} + value: Any + for property_name, property_value in properties.items(): + type = property_value["type"] + if type == "multi_select": + value = [] + multi_select_list = property_value[type] + for multi_select in multi_select_list: + value.append(multi_select["name"]) + elif type in {"rich_text", "title"}: + if len(property_value[type]) > 0: + value = property_value[type][0]["plain_text"] + else: + value = "" + elif type in {"select", "status"}: + if property_value[type]: + value = property_value[type]["name"] + else: + value = "" else: - value = "" - elif type in {"select", "status"}: - if property_value[type]: - value = property_value[type]["name"] + value = property_value[type] + data[property_name] = value + row_dict = {k: v for k, v in data.items() if v} + row_content = "" + for key, value in row_dict.items(): + if isinstance(value, dict): + value_dict = {k: v 
for k, v in value.items() if v} + value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) + row_content = row_content + f"{key}:{value_content}\n" else: - value = "" - else: - value = property_value[type] - data[property_name] = value - row_dict = {k: v for k, v in data.items() if v} - row_content = "" - for key, value in row_dict.items(): - if isinstance(value, dict): - value_dict = {k: v for k, v in value.items() if v} - value_content = "".join(f"{k}:{v} " for k, v in value_dict.items()) - row_content = row_content + f"{key}:{value_content}\n" - else: - row_content = row_content + f"{key}:{value}\n" - database_content.append(row_content) + row_content = row_content + f"{key}:{value}\n" + database_content.append(row_content) + + has_more = response_data.get("has_more", False) + next_cursor = response_data.get("next_cursor") + + if not database_content: + return [] return [Document(page_content="\n".join(database_content))] @@ -317,7 +333,7 @@ class NotionExtractor(BaseExtractor): data_source_info["last_edited_time"] = last_edited_time update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} - DocumentModel.query.filter_by(id=document_model.id).update(update_params) + db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params) db.session.commit() def get_notion_last_edited_time(self) -> str: @@ -347,14 +363,18 @@ class NotionExtractor(BaseExtractor): @classmethod def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == tenant_id, + DataSourceOauthBinding.provider == 
"notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', + ) ) - ).first() + .first() + ) if not data_source_binding: raise Exception( diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 6eaede7dbc..6d596e07d8 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -6,6 +6,12 @@ from urllib.parse import urljoin import requests from requests import Response +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) + class BaseAPIClient: def __init__(self, api_key, base_url): @@ -53,6 +59,15 @@ class WaterCrawlAPIClient(BaseAPIClient): yield data def process_response(self, response: Response) -> dict | bytes | list | None | Generator: + if response.status_code == 401: + raise WaterCrawlAuthenticationError(response) + + if response.status_code == 403: + raise WaterCrawlPermissionError(response) + + if 400 <= response.status_code < 500: + raise WaterCrawlBadRequestError(response) + response.raise_for_status() if response.status_code == 204: return None diff --git a/api/core/rag/extractor/watercrawl/exceptions.py b/api/core/rag/extractor/watercrawl/exceptions.py new file mode 100644 index 0000000000..e407a594e0 --- /dev/null +++ b/api/core/rag/extractor/watercrawl/exceptions.py @@ -0,0 +1,32 @@ +import json + + +class WaterCrawlError(Exception): + pass + + +class WaterCrawlBadRequestError(WaterCrawlError): + def __init__(self, response): + self.status_code = response.status_code + self.response = response + data = response.json() + self.message = data.get("message", "Unknown error occurred") + self.errors = data.get("errors", {}) + super().__init__(self.message) + + @property + def flat_errors(self): + return json.dumps(self.errors) + + def __str__(self): + return 
f"WaterCrawlBadRequestError: {self.message} \n {self.flat_errors}" + + +class WaterCrawlPermissionError(WaterCrawlBadRequestError): + def __str__(self): + return f"You are exceeding your WaterCrawl API limits. {self.message}" + + +class WaterCrawlAuthenticationError(WaterCrawlBadRequestError): + def __str__(self): + return "WaterCrawl API key is invalid or expired. Please check your API key and try again." diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index edaa8c92fa..bff0acc48f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import UploadFile logger = logging.getLogger(__name__) @@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor): parsed = urlparse(url) return bool(parsed.netloc) and bool(parsed.scheme) - def _extract_images_from_docx(self, doc, image_folder): - os.makedirs(image_folder, exist_ok=True) + def _extract_images_from_docx(self, doc): image_count = 0 image_map = {} @@ -117,7 +116,7 @@ class WordExtractor(BaseExtractor): extension=str(image_ext), mime_type=mime_type or "", created_by=self.user_id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, @@ -210,7 +209,7 @@ class WordExtractor(BaseExtractor): content = [] - image_map = self._extract_images_from_docx(doc, image_folder) + image_map = self._extract_images_from_docx(doc) hyperlinks_url = None url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") @@ -225,7 +224,7 @@ class WordExtractor(BaseExtractor): xml = ElementTree.XML(run.element.xml) x_child = [c 
for c in xml.iter() if c is not None] for x in x_child: - if x_child is None: + if x is None: continue if x.tag.endswith("instrText"): if x.text is None: diff --git a/api/core/rag/index_processor/processor/qa_index_processor.py b/api/core/rag/index_processor/processor/qa_index_processor.py index 0055625e13..75f3153697 100644 --- a/api/core/rag/index_processor/processor/qa_index_processor.py +++ b/api/core/rag/index_processor/processor/qa_index_processor.py @@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor): def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: # check file type - if not file.filename.endswith(".csv"): + if not file.filename or not file.filename.lower().endswith(".csv"): raise ValueError("Invalid file type. Only CSV files are allowed") try: diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index 421cdc05df..04a3428ad8 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -45,13 +45,12 @@ class BaseDocumentTransformer(ABC): .. 
code-block:: python class EmbeddingsRedundantFilter(BaseDocumentTransformer, BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + embeddings: Embeddings similarity_fn: Callable = cosine_similarity similarity_threshold: float = 0.95 - class Config: - arbitrary_types_allowed = True - def transform_documents( self, documents: Sequence[Document], **kwargs: Any ) -> Sequence[Document]: diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 9216b31b8e..38c0b540d5 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -35,6 +35,7 @@ from core.prompt.simple_prompt_transform import ModelMode from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.index_processor.constant.index_type import IndexType @@ -190,7 +191,7 @@ class DatasetRetrieval: retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled or True, + True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled, message_id, metadata_filter_document_ids, metadata_condition, @@ -198,21 +199,21 @@ class DatasetRetrieval: dify_documents = [item for item in all_documents if item.provider == "dify"] external_documents = [item for item in all_documents if item.provider == "external"] - document_context_list = [] - retrieval_resource_list = [] + document_context_list: list[DocumentContext] = [] + retrieval_resource_list: list[RetrievalSourceMetadata] = [] # deal with external 
documents for item in external_documents: document_context_list.append(DocumentContext(content=item.page_content, score=item.metadata.get("score"))) - source = { - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": invoke_from.to_source(), - "score": item.metadata.get("score"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=invoke_from.to_source(), + score=item.metadata.get("score"), + content=item.page_content, + ) retrieval_resource_list.append(source) # deal with dify documents if dify_documents: @@ -237,39 +238,43 @@ class DatasetRetrieval: if show_retrieve_source: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + document = ( + db.session.query(DatasetDocument) + .filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .first() + ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": invoke_from.to_source(), - "score": record.score or 0.0, - "doc_metadata": 
document.doc_metadata, - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=invoke_from.to_source(), + score=record.score or 0.0, + doc_metadata=document.doc_metadata, + ) if invoke_from.to_source() == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if hit_callback and retrieval_resource_list: - retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score") or 0.0, reverse=True) + retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.score or 0.0, reverse=True) for position, item in enumerate(retrieval_resource_list, start=1): - item["position"] = position + item.position = position hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: document_context_list = sorted(document_context_list, key=lambda x: x.score or 0.0, reverse=True) @@ -491,6 +496,8 @@ class DatasetRetrieval: all_documents = self.calculate_keyword_score(query, all_documents, top_k) elif index_type == "high_quality": all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) + else: + all_documents = all_documents[:top_k] if top_k else all_documents self._on_query(query, dataset_ids, app_id, user_from, 
user_id) @@ -506,19 +513,30 @@ class DatasetRetrieval: dify_documents = [document for document in documents if document.provider == "dify"] for document in dify_documents: if document.metadata is not None: - dataset_document = DatasetDocument.query.filter( - DatasetDocument.id == document.metadata["document_id"] - ).first() + dataset_document = ( + db.session.query(DatasetDocument) + .filter(DatasetDocument.id == document.metadata["document_id"]) + .first() + ) if dataset_document: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: - child_chunk = ChildChunk.query.filter( - ChildChunk.index_node_id == document.metadata["doc_id"], - ChildChunk.dataset_id == dataset_document.dataset_id, - ChildChunk.document_id == dataset_document.id, - ).first() + child_chunk = ( + db.session.query(ChildChunk) + .filter( + ChildChunk.index_node_id == document.metadata["doc_id"], + ChildChunk.dataset_id == dataset_document.dataset_id, + ChildChunk.document_id == dataset_document.id, + ) + .first() + ) if child_chunk: - segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == child_chunk.segment_id) + .update( + {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, + synchronize_session=False, + ) ) db.session.commit() else: @@ -921,6 +939,9 @@ class DatasetRetrieval: return metadata_filter_document_ids, metadata_condition def _replace_metadata_filter_value(self, text: str, inputs: dict) -> str: + if not inputs: + return text + def replacer(match): key = match.group(1) return str(inputs.get(key, f"{{{{{key}}}}}")) diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index f0426ace1f..33a283771d 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ 
b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -9,7 +9,7 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser -from core.workflow.nodes.llm import LLMNode +from core.workflow.nodes.llm import llm_utils PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -165,7 +165,7 @@ class ReactMultiDatasetRouter: text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - LLMNode.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) return text, usage diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py new file mode 100644 index 0000000000..e30538742a --- /dev/null +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -0,0 +1,243 @@ +""" +SQLAlchemy implementation of the WorkflowExecutionRepository. 
+""" + +import json +import logging +from typing import Optional, Union + +from sqlalchemy import select +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.workflow.entities.workflow_execution import ( + WorkflowExecution, + WorkflowExecutionStatus, + WorkflowType, +) +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowRun, +) +from models.enums import WorkflowRunTriggeredFrom + +logger = logging.getLogger(__name__) + + +class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): + """ + SQLAlchemy implementation of the WorkflowExecutionRepository interface. + + This implementation supports multi-tenancy by filtering operations based on tenant_id. + Each method creates its own session, handles the transaction, and commits changes + to the database. This prevents long-running connections in the workflow core. + + This implementation also includes an in-memory cache for workflow executions to improve + performance by reducing database queries. + """ + + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowRunTriggeredFrom], + ): + """ + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. 
+ + Args: + session_factory: SQLAlchemy sessionmaker or engine for creating sessions + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (DEBUGGING or APP_RUN) + """ + # If an engine is provided, create a sessionmaker from it + if isinstance(session_factory, Engine): + self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False) + elif isinstance(session_factory, sessionmaker): + self._session_factory = session_factory + else: + raise ValueError( + f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" + ) + + # Extract tenant_id from user + tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") + self._tenant_id = tenant_id + + # Store app context + self._app_id = app_id + + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize in-memory cache for workflow executions + # Key: execution_id, Value: WorkflowRun (DB model) + self._execution_cache: dict[str, WorkflowRun] = {} + + def _to_domain_model(self, db_model: WorkflowRun) -> WorkflowExecution: + """ + Convert a database model to a domain model. 
+ + Args: + db_model: The database model to convert + + Returns: + The domain model + """ + # Parse JSON fields + inputs = db_model.inputs_dict + outputs = db_model.outputs_dict + graph = db_model.graph_dict + + # Convert status to domain enum + status = WorkflowExecutionStatus(db_model.status) + + return WorkflowExecution( + id_=db_model.id, + workflow_id=db_model.workflow_id, + workflow_type=WorkflowType(db_model.type), + workflow_version=db_model.version, + graph=graph, + inputs=inputs, + outputs=outputs, + status=status, + error_message=db_model.error or "", + total_tokens=db_model.total_tokens, + total_steps=db_model.total_steps, + exceptions_count=db_model.exceptions_count, + started_at=db_model.created_at, + finished_at=db_model.finished_at, + ) + + def _to_db_model(self, domain_model: WorkflowExecution) -> WorkflowRun: + """ + Convert a domain model to a database model. + + Args: + domain_model: The domain model to convert + + Returns: + The database model + """ + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + db_model = WorkflowRun() + db_model.id = domain_model.id_ + db_model.tenant_id = self._tenant_id + if self._app_id is not None: + db_model.app_id = self._app_id + db_model.workflow_id = domain_model.workflow_id + db_model.triggered_from = self._triggered_from + + # No sequence number generation needed anymore + + db_model.type = domain_model.workflow_type + db_model.version = domain_model.workflow_version + db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None + db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None + db_model.outputs = json.dumps(domain_model.outputs) if 
domain_model.outputs else None + db_model.status = domain_model.status + db_model.error = domain_model.error_message if domain_model.error_message else None + db_model.total_tokens = domain_model.total_tokens + db_model.total_steps = domain_model.total_steps + db_model.exceptions_count = domain_model.exceptions_count + db_model.created_by_role = self._creator_user_role + db_model.created_by = self._creator_user_id + db_model.created_at = domain_model.started_at + db_model.finished_at = domain_model.finished_at + + # Calculate elapsed time if finished_at is available + if domain_model.finished_at: + db_model.elapsed_time = (domain_model.finished_at - domain_model.started_at).total_seconds() + else: + db_model.elapsed_time = 0 + + return db_model + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution domain entity to the database. + + This method serves as a domain-to-database adapter that: + 1. Converts the domain entity to its database representation + 2. Persists the database model using SQLAlchemy's merge operation + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Updates the in-memory cache for faster subsequent lookups + + The method handles both creating new records and updating existing ones through + SQLAlchemy's merge operation. 
+ + Args: + execution: The WorkflowExecution domain entity to persist + """ + # Convert domain model to database model using tenant context and other attributes + db_model = self._to_db_model(execution) + + # Create a new database session + with self._session_factory() as session: + # SQLAlchemy merge intelligently handles both insert and update operations + # based on the presence of the primary key + session.merge(db_model) + session.commit() + + # Update the in-memory cache for faster subsequent lookups + logger.debug(f"Updating cache for execution_id: {db_model.id}") + self._execution_cache[db_model.id] = db_model + + def get(self, execution_id: str) -> Optional[WorkflowExecution]: + """ + Retrieve a WorkflowExecution by its ID. + + First checks the in-memory cache, and if not found, queries the database. + If found in the database, adds it to the cache for future lookups. + + Args: + execution_id: The workflow execution ID + + Returns: + The WorkflowExecution instance if found, None otherwise + """ + # First check the cache + if execution_id in self._execution_cache: + logger.debug(f"Cache hit for execution_id: {execution_id}") + # Convert cached DB model to domain model + cached_db_model = self._execution_cache[execution_id] + return self._to_domain_model(cached_db_model) + + # If not in cache, query the database + logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") + with self._session_factory() as session: + stmt = select(WorkflowRun).where( + WorkflowRun.id == execution_id, + WorkflowRun.tenant_id == self._tenant_id, + ) + + if self._app_id: + stmt = stmt.where(WorkflowRun.app_id == self._app_id) + + db_model = session.scalar(stmt) + if db_model: + # Add DB model to cache + self._execution_cache[execution_id] = db_model + + # Convert to domain model and return + return self._to_domain_model(db_model) + + return None diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py 
b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8bf2ab8761..2f27442616 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -2,16 +2,30 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. """ +import json import logging from collections.abc import Sequence -from typing import Optional +from typing import Optional, Union from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowNodeExecutionModel, + WorkflowNodeExecutionTriggeredFrom, +) logger = logging.getLogger(__name__) @@ -23,16 +37,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This implementation supports multi-tenancy by filtering operations based on tenant_id. Each method creates its own session, handles the transaction, and commits changes to the database. This prevents long-running connections in the workflow core. + + This implementation also includes an in-memory cache for node executions to improve + performance by reducing database queries. 
""" - def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + ): """ - Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. Args: session_factory: SQLAlchemy sessionmaker or engine for creating sessions - tenant_id: Tenant ID for multi-tenancy - app_id: Optional app ID for filtering by application + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) """ # If an engine is provided, create a sessionmaker from it if isinstance(session_factory, Engine): @@ -44,56 +68,197 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" ) + # Extract tenant_id from user + tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id + + # Store app context self._app_id = app_id - def save(self, execution: WorkflowNodeExecution) -> None: + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize in-memory cache for node executions + # Key: node_execution_id, Value: WorkflowNodeExecution (DB model) + self._node_execution_cache: dict[str, WorkflowNodeExecutionModel] = {} 
+ + def _to_domain_model(self, db_model: WorkflowNodeExecutionModel) -> WorkflowNodeExecution: """ - Save a WorkflowNodeExecution instance and commit changes to the database. + Convert a database model to a domain model. Args: - execution: The WorkflowNodeExecution instance to save + db_model: The database model to convert + + Returns: + The domain model """ - with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id + # Parse JSON fields + inputs = db_model.inputs_dict + process_data = db_model.process_data_dict + outputs = db_model.outputs_dict + metadata = {WorkflowNodeExecutionMetadataKey(k): v for k, v in db_model.execution_metadata_dict.items()} + + # Convert status to domain enum + status = WorkflowNodeExecutionStatus(db_model.status) + + return WorkflowNodeExecution( + id=db_model.id, + node_execution_id=db_model.node_execution_id, + workflow_id=db_model.workflow_id, + workflow_execution_id=db_model.workflow_run_id, + index=db_model.index, + predecessor_node_id=db_model.predecessor_node_id, + node_id=db_model.node_id, + node_type=NodeType(db_model.node_type), + title=db_model.title, + inputs=inputs, + process_data=process_data, + outputs=outputs, + status=status, + error=db_model.error, + elapsed_time=db_model.elapsed_time, + metadata=metadata, + created_at=db_model.created_at, + finished_at=db_model.finished_at, + ) + + def to_db_model(self, domain_model: WorkflowNodeExecution) -> WorkflowNodeExecutionModel: + """ + Convert a domain model to a database model. 
- # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id + Args: + domain_model: The domain model to convert + + Returns: + The database model + """ + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + db_model = WorkflowNodeExecutionModel() + db_model.id = domain_model.id + db_model.tenant_id = self._tenant_id + if self._app_id is not None: + db_model.app_id = self._app_id + db_model.workflow_id = domain_model.workflow_id + db_model.triggered_from = self._triggered_from + db_model.workflow_run_id = domain_model.workflow_execution_id + db_model.index = domain_model.index + db_model.predecessor_node_id = domain_model.predecessor_node_id + db_model.node_execution_id = domain_model.node_execution_id + db_model.node_id = domain_model.node_id + db_model.node_type = domain_model.node_type + db_model.title = domain_model.title + db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None + db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None + db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.status = domain_model.status + db_model.error = domain_model.error + db_model.elapsed_time = domain_model.elapsed_time + db_model.execution_metadata = ( + json.dumps(jsonable_encoder(domain_model.metadata)) if domain_model.metadata else None + ) + db_model.created_at = domain_model.created_at + db_model.created_by_role = self._creator_user_role + db_model.created_by = self._creator_user_id + db_model.finished_at = domain_model.finished_at + return db_model - session.add(execution) + def 
save(self, execution: WorkflowNodeExecution) -> None: + """ + Save or update a NodeExecution domain entity to the database. + + This method serves as a domain-to-database adapter that: + 1. Converts the domain entity to its database representation + 2. Persists the database model using SQLAlchemy's merge operation + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Updates the in-memory cache for faster subsequent lookups + + The method handles both creating new records and updating existing ones through + SQLAlchemy's merge operation. + + Args: + execution: The NodeExecution domain entity to persist + """ + # Convert domain model to database model using tenant context and other attributes + db_model = self.to_db_model(execution) + + # Create a new database session + with self._session_factory() as session: + # SQLAlchemy merge intelligently handles both insert and update operations + # based on the presence of the primary key + session.merge(db_model) session.commit() + # Update the in-memory cache for faster subsequent lookups + # Only cache if we have a node_execution_id to use as the cache key + if db_model.node_execution_id: + logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") + self._node_execution_cache[db_model.node_execution_id] = db_model + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. + + First checks the in-memory cache, and if not found, queries the database. + If found in the database, adds it to the cache for future lookups. 
Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ + # First check the cache + if node_execution_id in self._node_execution_cache: + logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") + # Convert cached DB model to domain model + cached_db_model = self._node_execution_cache[node_execution_id] + return self._to_domain_model(cached_db_model) + + # If not in cache, query the database + logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") with self._session_factory() as session: - stmt = select(WorkflowNodeExecution).where( - WorkflowNodeExecution.node_execution_id == node_execution_id, - WorkflowNodeExecution.tenant_id == self._tenant_id, + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.node_execution_id == node_execution_id, + WorkflowNodeExecutionModel.tenant_id == self._tenant_id, ) if self._app_id: - stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - return session.scalar(stmt) + db_model = session.scalar(stmt) + if db_model: + # Add DB model to cache + self._node_execution_cache[node_execution_id] = db_model - def get_by_workflow_run( + # Convert to domain model and return + return self._to_domain_model(db_model) + + return None + + def get_db_models_by_workflow_run( self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, - ) -> Sequence[WorkflowNodeExecution]: + ) -> Sequence[WorkflowNodeExecutionModel]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all WorkflowNodeExecution database models for a specific workflow run. + + This method directly returns database models without converting to domain models, + which is useful when you need to access database-specific fields like triggered_from. 
+ It also updates the in-memory cache with the retrieved models. Args: workflow_run_id: The workflow run ID @@ -102,23 +267,23 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of WorkflowNodeExecution database models """ with self._session_factory() as session: - stmt = select(WorkflowNodeExecution).where( - WorkflowNodeExecution.workflow_run_id == workflow_run_id, - WorkflowNodeExecution.tenant_id == self._tenant_id, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + WorkflowNodeExecutionModel.tenant_id == self._tenant_id, + WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) if self._app_id: - stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) # Apply ordering if provided if order_config and order_config.order_by: order_columns: list[UnaryExpression] = [] for field in order_config.order_by: - column = getattr(WorkflowNodeExecution, field, None) + column = getattr(WorkflowNodeExecutionModel, field, None) if not column: continue if order_config.order_direction == "desc": @@ -129,49 +294,83 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if order_columns: stmt = stmt.order_by(*order_columns) - return session.scalars(stmt).all() + db_models = session.scalars(stmt).all() - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + # Update the cache with the retrieved DB models + for model in db_models: + if model.node_execution_id: + self._node_execution_cache[model.node_execution_id] = model + + return db_models + + def get_by_workflow_run( + self, + 
workflow_run_id: str, + order_config: Optional[OrderConfig] = None, + ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. + + This method always queries the database to ensure complete and ordered results, + but updates the cache with any retrieved executions. Args: workflow_run_id: The workflow run ID + order_config: Optional configuration for ordering results + order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) + order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of running WorkflowNodeExecution instances + A list of NodeExecution instances """ - with self._session_factory() as session: - stmt = select(WorkflowNodeExecution).where( - WorkflowNodeExecution.workflow_run_id == workflow_run_id, - WorkflowNodeExecution.tenant_id == self._tenant_id, - WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING, - WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # Get the database models using the new method + db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config) - if self._app_id: - stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + # Convert database models to domain models + domain_models = [] + for model in db_models: + domain_model = self._to_domain_model(model) + domain_models.append(domain_model) - return session.scalars(stmt).all() + return domain_models - def update(self, execution: WorkflowNodeExecution) -> None: + def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: """ - Update an existing WorkflowNodeExecution instance and commit changes to the database. + Retrieve all running NodeExecution instances for a specific workflow run. 
+ + This method queries the database directly and updates the cache with any + retrieved executions that have a node_execution_id. Args: - execution: The WorkflowNodeExecution instance to update + workflow_run_id: The workflow run ID + + Returns: + A list of running NodeExecution instances """ with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id + stmt = select(WorkflowNodeExecutionModel).where( + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + WorkflowNodeExecutionModel.tenant_id == self._tenant_id, + WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, + WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + if self._app_id: + stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id + db_models = session.scalars(stmt).all() + domain_models = [] - session.merge(execution) - session.commit() + for model in db_models: + # Update cache if node_execution_id is present + if model.node_execution_id: + self._node_execution_cache[model.node_execution_id] = model + + # Convert to domain model + domain_model = self._to_domain_model(model) + domain_models.append(domain_model) + + return domain_models def clear(self) -> None: """ @@ -179,12 +378,13 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This method deletes all WorkflowNodeExecution records that match the tenant_id and app_id (if provided) associated with this repository instance. + It also clears the in-memory cache. 
""" with self._session_factory() as session: - stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) + stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) if self._app_id: - stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) + stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) result = session.execute(stmt) session.commit() @@ -194,3 +394,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + (f" and app {self._app_id}" if self._app_id else "") ) + + # Clear the in-memory cache + self._node_execution_cache.clear() + logger.info("Cleared in-memory node execution cache") diff --git a/api/core/tools/builtin_tool/_position.yaml b/api/core/tools/builtin_tool/_position.yaml index b5875e2075..0e811de311 100644 --- a/api/core/tools/builtin_tool/_position.yaml +++ b/api/core/tools/builtin_tool/_position.yaml @@ -1,3 +1,4 @@ +- audio - code - time -- qrcode +- webscraper diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index 2f2f1ebbdd..2f5cc6d4c0 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -168,7 +168,7 @@ class ApiTool(Tool): cookies[parameter["name"]] = value elif parameter["in"] == "header": - headers[parameter["name"]] = value + headers[parameter["name"]] = str(value) # check if there is a request body and handle it if "requestBody" in self.api_bundle.openapi and self.api_bundle.openapi["requestBody"] is not None: diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index 37375f4a71..03047c0545 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -279,7 +279,6 @@ class ToolParameter(PluginParameter): :param options: the options of the parameter """ # 
convert options to ToolParameterOption - # FIXME fix the type error if options: option_objs = [ PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 3dce1ca293..178f2b9689 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -32,7 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import Message, MessageFile @@ -339,9 +339,9 @@ class ToolEngine: url=message.url, upload_file_id=tool_file_id, created_by_role=( - CreatedByRole.ACCOUNT + CreatorUserRole.ACCOUNT if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER + else CreatorUserRole.END_USER ), created_by=user_id, ) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index aa2661fe63..0bfe6329b1 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -528,7 +528,7 @@ class ToolManager: yield provider except Exception: - logger.exception(f"load builtin provider {provider}") + logger.exception(f"load builtin provider {provider_path}") continue # set builtin providers loaded cls._builtin_providers_loaded = True @@ -644,10 +644,10 @@ class ToolManager: ) workflow_provider_controllers: list[WorkflowToolProviderController] = [] - for provider in workflow_providers: + for workflow_provider in workflow_providers: try: workflow_provider_controllers.append( - ToolTransformService.workflow_provider_to_controller(db_provider=provider) + ToolTransformService.workflow_provider_to_controller(db_provider=workflow_provider) ) except Exception: # app has been deleted diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py 
b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 032274b87e..2cbc4b9821 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -8,6 +8,7 @@ from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.models.document import Document as RagDocument from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod @@ -84,13 +85,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): document_context_list = [] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] - segments = DocumentSegment.query.filter( - DocumentSegment.dataset_id.in_(self.dataset_ids), - DocumentSegment.completed_at.isnot(None), - DocumentSegment.status == "completed", - DocumentSegment.enabled == True, - DocumentSegment.index_node_id.in_(index_node_ids), - ).all() + segments = ( + db.session.query(DocumentSegment) + .filter( + DocumentSegment.dataset_id.in_(self.dataset_ids), + DocumentSegment.completed_at.isnot(None), + DocumentSegment.status == "completed", + DocumentSegment.enabled == True, + DocumentSegment.index_node_id.in_(index_node_ids), + ) + .all() + ) if segments: index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} @@ -103,38 +108,42 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): else: document_context_list.append(segment.get_sign_content()) if self.return_resource: - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = 
Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + document = ( + db.session.query(Document) + .filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ) + .first() + ) if dataset and document: - source = { - "position": resource_number, - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": document_score_list.get(segment.index_node_id, None), - "doc_metadata": document.doc_metadata, - } + source = RetrievalSourceMetadata( + position=resource_number, + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from=self.retriever_from, + score=document_score_list.get(segment.index_node_id, None), + doc_metadata=document.doc_metadata, + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content context_list.append(source) resource_number += 1 @@ -144,8 +153,6 @@ class 
DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): return str("\n".join(document_context_list)) return "" - raise RuntimeError("not segments found") - def _retriever( self, flask_app: Flask, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index ed97b44f95..ff1d9021ce 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig from core.rag.datasource.retrieval_service import RetrievalService +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.rag.entities.context_entities import DocumentContext from core.rag.models.document import Document as RetrievalDocument from core.rag.retrieval.dataset_retrieval import DatasetRetrieval @@ -14,7 +15,7 @@ from models.dataset import Dataset from models.dataset import Document as DatasetDocument from services.external_knowledge_service import ExternalDatasetService -default_retrieval_model = { +default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "reranking_enable": False, "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, @@ -79,7 +80,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): else: document_ids_filter = None if dataset.provider == "external": - results = [] + results: list[RetrievalDocument] = [] external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( tenant_id=dataset.tenant_id, dataset_id=dataset.id, @@ -100,21 +101,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): document.metadata["dataset_name"] = dataset.name results.append(document) # deal with external documents - context_list = [] + context_list: list[RetrievalSourceMetadata] = [] for position, 
item in enumerate(results, start=1): if item.metadata is not None: - source = { - "position": position, - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": self.retriever_from, - "score": item.metadata.get("score"), - "title": item.metadata.get("title"), - "content": item.page_content, - } + source = RetrievalSourceMetadata( + position=position, + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id") or item.metadata.get("title"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from=self.retriever_from, + score=item.metadata.get("score"), + title=item.metadata.get("title"), + content=item.page_content, + ) context_list.append(source) for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(context_list) @@ -125,6 +126,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): return "" # get retrieval model , if the model is not setting , using default retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model + retrieval_resource_list: list[RetrievalSourceMetadata] = [] if dataset.indexing_technique == "economy": # use keyword table query documents = RetrievalService.retrieve( @@ -162,7 +164,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): for item in documents: if item.metadata is not None and item.metadata.get("score"): document_score_list[item.metadata["doc_id"]] = item.metadata["score"] - document_context_list = [] + document_context_list: list[DocumentContext] = [] records = RetrievalService.format_retrieval_documents(documents) if records: for record in records: @@ -181,48 +183,52 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): 
score=record.score, ) ) - retrieval_resource_list = [] + if self.return_resource: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = DatasetDocument.query.filter( - DatasetDocument.id == segment.document_id, - DatasetDocument.enabled == True, - DatasetDocument.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + document = ( + db.session.query(DatasetDocument) # type: ignore + .filter( + DatasetDocument.id == segment.document_id, + DatasetDocument.enabled == True, + DatasetDocument.archived == False, + ) + .first() + ) if dataset and document: - source = { - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, # type: ignore - "document_name": document.name, # type: ignore - "data_source_type": document.data_source_type, # type: ignore - "segment_id": segment.id, - "retriever_from": self.retriever_from, - "score": record.score or 0.0, - "doc_metadata": document.doc_metadata, # type: ignore - } + source = RetrievalSourceMetadata( + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, # type: ignore + document_name=document.name, # type: ignore + data_source_type=document.data_source_type, # type: ignore + segment_id=segment.id, + retriever_from=self.retriever_from, + score=record.score or 0.0, + doc_metadata=document.doc_metadata, # type: ignore + ) if self.retriever_from == "dev": - source["hit_count"] = segment.hit_count - source["word_count"] = segment.word_count - source["segment_position"] = segment.position - source["index_node_hash"] = segment.index_node_hash + source.hit_count = segment.hit_count + source.word_count = segment.word_count + source.segment_position = segment.position + source.index_node_hash = segment.index_node_hash if segment.answer: - source["content"] = f"question:{segment.content} \nanswer:{segment.answer}" + source.content = 
f"question:{segment.content} \nanswer:{segment.answer}" else: - source["content"] = segment.content + source.content = segment.content retrieval_resource_list.append(source) if self.return_resource and retrieval_resource_list: retrieval_resource_list = sorted( retrieval_resource_list, - key=lambda x: x.get("score") or 0.0, + key=lambda x: x.score or 0.0, reverse=True, ) for position, item in enumerate(retrieval_resource_list, start=1): # type: ignore - item["position"] = position # type: ignore + item.position = position # type: ignore for hit_callback in self.hit_callbacks: hit_callback.return_retriever_resource_info(retrieval_resource_list) if document_context_list: diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 257d96133e..9998de0465 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -32,14 +32,14 @@ class ToolFileMessageTransformer: try: assert isinstance(message.message, ToolInvokeMessage.TextMessage) tool_file_manager = ToolFileManager() - file = tool_file_manager.create_file_by_url( + tool_file = tool_file_manager.create_file_by_url( user_id=user_id, tenant_id=tenant_id, file_url=message.message.text, conversation_id=conversation_id, ) - url = f"/files/tools/{file.id}{guess_extension(file.mimetype) or '.png'}" + url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, @@ -66,10 +66,9 @@ class ToolFileMessageTransformer: if not isinstance(message.message, ToolInvokeMessage.BlobMessage): raise ValueError("unexpected message type") - # FIXME: should do a type check here. 
assert isinstance(message.message.blob, bytes) tool_file_manager = ToolFileManager() - file = tool_file_manager.create_file_by_raw( + tool_file = tool_file_manager.create_file_by_raw( user_id=user_id, tenant_id=tenant_id, conversation_id=conversation_id, @@ -78,7 +77,7 @@ class ToolFileMessageTransformer: filename=filename, ) - url = cls.get_tool_file_url(tool_file_id=file.id, extension=guess_extension(file.mimetype)) + url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) # check if file is image if "image" in mimetype: diff --git a/api/core/tools/utils/parser.py b/api/core/tools/utils/parser.py index f72291783a..3f844e8234 100644 --- a/api/core/tools/utils/parser.py +++ b/api/core/tools/utils/parser.py @@ -55,6 +55,13 @@ class ApiBasedToolSchemaParser: # convert parameters parameters = [] if "parameters" in interface["operation"]: + for i, parameter in enumerate(interface["operation"]["parameters"]): + if "$ref" in parameter: + root = openapi + reference = parameter["$ref"].split("/")[1:] + for ref in reference: + root = root[ref] + interface["operation"]["parameters"][i] = root for parameter in interface["operation"]["parameters"]: tool_parameter = ToolParameter( name=parameter["name"], diff --git a/api/core/tools/utils/web_reader_tool.py b/api/core/tools/utils/web_reader_tool.py index d42fd99fce..cbd06fc186 100644 --- a/api/core/tools/utils/web_reader_tool.py +++ b/api/core/tools/utils/web_reader_tool.py @@ -1,21 +1,13 @@ -import hashlib -import json import mimetypes -import os import re -import site -import subprocess -import tempfile -import unicodedata -from contextlib import contextmanager -from pathlib import Path -from typing import Any, Literal, Optional, cast +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Any, Optional, cast from urllib.parse import unquote import chardet import cloudscraper # type: ignore -from bs4 import BeautifulSoup, CData, Comment, 
NavigableString # type: ignore -from regex import regex # type: ignore +from readabilipy import simple_json_from_html_string # type: ignore from core.helper import ssrf_proxy from core.rag.extractor import extract_processor @@ -23,9 +15,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor FULL_TEMPLATE = """ TITLE: {title} -AUTHORS: {authors} -PUBLISH DATE: {publish_date} -TOP_IMAGE_URL: {top_image} +AUTHOR: {author} TEXT: {text} @@ -73,8 +63,8 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: response = ssrf_proxy.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) elif response.status_code == 403: scraper = cloudscraper.create_scraper() - scraper.perform_request = ssrf_proxy.make_request - response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) + scraper.perform_request = ssrf_proxy.make_request # type: ignore + response = scraper.get(url, headers=headers, follow_redirects=True, timeout=(120, 300)) # type: ignore if response.status_code != 200: return "URL returned status code {}.".format(response.status_code) @@ -90,273 +80,36 @@ def get_url(url: str, user_agent: Optional[str] = None) -> str: else: content = response.text - a = extract_using_readabilipy(content) + article = extract_using_readabilipy(content) - if not a["plain_text"] or not a["plain_text"].strip(): + if not article.text: return "" res = FULL_TEMPLATE.format( - title=a["title"], - authors=a["byline"], - publish_date=a["date"], - top_image="", - text=a["plain_text"] or "", + title=article.title, + author=article.auther, + text=article.text, ) return res -def extract_using_readabilipy(html): - with tempfile.NamedTemporaryFile(delete=False, mode="w+") as f_html: - f_html.write(html) - f_html.close() - html_path = f_html.name +@dataclass +class Article: + title: str + auther: str + text: Sequence[dict] - # Call Mozilla's Readability.js Readability.parse() function via node, writing output to a temporary file - 
article_json_path = html_path + ".json" - jsdir = os.path.join(find_module_path("readabilipy"), "javascript") - with chdir(jsdir): - subprocess.check_call(["node", "ExtractArticle.js", "-i", html_path, "-o", article_json_path]) - # Read output of call to Readability.parse() from JSON file and return as Python dictionary - input_json = json.loads(Path(article_json_path).read_text(encoding="utf-8")) - - # Deleting files after processing - os.unlink(article_json_path) - os.unlink(html_path) - - article_json: dict[str, Any] = { - "title": None, - "byline": None, - "date": None, - "content": None, - "plain_content": None, - "plain_text": None, - } - # Populate article fields from readability fields where present - if input_json: - if input_json.get("title"): - article_json["title"] = input_json["title"] - if input_json.get("byline"): - article_json["byline"] = input_json["byline"] - if input_json.get("date"): - article_json["date"] = input_json["date"] - if input_json.get("content"): - article_json["content"] = input_json["content"] - article_json["plain_content"] = plain_content(article_json["content"], False, False) - article_json["plain_text"] = extract_text_blocks_as_plain_text(article_json["plain_content"]) - if input_json.get("textContent"): - article_json["plain_text"] = input_json["textContent"] - article_json["plain_text"] = re.sub(r"\n\s*\n", "\n", article_json["plain_text"]) - - return article_json - - -def find_module_path(module_name): - for package_path in site.getsitepackages(): - potential_path = os.path.join(package_path, module_name) - if os.path.exists(potential_path): - return potential_path - - return None - - -@contextmanager -def chdir(path): - """Change directory in context and return to original on exit""" - # From https://stackoverflow.com/a/37996581, couldn't find a built-in - original_path = os.getcwd() - os.chdir(path) - try: - yield - finally: - os.chdir(original_path) - - -def extract_text_blocks_as_plain_text(paragraph_html): - # Load 
article as DOM - soup = BeautifulSoup(paragraph_html, "html.parser") - # Select all lists - list_elements = soup.find_all(["ul", "ol"]) - # Prefix text in all list items with "* " and make lists paragraphs - for list_element in list_elements: - plain_items = "".join( - list(filter(None, [plain_text_leaf_node(li)["text"] for li in list_element.find_all("li")])) - ) - list_element.string = plain_items - list_element.name = "p" - # Select all text blocks - text_blocks = [s.parent for s in soup.find_all(string=True)] - text_blocks = [plain_text_leaf_node(block) for block in text_blocks] - # Drop empty paragraphs - text_blocks = list(filter(lambda p: p["text"] is not None, text_blocks)) - return text_blocks - - -def plain_text_leaf_node(element): - # Extract all text, stripped of any child HTML elements and normalize it - plain_text = normalize_text(element.get_text()) - if plain_text != "" and element.name == "li": - plain_text = "* {}, ".format(plain_text) - if plain_text == "": - plain_text = None - if "data-node-index" in element.attrs: - plain = {"node_index": element["data-node-index"], "text": plain_text} - else: - plain = {"text": plain_text} - return plain - - -def plain_content(readability_content, content_digests, node_indexes): - # Load article as DOM - soup = BeautifulSoup(readability_content, "html.parser") - # Make all elements plain - elements = plain_elements(soup.contents, content_digests, node_indexes) - if node_indexes: - # Add node index attributes to nodes - elements = [add_node_indexes(element) for element in elements] - # Replace article contents with plain elements - soup.contents = elements - return str(soup) - - -def plain_elements(elements, content_digests, node_indexes): - # Get plain content versions of all elements - elements = [plain_element(element, content_digests, node_indexes) for element in elements] - if content_digests: - # Add content digest attribute to nodes - elements = [add_content_digest(element) for element in elements] - 
return elements - - -def plain_element(element, content_digests, node_indexes): - # For lists, we make each item plain text - if is_leaf(element): - # For leaf node elements, extract the text content, discarding any HTML tags - # 1. Get element contents as text - plain_text = element.get_text() - # 2. Normalize the extracted text string to a canonical representation - plain_text = normalize_text(plain_text) - # 3. Update element content to be plain text - element.string = plain_text - elif is_text(element): - if is_non_printing(element): - # The simplified HTML may have come from Readability.js so might - # have non-printing text (e.g. Comment or CData). In this case, we - # keep the structure, but ensure that the string is empty. - element = type(element)("") - else: - plain_text = element.string - plain_text = normalize_text(plain_text) - element = type(element)(plain_text) - else: - # If not a leaf node or leaf type call recursively on child nodes, replacing - element.contents = plain_elements(element.contents, content_digests, node_indexes) - return element - - -def add_node_indexes(element, node_index="0"): - # Can't add attributes to string types - if is_text(element): - return element - # Add index to current element - element["data-node-index"] = node_index - # Add index to child elements - for local_idx, child in enumerate([c for c in element.contents if not is_text(c)], start=1): - # Can't add attributes to leaf string types - child_index = "{stem}.{local}".format(stem=node_index, local=local_idx) - add_node_indexes(child, node_index=child_index) - return element - - -def normalize_text(text): - """Normalize unicode and whitespace.""" - # Normalize unicode first to try and standardize whitespace characters as much as possible before normalizing them - text = strip_control_characters(text) - text = normalize_unicode(text) - text = normalize_whitespace(text) - return text - - -def strip_control_characters(text): - """Strip out unicode control characters 
which might break the parsing.""" - # Unicode control characters - # [Cc]: Other, Control [includes new lines] - # [Cf]: Other, Format - # [Cn]: Other, Not Assigned - # [Co]: Other, Private Use - # [Cs]: Other, Surrogate - control_chars = {"Cc", "Cf", "Cn", "Co", "Cs"} - retained_chars = ["\t", "\n", "\r", "\f"] - - # Remove non-printing control characters - return "".join( - [ - "" if (unicodedata.category(char) in control_chars) and (char not in retained_chars) else char - for char in text - ] +def extract_using_readabilipy(html: str): + json_article: dict[str, Any] = simple_json_from_html_string(html, use_readability=True) + article = Article( + title=json_article.get("title") or "", + auther=json_article.get("byline") or "", + text=json_article.get("plain_text") or [], ) - -def normalize_unicode(text): - """Normalize unicode such that things that are visually equivalent map to the same unicode string where possible.""" - normal_form: Literal["NFC", "NFD", "NFKC", "NFKD"] = "NFKC" - text = unicodedata.normalize(normal_form, text) - return text - - -def normalize_whitespace(text): - """Replace runs of whitespace characters with a single space as this is what happens when HTML text is displayed.""" - text = regex.sub(r"\s+", " ", text) - # Remove leading and trailing whitespace - text = text.strip() - return text - - -def is_leaf(element): - return element.name in {"p", "li"} - - -def is_text(element): - return isinstance(element, NavigableString) - - -def is_non_printing(element): - return any(isinstance(element, _e) for _e in [Comment, CData]) - - -def add_content_digest(element): - if not is_text(element): - element["data-content-digest"] = content_digest(element) - return element - - -def content_digest(element): - digest: Any - if is_text(element): - # Hash - trimmed_string = element.string.strip() - if trimmed_string == "": - digest = "" - else: - digest = hashlib.sha256(trimmed_string.encode("utf-8")).hexdigest() - else: - contents = element.contents - 
num_contents = len(contents) - if num_contents == 0: - # No hash when no child elements exist - digest = "" - elif num_contents == 1: - # If single child, use digest of child - digest = content_digest(contents[0]) - else: - # Build content digest from the "non-empty" digests of child nodes - digest = hashlib.sha256() - child_digests = list(filter(lambda x: x != "", [content_digest(content) for content in contents])) - for child in child_digests: - digest.update(child.encode("utf-8")) - digest = digest.hexdigest() - return digest + return article def get_image_upload_file_ids(content): diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 241b4a94de..57c93d1d45 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -1,7 +1,9 @@ import json import logging from collections.abc import Generator -from typing import Any, Optional, Union, cast +from typing import Any, Optional, cast + +from flask_login import current_user from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod from core.tools.__base.tool import Tool @@ -87,7 +89,7 @@ class WorkflowTool(Tool): result = generator.generate( app_model=app, workflow=workflow, - user=self._get_user(user_id), + user=cast("Account | EndUser", current_user), args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, streaming=False, @@ -111,20 +113,6 @@ class WorkflowTool(Tool): yield self.create_text_message(json.dumps(outputs, ensure_ascii=False)) yield self.create_json_message(outputs) - def _get_user(self, user_id: str) -> Union[EndUser, Account]: - """ - get the user by user id - """ - - user = db.session.query(EndUser).filter(EndUser.id == user_id).first() - if not user: - user = db.session.query(Account).filter(Account.id == user_id).first() - - if not user: - raise ValueError("user not found") - - return user - def fork_tool_runtime(self, runtime: ToolRuntime) -> "WorkflowTool": """ fork a 
new tool with metadata diff --git a/api/core/variables/consts.py b/api/core/variables/consts.py new file mode 100644 index 0000000000..03b277d619 --- /dev/null +++ b/api/core/variables/consts.py @@ -0,0 +1,7 @@ +# The minimal selector length for valid variables. +# +# The first element of the selector is the node id, and the second element is the variable name. +# +# If the selector length is more than 2, the remaining parts are the keys / indexes paths used +# to extract part of the variable value. +MIN_SELECTORS_LENGTH = 2 diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py new file mode 100644 index 0000000000..e5d222af7d --- /dev/null +++ b/api/core/variables/utils.py @@ -0,0 +1,8 @@ +from collections.abc import Iterable, Sequence + + +def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: + selectors = [node_id, name] + if paths: + selectors.extend(paths) + return selectors diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index c32815b24d..b650b1682e 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -30,7 +30,7 @@ class Variable(Segment): """ id: str = Field( - default=lambda _: str(uuid4()), + default_factory=lambda: str(uuid4()), description="Unique identity for variable.", ) name: str diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index 82fd6cdc30..687ec8e47c 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,36 +1,10 @@ from collections.abc import Mapping -from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel from core.model_runtime.entities.llm_entities import LLMUsage -from models.workflow import WorkflowNodeExecutionStatus - - -class NodeRunMetadataKey(StrEnum): - """ - Node Run Metadata Key. 
- """ - - TOTAL_TOKENS = "total_tokens" - TOTAL_PRICE = "total_price" - CURRENCY = "currency" - TOOL_INFO = "tool_info" - AGENT_LOG = "agent_log" - ITERATION_ID = "iteration_id" - ITERATION_INDEX = "iteration_index" - LOOP_ID = "loop_id" - LOOP_INDEX = "loop_index" - PARALLEL_ID = "parallel_id" - PARALLEL_START_NODE_ID = "parallel_start_node_id" - PARENT_PARALLEL_ID = "parent_parallel_id" - PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" - PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" - ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs - LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs - ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field - LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class NodeRunResult(BaseModel): @@ -43,7 +17,7 @@ class NodeRunResult(BaseModel): inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[Mapping[str, Any]] = None # process data outputs: Optional[Mapping[str, Any]] = None # node outputs - metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # node metadata + metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage edge_source_handle: Optional[str] = None # source handle id of node with multiple branches diff --git a/api/core/workflow/entities/workflow_execution.py b/api/core/workflow/entities/workflow_execution.py new file mode 100644 index 0000000000..781be4b3c6 --- /dev/null +++ b/api/core/workflow/entities/workflow_execution.py @@ -0,0 +1,87 @@ +""" +Domain entities for workflow execution. + +Models are independent of the storage mechanism and don't contain +implementation details like tenant_id, app_id, etc. 
+""" + +from collections.abc import Mapping +from datetime import UTC, datetime +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, Field + + +class WorkflowType(StrEnum): + """ + Workflow Type Enum for domain layer + """ + + WORKFLOW = "workflow" + CHAT = "chat" + + +class WorkflowExecutionStatus(StrEnum): + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + STOPPED = "stopped" + PARTIAL_SUCCEEDED = "partial-succeeded" + + +class WorkflowExecution(BaseModel): + """ + Domain model for workflow execution based on WorkflowRun but without + user, tenant, and app attributes. + """ + + id_: str = Field(...) + workflow_id: str = Field(...) + workflow_version: str = Field(...) + workflow_type: WorkflowType = Field(...) + graph: Mapping[str, Any] = Field(...) + + inputs: Mapping[str, Any] = Field(...) + outputs: Optional[Mapping[str, Any]] = None + + status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING + error_message: str = Field(default="") + total_tokens: int = Field(default=0) + total_steps: int = Field(default=0) + exceptions_count: int = Field(default=0) + + started_at: datetime = Field(...) + finished_at: Optional[datetime] = None + + @property + def elapsed_time(self) -> float: + """ + Calculate elapsed time in seconds. + If workflow is not finished, use current time. 
+ """ + end_time = self.finished_at or datetime.now(UTC).replace(tzinfo=None) + return (end_time - self.started_at).total_seconds() + + @classmethod + def new( + cls, + *, + id_: str, + workflow_id: str, + workflow_type: WorkflowType, + workflow_version: str, + graph: Mapping[str, Any], + inputs: Mapping[str, Any], + started_at: datetime, + ) -> "WorkflowExecution": + return WorkflowExecution( + id_=id_, + workflow_id=workflow_id, + workflow_type=workflow_type, + workflow_version=workflow_version, + graph=graph, + inputs=inputs, + status=WorkflowExecutionStatus.RUNNING, + started_at=started_at, + ) diff --git a/api/core/workflow/entities/workflow_node_execution.py b/api/core/workflow/entities/workflow_node_execution.py new file mode 100644 index 0000000000..773f5b777b --- /dev/null +++ b/api/core/workflow/entities/workflow_node_execution.py @@ -0,0 +1,122 @@ +""" +Domain entities for workflow node execution. + +This module contains the domain model for workflow node execution, which is used +by the core workflow module. These models are independent of the storage mechanism +and don't contain implementation details like tenant_id, app_id, etc. +""" + +from collections.abc import Mapping +from datetime import datetime +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.nodes.enums import NodeType + + +class WorkflowNodeExecutionMetadataKey(StrEnum): + """ + Node Run Metadata Key. 
+ """ + + TOTAL_TOKENS = "total_tokens" + TOTAL_PRICE = "total_price" + CURRENCY = "currency" + TOOL_INFO = "tool_info" + AGENT_LOG = "agent_log" + ITERATION_ID = "iteration_id" + ITERATION_INDEX = "iteration_index" + LOOP_ID = "loop_id" + LOOP_INDEX = "loop_index" + PARALLEL_ID = "parallel_id" + PARALLEL_START_NODE_ID = "parallel_start_node_id" + PARENT_PARALLEL_ID = "parent_parallel_id" + PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" + PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" + ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs + LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs + ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field + LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output + + +class WorkflowNodeExecutionStatus(StrEnum): + """ + Node Execution Status Enum. + """ + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + EXCEPTION = "exception" + RETRY = "retry" + + +class WorkflowNodeExecution(BaseModel): + """ + Domain model for workflow node execution. + + This model represents the core business entity of a node execution, + without implementation details like tenant_id, app_id, etc. + + Note: User/context-specific fields (triggered_from, created_by, created_by_role) + have been moved to the repository implementation to keep the domain model clean. + These fields are still accepted in the constructor for backward compatibility, + but they are not stored in the model. 
+ """ + + # Core identification fields + id: str # Unique identifier for this execution record + node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing + workflow_id: str # ID of the workflow this node belongs to + workflow_execution_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + + # Execution positioning and flow + index: int # Sequence number for ordering in trace visualization + predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + node_id: str # ID of the node being executed + node_type: NodeType # Type of node (e.g., start, llm, knowledge) + title: str # Display title of the node + + # Execution data + inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node + process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data + outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + + # Execution state + status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status + error: Optional[str] = None # Error message if execution failed + elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds + + # Additional metadata + metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + + # Timing information + created_at: datetime # When execution started + finished_at: Optional[datetime] = None # When execution completed + + def update_from_mapping( + self, + inputs: Optional[Mapping[str, Any]] = None, + process_data: Optional[Mapping[str, Any]] = None, + outputs: Optional[Mapping[str, Any]] = None, + metadata: Optional[Mapping[WorkflowNodeExecutionMetadataKey, Any]] = None, + ) -> None: + """ + Update the model from mappings. 
+ + Args: + inputs: The inputs to update + process_data: The process data to update + outputs: The outputs to update + metadata: The metadata to update + """ + if inputs is not None: + self.inputs = dict(inputs) + if process_data is not None: + self.process_data = dict(process_data) + if outputs is not None: + self.outputs = dict(outputs) + if metadata is not None: + self.metadata = dict(metadata) diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 9642efa1a5..b52a2b0e6e 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -13,4 +13,4 @@ class SystemVariableKey(StrEnum): DIALOGUE_COUNT = "dialogue_count" APP_ID = "app_id" WORKFLOW_ID = "workflow_id" - WORKFLOW_RUN_ID = "workflow_run_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" diff --git a/api/core/workflow/graph_engine/entities/event.py b/api/core/workflow/graph_engine/entities/event.py index 689a07c4f6..9a4939502e 100644 --- a/api/core/workflow/graph_engine/entities/event.py +++ b/api/core/workflow/graph_engine/entities/event.py @@ -1,9 +1,10 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from datetime import datetime from typing import Any, Optional from pydantic import BaseModel, Field +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import AgentNodeStrategyInit from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState from core.workflow.nodes import NodeType @@ -82,7 +83,7 @@ class NodeRunStreamChunkEvent(BaseNodeEvent): class NodeRunRetrieverResourceEvent(BaseNodeEvent): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") diff --git a/api/core/workflow/graph_engine/entities/graph.py b/api/core/workflow/graph_engine/entities/graph.py index 
5c672c985b..8e5b1e7142 100644 --- a/api/core/workflow/graph_engine/entities/graph.py +++ b/api/core/workflow/graph_engine/entities/graph.py @@ -36,7 +36,7 @@ class Graph(BaseModel): root_node_id: str = Field(..., description="root node id of the graph") node_ids: list[str] = Field(default_factory=list, description="graph node ids") node_id_config_mapping: dict[str, dict] = Field( - default_factory=list, description="node configs mapping (node id: node config)" + default_factory=dict, description="node configs mapping (node id: node config)" ) edge_mapping: dict[str, list[GraphEdge]] = Field( default_factory=dict, description="graph edge mapping (source node id: edges)" diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index 7683dcc9dc..f2d9c98936 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -6,7 +6,7 @@ from typing import Optional from pydantic import BaseModel, Field from core.workflow.entities.node_entities import NodeRunResult -from models.workflow import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus class RouteNodeState(BaseModel): diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 36273d8ec1..ee2164f22f 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -14,8 +14,9 @@ from flask import Flask, current_app from configs import dify_config from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from 
core.workflow.entities.variable_pool import VariablePool, VariableValue +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager from core.workflow.graph_engine.entities.event import ( BaseAgentEvent, @@ -52,9 +53,9 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING -from extensions.ext_database import db +from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom -from models.workflow import WorkflowNodeExecutionStatus, WorkflowType +from models.workflow import WorkflowType logger = logging.getLogger(__name__) @@ -537,10 +538,8 @@ class GraphEngine: """ Run parallel nodes """ - for var, val in context.items(): - var.set(val) - with flask_app.app_context(): + with preserve_flask_contexts(flask_app, context_vars=context): try: q.put( ParallelBranchRunStartedEvent( @@ -593,8 +592,6 @@ class GraphEngine: error=str(e), ) ) - finally: - db.session.remove() def _run_node( self, @@ -632,7 +629,6 @@ class GraphEngine: agent_strategy=agent_strategy, ) - db.session.close() max_retries = node_instance.node_data.retry_config.max_retries retry_interval = node_instance.node_data.retry_config.retry_interval_seconds retries = 0 @@ -643,26 +639,19 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - generator = node_instance.run() - for item in generator: - if isinstance(item, GraphEngineEvent): - if isinstance(item, BaseIterationEvent): - # add parallel info to iteration event - item.parallel_id = parallel_id - item.parallel_start_node_id = 
parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id - elif isinstance(item, BaseLoopEvent): - # add parallel info to loop event - item.parallel_id = parallel_id - item.parallel_start_node_id = parallel_start_node_id - item.parent_parallel_id = parent_parallel_id - item.parent_parallel_start_node_id = parent_parallel_start_node_id - - yield item + event_stream = node_instance.run() + for event in event_stream: + if isinstance(event, GraphEngineEvent): + # add parallel info to iteration event + if isinstance(event, BaseIterationEvent | BaseLoopEvent): + event.parallel_id = parallel_id + event.parallel_start_node_id = parallel_start_node_id + event.parent_parallel_id = parent_parallel_id + event.parent_parallel_start_node_id = parent_parallel_start_node_id + yield event else: - if isinstance(item, RunCompletedEvent): - run_result = item.run_result + if isinstance(event, RunCompletedEvent): + run_result = event.run_result if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries @@ -698,7 +687,7 @@ class GraphEngine: # if run failed, handle error run_result = self._handle_continue_on_error( node_instance, - item.run_result, + event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, ) @@ -746,10 +735,12 @@ class GraphEngine: and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS - if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS): + if run_result.metadata and run_result.metadata.get( + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS + ): # plus state total_tokens self.graph_runtime_state.total_tokens += int( - run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] + run_result.metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type] ) if 
run_result.llm_usage: @@ -772,13 +763,17 @@ class GraphEngine: if parallel_id and parallel_start_node_id: metadata_dict = dict(run_result.metadata) - metadata_dict[NodeRunMetadataKey.PARALLEL_ID] = parallel_id - metadata_dict[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id + metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_ID] = parallel_id + metadata_dict[WorkflowNodeExecutionMetadataKey.PARALLEL_START_NODE_ID] = ( + parallel_start_node_id + ) if parent_parallel_id and parent_parallel_start_node_id: - metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id - metadata_dict[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = ( - parent_parallel_start_node_id + metadata_dict[WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_ID] = ( + parent_parallel_id ) + metadata_dict[ + WorkflowNodeExecutionMetadataKey.PARENT_PARALLEL_START_NODE_ID + ] = parent_parallel_start_node_id run_result.metadata = metadata_dict yield NodeRunSucceededEvent( @@ -795,28 +790,28 @@ class GraphEngine: should_continue_retry = False break - elif isinstance(item, RunStreamChunkEvent): + elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, - chunk_content=item.chunk_content, - from_variable_selector=item.from_variable_selector, + chunk_content=event.chunk_content, + from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, ) - elif isinstance(item, RunRetrieverResourceEvent): + elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( id=node_instance.id, node_id=node_instance.node_id, node_type=node_instance.node_type, node_data=node_instance.node_data, - 
retriever_resources=item.retriever_resources, - context=item.context, + retriever_resources=event.retriever_resources, + context=event.context, route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -843,8 +838,6 @@ class GraphEngine: except Exception as e: logger.exception(f"Node {node_instance.node_data.title} run failed") raise e - finally: - db.session.close() def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): """ @@ -910,7 +903,7 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - NodeRunMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, }, } diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 771e0ca7a5..22c564c1fc 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,6 +2,9 @@ import json from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast +from sqlalchemy import select +from sqlalchemy.orm import Session + from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory @@ -15,6 +18,7 @@ from core.tools.tool_manager import ToolManager from core.variables.segments import StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated from core.workflow.nodes.base.entities import BaseNodeData @@ -25,7 +29,6 @@ 
from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories.agent_factory import get_plugin_agent_strategy from models.model import Conversation -from models.workflow import WorkflowNodeExecutionStatus class AgentNode(ToolNode): @@ -211,7 +214,7 @@ class AgentNode(ToolNode): ) if tool_runtime.entity.description: tool_runtime.entity.description.llm = ( - extra.get("descrption", "") or tool_runtime.entity.description.llm + extra.get("description", "") or tool_runtime.entity.description.llm ) for tool_runtime_params in tool_runtime.entity.parameters: tool_runtime_params.form = ( @@ -320,15 +323,12 @@ class AgentNode(ToolNode): return None conversation_id = conversation_id_variable.value - # get conversation - conversation = ( - db.session.query(Conversation) - .filter(Conversation.app_id == self.app_id, Conversation.id == conversation_id) - .first() - ) + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) - if not conversation: - return None + if not conversation: + return None memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) @@ -356,7 +356,9 @@ class AgentNode(ToolNode): def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity: if model_schema.features: - for feature in model_schema.features: - if feature.value not in AgentOldVersionModelFeatures: + for feature in model_schema.features[:]: # Create a copy to safely modify during iteration + try: + AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value + except ValueError: model_schema.features.remove(feature) return model_schema diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 77e94375bf..075a41fb2f 100644 --- 
a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Literal, Union from pydantic import BaseModel @@ -26,7 +26,7 @@ class ParamsAutoGenerated(Enum): OPEN = 1 -class AgentOldVersionModelFeatures(Enum): +class AgentOldVersionModelFeatures(StrEnum): """ Enum class for old SDK version llm feature. """ diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 520cbdbb60..aa030870e2 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -3,6 +3,7 @@ from typing import Any, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter from core.workflow.nodes.answer.entities import ( AnswerNodeData, @@ -13,7 +14,6 @@ from core.workflow.nodes.answer.entities import ( from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -from models.workflow import WorkflowNodeExecutionStatus class AnswerNode(BaseNode[AnswerNodeData]): diff --git a/api/core/workflow/nodes/answer/base_stream_processor.py b/api/core/workflow/nodes/answer/base_stream_processor.py index e4f2478890..09d5464d7a 100644 --- a/api/core/workflow/nodes/answer/base_stream_processor.py +++ b/api/core/workflow/nodes/answer/base_stream_processor.py @@ -57,7 +57,6 @@ class StreamProcessor(ABC): # The branch_identify parameter is added to ensure that # only nodes in the correct logical branch are included. 
- reachable_node_ids.append(edge.target_node_id) ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) reachable_node_ids.extend(ids) else: @@ -74,6 +73,8 @@ class StreamProcessor(ABC): self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: + if node_id not in self.rest_node_ids: + self.rest_node_ids.append(node_id) node_ids = [] for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id == self.graph.root_node_id: @@ -95,7 +96,12 @@ class StreamProcessor(ABC): if node_id not in self.rest_node_ids: return + if node_id in reachable_node_ids: + return + self.rest_node_ids.remove(node_id) + self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids)) + for edge in self.graph.edge_mapping.get(node_id, []): if edge.target_node_id in reachable_node_ids: continue diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index e566770870..7da0c19740 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -4,9 +4,9 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from models.workflow import WorkflowNodeExecutionStatus from .entities import BaseNodeData diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 3c34c5b4e7..61c08a7d71 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -8,10 +8,10 
@@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code.entities import CodeNodeData from core.workflow.nodes.enums import NodeType -from models.workflow import WorkflowNodeExecutionStatus from .exc import ( CodeNodeError, @@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]): depth: int = 1, ): if depth > dify_config.CODE_MAX_DEPTH: - raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") + raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") transformed_result: dict[str, Any] = {} if output_schema is None: @@ -167,8 +167,11 @@ class CodeNode(BaseNode[CodeNodeData]): value=value, variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", ) - elif isinstance(first_element, dict) and all( - value is None or isinstance(value, dict) for value in output_value + elif ( + isinstance(first_element, dict) + and all(value is None or isinstance(value, dict) for value in output_value) + or isinstance(first_element, list) + and all(value is None or isinstance(value, list) for value in output_value) ): for i, value in enumerate(output_value): if value is not None: diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 8fb1baec89..429fed2d04 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -7,6 +7,7 @@ import tempfile from collections.abc import Mapping, Sequence from typing import Any, cast +import chardet import docx import pandas as pd import 
pypandoc # type: ignore @@ -25,9 +26,9 @@ from core.helper import ssrf_proxy from core.variables import ArrayFileSegment from core.variables.segments import FileSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from models.workflow import WorkflowNodeExecutionStatus from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -180,26 +181,64 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) def _extract_text_from_plain_text(file_content: bytes) -> str: try: - return file_content.decode("utf-8", "ignore") - except UnicodeDecodeError as e: - raise TextExtractionError("Failed to decode plain text file") from e + # Detect encoding using chardet + result = chardet.detect(file_content) + encoding = result["encoding"] + + # Fallback to utf-8 if detection fails + if not encoding: + encoding = "utf-8" + + return file_content.decode(encoding, errors="ignore") + except (UnicodeDecodeError, LookupError) as e: + # If decoding fails, try with utf-8 as last resort + try: + return file_content.decode("utf-8", errors="ignore") + except UnicodeDecodeError: + raise TextExtractionError(f"Failed to decode plain text file: {e}") from e def _extract_text_from_json(file_content: bytes) -> str: try: - json_data = json.loads(file_content.decode("utf-8", "ignore")) + # Detect encoding using chardet + result = chardet.detect(file_content) + encoding = result["encoding"] + + # Fallback to utf-8 if detection fails + if not encoding: + encoding = "utf-8" + + json_data = json.loads(file_content.decode(encoding, errors="ignore")) return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError) as e: - raise 
TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e + except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: + # If decoding fails, try with utf-8 as last resort + try: + json_data = json.loads(file_content.decode("utf-8", errors="ignore")) + return json.dumps(json_data, indent=2, ensure_ascii=False) + except (UnicodeDecodeError, json.JSONDecodeError): + raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e def _extract_text_from_yaml(file_content: bytes) -> str: """Extract the content from yaml file""" try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", "ignore")) + # Detect encoding using chardet + result = chardet.detect(file_content) + encoding = result["encoding"] + + # Fallback to utf-8 if detection fails + if not encoding: + encoding = "utf-8" + + yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) - except (UnicodeDecodeError, yaml.YAMLError) as e: - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e + except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: + # If decoding fails, try with utf-8 as last resort + try: + yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) + return cast(str, yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False)) + except (UnicodeDecodeError, yaml.YAMLError): + raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e def _extract_text_from_pdf(file_content: bytes) -> str: @@ -338,26 +377,64 @@ def _extract_text_from_file(file: File): def _extract_text_from_csv(file_content: bytes) -> str: try: - csv_file = io.StringIO(file_content.decode("utf-8", "ignore")) + # Detect encoding using chardet + result = chardet.detect(file_content) + encoding = result["encoding"] + + # Fallback to utf-8 if detection fails + if not encoding: + encoding = "utf-8" + + try: + csv_file = 
io.StringIO(file_content.decode(encoding, errors="ignore")) + except (UnicodeDecodeError, LookupError): + # If decoding fails, try with utf-8 as last resort + csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore")) + csv_reader = csv.reader(csv_file) rows = list(csv_reader) if not rows: return "" + # Combine multi-line text in the header row + header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]] + # Create Markdown table - markdown_table = "| " + " | ".join(rows[0]) + " |\n" - markdown_table += "| " + " | ".join(["---"] * len(rows[0])) + " |\n" + markdown_table = "| " + " | ".join(header_row) + " |\n" + markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n" + + # Process each data row and combine multi-line text in each cell for row in rows[1:]: - markdown_table += "| " + " | ".join(row) + " |\n" + processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row] + markdown_table += "| " + " | ".join(processed_row) + " |\n" - return markdown_table.strip() + return markdown_table except Exception as e: raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e def _extract_text_from_excel(file_content: bytes) -> str: """Extract text from an Excel file using pandas.""" + + def _construct_markdown_table(df: pd.DataFrame) -> str: + """Manually construct a Markdown table from a DataFrame.""" + # Construct the header row + header_row = "| " + " | ".join(df.columns) + " |" + + # Construct the separator row + separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |" + + # Construct the data rows + data_rows = [] + for _, row in df.iterrows(): + data_row = "| " + " | ".join(map(str, row)) + " |" + data_rows.append(data_row) + + # Combine all rows into a single string + markdown_table = "\n".join([header_row, separator_row] + data_rows) + return markdown_table + try: excel_file = pd.ExcelFile(io.BytesIO(file_content)) markdown_table = "" @@ -365,8 +442,15 @@ 
def _extract_text_from_excel(file_content: bytes) -> str: try: df = excel_file.parse(sheet_name=sheet_name) df.dropna(how="all", inplace=True) - # Create Markdown table two times to separate tables with a newline - markdown_table += df.to_markdown(index=False) + "\n\n" + + # Combine multi-line text in each cell into a single line + df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore + + # Combine multi-line text in column names into a single line + df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) + + # Manually construct the Markdown table + markdown_table += _construct_markdown_table(df) + "\n\n" except Exception as e: continue return markdown_table diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 6acc915ab5..0e9756b243 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,8 +1,8 @@ from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.entities import EndNodeData from core.workflow.nodes.enums import NodeType -from models.workflow import WorkflowNodeExecutionStatus class EndNode(BaseNode[EndNodeData]): diff --git a/api/core/workflow/nodes/event/event.py b/api/core/workflow/nodes/event/event.py index 9fea3fbda3..3ebe80f245 100644 --- a/api/core/workflow/nodes/event/event.py +++ b/api/core/workflow/nodes/event/event.py @@ -1,10 +1,11 @@ +from collections.abc import Sequence from datetime import datetime from pydantic import BaseModel, Field from core.model_runtime.entities.llm_entities import LLMUsage +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.entities.node_entities import NodeRunResult -from models.workflow import WorkflowNodeExecutionStatus class 
RunCompletedEvent(BaseModel): @@ -17,7 +18,7 @@ class RunStreamChunkEvent(BaseModel): class RunRetrieverResourceEvent(BaseModel): - retriever_resources: list[dict] = Field(..., description="retriever resources") + retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") context: str = Field(..., description="context") @@ -37,11 +38,3 @@ class RunRetryEvent(BaseModel): error: str = Field(..., description="error") retry_index: int = Field(..., description="Retry attempt number") start_at: datetime = Field(..., description="Retry start time") - - -class SingleStepRetryEvent(NodeRunResult): - """Single step retry event""" - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RETRY - - elapsed_time: float = Field(..., description="elapsed time") diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index 2c42f5a1be..2c83b00d4a 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -1,8 +1,9 @@ import base64 import json +import secrets +import string from collections.abc import Mapping from copy import deepcopy -from random import randint from typing import Any, Literal from urllib.parse import urlencode, urlparse @@ -235,6 +236,10 @@ class Executor: files[key].append(file_tuple) # convert files to list for httpx request + # If there are no actual files, we still need to force httpx to use `multipart/form-data`. + # This is achieved by inserting a harmless placeholder file that will be ignored by the server. + if not files: + self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] if files: self.files = [] for key, file_tuples in files.items(): @@ -373,7 +378,10 @@ class Executor: raw += f"{k}: {v}\r\n" body_string = "" - if self.files: + # Only log actual files if present. 
+ # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. + # This prevents logging meaningless placeholder entries. + if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): for key, (filename, content, mime_type) in self.files: body_string += f"--{boundary}\r\n" body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' @@ -427,4 +435,4 @@ def _generate_random_string(n: int) -> str: >>> _generate_random_string(5) 'abcde' """ - return "".join([chr(randint(97, 122)) for _ in range(n)]) + return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 1c82637974..6b1ac57c06 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -8,12 +8,12 @@ from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser from factories import file_factory -from models.workflow import WorkflowNodeExecutionStatus from .entities import ( HttpRequestNodeData, diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index cb51b1ddd5..976922f75d 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -4,12 +4,12 @@ from typing_extensions import deprecated from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool 
import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -from models.workflow import WorkflowNodeExecutionStatus class IfElseNode(BaseNode[IfElseNodeData]): diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index a7d0aefc6d..42b6795fb0 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -12,10 +12,10 @@ from flask import Flask, current_app from configs import dify_config from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.workflow.entities.node_entities import ( - NodeRunMetadataKey, NodeRunResult, ) from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -37,7 +37,7 @@ from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from models.workflow import WorkflowNodeExecutionStatus +from libs.flask_utils import preserve_flask_contexts from .exc import ( InvalidIteratorValueError, @@ -249,8 +249,8 @@ class IterationNode(BaseNode[IterationNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"output": outputs}, metadata={ - NodeRunMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - NodeRunMetadataKey.TOTAL_TOKENS: 
graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, }, ) ) @@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]): ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: """ add iteration metadata to event. + ensures iteration context (ID, index/parallel_run_id) is added to metadata, """ if not isinstance(event, BaseNodeEvent): return event if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): event.parallel_mode_run_id = parallel_mode_run_id - return event + + iter_metadata = { + WorkflowNodeExecutionMetadataKey.ITERATION_ID: self.node_id, + WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, + } + if parallel_mode_run_id: + # for parallel, the specific branch ID is more important than the sequential index + iter_metadata[WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id + if event.route_node_state.node_run_result: - metadata = event.route_node_state.node_run_result.metadata - if not metadata: - metadata = {} - if NodeRunMetadataKey.ITERATION_ID not in metadata: - metadata = { - **metadata, - NodeRunMetadataKey.ITERATION_ID: self.node_id, - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID - if self.node_data.is_parallel - else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id - if self.node_data.is_parallel - else iter_run_index, - } - event.route_node_state.node_run_result.metadata = metadata + current_metadata = event.route_node_state.node_run_result.metadata or {} + if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: + event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata} + return event def _run_single_iter( @@ -585,9 +584,8 @@ class IterationNode(BaseNode[IterationNodeData]): """ run single iteration in parallel mode """ - for var, val in context.items(): - 
var.set(val) - with flask_app.app_context(): + + with preserve_flask_contexts(flask_app, context_vars=context): parallel_mode_run_id = uuid.uuid4().hex graph_engine_copy = graph_engine.create_copy() variable_pool_copy = graph_engine_copy.graph_runtime_state.variable_pool diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index fe955e47d1..bee481ebdb 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,8 +1,8 @@ from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData -from models.workflow import WorkflowNodeExecutionStatus class IterationStartNode(BaseNode[IterationStartNodeData]): diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index d2e5a15545..19bdee4fe2 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -132,3 +132,12 @@ class KnowledgeRetrievalNodeData(BaseNodeData): metadata_model_config: Optional[ModelConfig] = None metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None vision: VisionConfig = Field(default_factory=VisionConfig) + + @property + def structured_output_enabled(self) -> bool: + # NOTE(QuantumGhost): Temporary workaround for issue #20725 + # (https://github.com/langgenius/dify/issues/20725). + # + # The proper fix would be to make `KnowledgeRetrievalNode` inherit + # from `BaseNode` instead of `LLMNode`. 
+ return False diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 5c4cac9719..5cf5848d54 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,6 +8,7 @@ from typing import Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast +from sqlalchemy.orm import Session from core.app.app_config.entities import DatasetRetrieveConfigEntity from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity @@ -24,6 +25,7 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import StringSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event.event import ModelInvokeCompletedEvent from core.workflow.nodes.knowledge_retrieval.template_prompts import ( @@ -41,7 +43,6 @@ from extensions.ext_database import db from extensions.ext_redis import redis_client from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog -from models.workflow import WorkflowNodeExecutionStatus from services.feature_service import FeatureService from .entities import KnowledgeRetrievalNodeData, ModelConfig @@ -85,30 +86,31 @@ class KnowledgeRetrievalNode(LLMNode): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Query is required." ) + # TODO(-LAN-): Move this check outside. 
# check rate limit - if self.tenant_id: - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) - if knowledge_rate_limit.enabled: - current_time = int(time.time() * 1000) - key = f"rate_limit_{self.tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{self.tenant_id}" + redis_client.zadd(key, {current_time: current_time}) + redis_client.zremrangebyscore(key, 0, current_time - 60000) + request_count = redis_client.zcard(key) + if request_count > knowledge_rate_limit.limit: + with Session(db.engine) as session: # add ratelimit record rate_limit_log = RateLimitLog( tenant_id=self.tenant_id, subscription_plan=knowledge_rate_limit.subscription_plan, operation="knowledge", ) - db.session.add(rate_limit_log) - db.session.commit() - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Sorry, you have reached the knowledge base request rate limit of your subscription.", - error_type="RateLimitExceeded", - ) + session.add(rate_limit_log) + session.commit() + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error="Sorry, you have reached the knowledge base request rate limit of your subscription.", + error_type="RateLimitExceeded", + ) # retrieve knowledge try: @@ -173,7 +175,9 @@ class KnowledgeRetrievalNode(LLMNode): dataset_retrieval = DatasetRetrieval() if node_data.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value: # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.single_retrieval_config.model) # type: ignore + if node_data.single_retrieval_config is None: + 
raise ValueError("single_retrieval_config is required") + model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model) # check model is support tool calling model_type_instance = model_config.provider_model_bundle.model_type_instance model_type_instance = cast(LargeLanguageModel, model_type_instance) @@ -264,6 +268,7 @@ class KnowledgeRetrievalNode(LLMNode): "data_source_type": "external", "retriever_from": "workflow", "score": item.metadata.get("score"), + "doc_metadata": item.metadata, }, "title": item.metadata.get("title"), "content": item.page_content, @@ -275,12 +280,16 @@ class KnowledgeRetrievalNode(LLMNode): if records: for record in records: segment = record.segment - dataset = Dataset.query.filter_by(id=segment.dataset_id).first() - document = Document.query.filter( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ).first() + dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore + document = ( + db.session.query(Document) + .filter( + Document.id == segment.document_id, + Document.enabled == True, + Document.archived == False, + ) + .first() + ) if dataset and document: source = { "metadata": { @@ -289,7 +298,7 @@ class KnowledgeRetrievalNode(LLMNode): "dataset_name": dataset.name, "document_id": document.id, "document_name": document.name, - "document_data_source_type": document.data_source_type, + "data_source_type": document.data_source_type, "segment_id": segment.id, "retriever_from": "workflow", "score": record.score or 0.0, @@ -419,7 +428,7 @@ class KnowledgeRetrievalNode(LLMNode): raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.metadata_model_config) # type: ignore + model_instance, model_config = self.get_model_config(metadata_model_config) # fetch prompt messages prompt_template = 
self._get_prompt_template( node_data=node_data, @@ -545,14 +554,7 @@ class KnowledgeRetrievalNode(LLMNode): variable_mapping[node_id + ".query"] = node_data.query_variable_selector return variable_mapping - def _fetch_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: # type: ignore - """ - Fetch model config - :param model: model - :return: - """ - if model is None: - raise ValueError("model is required") + def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model_name = model.name provider_name = model.provider diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 04ccfc5405..e698d3f5d8 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -4,9 +4,9 @@ from typing import Any, Literal, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType -from models.workflow import WorkflowNodeExecutionStatus from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 486b4b01af..36d0688807 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -66,7 +66,8 @@ class LLMNodeData(BaseNodeData): context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) structured_output: dict | None = None - structured_output_enabled: bool = False + # We used 'structured_output_enabled' in the past, but it's not a good name. 
+ structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") @field_validator("prompt_config", mode="before") @classmethod @@ -74,3 +75,7 @@ class LLMNodeData(BaseNodeData): if v is None: return PromptConfig() return v + + @property + def structured_output_enabled(self) -> bool: + return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/core/workflow/nodes/llm/llm_utils.py b/api/core/workflow/nodes/llm/llm_utils.py new file mode 100644 index 0000000000..0966c87a1d --- /dev/null +++ b/api/core/workflow/nodes/llm/llm_utils.py @@ -0,0 +1,156 @@ +from collections.abc import Sequence +from datetime import UTC, datetime +from typing import Optional, cast + +from sqlalchemy import select, update +from sqlalchemy.orm import Session + +from configs import dify_config +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.entities.provider_entities import QuotaUnit +from core.file.models import File +from core.memory.token_buffer_memory import TokenBufferMemory +from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage +from core.model_runtime.entities.model_entities import ModelType +from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.entities.plugin import ModelProviderID +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.enums import SystemVariableKey +from core.workflow.nodes.llm.entities import ModelConfig +from models import db +from models.model import Conversation +from models.provider import Provider, ProviderType + +from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError + + +def fetch_model_config( + 
tenant_id: str, node_data_model: ModelConfig +) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: + if not node_data_model.mode: + raise LLMModeRequiredError("LLM mode is required.") + + model = ModelManager().get_model_instance( + tenant_id=tenant_id, + model_type=ModelType.LLM, + provider=node_data_model.provider, + model=node_data_model.name, + ) + + model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance) + + # check model + provider_model = model.provider_model_bundle.configuration.get_provider_model( + model=node_data_model.name, model_type=ModelType.LLM + ) + + if provider_model is None: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + provider_model.raise_for_status() + + # model config + stop: list[str] = [] + if "stop" in node_data_model.completion_params: + stop = node_data_model.completion_params.pop("stop") + + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) + if not model_schema: + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + + return model, ModelConfigWithCredentialsEntity( + provider=node_data_model.provider, + model=node_data_model.name, + model_schema=model_schema, + mode=node_data_model.mode, + provider_model_bundle=model.provider_model_bundle, + credentials=model.credentials, + parameters=node_data_model.completion_params, + stop=stop, + ) + + +def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]: + variable = variable_pool.get(selector) + if variable is None: + return [] + elif isinstance(variable, FileSegment): + return [variable.value] + elif isinstance(variable, ArrayFileSegment): + return variable.value + elif isinstance(variable, NoneSegment | ArrayAnySegment): + return [] + raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") + + +def fetch_memory( + variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: 
ModelInstance +) -> Optional[TokenBufferMemory]: + if not node_data_memory: + return None + + # get conversation id + conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value]) + if not isinstance(conversation_id_variable, StringSegment): + return None + conversation_id = conversation_id_variable.value + + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) + conversation = session.scalar(stmt) + if not conversation: + return None + + memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) + return memory + + +def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: + provider_model_bundle = model_instance.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if provider_configuration.using_provider_type != ProviderType.SYSTEM: + return + + system_configuration = provider_configuration.system_configuration + + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + + if quota_configuration.quota_limit == -1: + return + + break + + used_quota = None + if quota_unit: + if quota_unit == QuotaUnit.TOKENS: + used_quota = usage.total_tokens + elif quota_unit == QuotaUnit.CREDITS: + used_quota = dify_config.get_model_credits(model_instance.model) + else: + used_quota = 1 + + if used_quota is not None and system_configuration.current_quota_type is not None: + with Session(db.engine) as session: + stmt = ( + update(Provider) + .where( + Provider.tenant_id == tenant_id, + # TODO: Use provider name with prefix after the data migration. 
+ Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, + Provider.provider_type == ProviderType.SYSTEM.value, + Provider.quota_type == system_configuration.current_quota_type.value, + Provider.quota_limit > Provider.quota_used, + ) + .values( + quota_used=Provider.quota_used + used_quota, + last_used=datetime.now(tz=UTC).replace(tzinfo=None), + ) + ) + session.execute(stmt) + session.commit() diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index f42bc6784d..d27124d62c 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -3,16 +3,11 @@ import io import json import logging from collections.abc import Generator, Mapping, Sequence -from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, Optional, cast import json_repair -from configs import dify_config from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.model_entities import ModelStatus -from core.entities.provider_entities import QuotaUnit -from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError from core.file import FileType, file_manager from core.helper.code_executor import CodeExecutor, CodeLanguage from core.memory.token_buffer_memory import TokenBufferMemory @@ -40,11 +35,10 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.variables import ( - ArrayAnySegment, ArrayFileSegment, ArraySegment, FileSegment, @@ -53,9 
+47,10 @@ from core.variables import ( StringSegment, ) from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode @@ -70,15 +65,11 @@ from core.workflow.nodes.event import ( from core.workflow.utils.structured_output.entities import ( ResponseFormat, SpecialModelType, - SupportStructuredOutputStatus, ) from core.workflow.utils.structured_output.prompt import STRUCTURED_OUTPUT_PROMPT from core.workflow.utils.variable_template_parser import VariableTemplateParser -from extensions.ext_database import db -from models.model import Conversation -from models.provider import Provider, ProviderType -from models.workflow import WorkflowNodeExecutionStatus +from . 
import llm_utils from .entities import ( LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, @@ -88,7 +79,6 @@ from .entities import ( from .exc import ( InvalidContextStructureError, InvalidVariableTypeError, - LLMModeRequiredError, LLMNodeError, MemoryRolePrefixRequiredError, ModelNotExistError, @@ -160,6 +150,7 @@ class LLMNode(BaseNode[LLMNodeData]): result_text = "" usage = LLMUsage.empty_usage() finish_reason = None + variable_pool = self.graph_runtime_state.variable_pool try: # init messages template @@ -178,7 +169,10 @@ class LLMNode(BaseNode[LLMNodeData]): # fetch files files = ( - self._fetch_files(selector=self.node_data.vision.configs.variable_selector) + llm_utils.fetch_files( + variable_pool=variable_pool, + selector=self.node_data.vision.configs.variable_selector, + ) if self.node_data.vision.enabled else [] ) @@ -200,15 +194,18 @@ class LLMNode(BaseNode[LLMNodeData]): model_instance, model_config = self._fetch_model_config(self.node_data.model) # fetch memory - memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) + memory = llm_utils.fetch_memory( + variable_pool=variable_pool, + app_id=self.app_id, + node_data_memory=self.node_data.memory, + model_instance=model_instance, + ) query = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template if not query and ( - query_variable := self.graph_runtime_state.variable_pool.get( - (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY) - ) + query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): query = query_variable.text @@ -222,7 +219,7 @@ class LLMNode(BaseNode[LLMNodeData]): memory_config=self.node_data.memory, vision_enabled=self.node_data.vision.enabled, vision_detail=self.node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, + variable_pool=variable_pool, jinja2_variables=self.node_data.prompt_config.jinja2_variables, ) @@ -251,7 +248,7 @@ class 
LLMNode(BaseNode[LLMNodeData]): usage = event.usage finish_reason = event.finish_reason # deduct quota - self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) break outputs = {"text": result_text, "usage": jsonable_encoder(usage), "finish_reason": finish_reason} structured_output = process_structured_output(result_text) @@ -267,14 +264,14 @@ class LLMNode(BaseNode[LLMNodeData]): process_data=process_data, outputs=outputs, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, llm_usage=usage, ) ) - except LLMNodeError as e: + except ValueError as e: yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -302,8 +299,6 @@ class LLMNode(BaseNode[LLMNodeData]): prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, ) -> Generator[NodeEvent, None, None]: - db.session.close() - invoke_result = model_instance.invoke_llm( prompt_messages=list(prompt_messages), model_parameters=node_data_model.completion_params, @@ -449,18 +444,6 @@ class LLMNode(BaseNode[LLMNodeData]): return inputs - def _fetch_files(self, *, selector: Sequence[str]) -> Sequence["File"]: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - def _fetch_context(self, 
node_data: LLMNodeData): if not node_data.context.enabled: return @@ -474,7 +457,7 @@ class LLMNode(BaseNode[LLMNodeData]): yield RunRetrieverResourceEvent(retriever_resources=[], context=context_value_variable.value) elif isinstance(context_value_variable, ArraySegment): context_str = "" - original_retriever_resource = [] + original_retriever_resource: list[RetrievalSourceMetadata] = [] for item in context_value_variable.value: if isinstance(item, str): context_str += item + "\n" @@ -492,7 +475,7 @@ class LLMNode(BaseNode[LLMNodeData]): retriever_resources=original_retriever_resource, context=context_str.strip() ) - def _convert_to_original_retriever_resource(self, context_dict: dict) -> Optional[dict]: + def _convert_to_original_retriever_resource(self, context_dict: dict): if ( "metadata" in context_dict and "_source" in context_dict["metadata"] @@ -500,24 +483,24 @@ class LLMNode(BaseNode[LLMNodeData]): ): metadata = context_dict.get("metadata", {}) - source = { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("document_data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - } + source = RetrievalSourceMetadata( + position=metadata.get("position"), + dataset_id=metadata.get("dataset_id"), + dataset_name=metadata.get("dataset_name"), + document_id=metadata.get("document_id"), + document_name=metadata.get("document_name"), 
+ data_source_type=metadata.get("data_source_type"), + segment_id=metadata.get("segment_id"), + retriever_from=metadata.get("retriever_from"), + score=metadata.get("score"), + hit_count=metadata.get("segment_hit_count"), + word_count=metadata.get("segment_word_count"), + segment_position=metadata.get("segment_position"), + index_node_hash=metadata.get("segment_index_node_hash"), + content=context_dict.get("content"), + page=metadata.get("page"), + doc_metadata=metadata.get("doc_metadata"), + ) return source @@ -526,95 +509,25 @@ class LLMNode(BaseNode[LLMNodeData]): def _fetch_model_config( self, node_data_model: ModelConfig ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = node_data_model.name - provider_name = node_data_model.provider - - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name + model, model_config_with_cred = llm_utils.fetch_model_config( + tenant_id=self.tenant_id, node_data_model=node_data_model ) + completion_params = model_config_with_cred.parameters - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials - - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM - ) - - if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelCurrentlyNotSupportError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise 
QuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = node_data_model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = node_data_model.mode - if not model_mode: - raise LLMModeRequiredError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - + model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials) if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.SUPPORTED: - completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) - elif support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - # Set appropriate response format based on model capabilities - self._set_response_format(completion_params, model_schema.parameter_rules) - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - def _fetch_memory( - self, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance - ) -> Optional[TokenBufferMemory]: - if not node_data_memory: - return None - - # get conversation id - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID.value] - ) - if not isinstance(conversation_id_variable, StringSegment): - return None - conversation_id = conversation_id_variable.value - - # get conversation - conversation = ( - db.session.query(Conversation) - .filter(Conversation.app_id == self.app_id, 
Conversation.id == conversation_id) - .first() - ) - - if not conversation: - return None + raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance) - - return memory + if self.node_data.structured_output_enabled: + if model_schema.support_structure_output: + completion_params = self._handle_native_json_schema(completion_params, model_schema.parameter_rules) + else: + # Set appropriate response format based on model capabilities + self._set_response_format(completion_params, model_schema.parameter_rules) + model_config_with_cred.parameters = completion_params + # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`. + node_data_model.completion_params = completion_params + return model, model_config_with_cred def _fetch_prompt_messages( self, @@ -789,13 +702,25 @@ class LLMNode(BaseNode[LLMNodeData]): "No prompt found in the LLM configuration. " "Please ensure a prompt is properly configured before proceeding." 
) - support_structured_output = self._check_model_structured_output_support() - if support_structured_output == SupportStructuredOutputStatus.UNSUPPORTED: - filtered_prompt_messages = self._handle_prompt_based_schema( - prompt_messages=filtered_prompt_messages, - ) - stop = model_config.stop - return filtered_prompt_messages, stop + + model = ModelManager().get_model_instance( + tenant_id=self.tenant_id, + model_type=ModelType.LLM, + provider=model_config.provider, + model=model_config.model, + ) + model_schema = model.model_type_instance.get_model_schema( + model=model_config.model, + credentials=model.credentials, + ) + if not model_schema: + raise ModelNotExistError(f"Model {model_config.model} not exist.") + if self.node_data.structured_output_enabled: + if not model_schema.support_structure_output: + filtered_prompt_messages = self._handle_prompt_based_schema( + prompt_messages=filtered_prompt_messages, + ) + return filtered_prompt_messages, model_config.stop def _parse_structured_output(self, result_text: str) -> dict[str, Any]: structured_output: dict[str, Any] = {} @@ -816,51 +741,6 @@ class LLMNode(BaseNode[LLMNodeData]): structured_output = parsed return structured_output - @classmethod - def deduct_llm_quota(cls, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None: - provider_model_bundle = model_instance.provider_model_bundle - provider_configuration = provider_model_bundle.configuration - - if provider_configuration.using_provider_type != ProviderType.SYSTEM: - return - - system_configuration = provider_configuration.system_configuration - - quota_unit = None - for quota_configuration in system_configuration.quota_configurations: - if quota_configuration.quota_type == system_configuration.current_quota_type: - quota_unit = quota_configuration.quota_unit - - if quota_configuration.quota_limit == -1: - return - - break - - used_quota = None - if quota_unit: - if quota_unit == QuotaUnit.TOKENS: - used_quota = usage.total_tokens - elif 
quota_unit == QuotaUnit.CREDITS: - used_quota = dify_config.get_model_credits(model_instance.model) - else: - used_quota = 1 - - if used_quota is not None and system_configuration.current_quota_type is not None: - db.session.query(Provider).filter( - Provider.tenant_id == tenant_id, - # TODO: Use provider name with prefix after the data migration. - Provider.provider_name == ModelProviderID(model_instance.provider).provider_name, - Provider.provider_type == ProviderType.SYSTEM.value, - Provider.quota_type == system_configuration.current_quota_type.value, - Provider.quota_limit > Provider.quota_used, - ).update( - { - "quota_used": Provider.quota_used + used_quota, - "last_used": datetime.now(tz=UTC).replace(tzinfo=None), - } - ) - db.session.commit() - @classmethod def _extract_variable_selector_to_variable_mapping( cls, @@ -902,7 +782,7 @@ class LLMNode(BaseNode[LLMNodeData]): variable_mapping["#context#"] = node_data.context.variable_selector if node_data.vision.enabled: - variable_mapping["#files#"] = ["sys", SystemVariableKey.FILES.value] + variable_mapping["#files#"] = node_data.vision.configs.variable_selector if node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] @@ -1184,32 +1064,6 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") - def _check_model_structured_output_support(self) -> SupportStructuredOutputStatus: - """ - Check if the current model supports structured output. 
- - Returns: - SupportStructuredOutput: The support status of structured output - """ - # Early return if structured output is disabled - if ( - not isinstance(self.node_data, LLMNodeData) - or not self.node_data.structured_output_enabled - or not self.node_data.structured_output - ): - return SupportStructuredOutputStatus.DISABLED - # Get model schema and check if it exists - model_schema = self._fetch_model_schema(self.node_data.model.provider) - if not model_schema: - return SupportStructuredOutputStatus.DISABLED - - # Check if model supports structured output feature - return ( - SupportStructuredOutputStatus.SUPPORTED - if bool(model_schema.features and ModelFeature.STRUCTURED_OUTPUT in model_schema.features) - else SupportStructuredOutputStatus.UNSUPPORTED - ) - def _save_multimodal_output_and_convert_result_to_markdown( self, contents: str | list[PromptMessageContentUnionTypes] | None, diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 16802311dc..3f4a5edab9 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData): loop_count: int # Maximum number of loops break_conditions: list[Condition] # Conditions to break the loop logical_operator: Literal["and", "or"] - loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list) + loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData]) outputs: Optional[Mapping[str, Any]] = None diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index 5d4ce0ccbe..327b9e234b 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,8 +1,8 @@ from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from 
core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData -from models.workflow import WorkflowNodeExecutionStatus class LoopEndNode(BaseNode[LoopEndNodeData]): diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index eae33c0a92..fafa205386 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -15,7 +15,8 @@ from core.variables import ( SegmentType, StringSegment, ) -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.graph_engine.entities.event import ( BaseGraphEvent, BaseNodeEvent, @@ -37,7 +38,6 @@ from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor -from models.workflow import WorkflowNodeExecutionStatus if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -187,10 +187,10 @@ class LoopNode(BaseNode[LoopNodeData]): outputs=self.node_data.outputs, steps=loop_count, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, "completed_reason": "loop_break" if check_break_result else "loop_completed", - NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, ) @@ -198,9 +198,9 @@ 
class LoopNode(BaseNode[LoopNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, - NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, outputs=self.node_data.outputs, inputs=inputs, @@ -221,8 +221,8 @@ class LoopNode(BaseNode[LoopNodeData]): metadata={ "total_tokens": graph_engine.graph_runtime_state.total_tokens, "completed_reason": "error", - NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, error=str(e), ) @@ -232,9 +232,9 @@ class LoopNode(BaseNode[LoopNodeData]): status=WorkflowNodeExecutionStatus.FAILED, error=str(e), metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, - NodeRunMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - NodeRunMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, + WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, ) ) @@ -322,7 +322,9 @@ class LoopNode(BaseNode[LoopNodeData]): inputs=inputs, steps=current_index, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: ( + graph_engine.graph_runtime_state.total_tokens + ), "completed_reason": 
"error", }, error=event.error, @@ -331,13 +333,17 @@ class LoopNode(BaseNode[LoopNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=event.error, - metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: ( + graph_engine.graph_runtime_state.total_tokens + ) + }, ) ) return {"check_break_result": True} elif isinstance(event, NodeRunFailedEvent): # Loop run failed - yield event + yield self._handle_event_metadata(event=event, iter_run_index=current_index) yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, @@ -347,7 +353,7 @@ class LoopNode(BaseNode[LoopNodeData]): inputs=inputs, steps=current_index, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, "completed_reason": "error", }, error=event.error, @@ -356,7 +362,9 @@ class LoopNode(BaseNode[LoopNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=event.error, - metadata={NodeRunMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens}, + metadata={ + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens + }, ) ) return {"check_break_result": True} @@ -411,11 +419,11 @@ class LoopNode(BaseNode[LoopNodeData]): metadata = event.route_node_state.node_run_result.metadata if not metadata: metadata = {} - if NodeRunMetadataKey.LOOP_ID not in metadata: + if WorkflowNodeExecutionMetadataKey.LOOP_ID not in metadata: metadata = { **metadata, - NodeRunMetadataKey.LOOP_ID: self.node_id, - NodeRunMetadataKey.LOOP_INDEX: iter_run_index, + WorkflowNodeExecutionMetadataKey.LOOP_ID: self.node_id, + WorkflowNodeExecutionMetadataKey.LOOP_INDEX: iter_run_index, } event.route_node_state.node_run_result.metadata = metadata return event diff --git 
a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index 7cf145e4e5..5a15f36044 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,8 +1,8 @@ from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData -from models.workflow import WorkflowNodeExecutionStatus class LoopStartNode(BaseNode[LoopStartNodeData]): diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 8db1e432fc..2552784762 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -25,13 +25,13 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.node import BaseNode from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.llm import LLMNode, ModelConfig +from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser -from extensions.ext_database import db 
-from models.workflow import WorkflowNodeExecutionStatus from .entities import ParameterExtractorNodeData from .exc import ( @@ -84,7 +84,7 @@ def extract_json(text): return None -class ParameterExtractorNode(LLMNode): +class ParameterExtractorNode(BaseNode): """ Parameter Extractor Node. """ @@ -117,8 +117,11 @@ class ParameterExtractorNode(LLMNode): variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" + variable_pool = self.graph_runtime_state.variable_pool + files = ( - self._fetch_files( + llm_utils.fetch_files( + variable_pool=variable_pool, selector=node_data.vision.configs.variable_selector, ) if node_data.vision.enabled @@ -138,7 +141,9 @@ class ParameterExtractorNode(LLMNode): raise ModelSchemaNotFoundError("Model schema not found") # fetch memory - memory = self._fetch_memory( + memory = llm_utils.fetch_memory( + variable_pool=variable_pool, + app_id=self.app_id, node_data_memory=node_data.memory, model_instance=model_instance, ) @@ -244,9 +249,9 @@ class ParameterExtractorNode(LLMNode): process_data=process_data, outputs={"__is_success": 1 if not error else 0, "__reason": error, **result}, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, llm_usage=usage, ) @@ -259,8 +264,6 @@ class ParameterExtractorNode(LLMNode): tools: list[PromptMessageTool], stop: list[str], ) -> tuple[str, LLMUsage, Optional[AssistantPromptMessage.ToolCall]]: - db.session.close() - invoke_result = model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=node_data_model.completion_params, @@ -282,7 +285,7 @@ class ParameterExtractorNode(LLMNode): tool_call = invoke_result.message.tool_calls[0] 
if invoke_result.message.tool_calls else None # deduct quota - self.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) + llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage) if text is None: text = "" @@ -797,7 +800,9 @@ class ParameterExtractorNode(LLMNode): Fetch model config. """ if not self._model_instance or not self._model_config: - self._model_instance, self._model_config = super()._fetch_model_config(node_data_model) + self._model_instance, self._model_config = llm_utils.fetch_model_config( + tenant_id=self.tenant_id, node_data_model=node_data_model + ) return self._model_instance, self._model_config @@ -816,7 +821,6 @@ class ParameterExtractorNode(LLMNode): :param node_data: node data :return: """ - # FIXME: fix the type error later variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} if node_data.instruction: diff --git a/api/core/workflow/nodes/question_classifier/entities.py b/api/core/workflow/nodes/question_classifier/entities.py index 5219f11d26..6248df0edf 100644 --- a/api/core/workflow/nodes/question_classifier/entities.py +++ b/api/core/workflow/nodes/question_classifier/entities.py @@ -19,3 +19,12 @@ class QuestionClassifierNodeData(BaseNodeData): instruction: Optional[str] = None memory: Optional[MemoryConfig] = None vision: VisionConfig = Field(default_factory=VisionConfig) + + @property + def structured_output_enabled(self) -> bool: + # NOTE(QuantumGhost): Temporary workaround for issue #20725 + # (https://github.com/langgenius/dify/issues/20725). + # + # The proper fix would be to make `QuestionClassifierNode` inherit + # from `BaseNode` instead of `LLMNode`. 
+ return False diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index b4f34a3bef..1f50700c7e 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -10,17 +10,18 @@ from core.model_runtime.utils.encoders import jsonable_encoder from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.nodes.enums import NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, + llm_utils, ) from core.workflow.utils.variable_template_parser import VariableTemplateParser from libs.json_in_md_parser import parse_and_check_json_markdown -from models.workflow import WorkflowNodeExecutionStatus from .entities import QuestionClassifierNodeData from .exc import InvalidModelTypeError @@ -50,7 +51,9 @@ class QuestionClassifierNode(LLMNode): # fetch model config model_instance, model_config = self._fetch_model_config(node_data.model) # fetch memory - memory = self._fetch_memory( + memory = llm_utils.fetch_memory( + variable_pool=variable_pool, + app_id=self.app_id, node_data_memory=node_data.memory, model_instance=model_instance, ) @@ -59,7 +62,8 @@ class QuestionClassifierNode(LLMNode): node_data.instruction = variable_pool.convert_template(node_data.instruction).text files = ( - self._fetch_files( + llm_utils.fetch_files( + 
variable_pool=variable_pool, selector=node_data.vision.configs.variable_selector, ) if node_data.vision.enabled @@ -79,9 +83,13 @@ class QuestionClassifierNode(LLMNode): memory=memory, max_token_limit=rest_token, ) + # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). + # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, + # two consecutive user prompts will be generated, causing model's error. + # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. prompt_messages, stop = self._fetch_prompt_messages( prompt_template=prompt_template, - sys_query=query, + sys_query="", memory=memory, model_config=model_config, sys_files=files, @@ -142,9 +150,9 @@ class QuestionClassifierNode(LLMNode): outputs=outputs, edge_source_handle=category_id, metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, llm_usage=usage, ) @@ -154,9 +162,9 @@ class QuestionClassifierNode(LLMNode): inputs=variables, error=str(e), metadata={ - NodeRunMetadataKey.TOTAL_TOKENS: usage.total_tokens, - NodeRunMetadataKey.TOTAL_PRICE: usage.total_price, - NodeRunMetadataKey.CURRENCY: usage.currency, + WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, + WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, + WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, }, llm_usage=usage, ) diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 1b47b81517..8839aec9d6 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,9 +1,9 @@ from 
core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.start.entities import StartNodeData -from models.workflow import WorkflowNodeExecutionStatus class StartNode(BaseNode[StartNodeData]): diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index 22a1b21888..476cf7eee4 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -4,10 +4,10 @@ from typing import Any, Optional from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor, CodeLanguage from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData -from models.workflow import WorkflowNodeExecutionStatus MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c72ae5b69b..aaecc7b989 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -14,8 +14,9 @@ from core.tools.tool_engine import ToolEngine from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.variables.segments import ArrayAnySegment from core.variables.variables import ArrayAnyVariable -from core.workflow.entities.node_entities import NodeRunMetadataKey, 
NodeRunResult +from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode @@ -25,7 +26,6 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db from factories import file_factory from models import ToolFile -from models.workflow import WorkflowNodeExecutionStatus from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .entities import ToolNodeData @@ -70,7 +70,7 @@ class ToolNode(BaseNode[ToolNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs={}, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to get tool runtime: {str(e)}", error_type=type(e).__name__, ) @@ -110,7 +110,7 @@ class ToolNode(BaseNode[ToolNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to invoke tool: {str(e)}", error_type=type(e).__name__, ) @@ -125,7 +125,7 @@ class ToolNode(BaseNode[ToolNodeData]): run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - metadata={NodeRunMetadataKey.TOOL_INFO: tool_info}, + metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, error=f"Failed to transform tool message: {str(e)}", error_type=type(e).__name__, ) @@ -201,7 +201,7 @@ class ToolNode(BaseNode[ToolNodeData]): json: list[dict] = [] agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: 
Mapping[NodeRunMetadataKey, Any] = {} + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} variables: dict[str, Any] = {} @@ -274,7 +274,7 @@ class ToolNode(BaseNode[ToolNodeData]): agent_execution_metadata = { key: value for key, value in msg_metadata.items() - if key in NodeRunMetadataKey.__members__.values() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() } json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: @@ -366,8 +366,8 @@ class ToolNode(BaseNode[ToolNodeData]): outputs={"text": text, "files": files, "json": json, **variables}, metadata={ **agent_execution_metadata, - NodeRunMetadataKey.TOOL_INFO: tool_info, - NodeRunMetadataKey.AGENT_LOG: agent_logs, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, ) diff --git a/api/core/workflow/nodes/variable_aggregator/entities.py b/api/core/workflow/nodes/variable_aggregator/entities.py index 9e58f5e944..f4577d7573 100644 --- a/api/core/workflow/nodes/variable_aggregator/entities.py +++ b/api/core/workflow/nodes/variable_aggregator/entities.py @@ -1,7 +1,8 @@ -from typing import Literal, Optional +from typing import Optional from pydantic import BaseModel +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseNodeData @@ -17,7 +18,7 @@ class AdvancedSettings(BaseModel): Group. 
""" - output_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + output_type: SegmentType variables: list[list[str]] group_name: str diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 372496a8fa..db3e25b015 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,8 +1,8 @@ from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -from models.workflow import WorkflowNodeExecutionStatus class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index 7c7f14c0b8..835e1d77b5 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,11 +1,11 @@ from core.variables import SegmentType, Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory -from models.workflow import WorkflowNodeExecutionStatus from .node_data import VariableAssignerData, WriteMode diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py 
b/api/core/workflow/nodes/variable_assigner/v2/node.py index 6a7ad86b51..8759a55b34 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -6,11 +6,11 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode from core.workflow.nodes.enums import NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from models.workflow import WorkflowNodeExecutionStatus from . import helpers from .constants import EMPTY_VALUE_MAPPING diff --git a/api/core/workflow/repository/__init__.py b/api/core/workflow/repositories/__init__.py similarity index 69% rename from api/core/workflow/repository/__init__.py rename to api/core/workflow/repositories/__init__.py index 672abb6583..a778151baa 100644 --- a/api/core/workflow/repository/__init__.py +++ b/api/core/workflow/repositories/__init__.py @@ -6,7 +6,7 @@ for accessing and manipulating data, regardless of the underlying storage mechanism. 
""" -from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository __all__ = [ "OrderConfig", diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py new file mode 100644 index 0000000000..5917310c8b --- /dev/null +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -0,0 +1,42 @@ +from typing import Optional, Protocol + +from core.workflow.entities.workflow_execution import WorkflowExecution + + +class WorkflowExecutionRepository(Protocol): + """ + Repository interface for WorkflowExecution. + + This interface defines the contract for accessing and manipulating + WorkflowExecution data, regardless of the underlying storage mechanism. + + Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), + and other implementation details should be handled at the implementation level, not in + the core interface. This keeps the core domain model clean and independent of specific + application domains or deployment scenarios. + """ + + def save(self, execution: WorkflowExecution) -> None: + """ + Save or update a WorkflowExecution instance. + + This method handles both creating new records and updating existing ones. + The implementation should determine whether to create or update based on + the execution's ID or other identifying fields. + + Args: + execution: The WorkflowExecution instance to save or update + """ + ... + + def get(self, execution_id: str) -> Optional[WorkflowExecution]: + """ + Retrieve a WorkflowExecution by its ID. + + Args: + execution_id: The workflow execution ID + + Returns: + The WorkflowExecution instance if found, None otherwise + """ + ... 
diff --git a/api/core/workflow/repository/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py similarity index 65% rename from api/core/workflow/repository/workflow_node_execution_repository.py rename to api/core/workflow/repositories/workflow_node_execution_repository.py index 9bb790cb0f..1908a6b190 100644 --- a/api/core/workflow/repository/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -2,12 +2,12 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Optional, Protocol -from models.workflow import WorkflowNodeExecution +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution @dataclass class OrderConfig: - """Configuration for ordering WorkflowNodeExecution instances.""" + """Configuration for ordering NodeExecution instances.""" order_by: list[str] order_direction: Optional[Literal["asc", "desc"]] = None @@ -15,10 +15,10 @@ class OrderConfig: class WorkflowNodeExecutionRepository(Protocol): """ - Repository interface for WorkflowNodeExecution. + Repository interface for NodeExecution. This interface defines the contract for accessing and manipulating - WorkflowNodeExecution data, regardless of the underlying storage mechanism. + NodeExecution data, regardless of the underlying storage mechanism. Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), and trigger sources (triggered_from) should be handled at the implementation level, not in @@ -28,22 +28,26 @@ class WorkflowNodeExecutionRepository(Protocol): def save(self, execution: WorkflowNodeExecution) -> None: """ - Save a WorkflowNodeExecution instance. + Save or update a NodeExecution instance. + + This method handles both creating new records and updating existing ones. 
+ The implementation should determine whether to create or update based on + the execution's ID or other identifying fields. Args: - execution: The WorkflowNodeExecution instance to save + execution: The NodeExecution instance to save or update """ ... def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ ... @@ -53,7 +57,7 @@ class WorkflowNodeExecutionRepository(Protocol): order_config: Optional[OrderConfig] = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID @@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol): order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of NodeExecution instances """ ... def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all running NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID Returns: - A list of running WorkflowNodeExecution instances - """ - ... - - def update(self, execution: WorkflowNodeExecution) -> None: - """ - Update an existing WorkflowNodeExecution instance. - - Args: - execution: The WorkflowNodeExecution instance to update + A list of running NodeExecution instances """ ... def clear(self) -> None: """ - Clear all WorkflowNodeExecution records based on implementation-specific criteria. 
+ Clear all NodeExecution records based on implementation-specific criteria. This method is intended to be used for bulk deletion operations, such as removing all records associated with a specific app_id and tenant_id in multi-tenant implementations. diff --git a/api/core/workflow/utils/condition/entities.py b/api/core/workflow/utils/condition/entities.py index 799c735f54..56871a15d8 100644 --- a/api/core/workflow/utils/condition/entities.py +++ b/api/core/workflow/utils/condition/entities.py @@ -39,7 +39,7 @@ class SubCondition(BaseModel): class SubVariableCondition(BaseModel): logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default=list) + conditions: list[SubCondition] = Field(default_factory=list) class Condition(BaseModel): diff --git a/api/core/workflow/utils/structured_output/entities.py b/api/core/workflow/utils/structured_output/entities.py index 7954acbaee..6491042bfe 100644 --- a/api/core/workflow/utils/structured_output/entities.py +++ b/api/core/workflow/utils/structured_output/entities.py @@ -14,11 +14,3 @@ class SpecialModelType(StrEnum): GEMINI = "gemini" OLLAMA = "ollama" - - -class SupportStructuredOutputStatus(StrEnum): - """Constants for structured output support status""" - - SUPPORTED = "supported" - UNSUPPORTED = "unsupported" - DISABLED = "disabled" diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 01d5db4303..b88f9edd03 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,22 +1,11 @@ -import json -import time -from collections.abc import Mapping, Sequence +from collections.abc import Mapping +from dataclasses import dataclass from datetime import UTC, datetime -from typing import Any, Optional, Union, cast +from typing import Any, Optional, Union from uuid import uuid4 -from sqlalchemy import func, select -from sqlalchemy.orm import Session - -from core.app.entities.app_invoke_entities 
import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( - QueueAgentLogEvent, - QueueIterationCompletedEvent, - QueueIterationNextEvent, - QueueIterationStartEvent, - QueueLoopCompletedEvent, - QueueLoopNextEvent, - QueueLoopStartEvent, QueueNodeExceptionEvent, QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, @@ -24,49 +13,28 @@ from core.app.entities.queue_entities import ( QueueNodeRetryEvent, QueueNodeStartedEvent, QueueNodeSucceededEvent, - QueueParallelBranchRunFailedEvent, - QueueParallelBranchRunStartedEvent, - QueueParallelBranchRunSucceededEvent, -) -from core.app.entities.task_entities import ( - AgentLogStreamResponse, - IterationNodeCompletedStreamResponse, - IterationNodeNextStreamResponse, - IterationNodeStartStreamResponse, - LoopNodeCompletedStreamResponse, - LoopNodeNextStreamResponse, - LoopNodeStartStreamResponse, - NodeFinishStreamResponse, - NodeRetryStreamResponse, - NodeStartStreamResponse, - ParallelBranchFinishedStreamResponse, - ParallelBranchStartStreamResponse, - WorkflowFinishStreamResponse, - WorkflowStartStreamResponse, ) from core.app.task_pipeline.exc import WorkflowRunNotFoundError -from core.file import FILE_MODEL_IDENTITY, File -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from core.tools.tool_manager import ToolManager -from core.workflow.entities.node_entities import NodeRunMetadataKey -from core.workflow.enums import SystemVariableKey -from core.workflow.nodes import NodeType -from core.workflow.nodes.tool.entities import ToolNodeData -from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from core.workflow.workflow_entry import WorkflowEntry -from 
models.account import Account -from models.enums import CreatedByRole, WorkflowRunTriggeredFrom -from models.model import EndUser -from models.workflow import ( - Workflow, +from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType +from core.workflow.entities.workflow_node_execution import ( WorkflowNodeExecution, + WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, - WorkflowRun, - WorkflowRunStatus, ) +from core.workflow.enums import SystemVariableKey +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.workflow_entry import WorkflowEntry + + +@dataclass +class CycleManagerWorkflowInfo: + workflow_id: str + workflow_type: WorkflowType + version: str + graph_data: Mapping[str, Any] class WorkflowCycleManager: @@ -75,276 +43,232 @@ class WorkflowCycleManager: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], workflow_system_variables: dict[SystemVariableKey, Any], + workflow_info: CycleManagerWorkflowInfo, + workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: - self._workflow_run: WorkflowRun | None = None - self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables + self._workflow_info = workflow_info + self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - def _handle_workflow_run_start( - self, - *, - session: Session, - workflow_id: str, - user_id: str, - created_by_role: CreatedByRole, - ) -> WorkflowRun: - workflow_stmt = 
select(Workflow).where(Workflow.id == workflow_id) - workflow = session.scalar(workflow_stmt) - if not workflow: - raise ValueError(f"Workflow not found: {workflow_id}") - - max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where( - WorkflowRun.tenant_id == workflow.tenant_id, - WorkflowRun.app_id == workflow.app_id, - ) - max_sequence = session.scalar(max_sequence_stmt) or 0 - new_sequence_number = max_sequence + 1 - + def handle_workflow_run_start(self) -> WorkflowExecution: inputs = {**self._application_generate_entity.inputs} for key, value in (self._workflow_system_variables or {}).items(): if key.value == "conversation": continue inputs[f"sys.{key.value}"] = value - triggered_from = ( - WorkflowRunTriggeredFrom.DEBUGGING - if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER - else WorkflowRunTriggeredFrom.APP_RUN - ) - # handle special values inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) # init workflow run # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4()) - - workflow_run = WorkflowRun() - workflow_run.id = workflow_run_id - workflow_run.tenant_id = workflow.tenant_id - workflow_run.app_id = workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = workflow.id - workflow_run.type = workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = workflow.version - workflow_run.graph = workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING - workflow_run.created_by_role = created_by_role - workflow_run.created_by = user_id - workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - - session.add(workflow_run) - - return workflow_run - - def _handle_workflow_run_success( + execution_id = 
str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) + execution = WorkflowExecution.new( + id_=execution_id, + workflow_id=self._workflow_info.workflow_id, + workflow_type=self._workflow_info.workflow_type, + workflow_version=self._workflow_info.version, + graph=self._workflow_info.graph_data, + inputs=inputs, + started_at=datetime.now(UTC).replace(tzinfo=None), + ) + + self._workflow_execution_repository.save(execution) + + return execution + + def handle_workflow_run_success( self, *, - session: Session, workflow_run_id: str, - start_at: float, total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, - ) -> WorkflowRun: - """ - Workflow run success - :param workflow_run_id: workflow run id - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param outputs: outputs - :param conversation_id: conversation id - :return: - """ - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) outputs = WorkflowEntry.handle_special_values(outputs) - workflow_run.status = WorkflowRunStatus.SUCCEEDED - workflow_run.outputs = json.dumps(outputs or {}) - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED + workflow_execution.outputs = outputs or {} + workflow_execution.total_tokens = total_tokens + workflow_execution.total_steps = total_steps + workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if trace_manager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - 
workflow_run=workflow_run, + workflow_execution=workflow_execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + self._workflow_execution_repository.save(workflow_execution) + return workflow_execution - def _handle_workflow_run_partial_success( + def handle_workflow_run_partial_success( self, *, - session: Session, workflow_run_id: str, - start_at: float, total_tokens: int, total_steps: int, outputs: Mapping[str, Any] | None = None, exceptions_count: int = 0, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, - ) -> WorkflowRun: - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + execution = self._get_workflow_execution_or_raise_error(workflow_run_id) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) - workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCEEDED.value - workflow_run.outputs = json.dumps(outputs or {}) - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.exceptions_count = exceptions_count + execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED + execution.outputs = outputs or {} + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + execution.exceptions_count = exceptions_count if trace_manager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - workflow_run=workflow_run, + workflow_execution=execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + self._workflow_execution_repository.save(execution) + return execution - def _handle_workflow_run_failed( + def handle_workflow_run_failed( self, *, - session: Session, workflow_run_id: str, 
- start_at: float, total_tokens: int, total_steps: int, - status: WorkflowRunStatus, - error: str, + status: WorkflowExecutionStatus, + error_message: str, conversation_id: Optional[str] = None, trace_manager: Optional[TraceQueueManager] = None, exceptions_count: int = 0, - ) -> WorkflowRun: - """ - Workflow run failed - :param workflow_run_id: workflow run id - :param start_at: start time - :param total_tokens: total tokens - :param total_steps: total steps - :param status: status - :param error: error message - :return: - """ - workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id) + ) -> WorkflowExecution: + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - workflow_run.status = status.value - workflow_run.error = error - workflow_run.elapsed_time = time.perf_counter() - start_at - workflow_run.total_tokens = total_tokens - workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) - workflow_run.exceptions_count = exceptions_count + workflow_execution.status = WorkflowExecutionStatus(status.value) + workflow_execution.error_message = error_message + workflow_execution.total_tokens = total_tokens + workflow_execution.total_steps = total_steps + workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + workflow_execution.exceptions_count = exceptions_count # Use the instance repository to find running executions for a workflow run - running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_run.id + running_node_executions = self._workflow_node_execution_repository.get_running_executions( + workflow_run_id=workflow_execution.id_ ) - # Update the cache with the retrieved executions - for execution in running_workflow_node_executions: - if execution.node_execution_id: - self._workflow_node_executions[execution.node_execution_id] = execution + # Update the domain models 
+ now = datetime.now(UTC).replace(tzinfo=None) + for node_execution in running_node_executions: + if node_execution.node_execution_id: + # Update the domain model + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = error_message + node_execution.finished_at = now + node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() - for workflow_node_execution in running_workflow_node_executions: - now = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.finished_at = now - workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() + # Update the repository with the domain model + self._workflow_node_execution_repository.save(node_execution) if trace_manager: trace_manager.add_trace_task( TraceTask( TraceTaskName.WORKFLOW_TRACE, - workflow_run=workflow_run, + workflow_execution=workflow_execution, conversation_id=conversation_id, user_id=trace_manager.user_id, ) ) - return workflow_run + self._workflow_execution_repository.save(workflow_execution) + return workflow_execution - def _handle_node_execution_start( - self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent + def handle_node_execution_start( + self, + *, + workflow_execution_id: str, + event: QueueNodeStartedEvent, ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = 
event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, - } + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) + + # Create a domain model + created_at = datetime.now(UTC).replace(tzinfo=None) + metadata = { + WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, + WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, + } + + domain_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=workflow_execution.workflow_id, + workflow_execution_id=workflow_execution.id_, + predecessor_node_id=event.predecessor_node_id, + index=event.node_run_index, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=WorkflowNodeExecutionStatus.RUNNING, + metadata=metadata, + created_at=created_at, ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) - self._workflow_node_executions[event.node_execution_id] = 
workflow_node_execution - return workflow_node_execution + return domain_execution - def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - execution_metadata_dict = dict(event.execution_metadata or {}) - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - process_data = WorkflowEntry.handle_special_values(event.process_data) + # Update domain model + domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if 
process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - # Use the instance repository to update the workflow node execution - self._workflow_node_execution_repository.update(workflow_node_execution) - return workflow_node_execution + return domain_execution - def _handle_workflow_node_execution_failed( + def handle_workflow_node_execution_failed( self, *, event: QueueNodeFailedEvent @@ -357,592 +281,96 @@ class WorkflowCycleManager: :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata = ( - json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None - ) - process_data = WorkflowEntry.handle_special_values(event.process_data) - workflow_node_execution.status = ( - WorkflowNodeExecutionStatus.FAILED.value + + # Update 
domain model + domain_execution.status = ( + WorkflowNodeExecutionStatus.FAILED if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value + else WorkflowNodeExecutionStatus.EXCEPTION ) - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.execution_metadata = execution_metadata + domain_execution.error = event.error + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - self._workflow_node_execution_repository.update(workflow_node_execution) + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - return workflow_node_execution + return domain_execution - def _handle_workflow_node_execution_retried( - self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent + def handle_workflow_node_execution_retried( + self, *, workflow_execution_id: str, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: - """ - Workflow node execution failed - :param workflow_run: workflow run - :param event: queue node failed event - :return: - """ + workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) created_at = event.start_at finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings origin_metadata = { - 
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, + WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, + WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, } - merged_metadata = ( - {**jsonable_encoder(event.execution_metadata), **origin_metadata} - if event.execution_metadata is not None - else origin_metadata - ) - execution_metadata = json.dumps(merged_metadata) - - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.created_at = created_at - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - 
workflow_node_execution.index = event.node_run_index - - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) - - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution - - def _workflow_start_to_stream_response( - self, - *, - session: Session, - task_id: str, - workflow_run: WorkflowRun, - ) -> WorkflowStartStreamResponse: - _ = session - return WorkflowStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=WorkflowStartStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - sequence_number=workflow_run.sequence_number, - inputs=dict(workflow_run.inputs_dict or {}), - created_at=int(workflow_run.created_at.timestamp()), - ), - ) - def _workflow_finish_to_stream_response( - self, - *, - session: Session, - task_id: str, - workflow_run: WorkflowRun, - ) -> WorkflowFinishStreamResponse: - created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT: - stmt = select(Account).where(Account.id == workflow_run.created_by) - account = session.scalar(stmt) - if account: - created_by = { - "id": account.id, - "name": account.name, - "email": account.email, - } - elif workflow_run.created_by_role == CreatedByRole.END_USER: - stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) - end_user = session.scalar(stmt) - if end_user: - created_by = { - "id": end_user.id, - "user": end_user.session_id, - } - else: - raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}") - - return WorkflowFinishStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=WorkflowFinishStreamResponse.Data( - id=workflow_run.id, - workflow_id=workflow_run.workflow_id, - sequence_number=workflow_run.sequence_number, - status=workflow_run.status, - outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None, - 
error=workflow_run.error, - elapsed_time=workflow_run.elapsed_time, - total_tokens=workflow_run.total_tokens, - total_steps=workflow_run.total_steps, - created_by=created_by, - created_at=int(workflow_run.created_at.timestamp()), - finished_at=int(workflow_run.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)), - exceptions_count=workflow_run.exceptions_count, - ), + # Convert execution metadata keys to strings + execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + + merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata + + # Create a domain model + domain_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id=workflow_execution.workflow_id, + workflow_execution_id=workflow_execution.id_, + predecessor_node_id=event.predecessor_node_id, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=WorkflowNodeExecutionStatus.RETRY, + created_at=created_at, + finished_at=finished_at, + elapsed_time=elapsed_time, + error=event.error, + index=event.node_run_index, ) - def _workflow_node_start_to_stream_response( - self, - *, - event: QueueNodeStartedEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeStartStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: - return None - if not workflow_node_execution.workflow_run_id: - return None - - response = NodeStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_node_execution.workflow_run_id, - data=NodeStartStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - node_type=workflow_node_execution.node_type, - 
title=workflow_node_execution.title, - index=workflow_node_execution.index, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - created_at=int(workflow_node_execution.created_at.timestamp()), - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - iteration_id=event.in_iteration_id, - loop_id=event.in_loop_id, - parallel_run_id=event.parallel_mode_run_id, - agent_strategy=event.agent_strategy, - ), - ) + # Update with mappings + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) - # extras logic - if event.node_type == NodeType.TOOL: - node_data = cast(ToolNodeData, event.node_data) - response.data.extras["icon"] = ToolManager.get_tool_icon( - tenant_id=self._application_generate_entity.app_config.tenant_id, - provider_type=node_data.provider_type, - provider_id=node_data.provider_id, - ) - - return response + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) - def _workflow_node_finish_to_stream_response( - self, - *, - event: QueueNodeSucceededEvent - | QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[NodeFinishStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: - return None - if not workflow_node_execution.workflow_run_id: - return None - if not workflow_node_execution.finished_at: - return None - - return NodeFinishStreamResponse( - task_id=task_id, - workflow_run_id=workflow_node_execution.workflow_run_id, - data=NodeFinishStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - 
node_type=workflow_node_execution.node_type, - index=workflow_node_execution.index, - title=workflow_node_execution.title, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, - status=workflow_node_execution.status, - error=workflow_node_execution.error, - elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, - created_at=int(workflow_node_execution.created_at.timestamp()), - finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - iteration_id=event.in_iteration_id, - loop_id=event.in_loop_id, - ), - ) - - def _workflow_node_retry_to_stream_response( - self, - *, - event: QueueNodeRetryEvent, - task_id: str, - workflow_node_execution: WorkflowNodeExecution, - ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: - return None - if not workflow_node_execution.workflow_run_id: - return None - if not workflow_node_execution.finished_at: - return None - - return NodeRetryStreamResponse( - task_id=task_id, - workflow_run_id=workflow_node_execution.workflow_run_id, - data=NodeRetryStreamResponse.Data( - id=workflow_node_execution.id, - node_id=workflow_node_execution.node_id, - node_type=workflow_node_execution.node_type, - index=workflow_node_execution.index, - title=workflow_node_execution.title, - predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - 
process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, - status=workflow_node_execution.status, - error=workflow_node_execution.error, - elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, - created_at=int(workflow_node_execution.created_at.timestamp()), - finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - iteration_id=event.in_iteration_id, - loop_id=event.in_loop_id, - retry_index=event.retry_index, - ), - ) - - def _workflow_parallel_branch_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent - ) -> ParallelBranchStartStreamResponse: - _ = session - return ParallelBranchStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=ParallelBranchStartStreamResponse.Data( - parallel_id=event.parallel_id, - parallel_branch_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - iteration_id=event.in_iteration_id, - loop_id=event.in_loop_id, - created_at=int(time.time()), - ), - ) - - def _workflow_parallel_branch_finished_to_stream_response( - self, - *, - session: Session, - task_id: str, - workflow_run: WorkflowRun, - event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, - ) -> ParallelBranchFinishedStreamResponse: - _ = session - return ParallelBranchFinishedStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=ParallelBranchFinishedStreamResponse.Data( - parallel_id=event.parallel_id, - 
parallel_branch_id=event.parallel_start_node_id, - parent_parallel_id=event.parent_parallel_id, - parent_parallel_start_node_id=event.parent_parallel_start_node_id, - iteration_id=event.in_iteration_id, - loop_id=event.in_loop_id, - status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed", - error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None, - created_at=int(time.time()), - ), - ) - - def _workflow_iteration_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent - ) -> IterationNodeStartStreamResponse: - _ = session - return IterationNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=IterationNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs or {}, - metadata=event.metadata or {}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - ), - ) - - def _workflow_iteration_next_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent - ) -> IterationNodeNextStreamResponse: - _ = session - return IterationNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=IterationNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - index=event.index, - pre_iteration_output=event.output, - created_at=int(time.time()), - extras={}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parallel_mode_run_id=event.parallel_mode_run_id, - duration=event.duration, - ), - ) - - def _workflow_iteration_completed_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: 
QueueIterationCompletedEvent - ) -> IterationNodeCompletedStreamResponse: - _ = session - return IterationNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=IterationNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=event.inputs or {}, - status=WorkflowNodeExecutionStatus.SUCCEEDED - if event.error is None - else WorkflowNodeExecutionStatus.FAILED, - error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), - total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, - execution_metadata=event.metadata, - finished_at=int(time.time()), - steps=event.steps, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - ), - ) - - def _workflow_loop_start_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent - ) -> LoopNodeStartStreamResponse: - _ = session - return LoopNodeStartStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=LoopNodeStartStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - created_at=int(time.time()), - extras={}, - inputs=event.inputs or {}, - metadata=event.metadata or {}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - ), - ) - - def _workflow_loop_next_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent - ) -> LoopNodeNextStreamResponse: - _ = session - return LoopNodeNextStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=LoopNodeNextStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - 
title=event.node_data.title, - index=event.index, - pre_loop_output=event.output, - created_at=int(time.time()), - extras={}, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - parallel_mode_run_id=event.parallel_mode_run_id, - duration=event.duration, - ), - ) - - def _workflow_loop_completed_to_stream_response( - self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent - ) -> LoopNodeCompletedStreamResponse: - _ = session - return LoopNodeCompletedStreamResponse( - task_id=task_id, - workflow_run_id=workflow_run.id, - data=LoopNodeCompletedStreamResponse.Data( - id=event.node_id, - node_id=event.node_id, - node_type=event.node_type.value, - title=event.node_data.title, - outputs=event.outputs, - created_at=int(time.time()), - extras={}, - inputs=event.inputs or {}, - status=WorkflowNodeExecutionStatus.SUCCEEDED - if event.error is None - else WorkflowNodeExecutionStatus.FAILED, - error=None, - elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), - total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, - execution_metadata=event.metadata, - finished_at=int(time.time()), - steps=event.steps, - parallel_id=event.parallel_id, - parallel_start_node_id=event.parallel_start_node_id, - ), - ) - - def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]: - """ - Fetch files from node outputs - :param outputs_dict: node outputs dict - :return: - """ - if not outputs_dict: - return [] - - files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()] - # Remove None - files = [file for file in files if file] - # Flatten list - # Flatten the list of sequences into a single list of mappings - flattened_files = [file for sublist in files if sublist for file in sublist] - - # Convert to tuple to match Sequence type - return tuple(flattened_files) - - def 
_fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]: - """ - Fetch files from variable value - :param value: variable value - :return: - """ - if not value: - return [] - - files = [] - if isinstance(value, list): - for item in value: - file = self._get_file_var_from_value(item) - if file: - files.append(file) - elif isinstance(value, dict): - file = self._get_file_var_from_value(value) - if file: - files.append(file) - - return files - - def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None: - """ - Get file var from value - :param value: variable value - :return: - """ - if not value: - return None - - if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: - return value - elif isinstance(value, File): - return value.to_dict() - - return None - - def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun: - if self._workflow_run and self._workflow_run.id == workflow_run_id: - cached_workflow_run = self._workflow_run - cached_workflow_run = session.merge(cached_workflow_run) - return cached_workflow_run - stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id) - workflow_run = session.scalar(stmt) - if not workflow_run: - raise WorkflowRunNotFoundError(workflow_run_id) - self._workflow_run = workflow_run - - return workflow_run - - def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - # First check the cache for performance - if node_execution_id in self._workflow_node_executions: - cached_execution = self._workflow_node_executions[node_execution_id] - # No need to merge with session since expire_on_commit=False - return cached_execution - - # If not in cache, use the instance repository to get by node_execution_id - execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) + return domain_execution + def 
_get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: + execution = self._workflow_execution_repository.get(id) if not execution: - raise ValueError(f"Workflow node execution not found: {node_execution_id}") - - # Update cache - self._workflow_node_executions[node_execution_id] = execution + raise WorkflowRunNotFoundError(id) return execution - - def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: - """ - Handle agent log - :param task_id: task id - :param event: agent log event - :return: - """ - return AgentLogStreamResponse( - task_id=task_id, - data=AgentLogStreamResponse.Data( - node_execution_id=event.node_execution_id, - id=event.id, - parent_id=event.parent_id, - label=event.label, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - node_id=event.node_id, - ), - ) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 26bd6b3577..a837552007 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -70,6 +70,7 @@ def init_app(app: DifyApp) -> Celery: "schedule.update_tidb_serverless_status_task", "schedule.clean_messages", "schedule.mail_clean_document_notify_task", + "schedule.queue_monitor_task", ] day = dify_config.CELERY_BEAT_SCHEDULER_TIME beat_schedule = { @@ -98,6 +99,12 @@ def init_app(app: DifyApp) -> Celery: "task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task", "schedule": crontab(minute="0", hour="10", day_of_week="1"), }, + "datasets-queue-monitor": { + "task": "schedule.queue_monitor_task.queue_monitor_task", + "schedule": timedelta( + minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30 + ), + }, } celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index aa55862b7c..79d49aba5e 100644 --- a/api/extensions/ext_logging.py +++ 
b/api/extensions/ext_logging.py @@ -39,6 +39,10 @@ def init_app(app: DifyApp): handlers=log_handlers, force=True, ) + + # Apply RequestIdFormatter to all handlers + apply_request_id_formatter() + # Disable propagation for noisy loggers to avoid duplicate logs logging.getLogger("sqlalchemy.engine").propagate = False log_tz = dify_config.LOG_TZ @@ -74,3 +78,16 @@ class RequestIdFilter(logging.Filter): def filter(self, record): record.req_id = get_request_id() if flask.has_request_context() else "" return True + + +class RequestIdFormatter(logging.Formatter): + def format(self, record): + if not hasattr(record, "req_id"): + record.req_id = "" + return super().format(record) + + +def apply_request_id_formatter(): + for handler in logging.root.handlers: + if handler.formatter: + handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT) diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index 10fb89eb73..3b4d787d01 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -3,11 +3,14 @@ import json import flask_login # type: ignore from flask import Response, request from flask_login import user_loaded_from_request, user_logged_in -from werkzeug.exceptions import Unauthorized +from werkzeug.exceptions import NotFound, Unauthorized -import contexts +from configs import dify_config from dify_app import DifyApp +from extensions.ext_database import db from libs.passport import PassportService +from models.account import Account, Tenant, TenantAccountJoin +from models.model import EndUser from services.account_service import AccountService login_manager = flask_login.LoginManager() @@ -17,35 +20,72 @@ login_manager = flask_login.LoginManager() @login_manager.request_loader def load_user_from_request(request_from_flask_login): """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format 
auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: + auth_token: str | None = None + if auth_header: if " " not in auth_header: raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme, auth_token = auth_header.split(maxsplit=1) auth_scheme = auth_scheme.lower() if auth_scheme != "bearer": raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + else: + auth_token = request.args.get("_token") - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") + # Check for admin API key authentication first + if dify_config.ADMIN_API_KEY_ENABLE and auth_header: + admin_api_key = dify_config.ADMIN_API_KEY + if admin_api_key and admin_api_key == auth_token: + workspace_id = request.headers.get("X-WORKSPACE-ID") + if workspace_id: + tenant_account_join = ( + db.session.query(Tenant, TenantAccountJoin) + .filter(Tenant.id == workspace_id) + .filter(TenantAccountJoin.tenant_id == Tenant.id) + .filter(TenantAccountJoin.role == "owner") + .one_or_none() + ) + if tenant_account_join: + tenant, ta = tenant_account_join + account = db.session.query(Account).filter_by(id=ta.account_id).first() + if account: + account.current_tenant = tenant + return account - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - return logged_in_account + if request.blueprint in {"console", "inner_api"}: + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + source = decoded.get("token_source") + if source: + raise Unauthorized("Invalid Authorization token.") + if not user_id: + raise Unauthorized("Invalid Authorization token.") + + logged_in_account = 
AccountService.load_logged_in_account(account_id=user_id) + return logged_in_account + elif request.blueprint == "web": + decoded = PassportService().verify(auth_token) + end_user_id = decoded.get("end_user_id") + if not end_user_id: + raise Unauthorized("Invalid Authorization token.") + end_user = db.session.query(EndUser).filter(EndUser.id == decoded["end_user_id"]).first() + if not end_user: + raise NotFound("End user not found.") + return end_user @user_logged_in.connect @user_loaded_from_request.connect def on_user_logged_in(_sender, user): - """Called when a user logged in.""" - if user: - contexts.tenant_id.set(user.current_tenant_id) + """Called when a user logged in. + + Note: AccountService.load_logged_in_account will populate user.current_tenant_id + through the load_user method, which calls account.set_tenant_id(). + """ + # tenant_id context variable removed - using current_user.current_tenant_id directly + pass @login_manager.unauthorized_handler diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 9240ebe7fc..df5d8a9c11 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -26,7 +26,7 @@ class Mail: match mail_type: case "resend": - import resend # type: ignore + import resend api_key = dify_config.RESEND_API_KEY if not api_key: @@ -54,6 +54,15 @@ class Mail: use_tls=dify_config.SMTP_USE_TLS, opportunistic_tls=dify_config.SMTP_OPPORTUNISTIC_TLS, ) + case "sendgrid": + from libs.sendgrid import SendGridClient + + if not dify_config.SENDGRID_API_KEY: + raise ValueError("SENDGRID_API_KEY is required for SendGrid mail type") + + self._client = SendGridClient( + sendgrid_api_key=dify_config.SENDGRID_API_KEY, _from=dify_config.MAIL_DEFAULT_SEND_FROM or "" + ) case _: raise ValueError("Unsupported mail type {}".format(mail_type)) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 3cbdc8560b..6dcfa7bec6 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -12,19 +12,30 
@@ from flask_login import user_loaded_from_request, user_logged_in # type: ignore from configs import dify_config from dify_app import DifyApp +from models import Account, EndUser @user_logged_in.connect @user_loaded_from_request.connect -def on_user_loaded(_sender, user): +def on_user_loaded(_sender, user: Union["Account", "EndUser"]): if dify_config.ENABLE_OTEL: from opentelemetry.trace import get_current_span if user: - current_span = get_current_span() - if current_span: - current_span.set_attribute("service.tenant.id", user.current_tenant_id) - current_span.set_attribute("service.user.id", user.id) + try: + current_span = get_current_span() + if isinstance(user, Account) and user.current_tenant_id: + tenant_id = user.current_tenant_id + elif isinstance(user, EndUser): + tenant_id = user.tenant_id + else: + return + if current_span: + current_span.set_attribute("service.tenant.id", tenant_id) + current_span.set_attribute("service.user.id", user.id) + except Exception: + logging.exception("Error setting tenant and user attributes") + pass def init_app(app: DifyApp): @@ -47,21 +58,25 @@ def init_app(app: DifyApp): def response_hook(span: Span, status: str, response_headers: list): if span and span.is_recording(): - if status.startswith("2"): - span.set_status(StatusCode.OK) - else: - span.set_status(StatusCode.ERROR, status) - - status = status.split(" ")[0] - status_code = int(status) - status_class = f"{status_code // 100}xx" - attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class} - request = flask.request - if request and request.url_rule: - attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule) - if request and request.method: - attributes[SpanAttributes.HTTP_METHOD] = str(request.method) - _http_response_counter.add(1, attributes) + try: + if status.startswith("2"): + span.set_status(StatusCode.OK) + else: + span.set_status(StatusCode.ERROR, status) + + status = status.split(" ")[0] + status_code = 
int(status) + status_class = f"{status_code // 100}xx" + attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class} + request = flask.request + if request and request.url_rule: + attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule) + if request and request.method: + attributes[SpanAttributes.HTTP_METHOD] = str(request.method) + _http_response_counter.add(1, attributes) + except Exception: + logging.exception("Error setting status and attributes") + pass instrumentor = FlaskInstrumentor() if dify_config.DEBUG: @@ -92,7 +107,7 @@ def init_app(app: DifyApp): class ExceptionLoggingHandler(logging.Handler): """Custom logging handler that creates spans for logging.exception() calls""" - def emit(self, record): + def emit(self, record: logging.LogRecord): try: if record.exc_info: tracer = get_tracer_provider().get_tracer("dify.exception.logging") @@ -107,15 +122,20 @@ def init_app(app: DifyApp): }, ) as span: span.set_status(StatusCode.ERROR) - span.record_exception(record.exc_info[1]) - span.set_attribute("exception.type", record.exc_info[0].__name__) - span.set_attribute("exception.message", str(record.exc_info[1])) + if record.exc_info[1]: + span.record_exception(record.exc_info[1]) + span.set_attribute("exception.message", str(record.exc_info[1])) + if record.exc_info[0]: + span.set_attribute("exception.type", record.exc_info[0].__name__) + except Exception: pass from opentelemetry import trace - from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter + from 
opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor @@ -158,19 +178,32 @@ def init_app(app: DifyApp): sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE) provider = TracerProvider(resource=resource, sampler=sampler) set_tracer_provider(provider) - exporter: Union[OTLPSpanExporter, ConsoleSpanExporter] - metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter] + exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter] + metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter] + protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": - exporter = OTLPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, - ) - metric_exporter = OTLPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, - ) + if protocol == "grpc": + exporter = GRPCSpanExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT, + # Header field names must consist of lowercase letters, check RFC7540 + headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), + insecure=True, + ) + metric_exporter = GRPCMetricExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT, + headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), + insecure=True, + ) + else: + exporter = HTTPSpanExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", + headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + ) + metric_exporter = HTTPMetricExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", + headers={"Authorization": f"Bearer 
{dify_config.OTLP_API_KEY}"}, + ) else: - # Fallback to console exporter exporter = ConsoleSpanExporter() metric_exporter = ConsoleMetricExporter() diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index f8679f7e4b..c283b1b7ca 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,6 +1,7 @@ from typing import Any, Union import redis +from redis.cache import CacheConfig from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel @@ -51,6 +52,14 @@ def init_app(app: DifyApp): connection_class: type[Union[Connection, SSLConnection]] = Connection if dify_config.REDIS_USE_SSL: connection_class = SSLConnection + resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL + if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE: + if resp_protocol >= 3: + clientside_cache_config = CacheConfig() + else: + raise ValueError("Client side cache is only supported in RESP3") + else: + clientside_cache_config = None redis_params: dict[str, Any] = { "username": dify_config.REDIS_USERNAME, @@ -59,6 +68,8 @@ def init_app(app: DifyApp): "encoding": "utf-8", "encoding_errors": "strict", "decode_responses": False, + "protocol": resp_protocol, + "cache_config": clientside_cache_config, } if dify_config.REDIS_USE_SENTINEL: @@ -82,14 +93,22 @@ def init_app(app: DifyApp): ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) for node in dify_config.REDIS_CLUSTERS.split(",") ] - # FIXME: mypy error here, try to figure out how to fix it - redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore + redis_client.initialize( + RedisCluster( + startup_nodes=nodes, + password=dify_config.REDIS_CLUSTERS_PASSWORD, + protocol=resp_protocol, + cache_config=clientside_cache_config, + ) + ) else: redis_params.update( { "host": dify_config.REDIS_HOST, "port": dify_config.REDIS_PORT, "connection_class": 
connection_class, + "protocol": resp_protocol, + "cache_config": clientside_cache_config, } ) pool = redis.ConnectionPool(**redis_params) diff --git a/api/extensions/ext_request_logging.py b/api/extensions/ext_request_logging.py new file mode 100644 index 0000000000..7c69483e0f --- /dev/null +++ b/api/extensions/ext_request_logging.py @@ -0,0 +1,73 @@ +import json +import logging + +import flask +import werkzeug.http +from flask import Flask +from flask.signals import request_finished, request_started + +from configs import dify_config + +_logger = logging.getLogger(__name__) + + +def _is_content_type_json(content_type: str) -> bool: + if not content_type: + return False + content_type_no_option, _ = werkzeug.http.parse_options_header(content_type) + return content_type_no_option.lower() == "application/json" + + +def _log_request_started(_sender, **_extra): + """Log the start of a request.""" + if not _logger.isEnabledFor(logging.DEBUG): + return + + request = flask.request + if not (_is_content_type_json(request.content_type) and request.data): + _logger.debug("Received Request %s -> %s", request.method, request.path) + return + try: + json_data = json.loads(request.data) + except (TypeError, ValueError): + _logger.exception("Failed to parse JSON request") + return + formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2) + _logger.debug( + "Received Request %s -> %s, Request Body:\n%s", + request.method, + request.path, + formatted_json, + ) + + +def _log_request_finished(_sender, response, **_extra): + """Log the end of a request.""" + if not _logger.isEnabledFor(logging.DEBUG) or response is None: + return + + if not _is_content_type_json(response.content_type): + _logger.debug("Response %s %s", response.status, response.content_type) + return + + response_data = response.get_data(as_text=True) + try: + json_data = json.loads(response_data) + except (TypeError, ValueError): + _logger.exception("Failed to parse JSON response") + return + 
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2) + _logger.debug( + "Response %s %s, Response Body:\n%s", + response.status, + response.content_type, + formatted_json, + ) + + +def init_app(app: Flask): + """Initialize the request logging extension.""" + if not dify_config.ENABLE_REQUEST_LOGGING: + return + request_started.connect(_log_request_started, app) + request_finished.connect(_log_request_finished, app) diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index bbca8448ec..a41ef4ae4e 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -39,10 +39,6 @@ from core.variables.variables import ( from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID -class InvalidSelectorError(ValueError): - pass - - class UnsupportedSegmentTypeError(Exception): pass @@ -84,8 +80,8 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") - # FIXME: using Any here, fix it later - result: Any + + result: Variable match value_type: case SegmentType.STRING: result = StringVariable.model_validate(mapping) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 0b0e2a2f54..500ca47c7e 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -63,6 +63,7 @@ app_detail_fields = { "created_at": TimestampField, "updated_by": fields.String, "updated_at": TimestampField, + "access_mode": fields.String, } prompt_config_fields = { @@ -98,6 +99,9 @@ app_partial_fields = { "updated_by": fields.String, "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), + "access_mode": fields.String, + "create_user_name": fields.String, + "author_name": fields.String, } @@ -176,6 +180,7 @@ app_detail_fields_with_site = { "updated_by": fields.String, "updated_at": TimestampField, 
"deleted_tools": fields.List(fields.Nested(deleted_tool_fields)), + "access_mode": fields.String, } diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py index 74fdf8bd97..a106728e9c 100644 --- a/api/fields/workflow_run_fields.py +++ b/api/fields/workflow_run_fields.py @@ -19,7 +19,6 @@ workflow_run_for_log_fields = { workflow_run_for_list_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -36,7 +35,6 @@ advanced_chat_workflow_run_for_list_fields = { "id": fields.String, "conversation_id": fields.String, "message_id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "status": fields.String, "elapsed_time": fields.Float, @@ -63,7 +61,6 @@ workflow_run_pagination_fields = { workflow_run_detail_fields = { "id": fields.String, - "sequence_number": fields.Integer, "version": fields.String, "graph": fields.Raw(attribute="graph_dict"), "inputs": fields.Raw(attribute="inputs_dict"), diff --git a/api/libs/flask_utils.py b/api/libs/flask_utils.py new file mode 100644 index 0000000000..4ea2779584 --- /dev/null +++ b/api/libs/flask_utils.py @@ -0,0 +1,65 @@ +import contextvars +from collections.abc import Iterator +from contextlib import contextmanager +from typing import TypeVar + +from flask import Flask, g, has_request_context + +T = TypeVar("T") + + +@contextmanager +def preserve_flask_contexts( + flask_app: Flask, + context_vars: contextvars.Context, +) -> Iterator[None]: + """ + A context manager that handles: + 1. flask-login's UserProxy copy + 2. ContextVars copy + 3. flask_app.app_context() + + This context manager ensures that the Flask application context is properly set up, + the current user is preserved across context boundaries, and any provided context variables + are set within the new context. 
+ + Note: + This manager aims to allow use current_user cross thread and app context, + but it's not the recommend use, it's better to pass user directly in parameters. + + Args: + flask_app: The Flask application instance + context_vars: contextvars.Context object containing context variables to be set in the new context + + Yields: + None + + Example: + ```python + with preserve_flask_contexts(flask_app, context_vars=context_vars): + # Code that needs Flask app context and context variables + # Current user will be preserved if available + ``` + """ + # Set context variables if provided + if context_vars: + for var, val in context_vars.items(): + var.set(val) + + # Save current user before entering new app context + saved_user = None + if has_request_context() and hasattr(g, "_login_user"): + saved_user = g._login_user + + # Enter Flask app context + with flask_app.app_context(): + try: + # Restore user in new app context if it was saved + if saved_user is not None: + g._login_user = saved_user + + # Yield control back to the caller + yield + finally: + # Any cleanup can be added here if needed + pass diff --git a/api/libs/helper.py b/api/libs/helper.py index afc8f31681..3f2a630956 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -1,8 +1,9 @@ import json import logging -import random import re +import secrets import string +import struct import subprocess import time import uuid @@ -14,10 +15,12 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restful import fields +from pydantic import BaseModel from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.file import helpers as file_helpers +from core.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: @@ -175,14 +178,14 @@ def generate_string(n): letters_digits = string.ascii_letters + string.digits result = "" for i in range(n): 
- result += random.choice(letters_digits) + result += secrets.choice(letters_digits) return result def extract_remote_ip(request) -> str: if request.headers.get("CF-Connecting-IP"): - return cast(str, request.headers.get("Cf-Connecting-Ip")) + return cast(str, request.headers.get("CF-Connecting-IP")) elif request.headers.getlist("X-Forwarded-For"): return cast(str, request.headers.getlist("X-Forwarded-For")[0]) else: @@ -196,7 +199,7 @@ def generate_text_hash(text: str) -> str: def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: if isinstance(response, dict): - return Response(response=json.dumps(response), status=200, mimetype="application/json") + return Response(response=json.dumps(jsonable_encoder(response)), status=200, mimetype="application/json") else: def generate() -> Generator: @@ -205,6 +208,60 @@ def compact_generate_response(response: Union[Mapping, Generator, RateLimitGener return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") +def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response: + """ + This function is used to return a response with a length prefix. + Magic number is a one byte number that indicates the type of the response. + + For a compatibility with latest plugin daemon https://github.com/langgenius/dify-plugin-daemon/pull/341 + Avoid using line-based response, it leads a memory issue. 
+ + We uses following format: + | Field | Size | Description | + |---------------|----------|---------------------------------| + | Magic Number | 1 byte | Magic number identifier | + | Reserved | 1 byte | Reserved field | + | Header Length | 2 bytes | Header length (usually 0xa) | + | Data Length | 4 bytes | Length of the data | + | Reserved | 6 bytes | Reserved fields | + | Data | Variable | Actual data content | + + | Reserved Fields | Header | Data | + |-----------------|----------|----------| + | 4 bytes total | Variable | Variable | + + all data is in little endian + """ + + def pack_response_with_length_prefix(response: bytes) -> bytes: + header_length = 0xA + data_length = len(response) + # | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data + return struct.pack(" Generator: + for chunk in response: + if isinstance(chunk, str): + yield pack_response_with_length_prefix(chunk.encode("utf-8")) + else: + yield pack_response_with_length_prefix(chunk) + + return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream") + + class TokenManager: @classmethod def generate_token( diff --git a/api/libs/login.py b/api/libs/login.py index be9478e850..e3a7fe2948 100644 --- a/api/libs/login.py +++ b/api/libs/login.py @@ -2,14 +2,11 @@ from functools import wraps from typing import Any from flask import current_app, g, has_request_context, request -from flask_login import user_logged_in # type: ignore from flask_login.config import EXEMPT_METHODS # type: ignore -from werkzeug.exceptions import Unauthorized from werkzeug.local import LocalProxy from configs import dify_config -from extensions.ext_database import db -from models.account import Account, Tenant, TenantAccountJoin +from models.account import Account from models.model import EndUser #: A proxy for the current user. 
If no user is logged in, this will be an @@ -53,36 +50,6 @@ def login_required(func): @wraps(func) def decorated_view(*args, **kwargs): - auth_header = request.headers.get("Authorization") - if dify_config.ADMIN_API_KEY_ENABLE: - if auth_header: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - admin_api_key = dify_config.ADMIN_API_KEY - if admin_api_key: - if admin_api_key == auth_token: - workspace_id = request.headers.get("X-WORKSPACE-ID") - if workspace_id: - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == workspace_id) - .filter(TenantAccountJoin.tenant_id == Tenant.id) - .filter(TenantAccountJoin.role == "owner") - .one_or_none() - ) - if tenant_account_join: - tenant, ta = tenant_account_join - account = db.session.query(Account).filter_by(id=ta.account_id).first() - # Login admin - if account: - account.current_tenant = tenant - current_app.login_manager._update_request_context_with_user(account) # type: ignore - user_logged_in.send(current_app._get_current_object(), user=_get_user()) # type: ignore if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED: pass elif not current_user.is_authenticated: diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 1c151633f0..218109522d 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + 
data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ).first() + .first() + ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource): "total": len(pages), } # save data source binding - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.access_token == access_token, + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.access_token == access_token, + ) ) - ).first() + .first() + ) if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False @@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource): def sync_data_source(self, binding_id: str): # save data source binding - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.id == binding_id, - DataSourceOauthBinding.disabled == False, + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.id == binding_id, + DataSourceOauthBinding.disabled == False, + ) ) - ).first() + .first() + ) if data_source_binding: # get all authorized pages pages = 
self.get_authorized_pages(data_source_binding.access_token) diff --git a/api/libs/sendgrid.py b/api/libs/sendgrid.py new file mode 100644 index 0000000000..5409e3eeeb --- /dev/null +++ b/api/libs/sendgrid.py @@ -0,0 +1,45 @@ +import logging + +import sendgrid # type: ignore +from python_http_client.exceptions import ForbiddenError, UnauthorizedError +from sendgrid.helpers.mail import Content, Email, Mail, To # type: ignore + + +class SendGridClient: + def __init__(self, sendgrid_api_key: str, _from: str): + self.sendgrid_api_key = sendgrid_api_key + self._from = _from + + def send(self, mail: dict): + logging.debug("Sending email with SendGrid") + + try: + _to = mail["to"] + + if not _to: + raise ValueError("SendGridClient: Cannot send email: recipient address is missing.") + + sg = sendgrid.SendGridAPIClient(api_key=self.sendgrid_api_key) + from_email = Email(self._from) + to_email = To(_to) + subject = mail["subject"] + content = Content("text/html", mail["html"]) + mail = Mail(from_email, to_email, subject, content) + mail_json = mail.get() # type: ignore + response = sg.client.mail.send.post(request_body=mail_json) + logging.debug(response.status_code) + logging.debug(response.body) + logging.debug(response.headers) + + except TimeoutError as e: + logging.exception("SendGridClient Timeout occurred while sending email") + raise + except (UnauthorizedError, ForbiddenError) as e: + logging.exception( + "SendGridClient Authentication failed. 
" + "Verify that your credentials and the 'from' email address are correct" + ) + raise + except Exception as e: + logging.exception(f"SendGridClient Unexpected error occurred while sending email to {_to}") + raise diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 2325d69a41..35561f071c 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -28,7 +28,8 @@ class SMTPClient: else: smtp = smtplib.SMTP(self.server, self.port, timeout=10) - if self.username and self.password: + # Only authenticate if both username and password are non-empty + if self.username and self.password and self.username.strip() and self.password.strip(): smtp.login(self.username, self.password) msg = MIMEMultipart() diff --git a/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py b/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py new file mode 100644 index 0000000000..19f6c01655 --- /dev/null +++ b/api/migrations/versions/2025_05_14_1403-d28f2004b072_add_index_for_workflow_conversation_.py @@ -0,0 +1,33 @@ +"""add index for workflow_conversation_variables.conversation_id + +Revision ID: d28f2004b072 +Revises: 6a9f914f656c +Create Date: 2025-05-14 14:03:36.713828 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'd28f2004b072' +down_revision = '6a9f914f656c' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.create_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'), ['conversation_id'], unique=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_conversation_variables_conversation_id_idx')) + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py new file mode 100644 index 0000000000..5bf394b21c --- /dev/null +++ b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py @@ -0,0 +1,51 @@ +"""add WorkflowDraftVariable model + +Revision ID: 2adcbe1f5dfb +Revises: d28f2004b072 +Create Date: 2025-05-15 15:31:03.128680 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "2adcbe1f5dfb" +down_revision = "d28f2004b072" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! 
### + op.create_table( + "workflow_draft_variables", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("last_edited_at", sa.DateTime(), nullable=True), + sa.Column("node_id", sa.String(length=255), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.String(length=255), nullable=False), + sa.Column("selector", sa.String(length=255), nullable=False), + sa.Column("value_type", sa.String(length=20), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("visible", sa.Boolean(), nullable=False), + sa.Column("editable", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")), + sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")), + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # Dropping `workflow_draft_variables` also drops any index associated with it. + op.drop_table("workflow_draft_variables") + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py new file mode 100644 index 0000000000..d7a5d116c9 --- /dev/null +++ b/api/migrations/versions/2025_06_06_1424-4474872b0ee6_workflow_draft_varaibles_add_node_execution_id.py @@ -0,0 +1,60 @@ +"""`workflow_draft_varaibles` add `node_execution_id` column, add an index for `workflow_node_executions`. 
+ +Revision ID: 4474872b0ee6 +Revises: 2adcbe1f5dfb +Create Date: 2025-06-06 14:24:44.213018 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '4474872b0ee6' +down_revision = '2adcbe1f5dfb' +branch_labels = None +depends_on = None + + +def upgrade(): + # `CREATE INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block` + # context manager to wrap the index creation statement. + # Reference: + # + # - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot. + # - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block + with op.get_context().autocommit_block(): + op.create_index( + op.f('workflow_node_executions_tenant_id_idx'), + "workflow_node_executions", + ['tenant_id', 'workflow_id', 'node_id', sa.literal_column('created_at DESC')], + unique=False, + postgresql_concurrently=True, + ) + + with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op: + batch_op.add_column(sa.Column('node_execution_id', models.types.StringUUID(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # `DROP INDEX CONCURRENTLY` cannot run within a transaction, so use the `autocommit_block` + # context manager to wrap the index creation statement. + # Reference: + # + # - https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot. + # - https://alembic.sqlalchemy.org/en/latest/api/runtime.html#alembic.runtime.migration.MigrationContext.autocommit_block + # `DROP INDEX CONCURRENTLY` cannot run within a transaction, so commit existing transactions first. 
+ # Reference: + # + # https://www.postgresql.org/docs/current/sql-createindex.html#:~:text=Another%20difference%20is,CREATE%20INDEX%20CONCURRENTLY%20cannot. + with op.get_context().autocommit_block(): + op.drop_index(op.f('workflow_node_executions_tenant_id_idx'), postgresql_concurrently=True) + + with op.batch_alter_table('workflow_draft_variables', schema=None) as batch_op: + batch_op.drop_column('node_execution_id') + + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py new file mode 100644 index 0000000000..29fef77798 --- /dev/null +++ b/api/migrations/versions/2025_06_19_1633-0ab65e1cc7fa_remove_sequence_number_from_workflow_.py @@ -0,0 +1,66 @@ +"""remove sequence_number from workflow_runs + +Revision ID: 0ab65e1cc7fa +Revises: 4474872b0ee6 +Create Date: 2025-06-19 16:33:13.377215 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '0ab65e1cc7fa' +down_revision = '4474872b0ee6' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('workflow_run_tenant_app_sequence_idx')) + batch_op.drop_column('sequence_number') + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # WARNING: This downgrade CANNOT recover the original sequence_number values! + # The original sequence numbers are permanently lost after the upgrade. + # This downgrade will regenerate sequence numbers based on created_at order, + # which may result in different values than the original sequence numbers. 
+ # + # If you need to preserve original sequence numbers, use the alternative + # migration approach that creates a backup table before removal. + + # Step 1: Add sequence_number column as nullable first + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.add_column(sa.Column('sequence_number', sa.INTEGER(), autoincrement=False, nullable=True)) + + # Step 2: Populate sequence_number values based on created_at order within each app + # NOTE: This recreates sequence numbering logic but values will be different + # from the original sequence numbers that were removed in the upgrade + connection = op.get_bind() + connection.execute(sa.text(""" + UPDATE workflow_runs + SET sequence_number = subquery.row_num + FROM ( + SELECT id, ROW_NUMBER() OVER ( + PARTITION BY tenant_id, app_id + ORDER BY created_at, id + ) as row_num + FROM workflow_runs + ) subquery + WHERE workflow_runs.id = subquery.id + """)) + + # Step 3: Make the column NOT NULL and add the index + with op.batch_alter_table('workflow_runs', schema=None) as batch_op: + batch_op.alter_column('sequence_number', nullable=False) + batch_op.create_index(batch_op.f('workflow_run_tenant_app_sequence_idx'), ['tenant_id', 'app_id', 'sequence_number'], unique=False) + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 2066481a61..83b50eb099 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -27,7 +27,7 @@ from .dataset import ( Whitelist, ) from .engine import db -from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom +from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( ApiRequest, ApiToken, @@ -84,11 +84,9 @@ from .workflow import ( Workflow, WorkflowAppLog, WorkflowAppLogCreatedFrom, - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, + WorkflowNodeExecutionModel, WorkflowNodeExecutionTriggeredFrom, WorkflowRun, - WorkflowRunStatus, WorkflowType, ) @@ -100,19 +98,19 
@@ __all__ = [ "AccountStatus", "ApiRequest", "ApiToken", - "ApiToolProvider", # Added + "ApiToolProvider", "App", "AppAnnotationHitHistory", "AppAnnotationSetting", "AppDatasetJoin", "AppMode", "AppModelConfig", - "BuiltinToolProvider", # Added + "BuiltinToolProvider", "CeleryTask", "CeleryTaskSet", "Conversation", "ConversationVariable", - "CreatedByRole", + "CreatorUserRole", "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", @@ -171,11 +169,9 @@ __all__ = [ "Workflow", "WorkflowAppLog", "WorkflowAppLogCreatedFrom", - "WorkflowNodeExecution", - "WorkflowNodeExecutionStatus", + "WorkflowNodeExecutionModel", "WorkflowNodeExecutionTriggeredFrom", "WorkflowRun", - "WorkflowRunStatus", "WorkflowRunTriggeredFrom", "WorkflowToolProvider", "WorkflowType", diff --git a/api/models/account.py b/api/models/account.py index a0b8957fe1..7ffeefa980 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -1,9 +1,10 @@ import enum import json +from typing import Optional, cast from flask_login import UserMixin # type: ignore from sqlalchemy import func -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy.orm import Mapped, mapped_column, reconstructor from models.base import Base @@ -11,6 +12,66 @@ from .engine import db from .types import StringUUID +class TenantAccountRole(enum.StrEnum): + OWNER = "owner" + ADMIN = "admin" + EDITOR = "editor" + NORMAL = "normal" + DATASET_OPERATOR = "dataset_operator" + + @staticmethod + def is_valid_role(role: str) -> bool: + if not role: + return False + return role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + + @staticmethod + def is_privileged_role(role: Optional["TenantAccountRole"]) -> bool: + if not role: + return False + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + + @staticmethod + def is_admin_role(role: Optional["TenantAccountRole"]) -> bool: + if not role: 
+ return False + return role == TenantAccountRole.ADMIN + + @staticmethod + def is_non_owner_role(role: Optional["TenantAccountRole"]) -> bool: + if not role: + return False + return role in { + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.NORMAL, + TenantAccountRole.DATASET_OPERATOR, + } + + @staticmethod + def is_editing_role(role: Optional["TenantAccountRole"]) -> bool: + if not role: + return False + return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} + + @staticmethod + def is_dataset_edit_role(role: Optional["TenantAccountRole"]) -> bool: + if not role: + return False + return role in { + TenantAccountRole.OWNER, + TenantAccountRole.ADMIN, + TenantAccountRole.EDITOR, + TenantAccountRole.DATASET_OPERATOR, + } + + class AccountStatus(enum.StrEnum): PENDING = "pending" UNINITIALIZED = "uninitialized" @@ -40,54 +101,54 @@ class Account(UserMixin, Base): created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + @reconstructor + def init_on_load(self): + self.role: Optional[TenantAccountRole] = None + self._current_tenant: Optional[Tenant] = None + @property def is_password_set(self): return self.password is not None @property def current_tenant(self): - # FIXME: fix the type error later, because the type is important maybe cause some bugs - return self._current_tenant # type: ignore + return self._current_tenant @current_tenant.setter - def current_tenant(self, value: "Tenant"): - tenant = value - ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first() + def current_tenant(self, tenant: "Tenant"): + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first() if ta: - tenant.current_role = ta.role - else: - tenant = None # type: ignore - - self._current_tenant = tenant + self.role = 
TenantAccountRole(ta.role) + self._current_tenant = tenant + return + self._current_tenant = None @property def current_tenant_id(self) -> str | None: return self._current_tenant.id if self._current_tenant else None - @current_tenant_id.setter - def current_tenant_id(self, value: str): - try: - tenant_account_join = ( + def set_tenant_id(self, tenant_id: str): + tenant_account_join = cast( + tuple[Tenant, TenantAccountJoin], + ( db.session.query(Tenant, TenantAccountJoin) - .filter(Tenant.id == value) + .filter(Tenant.id == tenant_id) .filter(TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.account_id == self.id) .one_or_none() - ) + ), + ) - if tenant_account_join: - tenant, ta = tenant_account_join - tenant.current_role = ta.role - else: - tenant = None - except Exception: - tenant = None + if not tenant_account_join: + return + tenant, join = tenant_account_join + self.role = join.role self._current_tenant = tenant @property def current_role(self): - return self._current_tenant.current_role + return self.role def get_status(self) -> AccountStatus: status_str = self.status @@ -107,23 +168,23 @@ class Account(UserMixin, Base): # check current_user.current_tenant.current_role in ['admin', 'owner'] @property def is_admin_or_owner(self): - return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) + return TenantAccountRole.is_privileged_role(self.role) @property def is_admin(self): - return TenantAccountRole.is_admin_role(self._current_tenant.current_role) + return TenantAccountRole.is_admin_role(self.role) @property def is_editor(self): - return TenantAccountRole.is_editing_role(self._current_tenant.current_role) + return TenantAccountRole.is_editing_role(self.role) @property def is_dataset_editor(self): - return TenantAccountRole.is_dataset_edit_role(self._current_tenant.current_role) + return TenantAccountRole.is_dataset_edit_role(self.role) @property def is_dataset_operator(self): - return self._current_tenant.current_role 
== TenantAccountRole.DATASET_OPERATOR + return self.role == TenantAccountRole.DATASET_OPERATOR class TenantStatus(enum.StrEnum): @@ -131,67 +192,7 @@ class TenantStatus(enum.StrEnum): ARCHIVE = "archive" -class TenantAccountRole(enum.StrEnum): - OWNER = "owner" - ADMIN = "admin" - EDITOR = "editor" - NORMAL = "normal" - DATASET_OPERATOR = "dataset_operator" - - @staticmethod - def is_valid_role(role: str) -> bool: - if not role: - return False - return role in { - TenantAccountRole.OWNER, - TenantAccountRole.ADMIN, - TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR, - } - - @staticmethod - def is_privileged_role(role: str) -> bool: - if not role: - return False - return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} - - @staticmethod - def is_admin_role(role: str) -> bool: - if not role: - return False - return role == TenantAccountRole.ADMIN - - @staticmethod - def is_non_owner_role(role: str) -> bool: - if not role: - return False - return role in { - TenantAccountRole.ADMIN, - TenantAccountRole.EDITOR, - TenantAccountRole.NORMAL, - TenantAccountRole.DATASET_OPERATOR, - } - - @staticmethod - def is_editing_role(role: str) -> bool: - if not role: - return False - return role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN, TenantAccountRole.EDITOR} - - @staticmethod - def is_dataset_edit_role(role: str) -> bool: - if not role: - return False - return role in { - TenantAccountRole.OWNER, - TenantAccountRole.ADMIN, - TenantAccountRole.EDITOR, - TenantAccountRole.DATASET_OPERATOR, - } - - -class Tenant(db.Model): # type: ignore[name-defined] +class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) @@ -220,7 +221,7 @@ class Tenant(db.Model): # type: ignore[name-defined] self.custom_config = json.dumps(value) -class TenantAccountJoin(db.Model): # type: ignore[name-defined] +class TenantAccountJoin(Base): __tablename__ = "tenant_account_joins" 
__table_args__ = ( db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), @@ -239,7 +240,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class AccountIntegrate(db.Model): # type: ignore[name-defined] +class AccountIntegrate(Base): __tablename__ = "account_integrates" __table_args__ = ( db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), @@ -256,7 +257,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class InvitationCode(db.Model): # type: ignore[name-defined] +class InvitationCode(Base): __tablename__ = "invitation_codes" __table_args__ = ( db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), diff --git a/api/models/api_based_extension.py b/api/models/api_based_extension.py index 6b6d808710..5a70e18622 100644 --- a/api/models/api_based_extension.py +++ b/api/models/api_based_extension.py @@ -2,6 +2,7 @@ import enum from sqlalchemy import func +from .base import Base from .engine import db from .types import StringUUID @@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum): APP_MODERATION_OUTPUT = "app.moderation.output" -class APIBasedExtension(db.Model): # type: ignore[name-defined] +class APIBasedExtension(Base): __tablename__ = "api_based_extensions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), diff --git a/api/models/base.py b/api/models/base.py index da9509301a..bd120f5487 100644 --- a/api/models/base.py +++ b/api/models/base.py @@ -1,5 +1,7 @@ -from sqlalchemy.orm import declarative_base +from sqlalchemy.orm import DeclarativeBase from models.engine import metadata -Base = declarative_base(metadata=metadata) + +class Base(DeclarativeBase): + metadata = metadata diff --git a/api/models/dataset.py b/api/models/dataset.py index d6708ac88b..ad43d6f371 100644 --- 
a/api/models/dataset.py +++ b/api/models/dataset.py @@ -22,6 +22,7 @@ from extensions.ext_storage import storage from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from .account import Account +from .base import Base from .engine import db from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID @@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum): PARTIAL_TEAM = "partial_members" -class Dataset(db.Model): # type: ignore[name-defined] +class Dataset(Base): __tablename__ = "datasets" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_pkey"), @@ -92,7 +93,8 @@ class Dataset(db.Model): # type: ignore[name-defined] @property def latest_process_rule(self): return ( - DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) + db.session.query(DatasetProcessRule) + .filter(DatasetProcessRule.dataset_id == self.id) .order_by(DatasetProcessRule.created_at.desc()) .first() ) @@ -137,7 +139,8 @@ class Dataset(db.Model): # type: ignore[name-defined] @property def word_count(self): return ( - Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) + db.session.query(Document) + .with_entities(func.coalesce(func.sum(Document.word_count))) .filter(Document.dataset_id == self.id) .scalar() ) @@ -255,7 +258,7 @@ class Dataset(db.Model): # type: ignore[name-defined] return f"Vector_index_{normalized_dataset_id}_Node" -class DatasetProcessRule(db.Model): # type: ignore[name-defined] +class DatasetProcessRule(Base): __tablename__ = "dataset_process_rules" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), @@ -295,7 +298,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined] return None -class Document(db.Model): # type: ignore[name-defined] +class Document(Base): __tablename__ = "documents" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_pkey"), @@ -439,12 +442,13 @@ class Document(db.Model): # type: 
ignore[name-defined] @property def segment_count(self): - return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() + return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count() @property def hit_count(self): return ( - DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) + db.session.query(DocumentSegment) + .with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) .filter(DocumentSegment.document_id == self.id) .scalar() ) @@ -635,7 +639,7 @@ class Document(db.Model): # type: ignore[name-defined] ) -class DocumentSegment(db.Model): # type: ignore[name-defined] +class DocumentSegment(Base): __tablename__ = "document_segments" __table_args__ = ( db.PrimaryKeyConstraint("id", name="document_segment_pkey"), @@ -786,7 +790,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined] return text -class ChildChunk(db.Model): # type: ignore[name-defined] +class ChildChunk(Base): __tablename__ = "child_chunks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), @@ -829,7 +833,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined] return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() -class AppDatasetJoin(db.Model): # type: ignore[name-defined] +class AppDatasetJoin(Base): __tablename__ = "app_dataset_joins" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), @@ -846,7 +850,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined] return db.session.get(App, self.app_id) -class DatasetQuery(db.Model): # type: ignore[name-defined] +class DatasetQuery(Base): __tablename__ = "dataset_queries" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), @@ -863,7 +867,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) -class 
DatasetKeywordTable(db.Model): # type: ignore[name-defined] +class DatasetKeywordTable(Base): __tablename__ = "dataset_keyword_tables" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), @@ -891,7 +895,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined] return dct # get dataset - dataset = Dataset.query.filter_by(id=self.dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first() if not dataset: return None if self.data_source_type == "database": @@ -908,7 +912,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined] return None -class Embedding(db.Model): # type: ignore[name-defined] +class Embedding(Base): __tablename__ = "embeddings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="embedding_pkey"), @@ -932,7 +936,7 @@ class Embedding(db.Model): # type: ignore[name-defined] return cast(list[float], pickle.loads(self.embedding)) # noqa: S301 -class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] +class DatasetCollectionBinding(Base): __tablename__ = "dataset_collection_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), @@ -947,7 +951,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class TidbAuthBinding(db.Model): # type: ignore[name-defined] +class TidbAuthBinding(Base): __tablename__ = "tidb_auth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), @@ -967,7 +971,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class Whitelist(db.Model): # type: ignore[name-defined] +class Whitelist(Base): __tablename__ = "whitelists" __table_args__ = ( db.PrimaryKeyConstraint("id", name="whitelists_pkey"), @@ -979,7 +983,7 @@ class 
Whitelist(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetPermission(db.Model): # type: ignore[name-defined] +class DatasetPermission(Base): __tablename__ = "dataset_permissions" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), @@ -996,7 +1000,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] +class ExternalKnowledgeApis(Base): __tablename__ = "external_knowledge_apis" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), @@ -1049,7 +1053,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] return dataset_bindings -class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] +class ExternalKnowledgeBindings(Base): __tablename__ = "external_knowledge_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), @@ -1070,7 +1074,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] +class DatasetAutoDisableLog(Base): __tablename__ = "dataset_auto_disable_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), @@ -1087,7 +1091,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class RateLimitLog(db.Model): # type: ignore[name-defined] +class RateLimitLog(Base): __tablename__ = "rate_limit_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), @@ -1102,7 +1106,7 @@ class 
RateLimitLog(db.Model): # type: ignore[name-defined] created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class DatasetMetadata(db.Model): # type: ignore[name-defined] +class DatasetMetadata(Base): __tablename__ = "dataset_metadatas" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), @@ -1121,7 +1125,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined] updated_by = db.Column(StringUUID, nullable=True) -class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] +class DatasetMetadataBinding(Base): __tablename__ = "dataset_metadata_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), diff --git a/api/models/enums.py b/api/models/enums.py index 7b9500ebe4..4434c3fec8 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,7 +1,7 @@ from enum import StrEnum -class CreatedByRole(StrEnum): +class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" @@ -14,3 +14,10 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + + +class DraftVariableType(StrEnum): + # node means that the correspond variable + NODE = "node" + SYS = "sys" + CONVERSATION = "conversation" diff --git a/api/models/model.py b/api/models/model.py index fd05d67e9a..fa83baa9cf 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -9,14 +9,14 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast from core.plugin.entities.plugin import GenericProviderID from core.tools.entities.tool_entities import ToolProviderType from core.tools.signature import sign_tool_file -from services.plugin.plugin_service import PluginService +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus if TYPE_CHECKING: from models.workflow import Workflow import sqlalchemy as sa from flask import request -from flask_login import UserMixin # type: ignore +from 
flask_login import UserMixin from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy.orm import Mapped, Session, mapped_column @@ -25,12 +25,11 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import helpers as file_helpers from libs.helper import generate_string -from models.base import Base -from models.enums import CreatedByRole -from models.workflow import WorkflowRunStatus from .account import Account, Tenant +from .base import Base from .engine import db +from .enums import CreatorUserRole from .types import StringUUID if TYPE_CHECKING: @@ -169,6 +168,7 @@ class App(Base): @property def deleted_tools(self) -> list: from core.tools.tool_manager import ToolManager + from services.plugin.plugin_service import PluginService # get agent mode tools app_model_config = self.app_model_config @@ -294,6 +294,15 @@ class App(Base): return tags or [] + @property + def author_name(self): + if self.created_by: + account = db.session.query(Account).filter(Account.id == self.created_by).first() + if account: + return account.name + + return None + class AppModelConfig(Base): __tablename__ = "app_model_configs" @@ -602,7 +611,7 @@ class InstalledApp(Base): return tenant -class Conversation(db.Model): # type: ignore[name-defined] +class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="conversation_pkey"), @@ -785,22 +794,22 @@ class Conversation(db.Model): # type: ignore[name-defined] def status_count(self): messages = db.session.query(Message).filter(Message.conversation_id == self.id).all() status_counts = { - WorkflowRunStatus.RUNNING: 0, - WorkflowRunStatus.SUCCEEDED: 0, - WorkflowRunStatus.FAILED: 0, - WorkflowRunStatus.STOPPED: 0, - WorkflowRunStatus.PARTIAL_SUCCEEDED: 0, + WorkflowExecutionStatus.RUNNING: 0, + WorkflowExecutionStatus.SUCCEEDED: 0, + WorkflowExecutionStatus.FAILED: 0, + 
WorkflowExecutionStatus.STOPPED: 0, + WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0, } for message in messages: if message.workflow_run: - status_counts[message.workflow_run.status] += 1 + status_counts[WorkflowExecutionStatus(message.workflow_run.status)] += 1 return ( { - "success": status_counts[WorkflowRunStatus.SUCCEEDED], - "failed": status_counts[WorkflowRunStatus.FAILED], - "partial_success": status_counts[WorkflowRunStatus.PARTIAL_SUCCEEDED], + "success": status_counts[WorkflowExecutionStatus.SUCCEEDED], + "failed": status_counts[WorkflowExecutionStatus.FAILED], + "partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED], } if messages else None @@ -864,7 +873,7 @@ class Conversation(db.Model): # type: ignore[name-defined] } -class Message(db.Model): # type: ignore[name-defined] +class Message(Base): __tablename__ = "messages" __table_args__ = ( PrimaryKeyConstraint("id", name="message_pkey"), @@ -1211,7 +1220,7 @@ class Message(db.Model): # type: ignore[name-defined] ) -class MessageFeedback(db.Model): # type: ignore[name-defined] +class MessageFeedback(Base): __tablename__ = "message_feedbacks" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), @@ -1237,8 +1246,23 @@ class MessageFeedback(db.Model): # type: ignore[name-defined] account = db.session.query(Account).filter(Account.id == self.from_account_id).first() return account + def to_dict(self): + return { + "id": str(self.id), + "app_id": str(self.app_id), + "conversation_id": str(self.conversation_id), + "message_id": str(self.message_id), + "rating": self.rating, + "content": self.content, + "from_source": self.from_source, + "from_end_user_id": str(self.from_end_user_id) if self.from_end_user_id else None, + "from_account_id": str(self.from_account_id) if self.from_account_id else None, + "created_at": self.created_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } -class MessageFile(db.Model): # type: ignore[name-defined] + +class 
MessageFile(Base): __tablename__ = "message_files" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_file_pkey"), @@ -1255,7 +1279,7 @@ class MessageFile(db.Model): # type: ignore[name-defined] url: str | None = None, belongs_to: Literal["user", "assistant"] | None = None, upload_file_id: str | None = None, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, ): self.message_id = message_id @@ -1279,7 +1303,7 @@ class MessageFile(db.Model): # type: ignore[name-defined] created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) -class MessageAnnotation(db.Model): # type: ignore[name-defined] +class MessageAnnotation(Base): __tablename__ = "message_annotations" __table_args__ = ( db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), @@ -1310,7 +1334,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined] return account -class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] +class AppAnnotationHitHistory(Base): __tablename__ = "app_annotation_hit_histories" __table_args__ = ( db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), @@ -1322,7 +1346,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) app_id = db.Column(StringUUID, nullable=False) - annotation_id = db.Column(StringUUID, nullable=False) + annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False) source = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False) account_id = db.Column(StringUUID, nullable=False) @@ -1348,7 +1372,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] return account -class AppAnnotationSetting(db.Model): # type: ignore[name-defined] +class AppAnnotationSetting(Base): __tablename__ = "app_annotation_settings" __table_args__ = ( db.PrimaryKeyConstraint("id", 
name="app_annotation_settings_pkey"), @@ -1364,26 +1388,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined] updated_user_id = db.Column(StringUUID, nullable=False) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - @property - def created_account(self): - account = ( - db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id) - .first() - ) - return account - - @property - def updated_account(self): - account = ( - db.session.query(Account) - .join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id) - .filter(AppAnnotationSetting.id == self.annotation_id) - .first() - ) - return account - @property def collection_binding_detail(self): from .dataset import DatasetCollectionBinding @@ -1422,7 +1426,7 @@ class EndUser(Base, UserMixin): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) @@ -1552,7 +1556,7 @@ class UploadFile(Base): size: int, extension: str, mime_type: str, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, created_at: datetime, used: bool, diff --git a/api/models/provider.py b/api/models/provider.py index 567400702d..1e25f0c90f 100644 --- a/api/models/provider.py +++ b/api/models/provider.py @@ -1,9 +1,11 @@ +from datetime import datetime from enum import Enum +from typing import Optional -from sqlalchemy import func - -from models.base import Base +from sqlalchemy import func, text +from sqlalchemy.orm import Mapped, mapped_column +from .base import Base from .engine import db from .types import StringUUID @@ -52,20 +54,24 @@ class 
Provider(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - provider_type = db.Column(db.String(40), nullable=False, server_default=db.text("'custom'::character varying")) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - last_used = db.Column(db.DateTime, nullable=True) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + provider_type: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'custom'::character varying") + ) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + last_used: Mapped[Optional[datetime]] = mapped_column(db.DateTime, nullable=True) - quota_type = db.Column(db.String(40), nullable=True, server_default=db.text("''::character varying")) - quota_limit = db.Column(db.BigInteger, nullable=True) - quota_used = db.Column(db.BigInteger, default=0) + quota_type: Mapped[Optional[str]] = mapped_column( + db.String(40), nullable=True, server_default=text("''::character varying") + ) + quota_limit: Mapped[Optional[int]] = mapped_column(db.BigInteger, nullable=True) + quota_used: Mapped[Optional[int]] = mapped_column(db.BigInteger, default=0) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = 
mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) def __repr__(self): return ( @@ -105,15 +111,15 @@ class ProviderModel(Base): ), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - is_valid = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + is_valid: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantDefaultModel(Base): @@ -123,13 +129,13 @@ class TenantDefaultModel(Base): db.Index("tenant_default_model_tenant_id_provider_type_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = 
db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class TenantPreferredModelProvider(Base): @@ -139,12 +145,12 @@ class TenantPreferredModelProvider(Base): db.Index("tenant_preferred_model_provider_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - preferred_provider_type = db.Column(db.String(40), nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + preferred_provider_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + 
updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderOrder(Base): @@ -154,22 +160,24 @@ class ProviderOrder(Base): db.Index("provider_order_tenant_provider_idx", "tenant_id", "provider_name"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - account_id = db.Column(StringUUID, nullable=False) - payment_product_id = db.Column(db.String(191), nullable=False) - payment_id = db.Column(db.String(191)) - transaction_id = db.Column(db.String(191)) - quantity = db.Column(db.Integer, nullable=False, server_default=db.text("1")) - currency = db.Column(db.String(40)) - total_amount = db.Column(db.Integer) - payment_status = db.Column(db.String(40), nullable=False, server_default=db.text("'wait_pay'::character varying")) - paid_at = db.Column(db.DateTime) - pay_failed_at = db.Column(db.DateTime) - refunded_at = db.Column(db.DateTime) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + account_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + payment_product_id: Mapped[str] = mapped_column(db.String(191), nullable=False) + payment_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + transaction_id: Mapped[Optional[str]] = mapped_column(db.String(191)) + quantity: Mapped[int] = mapped_column(db.Integer, nullable=False, server_default=text("1")) + currency: Mapped[Optional[str]] = mapped_column(db.String(40)) + total_amount: Mapped[Optional[int]] = mapped_column(db.Integer) + 
payment_status: Mapped[str] = mapped_column( + db.String(40), nullable=False, server_default=text("'wait_pay'::character varying") + ) + paid_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + pay_failed_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + refunded_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class ProviderModelSetting(Base): @@ -183,15 +191,15 @@ class ProviderModelSetting(Base): db.Index("provider_model_setting_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - load_balancing_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + load_balancing_enabled: Mapped[bool] = mapped_column(db.Boolean, 
nullable=False, server_default=text("false")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) class LoadBalancingModelConfig(Base): @@ -205,13 +213,13 @@ class LoadBalancingModelConfig(Base): db.Index("load_balancing_model_config_tenant_provider_model_idx", "tenant_id", "provider_name", "model_type"), ) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) - provider_name = db.Column(db.String(255), nullable=False) - model_name = db.Column(db.String(255), nullable=False) - model_type = db.Column(db.String(40), nullable=False) - name = db.Column(db.String(255), nullable=False) - encrypted_config = db.Column(db.Text, nullable=True) - enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + id: Mapped[str] = mapped_column(StringUUID, server_default=text("uuid_generate_v4()")) + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + provider_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_name: Mapped[str] = mapped_column(db.String(255), nullable=False) + model_type: Mapped[str] = mapped_column(db.String(40), nullable=False) + name: Mapped[str] = mapped_column(db.String(255), nullable=False) + encrypted_config: Mapped[Optional[str]] = mapped_column(db.Text, nullable=True) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=text("true")) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + updated_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, 
server_default=func.current_timestamp()) diff --git a/api/models/source.py b/api/models/source.py index b9d7d91346..f6e0900ae6 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -9,7 +9,7 @@ from .engine import db from .types import StringUUID -class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] +class DataSourceOauthBinding(Base): __tablename__ = "data_source_oauth_bindings" __table_args__ = ( db.PrimaryKeyConstraint("id", name="source_binding_pkey"), diff --git a/api/models/tools.py b/api/models/tools.py index e027475e38..03fbc3acb1 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -172,10 +172,6 @@ class WorkflowToolProvider(Base): db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) - @property - def schema_type(self) -> ApiProviderSchemaType: - return ApiProviderSchemaType.value_of(self.schema_type_str) - @property def user(self) -> Account | None: return db.session.query(Account).filter(Account.id == self.user_id).first() diff --git a/api/models/types.py b/api/models/types.py index cb6773e70c..e5581c3ab0 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,4 +1,7 @@ -from sqlalchemy import CHAR, TypeDecorator +import enum +from typing import Generic, TypeVar + +from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy.dialects.postgresql import UUID @@ -24,3 +27,51 @@ class StringUUID(TypeDecorator): if value is None: return value return str(value) + + +_E = TypeVar("_E", bound=enum.StrEnum) + + +class EnumText(TypeDecorator, Generic[_E]): + impl = VARCHAR + cache_ok = True + + _length: int + _enum_class: type[_E] + + def __init__(self, enum_class: type[_E], length: int | None = None): + self._enum_class = enum_class + max_enum_value_len = max(len(e.value) for e in enum_class) + if length is not None: + if length < max_enum_value_len: + raise ValueError("length should be greater than enum value length.") + self._length = length + else: + # leave some rooms for future 
longer enum values. + self._length = max(max_enum_value_len, 20) + + def process_bind_param(self, value: _E | str | None, dialect): + if value is None: + return value + if isinstance(value, self._enum_class): + return value.value + elif isinstance(value, str): + self._enum_class(value) + return value + else: + raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(VARCHAR(self._length)) + + def process_result_value(self, value, dialect) -> _E | None: + if value is None: + return value + if not isinstance(value, str): + raise TypeError(f"expected str, got {type(value)}") + return self._enum_class(value) + + def compare_values(self, x, y): + if x is None or y is None: + return x is y + return x == y diff --git a/api/models/workflow.py b/api/models/workflow.py index da60617de5..1733dec0fc 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,29 +1,37 @@ import json +import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum -from typing import TYPE_CHECKING, Any, Optional, Self, Union +from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 +from flask_login import current_user + +from core.variables import utils as variable_utils +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment + if TYPE_CHECKING: from models.model import AppMode import sqlalchemy as sa -from sqlalchemy import Index, PrimaryKeyConstraint, func -from sqlalchemy.orm import Mapped, mapped_column +from sqlalchemy import Index, PrimaryKeyConstraint, UniqueConstraint, func +from sqlalchemy.orm import Mapped, declared_attr, mapped_column -import contexts from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Variable +from 
core.variables import SecretVariable, Segment, SegmentType, Variable from factories import variable_factory from libs import helper -from models.base import Base -from models.enums import CreatedByRole from .account import Account +from .base import Base from .engine import db -from .types import StringUUID +from .enums import CreatorUserRole, DraftVariableType +from .types import EnumText, StringUUID + +_logger = logging.getLogger(__name__) if TYPE_CHECKING: from models.model import AppMode @@ -143,7 +151,7 @@ class Workflow(Base): conversation_variables: Sequence[Variable], marked_name: str = "", marked_comment: str = "", - ) -> Self: + ) -> "Workflow": workflow = Workflow() workflow.id = str(uuid4()) workflow.tenant_id = tenant_id @@ -192,7 +200,9 @@ class Workflow(Base): features["file_upload"]["number_limits"] = image_number_limits features["file_upload"]["allowed_file_upload_methods"] = image_transfer_methods features["file_upload"]["allowed_file_types"] = features["file_upload"].get("allowed_file_types", ["image"]) - features["file_upload"]["allowed_file_extensions"] = [] + features["file_upload"]["allowed_file_extensions"] = features["file_upload"].get( + "allowed_file_extensions", [] + ) del features["file_upload"]["image"] self._features = json.dumps(features) return self._features @@ -265,7 +275,16 @@ class Workflow(Base): if self._environment_variables is None: self._environment_variables = "{}" - tenant_id = contexts.tenant_id.get() + # Get tenant_id from current_user (Account or EndUser) + if isinstance(current_user, Account): + # Account user + tenant_id = current_user.current_tenant_id + else: + # EndUser + tenant_id = current_user.tenant_id + + if not tenant_id: + return [] environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) results = [ @@ -288,7 +307,17 @@ class Workflow(Base): self._environment_variables = "{}" return - tenant_id = contexts.tenant_id.get() + # Get tenant_id from current_user (Account or EndUser) + 
if isinstance(current_user, Account): + # Account user + tenant_id = current_user.current_tenant_id + else: + # EndUser + tenant_id = current_user.tenant_id + + if not tenant_id: + self._environment_variables = "{}" + return value = list(value) if any(var for var in value if not var.id): @@ -348,18 +377,6 @@ class Workflow(Base): ) -class WorkflowRunStatus(StrEnum): - """ - Workflow Run Status Enum - """ - - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - STOPPED = "stopped" - PARTIAL_SUCCEEDED = "partial-succeeded" - - class WorkflowRun(Base): """ Workflow Run @@ -369,7 +386,7 @@ class WorkflowRun(Base): - id (uuid) Run ID - tenant_id (uuid) Workspace ID - app_id (uuid) App ID - - sequence_number (int) Auto-increment sequence number, incremented within the App, starting from 1 + - workflow_id (uuid) Workflow ID - type (string) Workflow type - triggered_from (string) Trigger source @@ -402,13 +419,12 @@ class WorkflowRun(Base): __table_args__ = ( db.PrimaryKeyConstraint("id", name="workflow_run_pkey"), db.Index("workflow_run_triggerd_from_idx", "tenant_id", "app_id", "triggered_from"), - db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - sequence_number: Mapped[int] = mapped_column() + workflow_id: Mapped[str] = mapped_column(StringUUID) type: Mapped[str] = mapped_column(db.String(255)) triggered_from: Mapped[str] = mapped_column(db.String(255)) @@ -418,29 +434,29 @@ class WorkflowRun(Base): status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") error: Mapped[Optional[str]] = mapped_column(db.Text) - elapsed_time = db.Column(db.Float, nullable=False, server_default=sa.text("0")) + elapsed_time: 
Mapped[float] = mapped_column(db.Float, nullable=False, server_default=sa.text("0")) total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0")) - total_steps = db.Column(db.Integer, server_default=db.text("0")) + total_steps: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user - created_by = db.Column(StringUUID, nullable=False) - created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - finished_at = db.Column(db.DateTime) - exceptions_count = db.Column(db.Integer, server_default=db.text("0")) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime) + exceptions_count: Mapped[int] = mapped_column(db.Integer, server_default=db.text("0"), nullable=True) @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property - def graph_dict(self): + def graph_dict(self) -> Mapping[str, Any]: return json.loads(self.graph) if self.graph else {} @property @@ -468,7 +484,6 @@ class WorkflowRun(Base): "id": 
self.id, "tenant_id": self.tenant_id, "app_id": self.app_id, - "sequence_number": self.sequence_number, "workflow_id": self.workflow_id, "type": self.type, "triggered_from": self.triggered_from, @@ -494,7 +509,6 @@ class WorkflowRun(Base): id=data.get("id"), tenant_id=data.get("tenant_id"), app_id=data.get("app_id"), - sequence_number=data.get("sequence_number"), workflow_id=data.get("workflow_id"), type=data.get("type"), triggered_from=data.get("triggered_from"), @@ -524,19 +538,7 @@ class WorkflowNodeExecutionTriggeredFrom(StrEnum): WORKFLOW_RUN = "workflow-run" -class WorkflowNodeExecutionStatus(StrEnum): - """ - Workflow Node Execution Status Enum - """ - - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - EXCEPTION = "exception" - RETRY = "retry" - - -class WorkflowNodeExecution(Base): +class WorkflowNodeExecutionModel(Base): """ Workflow Node Execution @@ -585,28 +587,48 @@ class WorkflowNodeExecution(Base): """ __tablename__ = "workflow_node_executions" - __table_args__ = ( - db.PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), - db.Index( - "workflow_node_execution_workflow_run_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "workflow_run_id", - ), - db.Index( - "workflow_node_execution_node_run_idx", "tenant_id", "app_id", "workflow_id", "triggered_from", "node_id" - ), - db.Index( - "workflow_node_execution_id_idx", - "tenant_id", - "app_id", - "workflow_id", - "triggered_from", - "node_execution_id", - ), - ) + + @declared_attr + def __table_args__(cls): # noqa + return ( + PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"), + Index( + "workflow_node_execution_workflow_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "workflow_run_id", + ), + Index( + "workflow_node_execution_node_run_idx", + "tenant_id", + "app_id", + "workflow_id", + "triggered_from", + "node_id", + ), + Index( + "workflow_node_execution_id_idx", + "tenant_id", + "app_id", + "workflow_id", 
+ "triggered_from", + "node_execution_id", + ), + Index( + # The first argument is the index name, + # which we leave as `None`` to allow auto-generation by the ORM. + None, + cls.tenant_id, + cls.workflow_id, + cls.node_id, + # MyPy may flag the following line because it doesn't recognize that + # the `declared_attr` decorator passes the receiving class as the first + # argument to this method, allowing us to reference class attributes. + cls.created_at.desc(), # type: ignore + ), + ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) @@ -634,24 +656,24 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. 
- return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property def inputs_dict(self): return json.loads(self.inputs) if self.inputs else None @property - def outputs_dict(self): + def outputs_dict(self) -> dict[str, Any] | None: return json.loads(self.outputs) if self.outputs else None @property @@ -659,8 +681,11 @@ class WorkflowNodeExecution(Base): return json.loads(self.process_data) if self.process_data else None @property - def execution_metadata_dict(self): - return json.loads(self.execution_metadata) if self.execution_metadata else None + def execution_metadata_dict(self) -> dict[str, Any]: + # When the metadata is unset, we return an empty dictionary instead of `None`. + # This approach streamlines the logic for the caller, making it easier to handle + # cases where metadata is absent. + return json.loads(self.execution_metadata) if self.execution_metadata else {} @property def extras(self): @@ -736,19 +761,18 @@ class WorkflowAppLog(Base): __tablename__ = "workflow_app_logs" __table_args__ = ( db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), - db.Index("workflow_app_log_app_idx", "tenant_id", "app_id", "created_at"), - db.Index("workflow_app_log_workflow_run_idx", "workflow_run_id"), + db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), ) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) tenant_id: Mapped[str] = mapped_column(StringUUID) app_id: Mapped[str] = mapped_column(StringUUID) - workflow_id = db.Column(StringUUID, nullable=False) + workflow_id: Mapped[str] = mapped_column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID) - created_from = db.Column(db.String(255), nullable=False) - created_by_role = db.Column(db.String(255), nullable=False) - created_by = db.Column(StringUUID, nullable=False) - 
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) + created_from: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by_role: Mapped[str] = mapped_column(db.String(255), nullable=False) + created_by: Mapped[str] = mapped_column(StringUUID, nullable=False) + created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) @property def workflow_run(self): @@ -756,31 +780,28 @@ class WorkflowAppLog(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None class ConversationVariable(Base): __tablename__ = "workflow_conversation_variables" - __table_args__ = ( - PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"), - Index("workflow__conversation_variables_app_id_idx", "app_id"), - Index("workflow__conversation_variables_created_at_idx", "created_at"), - ) id: Mapped[str] = mapped_column(StringUUID, primary_key=True) - conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) - app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) - data = mapped_column(db.Text, nullable=False) - created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) - 
updated_at = mapped_column( + conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True) + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True) + data: Mapped[str] = mapped_column(db.Text, nullable=False) + created_at: Mapped[datetime] = mapped_column( + db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True + ) + updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() ) @@ -803,3 +824,216 @@ class ConversationVariable(Base): def to_variable(self) -> Variable: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) + + +# Only `sys.query` and `sys.files` could be modified. +_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) + + +def _naive_utc_datetime(): + return datetime.now(UTC).replace(tzinfo=None) + + +class WorkflowDraftVariable(Base): + @staticmethod + def unique_columns() -> list[str]: + return [ + "app_id", + "node_id", + "name", + ] + + __tablename__ = "workflow_draft_variables" + __table_args__ = (UniqueConstraint(*unique_columns()),) + + # id is the unique identifier of a draft variable. + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + + created_at: Mapped[datetime] = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + ) + + updated_at: Mapped[datetime] = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + # "`app_id` maps to the `id` field in the `model.App` model." + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # `last_edited_at` records when the value of a given draft variable + # is edited. 
+ # + # If it's not edited after creation, its value is `None`. + last_edited_at: Mapped[datetime | None] = mapped_column( + db.DateTime, + nullable=True, + default=None, + ) + + # The `node_id` field is special. + # + # If the variable is a conversation variable or a system variable, then the value of `node_id` + # is `conversation` or `sys`, respective. + # + # Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is + # the identity of correspond node in graph definition. An example of node id is `"1745769620734"`. + # + # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other + # "Answer" node conform the rules above.) + node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") + + # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than + # 80 chars. + # + # ref: api/core/workflow/entities/variable_pool.py:18 + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column( + sa.String(255), + default="", + nullable=False, + ) + + selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") + + # The data type of this variable's value + value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) + + # The variable's value serialized as a JSON string + value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") + + # Controls whether the variable should be displayed in the variable inspection panel + visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + + # Determines whether this variable can be modified by users + editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) + + # The `node_execution_id` field identifies the workflow node execution that created this variable. + # It corresponds to the `id` field in the `WorkflowNodeExecutionModel` model. 
+ # + # This field is not `None` for system variables and node variables, and is `None` + # for conversation variables. + node_execution_id: Mapped[str | None] = mapped_column( + StringUUID, + nullable=True, + default=None, + ) + + def get_selector(self) -> list[str]: + selector = json.loads(self.selector) + if not isinstance(selector, list): + _logger.error( + "invalid selector loaded from database, type=%s, value=%s", + type(selector), + self.selector, + ) + raise ValueError("invalid selector.") + return selector + + def _set_selector(self, value: list[str]): + self.selector = json.dumps(value) + + def get_value(self) -> Segment | None: + return build_segment(json.loads(self.value)) + + def set_name(self, name: str): + self.name = name + self._set_selector([self.node_id, name]) + + def set_value(self, value: Segment): + self.value = json.dumps(value.value) + self.value_type = value.value_type + + def get_node_id(self) -> str | None: + if self.get_variable_type() == DraftVariableType.NODE: + return self.node_id + else: + return None + + def get_variable_type(self) -> DraftVariableType: + match self.node_id: + case DraftVariableType.CONVERSATION: + return DraftVariableType.CONVERSATION + case DraftVariableType.SYS: + return DraftVariableType.SYS + case _: + return DraftVariableType.NODE + + @classmethod + def _new( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + description: str = "", + ) -> "WorkflowDraftVariable": + variable = WorkflowDraftVariable() + variable.created_at = _naive_utc_datetime() + variable.updated_at = _naive_utc_datetime() + variable.description = description + variable.app_id = app_id + variable.node_id = node_id + variable.name = name + variable.set_value(value) + variable._set_selector(list(variable_utils.to_selector(node_id, name))) + return variable + + @classmethod + def new_conversation_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + ) -> "WorkflowDraftVariable": + variable = cls._new( + 
app_id=app_id, + node_id=CONVERSATION_VARIABLE_NODE_ID, + name=name, + value=value, + ) + return variable + + @classmethod + def new_sys_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + editable: bool = False, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable.editable = editable + return variable + + @classmethod + def new_node_variable( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + visible: bool = True, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable.visible = visible + variable.editable = True + return variable + + @property + def edited(self): + return self.last_edited_at is not None + + +def is_system_variable_editable(name: str) -> bool: + return name in _EDITABLE_SYSTEM_VARIABLE diff --git a/api/mypy.ini b/api/mypy.ini index 865be3c17d..6836b2602b 100644 --- a/api/mypy.ini +++ b/api/mypy.ini @@ -2,6 +2,8 @@ warn_return_any = True warn_unused_configs = True check_untyped_defs = True +cache_fine_grained = True +sqlite_cache = True exclude = (?x)( core/model_runtime/model_providers/ | tests/ @@ -16,4 +18,3 @@ ignore_missing_imports=True [mypy-flask_restful.inputs] ignore_missing_imports=True - diff --git a/api/pyproject.toml b/api/pyproject.toml index 65315e9be7..38cc9ae75d 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -14,7 +14,7 @@ dependencies = [ "chardet~=5.1.0", "flask~=3.1.0", "flask-compress~=1.17", - "flask-cors~=4.0.0", + "flask-cors~=6.0.0", "flask-login~=0.6.3", "flask-migrate~=4.0.7", "flask-restful~=0.3.10", @@ -36,10 +36,9 @@ dependencies = [ "mailchimp-transactional~=1.0.50", "markdown~=3.5.1", "numpy~=1.26.4", - "oci~=2.135.1", "openai~=1.61.0", "openpyxl~=3.1.5", - "opik~=1.3.4", + "opik~=1.7.25", "opentelemetry-api==1.27.0", "opentelemetry-distro==0.48b0", "opentelemetry-exporter-otlp==1.27.0", @@ -57,33 +56,32 @@ 
dependencies = [ "opentelemetry-sdk==1.27.0", "opentelemetry-semantic-conventions==0.48b0", "opentelemetry-util-http==0.48b0", - "pandas-stubs~=2.2.3.241009", "pandas[excel,output-formatting,performance]~=2.2.2", "pandoc~=2.4", "psycogreen~=1.0.2", "psycopg2-binary~=2.9.6", "pycryptodome==3.19.1", - "pydantic~=2.9.2", - "pydantic-extra-types~=2.9.0", - "pydantic-settings~=2.6.0", + "pydantic~=2.11.4", + "pydantic-extra-types~=2.10.3", + "pydantic-settings~=2.9.1", "pyjwt~=2.8.0", - "pypdfium2~=4.30.0", + "pypdfium2==4.30.0", "python-docx~=1.1.0", "python-dotenv==1.0.1", "pyyaml~=6.0.1", - "readabilipy==0.2.0", - "redis[hiredis]~=5.0.3", - "resend~=0.7.0", - "sentry-sdk[flask]~=1.44.1", + "readabilipy~=0.3.0", + "redis[hiredis]~=6.1.0", + "resend~=2.9.0", + "sentry-sdk[flask]~=2.28.0", "sqlalchemy~=2.0.29", "starlette==0.41.0", "tiktoken~=0.9.0", - "tokenizers~=0.15.0", - "transformers~=4.35.0", + "transformers~=4.51.0", "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", - "weave~=0.51.34", + "weave~=0.51.0", "yarl~=1.18.3", "webvtt-py~=0.5.1", + "sendgrid~=6.12.3", ] # Before adding new dependency, consider place it in # alphabet order (a-z) and suitable group. 
@@ -106,7 +104,7 @@ dev = [ "dotenv-linter~=0.5.0", "faker~=32.1.0", "lxml-stubs~=0.5.1", - "mypy~=1.15.0", + "mypy~=1.16.0", "ruff~=0.11.5", "pytest~=8.3.2", "pytest-benchmark~=4.0.0", @@ -144,11 +142,18 @@ dev = [ "types-requests~=2.32.0", "types-requests-oauthlib~=2.0.0", "types-shapely~=2.0.0", - "types-simplejson~=3.20.0", - "types-six~=1.17.0", - "types-tensorflow~=2.18.0", - "types-tqdm~=4.67.0", - "types-ujson~=5.10.0", + "types-simplejson>=3.20.0", + "types-six>=1.17.0", + "types-tensorflow>=2.18.0", + "types-tqdm>=4.67.0", + "types-ujson>=5.10.0", + "boto3-stubs>=1.38.20", + "types-jmespath>=1.0.2.20240106", + "types_pyOpenSSL>=24.1.0", + "types_cffi>=1.17.0", + "types_setuptools>=80.9.0", + "pandas-stubs~=2.2.3", + "scipy-stubs>=1.15.3.0", ] ############################################################ @@ -190,12 +195,13 @@ vdb = [ "pymilvus~=2.5.0", "pymochow==1.3.1", "pyobvector~=0.1.6", - "qdrant-client==1.7.3", + "qdrant-client==1.9.0", "tablestore==6.1.0", "tcvectordb~=1.6.4", "tidb-vector==0.0.9", "upstash-vector==0.6.0", - "volcengine-compat~=1.0.156", + "volcengine-compat~=1.0.0", "weaviate-client~=3.24.0", "xinference-client~=1.2.2", + "mo-vector~=0.1.13", ] diff --git a/api/pytest.ini b/api/pytest.ini index 618e921825..eb49619481 100644 --- a/api/pytest.ini +++ b/api/pytest.ini @@ -1,5 +1,4 @@ [pytest] -continue-on-collection-errors = true addopts = --cov=./api --cov-report=json --cov-report=xml env = ANTHROPIC_API_KEY = sk-ant-api11-IamNotARealKeyJustForMockTestKawaiiiiiiiiii-NotBaka-ASkksz diff --git a/api/schedule/clean_messages.py b/api/schedule/clean_messages.py index 5e4d3ec323..d02bc81f33 100644 --- a/api/schedule/clean_messages.py +++ b/api/schedule/clean_messages.py @@ -1,4 +1,5 @@ import datetime +import logging import time import click @@ -20,6 +21,8 @@ from models.model import ( from models.web import SavedMessage from services.feature_service import FeatureService +_logger = logging.getLogger(__name__) + 
@app.celery.task(queue="dataset") def clean_messages(): @@ -31,9 +34,8 @@ def clean_messages(): while True: try: # Main query with join and filter - # FIXME:for mypy no paginate method error messages = ( - db.session.query(Message) # type: ignore + db.session.query(Message) .filter(Message.created_at < plan_sandbox_clean_message_day) .order_by(Message.created_at.desc()) .limit(100) @@ -46,7 +48,14 @@ def clean_messages(): break for message in messages: plan_sandbox_clean_message_day = message.created_at - app = App.query.filter_by(id=message.app_id).first() + app = db.session.query(App).filter_by(id=message.app_id).first() + if not app: + _logger.warning( + "Expected App record to exist, but none was found, app_id=%s, message_id=%s", + message.app_id, + message.id, + ) + continue features_cache_key = f"features:{app.tenant_id}" plan_cache = redis_client.get(features_cache_key) if plan_cache is None: diff --git a/api/schedule/clean_unused_datasets_task.py b/api/schedule/clean_unused_datasets_task.py index 4e7e443c2c..c0cd42a226 100644 --- a/api/schedule/clean_unused_datasets_task.py +++ b/api/schedule/clean_unused_datasets_task.py @@ -2,7 +2,7 @@ import datetime import time import click -from sqlalchemy import func +from sqlalchemy import func, select from werkzeug.exceptions import NotFound import app @@ -51,8 +51,9 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + stmt = ( + select(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_sandbox_clean_day, @@ -60,9 +61,10 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, 
page=1, per_page=50) + except NotFound: break if datasets.items is None or len(datasets.items) == 0: @@ -99,7 +101,7 @@ def clean_unused_datasets_task(): # update document update_params = {Document.enabled: False} - Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) except Exception as e: @@ -135,8 +137,9 @@ def clean_unused_datasets_task(): ) # Main query with join and filter - datasets = ( - Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) + stmt = ( + select(Dataset) + .outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .filter( Dataset.created_at < plan_pro_clean_day, @@ -144,8 +147,8 @@ def clean_unused_datasets_task(): func.coalesce(document_subquery_old.c.document_count, 0) > 0, ) .order_by(Dataset.created_at.desc()) - .paginate(page=1, per_page=50) ) + datasets = db.paginate(stmt, page=1, per_page=50) except NotFound: break @@ -175,7 +178,7 @@ def clean_unused_datasets_task(): # update document update_params = {Document.enabled: False} - Document.query.filter_by(dataset_id=dataset.id).update(update_params) + db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params) db.session.commit() click.echo( click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") diff --git a/api/schedule/create_tidb_serverless_task.py b/api/schedule/create_tidb_serverless_task.py index 1c985461c6..8a02278de8 100644 --- a/api/schedule/create_tidb_serverless_task.py +++ b/api/schedule/create_tidb_serverless_task.py @@ -19,7 +19,9 @@ def create_tidb_serverless_task(): while True: try: # check the number of idle tidb serverless - 
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() + idle_tidb_serverless_number = ( + db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count() + ) if idle_tidb_serverless_number >= tidb_serverless_number: break # create tidb serverless diff --git a/api/schedule/mail_clean_document_notify_task.py b/api/schedule/mail_clean_document_notify_task.py index b3d0e09784..5ee813e1de 100644 --- a/api/schedule/mail_clean_document_notify_task.py +++ b/api/schedule/mail_clean_document_notify_task.py @@ -29,7 +29,9 @@ def mail_clean_document_notify_task(): # send document clean notify mail try: - dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() + dataset_auto_disable_logs = ( + db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all() + ) # group by tenant_id dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) for dataset_auto_disable_log in dataset_auto_disable_logs: @@ -43,14 +45,16 @@ def mail_clean_document_notify_task(): if plan != "sandbox": knowledge_details = [] # check tenant - tenant = Tenant.query.filter(Tenant.id == tenant_id).first() + tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first() if not tenant: continue # check current owner - current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() + current_owner_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + ) if not current_owner_join: continue - account = Account.query.filter(Account.id == current_owner_join.account_id).first() + account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first() if not account: continue @@ -63,7 +67,7 @@ def mail_clean_document_notify_task(): ) for dataset_id, document_ids in dataset_auto_dataset_map.items(): - dataset = 
Dataset.query.filter(Dataset.id == dataset_id).first() + dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() if dataset: document_count = len(document_ids) knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") diff --git a/api/schedule/queue_monitor_task.py b/api/schedule/queue_monitor_task.py new file mode 100644 index 0000000000..e3a7021b9d --- /dev/null +++ b/api/schedule/queue_monitor_task.py @@ -0,0 +1,62 @@ +import logging +from datetime import datetime +from urllib.parse import urlparse + +import click +from flask import render_template +from redis import Redis + +import app +from configs import dify_config +from extensions.ext_database import db +from extensions.ext_mail import mail + +# Create a dedicated Redis connection (using the same configuration as Celery) +celery_broker_url = dify_config.CELERY_BROKER_URL + +parsed = urlparse(celery_broker_url) +host = parsed.hostname or "localhost" +port = parsed.port or 6379 +password = parsed.password or None +redis_db = parsed.path.strip("/") or "1" # type: ignore + +celery_redis = Redis(host=host, port=port, password=password, db=redis_db) + + +@app.celery.task(queue="monitor") +def queue_monitor_task(): + queue_name = "dataset" + threshold = dify_config.QUEUE_MONITOR_THRESHOLD + + try: + queue_length = celery_redis.llen(f"{queue_name}") + logging.info(click.style(f"Start monitor {queue_name}", fg="green")) + logging.info(click.style(f"Queue length: {queue_length}", fg="green")) + + if queue_length >= threshold: + warning_msg = f"Queue {queue_name} task count exceeded the limit.: {queue_length}/{threshold}" + logging.warning(click.style(warning_msg, fg="red")) + alter_emails = dify_config.QUEUE_MONITOR_ALERT_EMAILS + if alter_emails: + to_list = alter_emails.split(",") + for to in to_list: + try: + current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + html_content = render_template( + "queue_monitor_alert_email_template_en-US.html", + 
queue_name=queue_name, + queue_length=queue_length, + threshold=threshold, + alert_time=current_time, + ) + mail.send( + to=to, subject="Alert: Dataset Queue pending tasks exceeded the limit", html=html_content + ) + except Exception as e: + logging.exception(click.style("Exception occurred during sending email", fg="red")) + + except Exception as e: + logging.exception(click.style("Exception occurred during queue monitoring", fg="red")) + finally: + if db.session.is_active: + db.session.close() diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 11a39e60ee..ce4ecb6e7c 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -5,6 +5,7 @@ import click import app from configs import dify_config from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService +from extensions.ext_database import db from models.dataset import TidbAuthBinding @@ -14,9 +15,11 @@ def update_tidb_serverless_status_task(): start_at = time.perf_counter() try: # check the number of idle tidb serverless - tidb_serverless_list = TidbAuthBinding.query.filter( - TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" - ).all() + tidb_serverless_list = ( + db.session.query(TidbAuthBinding) + .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING") + .all() + ) if len(tidb_serverless_list) == 0: return # update tidb serverless status diff --git a/api/services/account_service.py b/api/services/account_service.py index f930ef910b..14d238467d 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1,7 +1,6 @@ import base64 import json import logging -import random import secrets import uuid from datetime import UTC, datetime, timedelta @@ -49,7 +48,7 @@ from services.errors.account import ( RoleAlreadyAssignedError, TenantNotFoundError, ) -from services.errors.workspace import 
WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService from tasks.delete_account_task import delete_account_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code @@ -108,17 +107,20 @@ class AccountService: if account.status == AccountStatus.BANNED.value: raise Unauthorized("Account is banned.") - current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() + current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first() if current_tenant: - account.current_tenant_id = current_tenant.tenant_id + account.set_tenant_id(current_tenant.tenant_id) else: available_ta = ( - TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + db.session.query(TenantAccountJoin) + .filter_by(account_id=account.id) + .order_by(TenantAccountJoin.id.asc()) + .first() ) if not available_ta: return None - account.current_tenant_id = available_ta.tenant_id + account.set_tenant_id(available_ta.tenant_id) available_ta.current = True db.session.commit() @@ -258,7 +260,7 @@ class AccountService: @staticmethod def generate_account_deletion_verification_code(account: Account) -> tuple[str, str]: - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( account=account, token_type="account_deletion", additional_data={"code": code} ) @@ -297,9 +299,9 @@ class AccountService: """Link account integrate""" try: # Query whether there is an existing binding record for the same provider - account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( - account_id=account.id, provider=provider - ).first() + account_integrate: Optional[AccountIntegrate] = ( + 
db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first() + ) if account_integrate: # If it exists, update the record @@ -426,7 +428,7 @@ class AccountService: additional_data: dict[str, Any] = {}, ): if not code: - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) additional_data["code"] = code token = TokenManager.generate_token( account=account, email=email, token_type="reset_password", additional_data=additional_data @@ -453,7 +455,7 @@ class AccountService: raise EmailCodeLoginRateLimitExceededError() - code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) token = TokenManager.generate_token( account=account, email=email, token_type="email_code_login", additional_data={"code": code} ) @@ -612,7 +614,10 @@ class TenantService: ): """Check if user have a workspace or not""" available_ta = ( - TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() + db.session.query(TenantAccountJoin) + .filter_by(account_id=account.id) + .order_by(TenantAccountJoin.id.asc()) + .first() ) if available_ta: @@ -622,6 +627,10 @@ class TenantService: if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: raise WorkSpaceNotAllowedCreateError() + workspaces = FeatureService.get_system_features().license.workspaces + if not workspaces.is_available(): + raise WorkspacesLimitExceededError() + if name: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: @@ -666,7 +675,7 @@ class TenantService: if not tenant: raise TenantNotFoundError("Tenant not found.") - ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if ta: 
tenant.role = ta.role else: @@ -695,12 +704,12 @@ class TenantService: if not tenant_account_join: raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") else: - TenantAccountJoin.query.filter( + db.session.query(TenantAccountJoin).filter( TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id ).update({"current": False}) tenant_account_join.current = True # Set the current tenant for the account - account.current_tenant_id = tenant_account_join.tenant_id + account.set_tenant_id(tenant_account_join.tenant_id) db.session.commit() @staticmethod @@ -787,7 +796,7 @@ class TenantService: if operator.id == member.id: raise CannotOperateSelfError("Cannot operate self.") - ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() + ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first() if not ta_operator or ta_operator.role not in perms[action]: raise NoPermissionError(f"No permission to {action} member.") @@ -800,7 +809,7 @@ class TenantService: TenantService.check_member_permission(tenant, operator, account, "remove") - ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: raise MemberNotInTenantError("Member not in tenant.") @@ -812,15 +821,23 @@ class TenantService: """Update member role""" TenantService.check_member_permission(tenant, operator, member, "update") - target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() + target_member_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first() + ) + + if not target_member_join: + raise MemberNotInTenantError("Member not in tenant.") if target_member_join.role == new_role: raise 
RoleAlreadyAssignedError("The provided role is already assigned to the member.") if new_role == "owner": # Find the current owner and change their role to 'admin' - current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() - current_owner_join.role = "admin" + current_owner_join = ( + db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first() + ) + if current_owner_join: + current_owner_join.role = "admin" # Update the role of the target member target_member_join.role = new_role @@ -837,7 +854,7 @@ class TenantService: @staticmethod def get_custom_config(tenant_id: str) -> dict: - tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() + tenant = db.get_or_404(Tenant, tenant_id) return cast(dict, tenant.custom_config_dict) @@ -914,7 +931,11 @@ class RegisterService: if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: + if ( + FeatureService.get_system_features().is_allow_create_workspace + and create_workspace_required + and FeatureService.get_system_features().license.workspaces.is_available() + ): tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant @@ -959,7 +980,7 @@ class RegisterService: TenantService.switch_tenant(account, tenant.id) else: TenantService.check_member_permission(tenant, inviter, account, "add") - ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() + ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first() if not ta: TenantService.create_tenant_member(tenant, account, role) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index ae7b372b82..8c950abc24 100644 --- 
a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -4,7 +4,7 @@ from typing import cast import pandas as pd from flask_login import current_user -from sqlalchemy import or_ +from sqlalchemy import or_, select from werkzeug.datastructures import FileStorage from werkzeug.exceptions import NotFound @@ -124,8 +124,9 @@ class AppAnnotationService: if not app: raise NotFound("App not found") if keyword: - annotations = ( - MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) + stmt = ( + select(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) .filter( or_( MessageAnnotation.question.ilike("%{}%".format(keyword)), @@ -133,14 +134,14 @@ class AppAnnotationService: ) ) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) else: - annotations = ( - MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) + stmt = ( + select(MessageAnnotation) + .filter(MessageAnnotation.app_id == app_id) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) ) + annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False) return annotations.items, annotations.total @classmethod @@ -325,13 +326,16 @@ class AppAnnotationService: if not annotation: raise NotFound("Annotation not found") - annotation_hit_histories = ( - AppAnnotationHitHistory.query.filter( + stmt = ( + select(AppAnnotationHitHistory) + .filter( AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.annotation_id == annotation_id, ) .order_by(AppAnnotationHitHistory.created_at.desc()) - .paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + ) + annotation_hit_histories = db.paginate( + select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False ) return annotation_hit_histories.items, 
annotation_hit_histories.total diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index a2775fe6ad..1b026acfd6 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.2.0" +CURRENT_DSL_VERSION = "0.3.0" class ImportMode(StrEnum): @@ -421,7 +421,7 @@ class AppDslService: # Set icon type icon_type_value = icon_type or app_data.get("icon_type") - if icon_type_value in ["emoji", "link"]: + if icon_type_value in ["emoji", "link", "image"]: icon_type = icon_type_value else: icon_type = "emoji" diff --git a/api/services/app_service.py b/api/services/app_service.py index 2fae479e05..d08462d001 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService from services.tag_service import TagService from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task @@ -155,6 +157,10 @@ class AppService: app_was_created.send(app, account=account) + if FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private") + return app def get_app(self, app: App) -> App: @@ -307,6 +313,10 @@ class AppService: db.session.delete(app) db.session.commit() + # clean up web app 
settings + if FeatureService.get_system_features().webapp_auth.enabled: + EnterpriseService.WebAppAuth.cleanup_webapp(app.id) + # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) @@ -373,3 +383,27 @@ class AppService: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta + + @staticmethod + def get_app_code_by_id(app_id: str) -> str: + """ + Get app code by app id + :param app_id: app id + :return: app code + """ + site = db.session.query(Site).filter(Site.app_id == app_id).first() + if not site: + raise ValueError(f"App with id {app_id} not found") + return str(site.code) + + @staticmethod + def get_app_id_by_code(app_code: str) -> str: + """ + Get app id by app code + :param app_code: app code + :return: app id + """ + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise ValueError(f"App with code {app_code} not found") + return str(site.app_id) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 5762bf9600..1fd560d581 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -14,7 +14,7 @@ from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Tenant from models.model import App, Conversation, Message -from models.workflow import WorkflowNodeExecution, WorkflowRun +from models.workflow import WorkflowNodeExecutionModel, WorkflowRun from services.billing_service import BillingService logger = logging.getLogger(__name__) @@ -108,10 +108,11 @@ class ClearFreePlanTenantExpiredLogs: while True: with Session(db.engine).no_autoflush as session: workflow_node_executions = ( - session.query(WorkflowNodeExecution) + session.query(WorkflowNodeExecutionModel) .filter( - WorkflowNodeExecution.tenant_id == tenant_id, - 
WorkflowNodeExecution.created_at < datetime.datetime.now() - datetime.timedelta(days=days), + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.created_at + < datetime.datetime.now() - datetime.timedelta(days=days), ) .limit(batch) .all() @@ -135,8 +136,8 @@ class ClearFreePlanTenantExpiredLogs: ] # delete workflow node executions - session.query(WorkflowNodeExecution).filter( - WorkflowNodeExecution.id.in_(workflow_node_execution_ids), + session.query(WorkflowNodeExecutionModel).filter( + WorkflowNodeExecutionModel.id.in_(workflow_node_execution_ids), ).delete(synchronize_session=False) session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index de90355ebf..49ca98624a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -2,14 +2,14 @@ import copy import datetime import json import logging -import random +import secrets import time import uuid from collections import Counter from typing import Any, Optional from flask_login import current_user -from sqlalchemy import func +from sqlalchemy import func, select from sqlalchemy.orm import Session from werkzeug.exceptions import NotFound @@ -59,6 +59,7 @@ from services.external_knowledge_service import ExternalDatasetService from services.feature_service import FeatureModel, FeatureService from services.tag_service import TagService from services.vector_service import VectorService +from tasks.add_document_to_index_task import add_document_to_index_task from tasks.batch_clean_document_task import batch_clean_document_task from tasks.clean_notion_document_task import clean_notion_document_task from tasks.deal_dataset_vector_index_task import deal_dataset_vector_index_task @@ -70,6 +71,7 @@ from tasks.document_indexing_update_task import document_indexing_update_task from tasks.duplicate_document_indexing_task import duplicate_document_indexing_task from tasks.enable_segments_to_index_task import 
enable_segments_to_index_task from tasks.recover_document_indexing_task import recover_document_indexing_task +from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.retry_document_indexing_task import retry_document_indexing_task from tasks.sync_website_document_indexing_task import sync_website_document_indexing_task @@ -77,11 +79,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde class DatasetService: @staticmethod def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): - query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) + query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) if user: # get permitted dataset ids - dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() + dataset_permission = ( + db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all() + ) permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None if user.current_role == TenantAccountRole.DATASET_OPERATOR: @@ -129,7 +133,7 @@ class DatasetService: else: return [], 0 - datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False) return datasets.items, datasets.total @@ -153,9 +157,10 @@ class DatasetService: @staticmethod def get_datasets_by_ids(ids, tenant_id): - datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( - page=1, per_page=len(ids), max_per_page=len(ids), error_out=False - ) + stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id) + + datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False) + return 
datasets.items, datasets.total @staticmethod @@ -174,7 +179,7 @@ class DatasetService: retrieval_model: Optional[RetrievalModel] = None, ): # check if dataset name already exists - if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == "high_quality": @@ -235,7 +240,7 @@ class DatasetService: @staticmethod def get_dataset(dataset_id) -> Optional[Dataset]: - dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() + dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first() return dataset @staticmethod @@ -431,12 +436,12 @@ class DatasetService: raise ValueError(ex.description) filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now() + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] - dataset.query.filter_by(id=dataset_id).update(filtered_data) + db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) db.session.commit() if action: @@ -460,7 +465,7 @@ class DatasetService: @staticmethod def dataset_use_check(dataset_id) -> bool: - count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() + count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count() if count > 0: return True return False @@ -474,15 +479,15 @@ class DatasetService: if dataset.permission == DatasetPermissionEnum.ONLY_ME and dataset.created_by != user.id: logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") raise NoPermissionError("You do not have permission to access this dataset.") - if dataset.permission == "partial_members": - user_permission = 
DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() - if ( - not user_permission - and dataset.tenant_id != user.current_tenant_id - and dataset.created_by != user.id - ): - logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") - raise NoPermissionError("You do not have permission to access this dataset.") + if dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: + # For partial team permission, user needs explicit permission or be the creator + if dataset.created_by != user.id: + user_permission = ( + db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first() + ) + if not user_permission: + logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") + raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def check_dataset_operator_permission(user: Optional[Account] = None, dataset: Optional[Dataset] = None): @@ -499,23 +504,24 @@ class DatasetService: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: if not any( - dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() + dp.dataset_id == dataset.id + for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all() ): raise NoPermissionError("You do not have permission to access this dataset.") @staticmethod def get_dataset_queries(dataset_id: str, page: int, per_page: int): - dataset_queries = ( - DatasetQuery.query.filter_by(dataset_id=dataset_id) - .order_by(db.desc(DatasetQuery.created_at)) - .paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) - ) + stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at)) + + dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False) + return dataset_queries.items, dataset_queries.total @staticmethod def 
get_related_apps(dataset_id: str): return ( - AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) + db.session.query(AppDatasetJoin) + .filter(AppDatasetJoin.dataset_id == dataset_id) .order_by(db.desc(AppDatasetJoin.created_at)) .all() ) @@ -530,10 +536,14 @@ class DatasetService: } # get recent 30 days auto disable logs start_date = datetime.datetime.now() - datetime.timedelta(days=30) - dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( - DatasetAutoDisableLog.dataset_id == dataset_id, - DatasetAutoDisableLog.created_at >= start_date, - ).all() + dataset_auto_disable_logs = ( + db.session.query(DatasetAutoDisableLog) + .filter( + DatasetAutoDisableLog.dataset_id == dataset_id, + DatasetAutoDisableLog.created_at >= start_date, + ) + .all() + ) if dataset_auto_disable_logs: return { "document_ids": [log.document_id for log in dataset_auto_disable_logs], @@ -873,7 +883,9 @@ class DocumentService: @staticmethod def get_documents_position(dataset_id): - document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + document = ( + db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() + ) if document: return document.position + 1 else: @@ -948,11 +960,11 @@ class DocumentService: "score_threshold_enabled": False, } - dataset.retrieval_model = ( - knowledge_config.retrieval_model.model_dump() - if knowledge_config.retrieval_model - else default_retrieval_model - ) # type: ignore + dataset.retrieval_model = ( + knowledge_config.retrieval_model.model_dump() + if knowledge_config.retrieval_model + else default_retrieval_model + ) # type: ignore documents = [] if knowledge_config.original_document_id: @@ -960,18 +972,23 @@ class DocumentService: documents.append(document) batch = document.batch else: - batch = time.strftime("%Y%m%d%H%M%S") + str(random.randint(100000, 999999)) + batch = time.strftime("%Y%m%d%H%M%S") + str(100000 + 
secrets.randbelow(exclusive_upper_bound=900000)) # save process rule if not dataset_process_rule: process_rule = knowledge_config.process_rule if process_rule: if process_rule.mode in ("custom", "hierarchical"): - dataset_process_rule = DatasetProcessRule( - dataset_id=dataset.id, - mode=process_rule.mode, - rules=process_rule.rules.model_dump_json() if process_rule.rules else None, - created_by=account.id, - ) + if process_rule.rules: + dataset_process_rule = DatasetProcessRule( + dataset_id=dataset.id, + mode=process_rule.mode, + rules=process_rule.rules.model_dump_json() if process_rule.rules else None, + created_by=account.id, + ) + else: + dataset_process_rule = dataset.latest_process_rule + if not dataset_process_rule: + raise ValueError("No process rule found.") elif process_rule.mode == "automatic": dataset_process_rule = DatasetProcessRule( dataset_id=dataset.id, @@ -980,7 +997,7 @@ class DocumentService: created_by=account.id, ) else: - logging.warn( + logging.warning( f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" ) return @@ -1010,13 +1027,17 @@ class DocumentService: } # check duplicate if knowledge_config.duplicate: - document = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - data_source_type="upload_file", - enabled=True, - name=file_name, - ).first() + document = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="upload_file", + enabled=True, + name=file_name, + ) + .first() + ) if document: document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) @@ -1054,12 +1075,16 @@ class DocumentService: raise ValueError("No notion info list found.") exist_page_ids = [] exist_document = {} - documents = Document.query.filter_by( - dataset_id=dataset.id, - tenant_id=current_user.current_tenant_id, - 
data_source_type="notion_import", - enabled=True, - ).all() + documents = ( + db.session.query(Document) + .filter_by( + dataset_id=dataset.id, + tenant_id=current_user.current_tenant_id, + data_source_type="notion_import", + enabled=True, + ) + .all() + ) if documents: for document in documents: data_source_info = json.loads(document.data_source_info) @@ -1067,14 +1092,18 @@ class DocumentService: exist_document[data_source_info["notion_page_id"]] = document.id for notion_info in notion_info_list: workspace_id = notion_info.workspace_id - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) ) - ).first() + .first() + ) if not data_source_binding: raise ValueError("Data source binding not found.") for page in notion_info.pages: @@ -1206,12 +1235,16 @@ class DocumentService: @staticmethod def get_tenant_documents_count(): - documents_count = Document.query.filter( - Document.completed_at.isnot(None), - Document.enabled == True, - Document.archived == False, - Document.tenant_id == current_user.current_tenant_id, - ).count() + documents_count = ( + db.session.query(Document) + .filter( + Document.completed_at.isnot(None), + Document.enabled == True, + Document.archived == False, + Document.tenant_id == current_user.current_tenant_id, + ) + .count() + ) return documents_count @staticmethod @@ -1278,14 +1311,18 @@ class DocumentService: notion_info_list = 
document_data.data_source.info_list.notion_info_list for notion_info in notion_info_list: workspace_id = notion_info.workspace_id - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) ) - ).first() + .first() + ) if not data_source_binding: raise ValueError("Data source binding not found.") for page in notion_info.pages: @@ -1328,7 +1365,7 @@ class DocumentService: db.session.commit() # update document segment update_params = {DocumentSegment.status: "re_segment"} - DocumentSegment.query.filter_by(document_id=document.id).update(update_params) + db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params) db.session.commit() # trigger async task document_indexing_update_task.delay(document.dataset_id, document.id) @@ -1372,16 +1409,16 @@ class DocumentService: knowledge_config.embedding_model, # type: ignore ) dataset_collection_binding_id = dataset_collection_binding.id - if knowledge_config.retrieval_model: - retrieval_model = knowledge_config.retrieval_model - else: - retrieval_model = RetrievalModel( - search_method=RetrievalMethod.SEMANTIC_SEARCH.value, - reranking_enable=False, - reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), - top_k=2, - score_threshold_enabled=False, - ) + if knowledge_config.retrieval_model: + retrieval_model = knowledge_config.retrieval_model + else: + retrieval_model = RetrievalModel( + 
search_method=RetrievalMethod.SEMANTIC_SEARCH.value, + reranking_enable=False, + reranking_model=RerankingModel(reranking_provider_name="", reranking_model_name=""), + top_k=2, + score_threshold_enabled=False, + ) # save dataset dataset = Dataset( tenant_id=tenant_id, @@ -1573,6 +1610,191 @@ class DocumentService: if not isinstance(args["process_rule"]["rules"]["segmentation"]["max_tokens"], int): raise ValueError("Process rule segmentation max_tokens is invalid") + @staticmethod + def batch_update_document_status(dataset: Dataset, document_ids: list[str], action: str, user): + """ + Batch update document status. + + Args: + dataset (Dataset): The dataset object + document_ids (list[str]): List of document IDs to update + action (str): Action to perform (enable, disable, archive, un_archive) + user: Current user performing the action + + Raises: + DocumentIndexingError: If document is being indexed or not in correct state + ValueError: If action is invalid + """ + if not document_ids: + return + + # Early validation of action parameter + valid_actions = ["enable", "disable", "archive", "un_archive"] + if action not in valid_actions: + raise ValueError(f"Invalid action: {action}. 
Must be one of {valid_actions}") + + documents_to_update = [] + + # First pass: validate all documents and prepare updates + for document_id in document_ids: + document = DocumentService.get_document(dataset.id, document_id) + if not document: + continue + + # Check if document is being indexed + indexing_cache_key = f"document_{document.id}_indexing" + cache_result = redis_client.get(indexing_cache_key) + if cache_result is not None: + raise DocumentIndexingError(f"Document:{document.name} is being indexed, please try again later") + + # Prepare update based on action + update_info = DocumentService._prepare_document_status_update(document, action, user) + if update_info: + documents_to_update.append(update_info) + + # Second pass: apply all updates in a single transaction + if documents_to_update: + try: + for update_info in documents_to_update: + document = update_info["document"] + updates = update_info["updates"] + + # Apply updates to the document + for field, value in updates.items(): + setattr(document, field, value) + + db.session.add(document) + + # Batch commit all changes + db.session.commit() + except Exception as e: + # Rollback on any error + db.session.rollback() + raise e + # Execute async tasks and set Redis cache after successful commit + # propagation_error is used to capture any errors for submitting async task execution + propagation_error = None + for update_info in documents_to_update: + try: + # Execute async tasks after successful commit + if update_info["async_task"]: + task_info = update_info["async_task"] + task_func = task_info["function"] + task_args = task_info["args"] + task_func.delay(*task_args) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error executing async task for document {update_info['document'].id}") + # don't raise the error immediately, but capture it for later + propagation_error = e + try: + # Set Redis cache if needed after successful commit + if 
update_info["set_cache"]: + document = update_info["document"] + indexing_cache_key = f"document_{document.id}_indexing" + redis_client.setex(indexing_cache_key, 600, 1) + except Exception as e: + # Log the error but do not rollback the transaction + logging.exception(f"Error setting cache for document {update_info['document'].id}") + # Raise any propagation error after all updates + if propagation_error: + raise propagation_error + + @staticmethod + def _prepare_document_status_update(document, action: str, user): + """ + Prepare document status update information. + + Args: + document: Document object to update + action: Action to perform + user: Current user + + Returns: + dict: Update information or None if no update needed + """ + now = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + + if action == "enable": + return DocumentService._prepare_enable_update(document, now) + elif action == "disable": + return DocumentService._prepare_disable_update(document, user, now) + elif action == "archive": + return DocumentService._prepare_archive_update(document, user, now) + elif action == "un_archive": + return DocumentService._prepare_unarchive_update(document, now) + + return None + + @staticmethod + def _prepare_enable_update(document, now): + """Prepare updates for enabling a document.""" + if document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": True, "disabled_at": None, "disabled_by": None, "updated_at": now}, + "async_task": {"function": add_document_to_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_disable_update(document, user, now): + """Prepare updates for disabling a document.""" + if not document.completed_at or document.indexing_status != "completed": + raise DocumentIndexingError(f"Document: {document.name} is not completed.") + + if not document.enabled: + return None + + return { + "document": document, + "updates": {"enabled": False, "disabled_at": now, 
"disabled_by": user.id, "updated_at": now}, + "async_task": {"function": remove_document_from_index_task, "args": [document.id]}, + "set_cache": True, + } + + @staticmethod + def _prepare_archive_update(document, user, now): + """Prepare updates for archiving a document.""" + if document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": True, "archived_at": now, "archived_by": user.id, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only set async task and cache if document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": remove_document_from_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + + @staticmethod + def _prepare_unarchive_update(document, now): + """Prepare updates for unarchiving a document.""" + if not document.archived: + return None + + update_info = { + "document": document, + "updates": {"archived": False, "archived_at": None, "archived_by": None, "updated_at": now}, + "async_task": None, + "set_cache": False, + } + + # Only re-index if the document is currently enabled + if document.enabled: + update_info["async_task"] = {"function": add_document_to_index_task, "args": [document.id]} + update_info["set_cache"] = True + + return update_info + class SegmentService: @classmethod @@ -1918,7 +2140,8 @@ class SegmentService: @classmethod def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): index_node_ids = ( - DocumentSegment.query.with_entities(DocumentSegment.index_node_id) + db.session.query(DocumentSegment) + .with_entities(DocumentSegment.index_node_id) .filter( DocumentSegment.id.in_(segment_ids), DocumentSegment.dataset_id == dataset.id, @@ -2157,20 +2380,28 @@ class SegmentService: def get_child_chunks( cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None ): - query = ChildChunk.query.filter_by( - 
tenant_id=current_user.current_tenant_id, - dataset_id=dataset_id, - document_id=document_id, - segment_id=segment_id, - ).order_by(ChildChunk.position.asc()) + query = ( + select(ChildChunk) + .filter_by( + tenant_id=current_user.current_tenant_id, + dataset_id=dataset_id, + document_id=document_id, + segment_id=segment_id, + ) + .order_by(ChildChunk.position.asc()) + ) if keyword: query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) - return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) + return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) @classmethod def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: """Get a child chunk by its ID.""" - result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first() + result = ( + db.session.query(ChildChunk) + .filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id) + .first() + ) return result if isinstance(result, ChildChunk) else None @classmethod @@ -2184,7 +2415,7 @@ class SegmentService: limit: int = 20, ): """Get segments for a document with optional filtering.""" - query = DocumentSegment.query.filter( + query = select(DocumentSegment).filter( DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id ) @@ -2194,9 +2425,8 @@ class SegmentService: if keyword: query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) - paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( - page=page, per_page=limit, max_per_page=100, error_out=False - ) + query = query.order_by(DocumentSegment.position.asc()) + paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False) return paginated_segments.items, paginated_segments.total @@ -2236,9 +2466,11 @@ class SegmentService: raise ValueError(ex.description) # check segment - segment = 
DocumentSegment.query.filter( - DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id - ).first() + segment = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id) + .first() + ) if not segment: raise NotFound("Segment not found.") @@ -2251,9 +2483,11 @@ class SegmentService: @classmethod def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: """Get a segment by its ID.""" - result = DocumentSegment.query.filter( - DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id - ).first() + result = ( + db.session.query(DocumentSegment) + .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id) + .first() + ) return result if isinstance(result, DocumentSegment) else None diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index abc01ddf8f..8c06ee9386 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,11 +1,114 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + from services.enterprise.base import EnterpriseRequest +class WebAppSettings(BaseModel): + access_mode: str = Field( + description="Access mode for the web app. 
Can be 'public', 'private', 'private_all', 'sso_verified'", + default="private", + alias="accessMode", + ) + + class EnterpriseService: @classmethod def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") @classmethod - def get_app_web_sso_enabled(cls, app_code): - return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") + def get_workspace_info(cls, tenant_id: str): + return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + + @classmethod + def get_app_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time") + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + + @classmethod + def get_workspace_sso_settings_last_update_time(cls) -> datetime: + data = EnterpriseRequest.send_request("GET", "/sso/workspace/last-update-time") + if not data: + raise ValueError("No data found.") + try: + # parse the UTC timestamp from the response + return datetime.fromisoformat(data.replace("Z", "+00:00")) + except ValueError as e: + raise ValueError(f"Invalid date format: {data}") from e + + class WebAppAuth: + @classmethod + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) + + return data.get("result", False) + + @classmethod + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def 
batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: + if not app_ids: + return {} + body = {"appIds": app_ids} + data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body) + if not data: + raise ValueError("No data found.") + + if not isinstance(data["accessModes"], dict): + raise ValueError("Invalid data format.") + + ret = {} + for key, value in data["accessModes"].items(): + curr = WebAppSettings() + curr.access_mode = value + ret[key] = curr + + return ret + + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def update_app_access_mode(cls, app_id: str, access_mode: str): + if not app_id: + raise ValueError("app_id must be provided.") + if access_mode not in ["public", "private", "private_all"]: + raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + + data = {"appId": app_id, "accessMode": access_mode} + + response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + + return response.get("result", False) + + @classmethod + def cleanup_webapp(cls, app_id: str): + if not app_id: + raise ValueError("app_id must be provided.") + + body = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) diff --git a/api/services/enterprise/mail_service.py b/api/services/enterprise/mail_service.py new file mode 100644 index 0000000000..630e7679ac --- /dev/null +++ b/api/services/enterprise/mail_service.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from tasks.mail_enterprise_task import send_enterprise_email_task + + +class DifyMail(BaseModel): + to: list[str] + subject: str 
+ body: str + substitutions: dict[str, str] = {} + + +class EnterpriseMailService: + @classmethod + def send_mail(cls, mail: DifyMail): + send_enterprise_email_task.delay( + to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions + ) diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index bb3be61f85..603064ca07 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -101,7 +101,7 @@ class WeightModel(BaseModel): class RetrievalModel(BaseModel): - search_method: Literal["hybrid_search", "semantic_search", "full_text_search"] + search_method: Literal["hybrid_search", "semantic_search", "full_text_search", "keyword_search"] reranking_enable: bool reranking_model: Optional[RerankingModel] = None reranking_mode: Optional[str] = None diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index eb1f055708..697e691224 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -4,7 +4,6 @@ from . 
import ( app_model_config, audio, base, - completion, conversation, dataset, document, @@ -19,7 +18,6 @@ __all__ = [ "app_model_config", "audio", "base", - "completion", "conversation", "dataset", "document", diff --git a/api/services/errors/account.py b/api/services/errors/account.py index 5aca12ffeb..4d3d150e07 100644 --- a/api/services/errors/account.py +++ b/api/services/errors/account.py @@ -55,7 +55,3 @@ class MemberNotInTenantError(BaseServiceError): class RoleAlreadyAssignedError(BaseServiceError): pass - - -class RateLimitExceededError(BaseServiceError): - pass diff --git a/api/services/errors/completion.py b/api/services/errors/plugin.py similarity index 51% rename from api/services/errors/completion.py rename to api/services/errors/plugin.py index 7fc50a588e..be5b144b3d 100644 --- a/api/services/errors/completion.py +++ b/api/services/errors/plugin.py @@ -1,5 +1,5 @@ from services.errors.base import BaseServiceError -class CompletionStoppedError(BaseServiceError): +class PluginInstallationForbiddenError(BaseServiceError): pass diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py index 714064ffdf..577238507f 100644 --- a/api/services/errors/workspace.py +++ b/api/services/errors/workspace.py @@ -7,3 +7,7 @@ class WorkSpaceNotAllowedCreateError(BaseServiceError): class WorkSpaceNotFoundError(BaseServiceError): pass + + +class WorkspacesLimitExceededError(BaseServiceError): + pass diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 6b75c29d95..eb50d79494 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast from urllib.parse import urlparse import httpx +from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy @@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError class ExternalDatasetService: 
@staticmethod - def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: - query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( - ExternalKnowledgeApis.created_at.desc() + def get_external_knowledge_apis( + page, per_page, tenant_id, search=None + ) -> tuple[list[ExternalKnowledgeApis], int | None]: + query = ( + select(ExternalKnowledgeApis) + .filter(ExternalKnowledgeApis.tenant_id == tenant_id) + .order_by(ExternalKnowledgeApis.created_at.desc()) ) if search: query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) - external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) + external_knowledge_apis = db.paginate( + select=query, page=page, per_page=per_page, max_per_page=100, error_out=False + ) return external_knowledge_apis.items, external_knowledge_apis.total @@ -92,18 +99,18 @@ class ExternalDatasetService: @staticmethod def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id - ).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") return external_knowledge_api @staticmethod def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: - external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api: Optional[ExternalKnowledgeApis] = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is 
None: raise ValueError("api template not found") if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: @@ -120,9 +127,9 @@ class ExternalDatasetService: @staticmethod def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -131,25 +138,29 @@ class ExternalDatasetService: @staticmethod def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: - count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() + count = ( + db.session.query(ExternalKnowledgeBindings) + .filter_by(external_knowledge_api_id=external_knowledge_api_id) + .count() + ) if count > 0: return True, count return False, 0 @staticmethod def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: - external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( - dataset_id=dataset_id, tenant_id=tenant_id - ).first() + external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ( + db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") return external_knowledge_binding @staticmethod def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_api_id, tenant_id=tenant_id - ).first() + external_knowledge_api = ( + 
db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first() + ) if external_knowledge_api is None: raise ValueError("api template not found") settings = json.loads(external_knowledge_api.settings) @@ -212,11 +223,13 @@ class ExternalDatasetService: @staticmethod def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: # check if dataset name already exists - if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): + if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first(): raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=args.get("external_knowledge_api_id"), tenant_id=tenant_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id) + .first() + ) if external_knowledge_api is None: raise ValueError("api template not found") @@ -254,15 +267,17 @@ class ExternalDatasetService: external_retrieval_parameters: dict, metadata_condition: Optional[MetadataCondition] = None, ) -> list: - external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( - dataset_id=dataset_id, tenant_id=tenant_id - ).first() + external_knowledge_binding = ( + db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first() + ) if not external_knowledge_binding: raise ValueError("external knowledge binding not found") - external_knowledge_api = ExternalKnowledgeApis.query.filter_by( - id=external_knowledge_binding.external_knowledge_api_id - ).first() + external_knowledge_api = ( + db.session.query(ExternalKnowledgeApis) + .filter_by(id=external_knowledge_binding.external_knowledge_api_id) + .first() + ) if not external_knowledge_api: raise ValueError("external api template not found") diff --git 
a/api/services/feature_service.py b/api/services/feature_service.py index c2226c319f..188caf3505 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,6 +1,6 @@ from enum import StrEnum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from configs import dify_config from services.billing_service import BillingService @@ -27,6 +27,32 @@ class LimitationModel(BaseModel): limit: int = 0 +class LicenseLimitationModel(BaseModel): + """ + - enabled: whether this limit is enforced + - size: current usage count + - limit: maximum allowed count; 0 means unlimited + """ + + enabled: bool = Field(False, description="Whether this limit is currently active") + size: int = Field(0, description="Number of resources already consumed") + limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit") + + def is_available(self, required: int = 1) -> bool: + """ + Determine whether the requested amount can be allocated. + + Returns True if: + - this limit is not active, or + - the limit is zero (unlimited), or + - there is enough remaining quota. 
+ """ + if not self.enabled or self.limit == 0: + return True + + return (self.limit - self.size) >= required + + class LicenseStatus(StrEnum): NONE = "none" INACTIVE = "inactive" @@ -39,6 +65,47 @@ class LicenseStatus(StrEnum): class LicenseModel(BaseModel): status: LicenseStatus = LicenseStatus.NONE expired_at: str = "" + workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) + + +class BrandingModel(BaseModel): + enabled: bool = False + application_title: str = "" + login_page_logo: str = "" + workspace_logo: str = "" + favicon: str = "" + + +class WebAppAuthSSOModel(BaseModel): + protocol: str = "" + + +class WebAppAuthModel(BaseModel): + enabled: bool = False + allow_sso: bool = False + sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel() + allow_email_code_login: bool = False + allow_email_password_login: bool = False + + +class PluginInstallationScope(StrEnum): + NONE = "none" + OFFICIAL_ONLY = "official_only" + OFFICIAL_AND_SPECIFIC_PARTNERS = "official_and_specific_partners" + ALL = "all" + + +class PluginInstallationPermissionModel(BaseModel): + # Plugin installation scope – possible values: + # none: prohibit all plugin installations + # official_only: allow only Dify official plugins + # official_and_specific_partners: allow official and specific partner plugins + # all: allow installation of all plugins + plugin_installation_scope: PluginInstallationScope = PluginInstallationScope.ALL + + # If True, restrict plugin installation to the marketplace only + # Equivalent to ForceEnablePluginVerification + restrict_to_marketplace_only: bool = False class FeatureModel(BaseModel): @@ -54,6 +121,8 @@ class FeatureModel(BaseModel): can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False + webapp_copyright_enabled: bool = False + workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) # pydantic configs model_config = 
ConfigDict(protected_namespaces=()) @@ -68,9 +137,6 @@ class KnowledgeRateLimitModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" - sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = "" - enable_web_sso_switch_component: bool = False enable_marketplace: bool = False max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE enable_email_code_login: bool = False @@ -80,6 +146,9 @@ class SystemFeatureModel(BaseModel): is_allow_create_workspace: bool = False is_email_setup: bool = False license: LicenseModel = LicenseModel() + branding: BrandingModel = BrandingModel() + webapp_auth: WebAppAuthModel = WebAppAuthModel() + plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() class FeatureService: @@ -92,6 +161,10 @@ class FeatureService: if dify_config.BILLING_ENABLED and tenant_id: cls._fulfill_params_from_billing_api(features, tenant_id) + if dify_config.ENTERPRISE_ENABLED: + features.webapp_copyright_enabled = True + cls._fulfill_params_from_workspace_info(features, tenant_id) + return features @classmethod @@ -111,8 +184,8 @@ class FeatureService: cls._fulfill_system_params_from_env(system_features) if dify_config.ENTERPRISE_ENABLED: - system_features.enable_web_sso_switch_component = True - + system_features.branding.enabled = True + system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -136,6 +209,14 @@ class FeatureService: features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED features.education.enabled = dify_config.EDUCATION_ENABLED + @classmethod + def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str): + workspace_info = EnterpriseService.get_workspace_info(tenant_id) + if "WorkspaceMembers" in workspace_info: + features.workspace_members.size = 
workspace_info["WorkspaceMembers"]["used"] + features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"] + features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"] + @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) @@ -145,6 +226,9 @@ class FeatureService: features.billing.subscription.interval = billing_info["subscription"]["interval"] features.education.activated = billing_info["subscription"].get("education", False) + if features.billing.subscription.plan != "sandbox": + features.webapp_copyright_enabled = True + if "members" in billing_info: features.members.size = billing_info["members"]["size"] features.members.limit = billing_info["members"]["limit"] @@ -178,38 +262,62 @@ class FeatureService: features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] @classmethod - def _fulfill_params_from_enterprise(cls, features): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): enterprise_info = EnterpriseService.get_info() - if "sso_enforced_for_signin" in enterprise_info: - features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + if "SSOEnforcedForSignin" in enterprise_info: + features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] - if "sso_enforced_for_signin_protocol" in enterprise_info: - features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + if "SSOEnforcedForSigninProtocol" in enterprise_info: + features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"] - if "sso_enforced_for_web" in enterprise_info: - features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + if "EnableEmailCodeLogin" in enterprise_info: + features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] - if "sso_enforced_for_web_protocol" in enterprise_info: - 
features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + if "EnableEmailPasswordLogin" in enterprise_info: + features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"] - if "enable_email_code_login" in enterprise_info: - features.enable_email_code_login = enterprise_info["enable_email_code_login"] + if "IsAllowRegister" in enterprise_info: + features.is_allow_register = enterprise_info["IsAllowRegister"] - if "enable_email_password_login" in enterprise_info: - features.enable_email_password_login = enterprise_info["enable_email_password_login"] + if "IsAllowCreateWorkspace" in enterprise_info: + features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"] - if "is_allow_register" in enterprise_info: - features.is_allow_register = enterprise_info["is_allow_register"] + if "Branding" in enterprise_info: + features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "") + features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "") + features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") + features.branding.favicon = enterprise_info["Branding"].get("favicon", "") - if "is_allow_create_workspace" in enterprise_info: - features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] + if "WebAppAuth" in enterprise_info: + features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( + "allowEmailCodeLogin", False + ) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( + "allowEmailPasswordLogin", False + ) + features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "license" in enterprise_info: - license_info = enterprise_info["license"] + if "License" in enterprise_info: + license_info = 
enterprise_info["License"] if "status" in license_info: features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - if "expired_at" in license_info: - features.license.expired_at = license_info["expired_at"] + if "expiredAt" in license_info: + features.license.expired_at = license_info["expiredAt"] + + if "workspaces" in license_info: + features.license.workspaces.enabled = license_info["workspaces"]["enabled"] + features.license.workspaces.limit = license_info["workspaces"]["limit"] + features.license.workspaces.size = license_info["workspaces"]["used"] + + if "PluginInstallationPermission" in enterprise_info: + plugin_installation_info = enterprise_info["PluginInstallationPermission"] + features.plugin_installation_permission.plugin_installation_scope = plugin_installation_info[ + "pluginInstallationScope" + ] + features.plugin_installation_permission.restrict_to_marketplace_only = plugin_installation_info[ + "restrictToMarketplaceOnly" + ] diff --git a/api/services/file_service.py b/api/services/file_service.py index 2ca6b4f9aa..2d68f30c5a 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser, UploadFile from .errors.file import FileTooLargeError, UnsupportedFileTypeError @@ -81,7 +81,7 @@ class FileService: size=file_size, extension=extension, mime_type=mimetype, - created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=False, @@ -133,7 +133,7 @@ class 
FileService: extension="txt", mime_type="text/plain", created_by=current_user.id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=current_user.id, diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 56e06cc33e..519d5abca5 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -2,8 +2,11 @@ import logging import time from typing import Any +from core.app.app_config.entities import ModelConfig +from core.model_runtime.entities import LLMMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from extensions.ext_database import db from models.account import Account @@ -34,7 +37,29 @@ class HitTestingService: # get retrieval model , if the model is not setting , using default if not retrieval_model: retrieval_model = dataset.retrieval_model or default_retrieval_model - + document_ids_filter = None + metadata_filtering_conditions = retrieval_model.get("metadata_filtering_conditions", {}) + if metadata_filtering_conditions: + dataset_retrieval = DatasetRetrieval() + + from core.app.app_config.entities import MetadataFilteringCondition + + metadata_filtering_conditions = MetadataFilteringCondition(**metadata_filtering_conditions) + + metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition( + dataset_ids=[dataset.id], + query=query, + metadata_filtering_mode="manual", + metadata_filtering_conditions=metadata_filtering_conditions, + inputs={}, + tenant_id="", + user_id="", + metadata_model_config=ModelConfig(provider="", name="", mode=LLMMode.CHAT, completion_params={}), + ) + if metadata_filter_document_ids: + document_ids_filter = 
metadata_filter_document_ids.get(dataset.id, []) + if metadata_condition and not document_ids_filter: + return cls.compact_retrieve_response(query, []) all_documents = RetrievalService.retrieve( retrieval_method=retrieval_model.get("search_method", "semantic_search"), dataset_id=dataset.id, @@ -48,6 +73,7 @@ class HitTestingService: else None, reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", weights=retrieval_model.get("weights", None), + document_ids_filter=document_ids_filter, ) end = time.perf_counter() @@ -99,7 +125,7 @@ class HitTestingService: return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) @classmethod - def compact_retrieve_response(cls, query: str, documents: list[Document]): + def compact_retrieve_response(cls, query: str, documents: list[Document]) -> dict[Any, Any]: records = RetrievalService.format_retrieval_documents(documents) return { diff --git a/api/services/message_service.py b/api/services/message_service.py index aefab1556c..51b070ece7 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -177,6 +177,21 @@ class MessageService: return feedback + @classmethod + def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int): + """Get all feedbacks of an app""" + offset = (page - 1) * limit + feedbacks = ( + db.session.query(MessageFeedback) + .filter(MessageFeedback.app_id == app_model.id) + .order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc()) + .limit(limit) + .offset(offset) + .all() + ) + + return [record.to_dict() for record in feedbacks] + @classmethod def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): message = ( diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index c47c16f2f7..26d6d4ce18 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -20,9 +20,11 @@ class MetadataService: @staticmethod def 
create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: # check if metadata name already exists - if DatasetMetadata.query.filter_by( - tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name - ).first(): + if ( + db.session.query(DatasetMetadata) + .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name) + .first() + ): raise ValueError("Metadata name already exists.") for field in BuiltInField: if field.value == metadata_args.name: @@ -42,16 +44,18 @@ class MetadataService: def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists - if DatasetMetadata.query.filter_by( - tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name - ).first(): + if ( + db.session.query(DatasetMetadata) + .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name) + .first() + ): raise ValueError("Metadata name already exists.") for field in BuiltInField: if field.value == name: raise ValueError("Metadata name already exists in Built-in fields.") try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() if metadata is None: raise ValueError("Metadata not found.") old_name = metadata.name @@ -60,7 +64,9 @@ class MetadataService: metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) # update related documents - dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + dataset_metadata_bindings = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() + ) if dataset_metadata_bindings: document_ids = [binding.document_id for binding in 
dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -82,13 +88,15 @@ class MetadataService: lock_key = f"dataset_metadata_lock_{dataset_id}" try: MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) - metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() + metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first() if metadata is None: raise ValueError("Metadata not found.") db.session.delete(metadata) # deal related documents - dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() + dataset_metadata_bindings = ( + db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all() + ) if dataset_metadata_bindings: document_ids = [binding.document_id for binding in dataset_metadata_bindings] documents = DocumentService.get_document_by_ids(document_ids) @@ -193,7 +201,7 @@ class MetadataService: db.session.add(document) db.session.commit() # deal metadata binding - DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() + db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete() for metadata_value in operation.metadata_list: dataset_metadata_binding = DatasetMetadataBinding( tenant_id=current_user.current_tenant_id, @@ -230,9 +238,9 @@ class MetadataService: "id": item.get("id"), "name": item.get("name"), "type": item.get("type"), - "count": DatasetMetadataBinding.query.filter_by( - metadata_id=item.get("id"), dataset_id=dataset.id - ).count(), + "count": db.session.query(DatasetMetadataBinding) + .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id) + .count(), } for item in dataset.doc_metadata or [] if item.get("id") != "built-in" diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 6b317212d1..792f50703e 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -1,5 +1,6 @@ -from typing import Optional 
+from typing import Any, Optional +from core.ops.entities.config_entity import BaseTracingConfig from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map from extensions.ext_database import db from models.model import App, TraceAppConfig @@ -87,16 +88,17 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map and tracing_provider: + try: + provider_config_map[tracing_provider] + except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} - config_class, other_keys = ( - provider_config_map[tracing_provider]["config_class"], - provider_config_map[tracing_provider]["other_keys"], - ) - # FIXME: ignore type error - default_config_instance = config_class(**tracing_config) # type: ignore - for key in other_keys: # type: ignore + provider_config: dict[str, Any] = provider_config_map[tracing_provider] + config_class: type[BaseTracingConfig] = provider_config["config_class"] + other_keys: list[str] = provider_config["other_keys"] + + default_config_instance: BaseTracingConfig = config_class(**tracing_config) + for key in other_keys: if key in tracing_config and tracing_config[key] == "": tracing_config[key] = getattr(default_config_instance, key, None) @@ -150,7 +152,9 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map: + try: + provider_config_map[tracing_provider] + except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py index 1c5abfecba..02de5a79d7 100644 --- a/api/services/plugin/data_migration.py +++ b/api/services/plugin/data_migration.py @@ -3,7 +3,7 @@ import logging import click -from core.entities import DEFAULT_PLUGIN_ID +from core.plugin.entities.plugin import GenericProviderID, ModelProviderID, ToolProviderID from models.engine 
import db logger = logging.getLogger(__name__) @@ -12,17 +12,17 @@ logger = logging.getLogger(__name__) class PluginDataMigration: @classmethod def migrate(cls) -> None: - cls.migrate_db_records("providers", "provider_name") # large table - cls.migrate_db_records("provider_models", "provider_name") - cls.migrate_db_records("provider_orders", "provider_name") - cls.migrate_db_records("tenant_default_models", "provider_name") - cls.migrate_db_records("tenant_preferred_model_providers", "provider_name") - cls.migrate_db_records("provider_model_settings", "provider_name") - cls.migrate_db_records("load_balancing_model_configs", "provider_name") + cls.migrate_db_records("providers", "provider_name", ModelProviderID) # large table + cls.migrate_db_records("provider_models", "provider_name", ModelProviderID) + cls.migrate_db_records("provider_orders", "provider_name", ModelProviderID) + cls.migrate_db_records("tenant_default_models", "provider_name", ModelProviderID) + cls.migrate_db_records("tenant_preferred_model_providers", "provider_name", ModelProviderID) + cls.migrate_db_records("provider_model_settings", "provider_name", ModelProviderID) + cls.migrate_db_records("load_balancing_model_configs", "provider_name", ModelProviderID) cls.migrate_datasets() - cls.migrate_db_records("embeddings", "provider_name") # large table - cls.migrate_db_records("dataset_collection_bindings", "provider_name") - cls.migrate_db_records("tool_builtin_providers", "provider") + cls.migrate_db_records("embeddings", "provider_name", ModelProviderID) # large table + cls.migrate_db_records("dataset_collection_bindings", "provider_name", ModelProviderID) + cls.migrate_db_records("tool_builtin_providers", "provider_name", ToolProviderID) @classmethod def migrate_datasets(cls) -> None: @@ -66,9 +66,10 @@ limit 1000""" fg="white", ) ) - retrieval_model["reranking_model"]["reranking_provider_name"] = ( - 
f"{DEFAULT_PLUGIN_ID}/{retrieval_model['reranking_model']['reranking_provider_name']}/{retrieval_model['reranking_model']['reranking_provider_name']}" - ) + # update google to langgenius/gemini/google etc. + retrieval_model["reranking_model"]["reranking_provider_name"] = ModelProviderID( + retrieval_model["reranking_model"]["reranking_provider_name"] + ).to_string() retrieval_model_changed = True click.echo( @@ -86,9 +87,11 @@ limit 1000""" update_retrieval_model_sql = ", retrieval_model = :retrieval_model" params["retrieval_model"] = json.dumps(retrieval_model) + params["provider_name"] = ModelProviderID(provider_name).to_string() + sql = f"""update {table_name} set {provider_column_name} = - concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name}) + :provider_name {update_retrieval_model_sql} where id = :record_id""" conn.execute(db.text(sql), params) @@ -122,7 +125,9 @@ limit 1000""" ) @classmethod - def migrate_db_records(cls, table_name: str, provider_column_name: str) -> None: + def migrate_db_records( + cls, table_name: str, provider_column_name: str, provider_cls: type[GenericProviderID] + ) -> None: click.echo(click.style(f"Migrating [{table_name}] data for plugin", fg="white")) processed_count = 0 @@ -166,7 +171,8 @@ limit 1000""" ) try: - updated_value = f"{DEFAULT_PLUGIN_ID}/{provider_name}/{provider_name}" + # update jina to langgenius/jina_tool/jina etc. 
+ updated_value = provider_cls(provider_name).to_string() batch_updates.append((updated_value, record_id)) except Exception as e: failed_ids.append(record_id) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index be722a59ad..d7fb4a7c1b 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -17,11 +17,18 @@ from core.plugin.entities.plugin import ( PluginInstallation, PluginInstallationSource, ) -from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse +from core.plugin.entities.plugin_daemon import ( + PluginDecodeResponse, + PluginInstallTask, + PluginListResponse, + PluginVerification, +) from core.plugin.impl.asset import PluginAssetManager from core.plugin.impl.debugging import PluginDebuggingClient from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client +from services.errors.plugin import PluginInstallationForbiddenError +from services.feature_service import FeatureService, PluginInstallationScope logger = logging.getLogger(__name__) @@ -86,6 +93,42 @@ class PluginService: logger.exception("failed to fetch latest plugin version") return result + @staticmethod + def _check_marketplace_only_permission(): + """ + Check if the marketplace only permission is enabled + """ + features = FeatureService.get_system_features() + if features.plugin_installation_permission.restrict_to_marketplace_only: + raise PluginInstallationForbiddenError("Plugin installation is restricted to marketplace only") + + @staticmethod + def _check_plugin_installation_scope(plugin_verification: Optional[PluginVerification]): + """ + Check the plugin installation scope + """ + features = FeatureService.get_system_features() + + match features.plugin_installation_permission.plugin_installation_scope: + case PluginInstallationScope.OFFICIAL_ONLY: + if ( + plugin_verification is None + or plugin_verification.authorized_category != 
PluginVerification.AuthorizedCategory.Langgenius + ): + raise PluginInstallationForbiddenError("Plugin installation is restricted to official only") + case PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS: + if plugin_verification is None or plugin_verification.authorized_category not in [ + PluginVerification.AuthorizedCategory.Langgenius, + PluginVerification.AuthorizedCategory.Partner, + ]: + raise PluginInstallationForbiddenError( + "Plugin installation is restricted to official and specific partners" + ) + case PluginInstallationScope.NONE: + raise PluginInstallationForbiddenError("Installing plugins is not allowed") + case PluginInstallationScope.ALL: + pass + @staticmethod def get_debugging_key(tenant_id: str) -> str: """ @@ -110,6 +153,15 @@ class PluginService: plugins = manager.list_plugins(tenant_id) return plugins + @staticmethod + def list_with_total(tenant_id: str, page: int, page_size: int) -> PluginListResponse: + """ + list all plugins of the tenant + """ + manager = PluginInstaller() + plugins = manager.list_plugins_with_total(tenant_id, page, page_size) + return plugins + @staticmethod def list_installations_from_ids(tenant_id: str, ids: Sequence[str]) -> Sequence[PluginInstallation]: """ @@ -199,6 +251,8 @@ class PluginService: # check if plugin pkg is already downloaded manager = PluginInstaller() + features = FeatureService.get_system_features() + try: manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier) # already downloaded, skip, and record install event @@ -206,7 +260,14 @@ class PluginService: except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(new_plugin_unique_identifier) - manager.upload_pkg(tenant_id, pkg, verify_signature=False) + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + + # check if the plugin is available to install + 
PluginService._check_plugin_installation_scope(response.verification) return manager.upgrade_plugin( tenant_id, @@ -230,6 +291,7 @@ class PluginService: """ Upgrade plugin with github """ + PluginService._check_marketplace_only_permission() manager = PluginInstaller() return manager.upgrade_plugin( tenant_id, @@ -244,33 +306,43 @@ class PluginService: ) @staticmethod - def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginUploadResponse: + def upload_pkg(tenant_id: str, pkg: bytes, verify_signature: bool = False) -> PluginDecodeResponse: """ Upload plugin package files returns: plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() manager = PluginInstaller() - return manager.upload_pkg(tenant_id, pkg, verify_signature) + features = FeatureService.get_system_features() + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + return response @staticmethod def upload_pkg_from_github( tenant_id: str, repo: str, version: str, package: str, verify_signature: bool = False - ) -> PluginUploadResponse: + ) -> PluginDecodeResponse: """ Install plugin from github release package files, returns plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() pkg = download_with_size_limit( f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE ) + features = FeatureService.get_system_features() manager = PluginInstaller() - return manager.upload_pkg( + response = manager.upload_pkg( tenant_id, pkg, - verify_signature, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, ) + return response @staticmethod def upload_bundle( @@ -280,11 +352,15 @@ class PluginService: Upload a plugin bundle and return the dependencies. 
""" manager = PluginInstaller() + PluginService._check_marketplace_only_permission() return manager.upload_bundle(tenant_id, bundle, verify_signature) @staticmethod def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): + PluginService._check_marketplace_only_permission() + manager = PluginInstaller() + return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, @@ -298,6 +374,8 @@ class PluginService: Install plugin from github release package files, returns plugin_unique_identifier """ + PluginService._check_marketplace_only_permission() + manager = PluginInstaller() return manager.install_from_identifiers( tenant_id, @@ -313,28 +391,33 @@ class PluginService: ) @staticmethod - def fetch_marketplace_pkg( - tenant_id: str, plugin_unique_identifier: str, verify_signature: bool = False - ) -> PluginDeclaration: + def fetch_marketplace_pkg(tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration: """ Fetch marketplace package """ if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") + features = FeatureService.get_system_features() + manager = PluginInstaller() try: declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) except Exception: pkg = download_plugin_pkg(plugin_unique_identifier) - declaration = manager.upload_pkg(tenant_id, pkg, verify_signature).manifest + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(response.verification) + declaration = response.manifest return declaration @staticmethod - def install_from_marketplace_pkg( - tenant_id: str, plugin_unique_identifiers: Sequence[str], verify_signature: bool = False - ): + def install_from_marketplace_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): """ Install plugin from marketplace 
package files, returns installation task id @@ -344,15 +427,26 @@ class PluginService: manager = PluginInstaller() + features = FeatureService.get_system_features() + # check if already downloaded for plugin_unique_identifier in plugin_unique_identifiers: try: manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) + plugin_decode_response = manager.decode_plugin_from_identifier(tenant_id, plugin_unique_identifier) + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(plugin_decode_response.verification) # already downloaded, skip except Exception: # plugin not installed, download and upload pkg pkg = download_plugin_pkg(plugin_unique_identifier) - manager.upload_pkg(tenant_id, pkg, verify_signature) + response = manager.upload_pkg( + tenant_id, + pkg, + verify_signature=features.plugin_installation_permission.restrict_to_marketplace_only, + ) + # check if the plugin is available to install + PluginService._check_plugin_installation_scope(response.verification) return manager.install_from_identifiers( tenant_id, diff --git a/api/services/tag_service.py b/api/services/tag_service.py index 21cb861f87..74c6150b44 100644 --- a/api/services/tag_service.py +++ b/api/services/tag_service.py @@ -44,6 +44,19 @@ class TagService: results = [tag_binding.target_id for tag_binding in tag_bindings] return results + @staticmethod + def get_tag_by_tag_name(tag_type: str, current_tenant_id: str, tag_name: str) -> list: + if not tag_type or not tag_name: + return [] + tags = ( + db.session.query(Tag) + .filter(Tag.name == tag_name, Tag.tenant_id == current_tenant_id, Tag.type == tag_type) + .all() + ) + if not tags: + return [] + return tags + @staticmethod def get_tags_by_target_id(tag_type: str, current_tenant_id: str, target_id: str) -> list: tags = ( @@ -62,6 +75,8 @@ class TagService: @staticmethod def save_tags(args: dict) -> Tag: + if TagService.get_tag_by_tag_name(args["type"], current_user.current_tenant_id, 
args["name"]): + raise ValueError("Tag name already exists") tag = Tag( id=str(uuid.uuid4()), name=args["name"], @@ -75,6 +90,8 @@ class TagService: @staticmethod def update_tags(args: dict, tag_id: str) -> Tag: + if TagService.get_tag_by_tag_name(args.get("type", ""), current_user.current_tenant_id, args.get("name", "")): + raise ValueError("Tag name already exists") tag = db.session.query(Tag).filter(Tag.id == tag_id).first() if not tag: raise NotFound("Tag not found") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 3ccd14415d..58a4b2f179 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.entities.plugin import GenericProviderID, ToolProviderID +from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity @@ -290,7 +290,7 @@ class BuiltinToolManageService: def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: try: full_provider_name = provider_name - provider_id_entity = GenericProviderID(provider_name) + provider_id_entity = ToolProviderID(provider_name) provider_name = provider_id_entity.provider_name if provider_id_entity.organization != "langgenius": provider_obj = ( @@ -315,7 +315,7 @@ class BuiltinToolManageService: if provider_obj is None: return None - provider_obj.provider = GenericProviderID(provider_obj.provider).to_string() + provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() return provider_obj except Exception: # it's 
an old provider without organization diff --git a/api/services/vector_service.py b/api/services/vector_service.py index 92422bf29d..19e37f4ee3 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,3 +1,4 @@ +import logging from typing import Optional from core.model_manager import ModelInstance, ModelManager @@ -12,21 +13,30 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm from models.dataset import Document as DatasetDocument from services.entities.knowledge_entities.knowledge_entities import ParentMode +_logger = logging.getLogger(__name__) + class VectorService: @classmethod def create_segments_vector( cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str ): - documents = [] + documents: list[Document] = [] for segment in segments: if doc_form == IndexType.PARENT_CHILD_INDEX: - document = DatasetDocument.query.filter_by(id=segment.document_id).first() + dataset_document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first() + if not dataset_document: + _logger.warning( + "Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s", + segment.document_id, + segment.id, + ) + continue # get the process rule processing_rule = ( db.session.query(DatasetProcessRule) - .filter(DatasetProcessRule.id == document.dataset_process_rule_id) + .filter(DatasetProcessRule.id == dataset_document.dataset_process_rule_id) .first() ) if not processing_rule: @@ -50,9 +60,11 @@ class VectorService: ) else: raise ValueError("The knowledge base index technique is not high quality!") - cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False) + cls.generate_child_chunks( + segment, dataset_document, dataset, embedding_model_instance, processing_rule, False + ) else: - document = Document( + rag_document = Document( page_content=segment.content, metadata={ 
"doc_id": segment.index_node_id, @@ -61,7 +73,7 @@ class VectorService: "dataset_id": segment.dataset_id, }, ) - documents.append(document) + documents.append(rag_document) if len(documents) > 0: index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list) diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py new file mode 100644 index 0000000000..8f92b3f070 --- /dev/null +++ b/api/services/webapp_auth_service.py @@ -0,0 +1,178 @@ +import enum +import secrets +from datetime import UTC, datetime, timedelta +from typing import Any, Optional, cast + +from werkzeug.exceptions import NotFound, Unauthorized + +from configs import dify_config +from extensions.ext_database import db +from libs.helper import TokenManager +from libs.passport import PassportService +from libs.password import compare_password +from models.account import Account, AccountStatus +from models.model import App, EndUser, Site +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from tasks.mail_email_code_login import send_email_code_login_mail_task + + +class WebAppAuthType(enum.StrEnum): + """Enum for web app authentication types.""" + + PUBLIC = "public" + INTERNAL = "internal" + EXTERNAL = "external" + + +class WebAppAuthService: + """Service for web app authentication.""" + + @staticmethod + def authenticate(email: str, password: str) -> Account: + """authenticate account with email and password""" + account = db.session.query(Account).filter_by(email=email).first() + if not account: + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if account.password is None or not compare_password(password, account.password, 
account.password_salt): + raise AccountPasswordError("Invalid email or password.") + + return cast(Account, account) + + @classmethod + def login(cls, account: Account) -> str: + access_token = cls._get_account_jwt_token(account=account) + + return access_token + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") + + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "email_code_login") + + @classmethod + def create_end_user(cls, app_code, email) -> EndUser: + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise NotFound("Site not found.") + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model: + raise NotFound("App not found.") + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=False, + session_id=email, + name="enterpriseuser", + external_user_id="enterpriseuser", + ) 
+ db.session.add(end_user) + db.session.commit() + + return end_user + + @classmethod + def _get_account_jwt_token(cls, account: Account) -> str: + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) + exp = int(exp_dt.timestamp()) + + payload = { + "sub": "Web API Passport", + "user_id": account.id, + "session_id": account.email, + "token_source": "webapp_login_token", + "auth_type": "internal", + "exp": exp, + } + + token: str = PassportService().issue(payload) + return token + + @classmethod + def is_app_require_permission_check( + cls, app_code: Optional[str] = None, app_id: Optional[str] = None, access_mode: Optional[str] = None + ) -> bool: + """ + Check if the app requires permission check based on its access mode. + """ + modes_requiring_permission_check = [ + "private", + "private_all", + ] + if access_mode: + return access_mode in modes_requiring_permission_check + + if not app_code and not app_id: + raise ValueError("Either app_code or app_id must be provided.") + + if app_code: + app_id = AppService.get_app_id_by_code(app_code) + if not app_id: + raise ValueError("App ID could not be determined from the provided app_code.") + + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + if webapp_settings and webapp_settings.access_mode in modes_requiring_permission_check: + return True + return False + + @classmethod + def get_app_auth_type(cls, app_code: str | None = None, access_mode: str | None = None) -> WebAppAuthType: + """ + Get the authentication type for the app based on its access mode. 
+ """ + if not app_code and not access_mode: + raise ValueError("Either app_code or access_mode must be provided.") + + if access_mode: + if access_mode == "public": + return WebAppAuthType.PUBLIC + elif access_mode in ["private", "private_all"]: + return WebAppAuthType.INTERNAL + elif access_mode == "sso_verified": + return WebAppAuthType.EXTERNAL + + if app_code: + webapp_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code) + return cls.get_app_auth_type(access_mode=webapp_settings.access_mode) + + raise ValueError("Could not determine app authentication type.") diff --git a/api/services/website_service.py b/api/services/website_service.py index 3913dc2efe..6720932a3a 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -173,26 +173,27 @@ class WebsiteService: return crawl_status_data @classmethod - def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[Any, Any] | None: + def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None: credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) # decrypt api_key api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - # FIXME data is redefine too many times here, use Any to ease the type checking, fix it later - data: Any + if provider == "firecrawl": + crawl_data: list[dict[str, Any]] | None = None file_key = "website_files/" + job_id + ".txt" if storage.exists(file_key): - d = storage.load_once(file_key) - if d: - data = json.loads(d.decode("utf-8")) + stored_data = storage.load_once(file_key) + if stored_data: + crawl_data = json.loads(stored_data.decode("utf-8")) else: firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) result = firecrawl_app.check_crawl_status(job_id) if result.get("status") != "completed": raise ValueError("Crawl job is not 
completed") - data = result.get("data") - if data: - for item in data: + crawl_data = result.get("data") + + if crawl_data: + for item in crawl_data: if item.get("source_url") == url: return dict(item) return None @@ -211,23 +212,24 @@ class WebsiteService: raise ValueError("Failed to crawl") return dict(response.json().get("data", {})) else: - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - response = requests.post( + # Get crawl status first + status_response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, json={"taskId": job_id}, ) - data = response.json().get("data", {}) - if data.get("status") != "completed": + status_data = status_response.json().get("data", {}) + if status_data.get("status") != "completed": raise ValueError("Crawl job is not completed") - response = requests.post( + # Get processed data + data_response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, + json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, ) - data = response.json().get("data", {}) - for item in data.get("processed", {}).values(): + processed_data = data_response.json().get("data", {}) + for item in processed_data.get("processed", {}).values(): if item.get("data", {}).get("url") == url: return dict(item.get("data", {})) return None diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index e526517b51..6eabf03018 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -4,9 +4,9 @@ from datetime import datetime from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session -from models import App, EndUser, 
WorkflowAppLog, WorkflowRun -from models.enums import CreatedByRole -from models.workflow import WorkflowRunStatus +from core.workflow.entities.workflow_execution import WorkflowExecutionStatus +from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun +from models.enums import CreatorUserRole class WorkflowAppService: @@ -16,11 +16,13 @@ class WorkflowAppService: session: Session, app_model: App, keyword: str | None = None, - status: WorkflowRunStatus | None = None, + status: WorkflowExecutionStatus | None = None, created_at_before: datetime | None = None, created_at_after: datetime | None = None, page: int = 1, limit: int = 20, + created_by_end_user_session_id: str | None = None, + created_by_account: str | None = None, ) -> dict: """ Get paginate workflow app logs using SQLAlchemy 2.0 style @@ -32,6 +34,8 @@ class WorkflowAppService: :param created_at_after: filter logs created after this timestamp :param page: page number :param limit: items per page + :param created_by_end_user_session_id: filter by end user session id + :param created_by_account: filter by account email :return: Pagination object """ # Build base statement using SQLAlchemy 2.0 style @@ -58,7 +62,7 @@ class WorkflowAppService: stmt = stmt.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), ).where(or_(*keyword_conditions)) if status: @@ -71,6 +75,26 @@ class WorkflowAppService: if created_at_after: stmt = stmt.where(WorkflowAppLog.created_at >= created_at_after) + # Filter by end user session id or account email + if created_by_end_user_session_id: + stmt = stmt.join( + EndUser, + and_( + WorkflowAppLog.created_by == EndUser.id, + WorkflowAppLog.created_by_role == CreatorUserRole.END_USER, + EndUser.session_id == created_by_end_user_session_id, + ), + ) + if created_by_account: + stmt = stmt.join( + Account, + 
and_( + WorkflowAppLog.created_by == Account.id, + WorkflowAppLog.created_by_role == CreatorUserRole.ACCOUNT, + Account.email == created_by_account, + ), + ) + stmt = stmt.order_by(WorkflowAppLog.created_at.desc()) # Get total count using the same filters diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 6d5b737962..483c0d3086 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,17 +1,21 @@ import threading +from collections.abc import Sequence from typing import Optional import contexts from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repository.workflow_node_execution_repository import OrderConfig +from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.model import App -from models.workflow import ( - WorkflowNodeExecution, +from models import ( + Account, + App, + EndUser, + WorkflowNodeExecutionModel, WorkflowRun, + WorkflowRunTriggeredFrom, ) +from models.workflow import WorkflowNodeExecutionTriggeredFrom class WorkflowRunService: @@ -116,7 +120,12 @@ class WorkflowRunService: return workflow_run - def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: + def get_workflow_run_node_executions( + self, + app_model: App, + run_id: str, + user: Account | EndUser, + ) -> Sequence[WorkflowNodeExecutionModel]: """ Get workflow run node execution list """ @@ -128,13 +137,17 @@ class WorkflowRunService: if not workflow_run: return [] - # Use the repository to get the node executions repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=user, + app_id=app_model.id, + 
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - # Use the repository to get the node executions with ordering + # Use the repository to get the database models directly order_config = OrderConfig(order_by=["index"], order_direction="desc") - node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) + workflow_node_executions = repository.get_db_models_by_workflow_run( + workflow_run_id=run_id, order_config=order_config + ) - return list(node_executions) + return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 331dba8bf1..bc213ccce6 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,10 +10,10 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -26,13 +26,11 @@ from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account -from models.enums import CreatedByRole from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import ( Workflow, - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, + WorkflowNodeExecutionModel, 
WorkflowNodeExecutionTriggeredFrom, WorkflowType, ) @@ -256,7 +254,7 @@ class WorkflowService: def run_draft_workflow_node( self, app_model: App, node_id: str, user_inputs: dict, account: Account - ) -> WorkflowNodeExecution: + ) -> WorkflowNodeExecutionModel: """ Run draft workflow node """ @@ -268,27 +266,31 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( - getter=lambda: WorkflowEntry.single_step_run( + node_execution = self._handle_node_run_result( + invoke_node_fn=lambda: WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ), start_at=start_at, - tenant_id=app_model.tenant_id, node_id=node_id, ) - workflow_node_execution.app_id = app_model.id - workflow_node_execution.created_by = account.id - workflow_node_execution.workflow_id = draft_workflow.id + # Set workflow_id on the NodeExecution + node_execution.workflow_id = draft_workflow.id - # Use the repository to save the workflow node execution + # Create repository and save the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=account, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - repository.save(workflow_node_execution) + repository.save(node_execution) + + # Convert node_execution to WorkflowNodeExecution after save + workflow_node_execution = repository.to_db_model(node_execution) return workflow_node_execution @@ -302,7 +304,7 @@ class WorkflowService: start_at = time.perf_counter() workflow_node_execution = self._handle_node_run_result( - getter=lambda: WorkflowEntry.run_free_node( + invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, tenant_id=tenant_id, @@ -310,7 +312,6 @@ class WorkflowService: user_inputs=user_inputs, ), 
start_at=start_at, - tenant_id=tenant_id, node_id=node_id, ) @@ -318,21 +319,12 @@ class WorkflowService: def _handle_node_run_result( self, - getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], + invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], start_at: float, - tenant_id: str, node_id: str, ) -> WorkflowNodeExecution: - """ - Handle node run result - - :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]] - :param start_at: float - :param tenant_id: str - :param node_id: str - """ try: - node_instance, generator = getter() + node_instance, generator = invoke_node_fn() node_run_result: NodeRunResult | None = None for event in generator: @@ -381,20 +373,21 @@ class WorkflowService: node_run_result = None error = e.error - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = tenant_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value - workflow_node_execution.index = 1 - workflow_node_execution.node_id = node_id - workflow_node_execution.node_type = node_instance.node_type - workflow_node_execution.title = node_instance.node_data.title - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + # Create a NodeExecution domain model + node_execution = WorkflowNodeExecution( + id=str(uuid4()), + workflow_id="", # This is a single-step execution, so no workflow ID + index=1, + node_id=node_id, + node_type=node_instance.node_type, + title=node_instance.node_data.title, + elapsed_time=time.perf_counter() - start_at, + created_at=datetime.now(UTC).replace(tzinfo=None), + 
finished_at=datetime.now(UTC).replace(tzinfo=None), + ) + if run_succeeded and node_run_result: - # create workflow node execution + # Set inputs, process_data, and outputs as dictionaries (not JSON strings) inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None process_data = ( WorkflowEntry.handle_special_values(node_run_result.process_data) @@ -403,23 +396,23 @@ class WorkflowService: ) outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None - workflow_node_execution.inputs = json.dumps(inputs) - workflow_node_execution.process_data = json.dumps(process_data) - workflow_node_execution.outputs = json.dumps(outputs) - workflow_node_execution.execution_metadata = ( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ) + node_execution.inputs = inputs + node_execution.process_data = process_data + node_execution.outputs = outputs + node_execution.metadata = node_run_result.metadata + + # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: - workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value - workflow_node_execution.error = node_run_result.error + node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION + node_execution.error = node_run_result.error else: - # create workflow node execution - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error + # Set failed status and error + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = error - return workflow_node_execution + return node_execution def 
convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: """ @@ -514,11 +507,11 @@ class WorkflowService: raise DraftWorkflowDeletionError("Cannot delete draft workflow versions") # Check if this workflow is currently referenced by an app - stmt = select(App).where(App.workflow_id == workflow_id) - app = session.scalar(stmt) + app_stmt = select(App).where(App.workflow_id == workflow_id) + app = session.scalar(app_stmt) if app: # Cannot delete a workflow that's currently in use by an app - raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'") + raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.id}'") # Don't use workflow.tool_published as it's not accurate for specific workflow versions # Check if there's a tool provider using this specific workflow version diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index be88881efc..75d648e1b7 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -111,7 +111,7 @@ def add_document_to_index_task(dataset_document_id: str): logging.exception("add document to index failed") dataset_document.enabled = False dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) - dataset_document.status = "error" + dataset_document.indexing_status = "error" dataset_document.error = str(e) db.session.commit() finally: diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index f32bc4f187..51b6343fdc 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -5,7 +5,7 @@ import uuid import click from celery import shared_task # type: ignore -from sqlalchemy import func, select +from sqlalchemy import func from sqlalchemy.orm import Session from core.model_manager import ModelManager @@ -68,11 +68,6 @@ def 
batch_create_segment_to_index_task( model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) - word_count_change = 0 - segments_to_insert: list[str] = [] - max_position_stmt = select(func.max(DocumentSegment.position)).where( - DocumentSegment.document_id == dataset_document.id - ) word_count_change = 0 if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens( diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 4500b2a44b..a3f811faa1 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] DocumentSegment.status: "indexing", DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) db.session.commit() document = Document( page_content=segment.content, @@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] DocumentSegment.status: "completed", DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } - DocumentSegment.query.filter_by(id=segment.id).update(update_params) + db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params) db.session.commit() end_at = time.perf_counter() diff --git a/api/tasks/deal_dataset_vector_index_task.py b/api/tasks/deal_dataset_vector_index_task.py index 075453e283..a27207f2f1 100644 --- a/api/tasks/deal_dataset_vector_index_task.py +++ b/api/tasks/deal_dataset_vector_index_task.py @@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str): start_at = time.perf_counter() try: - dataset = Dataset.query.filter_by(id=dataset_id).first() + dataset = db.session.query(Dataset).filter_by(id=dataset_id).first() if 
not dataset: raise Exception("Dataset not found") diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 2e68dcb0fb..b4848be192 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -44,14 +44,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): page_id = data_source_info["notion_page_id"] page_type = data_source_info["type"] page_edited_time = data_source_info["last_edited_time"] - data_source_binding = DataSourceOauthBinding.query.filter( - db.and_( - DataSourceOauthBinding.tenant_id == document.tenant_id, - DataSourceOauthBinding.provider == "notion", - DataSourceOauthBinding.disabled == False, - DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + data_source_binding = ( + db.session.query(DataSourceOauthBinding) + .filter( + db.and_( + DataSourceOauthBinding.tenant_id == document.tenant_id, + DataSourceOauthBinding.provider == "notion", + DataSourceOauthBinding.disabled == False, + DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', + ) ) - ).first() + .first() + ) if not data_source_binding: raise ValueError("Data source binding not found.") @@ -110,4 +114,4 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - pass + logging.exception("document_indexing_sync_task failed, document_id: {}".format(document_id)) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index ee470d44e8..55cac6a9af 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -81,6 +81,6 @@ def document_indexing_task(dataset_id: str, document_ids: list): except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - pass + logging.exception("Document indexing task failed, dataset_id: 
{}".format(dataset_id)) finally: db.session.close() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index b9ed11a8da..167b928f5d 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -73,6 +73,6 @@ def document_indexing_update_task(dataset_id: str, document_id: str): except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - pass + logging.exception("document_indexing_update_task failed, document_id: {}".format(document_id)) finally: db.session.close() diff --git a/api/tasks/duplicate_document_indexing_task.py b/api/tasks/duplicate_document_indexing_task.py index 100fc257ce..a6c93e110e 100644 --- a/api/tasks/duplicate_document_indexing_task.py +++ b/api/tasks/duplicate_document_indexing_task.py @@ -99,6 +99,6 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list): except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - pass + logging.exception("duplicate_document_indexing_task failed, dataset_id: {}".format(dataset_id)) finally: db.session.close() diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index 5dc935548f..ddad331725 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -6,6 +6,7 @@ from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -25,10 +26,24 @@ def send_email_code_login_mail_task(language: str, to: str, code: str): # send email code login mail using different languages try: if language == "zh-Hans": - html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + template = "email_code_login_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if 
system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/email_code_login_mail_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + else: + html_content = render_template(template, to=to, code=code) mail.send(to=to, subject="邮箱验证码", html=html_content) else: - html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + template = "email_code_login_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/email_code_login_mail_template_en-US.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + else: + html_content = render_template(template, to=to, code=code) mail.send(to=to, subject="Email Code", html=html_content) end_at = time.perf_counter() diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py new file mode 100644 index 0000000000..b9d8fd55df --- /dev/null +++ b/api/tasks/mail_enterprise_task.py @@ -0,0 +1,33 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template_string + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_enterprise_email_task(to, subject, body, substitutions): + if not mail.is_inited(): + return + + logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template_string(body, **substitutions) + + if isinstance(to, list): + for t in to: + mail.send(to=t, subject=subject, html=html_content) + else: + mail.send(to=to, subject=subject, html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style("Send 
enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send enterprise mail to {} failed".format(to)) diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 3094527fd4..7ca85c7f2d 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -7,6 +7,7 @@ from flask import render_template from configs import dify_config from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -33,23 +34,45 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam try: url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" if language == "zh-Hans": - html_content = render_template( - "invite_member_mail_template_zh-CN.html", - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - ) - mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) + template = "invite_member_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/invite_member_mail_template_zh-CN.html" + html_content = render_template( + template, + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + application_title=application_title, + ) + mail.send(to=to, subject=f"立即加入 {application_title} 工作空间", html=html_content) + else: + html_content = render_template( + template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url + ) + mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template( - "invite_member_mail_template_en-US.html", - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - ) - mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) + template = 
"invite_member_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/invite_member_mail_template_en-US.html" + html_content = render_template( + template, + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + application_title=application_title, + ) + mail.send(to=to, subject=f"Join {application_title} Workspace Now", html=html_content) + else: + html_content = render_template( + template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url + ) + mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index d5be94431b..d4f4482a48 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -6,6 +6,7 @@ from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -25,11 +26,27 @@ def send_reset_password_mail_task(language: str, to: str, code: str): # send reset password mail using different languages try: if language == "zh-Hans": - html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code) - mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) + template = "reset_password_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/reset_password_mail_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + mail.send(to=to, subject=f"设置您的 {application_title} 密码", html=html_content) + 
else: + html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) else: - html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code) - mail.send(to=to, subject="Set Your Dify Password", html=html_content) + template = "reset_password_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/reset_password_mail_template_en-US.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + mail.send(to=to, subject=f"Set Your {application_title} Password", html=html_content) + else: + html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject="Set Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/recover_document_indexing_task.py b/api/tasks/recover_document_indexing_task.py index eada2ff9db..e7d49c78dc 100644 --- a/api/tasks/recover_document_indexing_task.py +++ b/api/tasks/recover_document_indexing_task.py @@ -43,6 +43,6 @@ def recover_document_indexing_task(dataset_id: str, document_id: str): except DocumentIsPausedError as ex: logging.info(click.style(str(ex), fg="yellow")) except Exception: - pass + logging.exception("recover_document_indexing_task failed, document_id: {}".format(document_id)) finally: db.session.close() diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index d5a783396a..d366efd6f2 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -7,13 +7,12 @@ from celery import shared_task # type: ignore from sqlalchemy import delete from sqlalchemy.exc import SQLAlchemyError -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from 
extensions.ext_database import db -from models.dataset import AppDatasetJoin -from models.model import ( +from models import ( ApiToken, AppAnnotationHitHistory, AppAnnotationSetting, + AppDatasetJoin, AppModelConfig, Conversation, EndUser, @@ -31,7 +30,7 @@ from models.model import ( ) from models.tools import WorkflowToolProvider from models.web import PinnedConversation, SavedMessage -from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowRun +from models.workflow import ConversationVariable, Workflow, WorkflowAppLog, WorkflowNodeExecutionModel, WorkflowRun @shared_task(queue="app_deletion", bind=True, max_retries=3) @@ -188,15 +187,17 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): - # Create a repository instance for WorkflowNodeExecution - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=tenant_id, app_id=app_id - ) - - # Use the clear method to delete all records for this tenant_id and app_id - repository.clear() + def del_workflow_node_execution(workflow_node_execution_id: str): + db.session.query(WorkflowNodeExecutionModel).filter( + WorkflowNodeExecutionModel.id == workflow_node_execution_id + ).delete(synchronize_session=False) - logging.info(click.style(f"Deleted workflow node executions for tenant {tenant_id} and app {app_id}", fg="green")) + _delete_records( + """select id from workflow_node_executions where tenant_id=:tenant_id and app_id=:app_id limit 1000""", + {"tenant_id": tenant_id, "app_id": app_id}, + del_workflow_node_execution, + "workflow node execution", + ) def _delete_app_workflow_app_logs(tenant_id: str, app_id: str): diff --git a/api/tasks/retry_document_indexing_task.py b/api/tasks/retry_document_indexing_task.py index 7e50eb9f8d..8f8c3f9d81 100644 --- a/api/tasks/retry_document_indexing_task.py +++ b/api/tasks/retry_document_indexing_task.py @@ -30,11 +30,11 @@ def 
retry_document_indexing_task(dataset_id: str, document_ids: list[str]): logging.info(click.style("Dataset not found: {}".format(dataset_id), fg="red")) db.session.close() return - + tenant_id = dataset.tenant_id for document_id in document_ids: retry_indexing_cache_key = "document_{}_is_retried".format(document_id) # check document limit - features = FeatureService.get_features(dataset.tenant_id) + features = FeatureService.get_features(tenant_id) try: if features.billing.enabled: vector_space = features.vector_space @@ -95,7 +95,7 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]): db.session.commit() logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(retry_indexing_cache_key) - pass + logging.exception("retry_document_indexing_task failed, document_id: {}".format(document_id)) finally: db.session.close() end_at = time.perf_counter() diff --git a/api/tasks/sync_website_document_indexing_task.py b/api/tasks/sync_website_document_indexing_task.py index e75252edbe..dba0a39c2d 100644 --- a/api/tasks/sync_website_document_indexing_task.py +++ b/api/tasks/sync_website_document_indexing_task.py @@ -87,6 +87,6 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str): db.session.commit() logging.info(click.style(str(ex), fg="yellow")) redis_client.delete(sync_indexing_cache_key) - pass + logging.exception("sync_website_document_indexing_task failed, document_id: {}".format(document_id)) end_at = time.perf_counter() logging.info(click.style("Sync document: {} latency: {}".format(document_id, end_at - start_at), fg="green")) diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html index 0f7ddc62a9..2d8f78b46a 100644 --- a/api/templates/clean_document_job_mail_template-US.html +++ b/api/templates/clean_document_job_mail_template-US.html @@ -69,7 +69,7 @@