pull/22133/head
Joel 10 months ago
commit 38d1c85c57

@ -83,9 +83,15 @@ jobs:
compose-file: | compose-file: |
docker/docker-compose.middleware.yaml docker/docker-compose.middleware.yaml
services: | services: |
db
redis
sandbox sandbox
ssrf_proxy ssrf_proxy
- name: setup test config
run: |
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Run Workflow - name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh run: uv run --project api bash dev/pytest/pytest_workflow.sh

@ -84,6 +84,12 @@ jobs:
elasticsearch elasticsearch
oceanbase oceanbase
- name: setup test config
run: |
echo $(pwd)
ls -lah .
cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env
- name: Check VDB Ready (TiDB) - name: Check VDB Ready (TiDB)
run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py run: uv run --project api python api/tests/integration_tests/vdb/tidb_vector/check_tiflash_ready.py

1
.gitignore vendored

@ -179,6 +179,7 @@ docker/volumes/pgvecto_rs/data/*
docker/volumes/couchbase/* docker/volumes/couchbase/*
docker/volumes/oceanbase/* docker/volumes/oceanbase/*
docker/volumes/plugin_daemon/* docker/volumes/plugin_daemon/*
docker/volumes/matrixone/*
!docker/volumes/oceanbase/init.d !docker/volumes/oceanbase/init.d
docker/nginx/conf.d/default.conf docker/nginx/conf.d/default.conf

@ -226,6 +226,15 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Using Alibaba Cloud Computing Nest
Quickly deploy Dify to Alibaba cloud with [Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Using Alibaba Cloud Data Management
One-Click deploy Dify to Alibaba Cloud with [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing ## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -209,6 +209,14 @@ docker compose up -d
- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### استخدام Alibaba Cloud للنشر
[بسرعة نشر Dify إلى سحابة علي بابا مع عش الحوسبة السحابية علي بابا](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### استخدام Alibaba Cloud Data Management للنشر
انشر Dify على علي بابا كلاود بنقرة واحدة باستخدام [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## المساهمة ## المساهمة
لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا. لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا.

@ -225,6 +225,15 @@ GitHub-এ ডিফাইকে স্টার দিয়ে রাখুন
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management ব্যবহার করে ডিপ্লয়
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing ## Contributing
যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)। যারা কোড অবদান রাখতে চান, তাদের জন্য আমাদের [অবদান নির্দেশিকা] দেখুন (https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)।

@ -221,6 +221,15 @@ docker compose up -d
##### AWS ##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢 部署
使用 [阿里云计算巢](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88) 将 Dify 一键部署到 阿里云
#### 使用 阿里云数据管理DMS 部署
使用 [阿里云数据管理DMS](https://help.aliyun.com/zh/dms/dify-in-invitational-preview) 将 Dify 一键部署到 阿里云
## Star History ## Star History
[![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date)

@ -221,6 +221,15 @@ Bereitstellung von Dify auf AWS mit [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Ein-Klick-Bereitstellung von Dify in der Alibaba Cloud mit [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing ## Contributing
Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren. Falls Sie Code beitragen möchten, lesen Sie bitte unseren [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Gleichzeitig bitten wir Sie, Dify zu unterstützen, indem Sie es in den sozialen Medien teilen und auf Veranstaltungen und Konferenzen präsentieren.

@ -221,6 +221,15 @@ Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Despliega Dify en Alibaba Cloud con un solo clic con [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuir ## Contribuir
Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -219,6 +219,15 @@ Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Déployez Dify en un clic sur Alibaba Cloud avec [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuer ## Contribuer
Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -155,7 +155,7 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ
[こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。 [こちら](https://dify.ai)のDify Cloudサービスを利用して、セットアップ不要で試すことができます。サンドボックスプランには、200回のGPT-4呼び出しが無料で含まれています。
- **Dify Community Editionのセルフホスティング</br>** - **Dify Community Editionのセルフホスティング</br>**
この[スタートガイド](#quick-start)を使用して、ローカル環境でDifyを簡単に実行できます。 この[スタートガイド](#クイックスタート)を使用して、ローカル環境でDifyを簡単に実行できます。
詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。 詳しくは[ドキュメント](https://docs.dify.ai)をご覧ください。
- **企業/組織向けのDify</br>** - **企業/組織向けのDify</br>**
@ -220,6 +220,13 @@ docker compose up -d
##### AWS ##### AWS
- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) を利用して、DifyをAlibaba Cloudへワンクリックでデプロイできます
## 貢献 ## 貢献
コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。

@ -219,6 +219,15 @@ wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo
##### AWS ##### AWS
- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contributing ## Contributing
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -213,6 +213,15 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했
##### AWS ##### AWS
- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)를 통해 원클릭으로 Dify를 Alibaba Cloud에 배포할 수 있습니다
## 기여 ## 기여
코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요.

@ -218,6 +218,15 @@ Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Implante o Dify na Alibaba Cloud com um clique usando o [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Contribuindo ## Contribuindo
Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).

@ -219,6 +219,15 @@ Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Z enim klikom namestite Dify na Alibaba Cloud z [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Prispevam ## Prispevam
Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.

@ -212,6 +212,15 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter
##### AWS ##### AWS
- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
[Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/) kullanarak Dify'ı tek tıkla Alibaba Cloud'a dağıtın
## Katkıda Bulunma ## Katkıda Bulunma
Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz.

@ -224,6 +224,15 @@ Dify 的所有功能都提供相應的 API因此您可以輕鬆地將 Dify
- [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [由 @KevinZhao 提供的 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### 使用 阿里云计算巢進行部署
[阿里云](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### 使用 阿里雲數據管理DMS 進行部署
透過 [阿里雲數據管理DMS](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/),一鍵將 Dify 部署至阿里雲
## 貢獻 ## 貢獻
對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。 對於想要貢獻程式碼的開發者,請參閱我們的[貢獻指南](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)。

@ -214,6 +214,16 @@ Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/)
##### AWS ##### AWS
- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) - [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
#### Alibaba Cloud
[Alibaba Cloud Computing Nest](https://computenest.console.aliyun.com/service/instance/create/default?type=user&ServiceName=Dify%E7%A4%BE%E5%8C%BA%E7%89%88)
#### Alibaba Cloud Data Management
Triển khai Dify lên Alibaba Cloud chỉ với một cú nhấp chuột bằng [Alibaba Cloud Data Management](https://www.alibabacloud.com/help/en/dms/dify-in-invitational-preview/)
## Đóng góp ## Đóng góp
Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi.

@ -137,7 +137,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration # Vector database configuration
# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore # support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone
VECTOR_STORE=weaviate VECTOR_STORE=weaviate
# Weaviate configuration # Weaviate configuration
@ -294,6 +294,13 @@ VIKINGDB_SCHEMA=http
VIKINGDB_CONNECTION_TIMEOUT=30 VIKINGDB_CONNECTION_TIMEOUT=30
VIKINGDB_SOCKET_TIMEOUT=30 VIKINGDB_SOCKET_TIMEOUT=30
# Matrixone configration
MATRIXONE_HOST=127.0.0.1
MATRIXONE_PORT=6001
MATRIXONE_USER=dump
MATRIXONE_PASSWORD=111
MATRIXONE_DATABASE=dify
# Lindorm configuration # Lindorm configuration
LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070 LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:30070
LINDORM_USERNAME=admin LINDORM_USERNAME=admin
@ -332,9 +339,11 @@ PROMPT_GENERATION_MAX_TOKENS=512
CODE_GENERATION_MAX_TOKENS=1024 CODE_GENERATION_MAX_TOKENS=1024
PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false PLUGIN_BASED_TOKEN_COUNTING_ENABLED=false
# Mail configuration, support: resend, smtp # Mail configuration, support: resend, smtp, sendgrid
MAIL_TYPE= MAIL_TYPE=
# If using SendGrid, use the 'from' field for authentication if necessary.
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai> MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
# resend configuration
RESEND_API_KEY= RESEND_API_KEY=
RESEND_API_URL=https://api.resend.com RESEND_API_URL=https://api.resend.com
# smtp configuration # smtp configuration
@ -344,7 +353,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc SMTP_PASSWORD=abc
SMTP_USE_TLS=true SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false SMTP_OPPORTUNISTIC_TLS=false
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration # Sentry configuration
SENTRY_DSN= SENTRY_DSN=

@ -1,6 +1,4 @@
exclude = [ exclude = ["migrations/*"]
"migrations/*",
]
line-length = 120 line-length = 120
[format] [format]
@ -9,14 +7,14 @@ quote-style = "double"
[lint] [lint]
preview = false preview = false
select = [ select = [
"B", # flake8-bugbear rules "B", # flake8-bugbear rules
"C4", # flake8-comprehensions "C4", # flake8-comprehensions
"E", # pycodestyle E rules "E", # pycodestyle E rules
"F", # pyflakes rules "F", # pyflakes rules
"FURB", # refurb rules "FURB", # refurb rules
"I", # isort rules "I", # isort rules
"N", # pep8-naming "N", # pep8-naming
"PT", # flake8-pytest-style rules "PT", # flake8-pytest-style rules
"PLC0208", # iteration-over-set "PLC0208", # iteration-over-set
"PLC0414", # useless-import-alias "PLC0414", # useless-import-alias
"PLE0604", # invalid-all-object "PLE0604", # invalid-all-object
@ -24,19 +22,19 @@ select = [
"PLR0402", # manual-from-import "PLR0402", # manual-from-import
"PLR1711", # useless-return "PLR1711", # useless-return
"PLR1714", # repeated-equality-comparison "PLR1714", # repeated-equality-comparison
"RUF013", # implicit-optional "RUF013", # implicit-optional
"RUF019", # unnecessary-key-check "RUF019", # unnecessary-key-check
"RUF100", # unused-noqa "RUF100", # unused-noqa
"RUF101", # redirected-noqa "RUF101", # redirected-noqa
"RUF200", # invalid-pyproject-toml "RUF200", # invalid-pyproject-toml
"RUF022", # unsorted-dunder-all "RUF022", # unsorted-dunder-all
"S506", # unsafe-yaml-load "S506", # unsafe-yaml-load
"SIM", # flake8-simplify rules "SIM", # flake8-simplify rules
"TRY400", # error-instead-of-exception "TRY400", # error-instead-of-exception
"TRY401", # verbose-log-message "TRY401", # verbose-log-message
"UP", # pyupgrade rules "UP", # pyupgrade rules
"W191", # tab-indentation "W191", # tab-indentation
"W605", # invalid-escape-sequence "W605", # invalid-escape-sequence
# security related linting rules # security related linting rules
# RCE proctection (sort of) # RCE proctection (sort of)
"S102", # exec-builtin, disallow use of `exec` "S102", # exec-builtin, disallow use of `exec`
@ -47,36 +45,37 @@ select = [
] ]
ignore = [ ignore = [
"E402", # module-import-not-at-top-of-file "E402", # module-import-not-at-top-of-file
"E711", # none-comparison "E711", # none-comparison
"E712", # true-false-comparison "E712", # true-false-comparison
"E721", # type-comparison "E721", # type-comparison
"E722", # bare-except "E722", # bare-except
"F821", # undefined-name "F821", # undefined-name
"F841", # unused-variable "F841", # unused-variable
"FURB113", # repeated-append "FURB113", # repeated-append
"FURB152", # math-constant "FURB152", # math-constant
"UP007", # non-pep604-annotation "UP007", # non-pep604-annotation
"UP032", # f-string "UP032", # f-string
"UP045", # non-pep604-annotation-optional "UP045", # non-pep604-annotation-optional
"B005", # strip-with-multi-characters "B005", # strip-with-multi-characters
"B006", # mutable-argument-default "B006", # mutable-argument-default
"B007", # unused-loop-control-variable "B007", # unused-loop-control-variable
"B026", # star-arg-unpacking-after-keyword-arg "B026", # star-arg-unpacking-after-keyword-arg
"B903", # class-as-data-structure "B903", # class-as-data-structure
"B904", # raise-without-from-inside-except "B904", # raise-without-from-inside-except
"B905", # zip-without-explicit-strict "B905", # zip-without-explicit-strict
"N806", # non-lowercase-variable-in-function "N806", # non-lowercase-variable-in-function
"N815", # mixed-case-variable-in-class-scope "N815", # mixed-case-variable-in-class-scope
"PT011", # pytest-raises-too-broad "PT011", # pytest-raises-too-broad
"SIM102", # collapsible-if "SIM102", # collapsible-if
"SIM103", # needless-bool "SIM103", # needless-bool
"SIM105", # suppressible-exception "SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally "SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp "SIM108", # if-else-block-instead-of-if-exp
"SIM113", # enumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
"UP038", # deprecated and not recommended by Ruff, https://docs.astral.sh/ruff/rules/non-pep604-isinstance/
] ]
[lint.per-file-ignores] [lint.per-file-ignores]

@ -281,6 +281,7 @@ def migrate_knowledge_vector_database():
VectorType.ELASTICSEARCH, VectorType.ELASTICSEARCH,
VectorType.OPENGAUSS, VectorType.OPENGAUSS,
VectorType.TABLESTORE, VectorType.TABLESTORE,
VectorType.MATRIXONE,
} }
lower_collection_vector_types = { lower_collection_vector_types = {
VectorType.ANALYTICDB, VectorType.ANALYTICDB,

@ -609,7 +609,7 @@ class MailConfig(BaseSettings):
""" """
MAIL_TYPE: Optional[str] = Field( MAIL_TYPE: Optional[str] = Field(
description="Email service provider type ('smtp' or 'resend'), default to None.", description="Email service provider type ('smtp' or 'resend' or 'sendGrid), default to None.",
default=None, default=None,
) )
@ -663,6 +663,11 @@ class MailConfig(BaseSettings):
default=50, default=50,
) )
SENDGRID_API_KEY: Optional[str] = Field(
description="API key for SendGrid service",
default=None,
)
class RagEtlConfig(BaseSettings): class RagEtlConfig(BaseSettings):
""" """

@ -24,6 +24,7 @@ from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig from .vdb.elasticsearch_config import ElasticsearchConfig
from .vdb.huawei_cloud_config import HuaweiCloudConfig from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig from .vdb.lindorm_config import LindormConfig
from .vdb.matrixone_config import MatrixoneConfig
from .vdb.milvus_config import MilvusConfig from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig from .vdb.myscale_config import MyScaleConfig
from .vdb.oceanbase_config import OceanBaseVectorConfig from .vdb.oceanbase_config import OceanBaseVectorConfig
@ -323,5 +324,6 @@ class MiddlewareConfig(
OpenGaussConfig, OpenGaussConfig,
TableStoreConfig, TableStoreConfig,
DatasetQueueMonitorConfig, DatasetQueueMonitorConfig,
MatrixoneConfig,
): ):
pass pass

@ -0,0 +1,14 @@
from pydantic import BaseModel, Field
class MatrixoneConfig(BaseModel):
"""Matrixone vector database configuration."""
MATRIXONE_HOST: str = Field(default="localhost", description="Host address of the Matrixone server")
MATRIXONE_PORT: int = Field(default=6001, description="Port number of the Matrixone server")
MATRIXONE_USER: str = Field(default="dump", description="Username for authenticating with Matrixone")
MATRIXONE_PASSWORD: str = Field(default="111", description="Password for authenticating with Matrixone")
MATRIXONE_DATABASE: str = Field(default="dify", description="Name of the Matrixone database to connect to")
MATRIXONE_METRIC: str = Field(
default="l2", description="Distance metric type for vector similarity search (cosine or l2)"
)

@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field( CURRENT_VERSION: str = Field(
description="Dify version", description="Dify version",
default="1.4.3", default="1.5.0",
) )
COMMIT_SHA: str = Field( COMMIT_SHA: str = Field(

@ -63,6 +63,7 @@ from .app import (
statistic, statistic,
workflow, workflow,
workflow_app_log, workflow_app_log,
workflow_draft_variable,
workflow_run, workflow_run,
workflow_statistic, workflow_statistic,
) )

@ -56,8 +56,7 @@ class InsertExploreAppListApi(Resource):
parser.add_argument("position", type=int, required=True, nullable=False, location="json") parser.add_argument("position", type=int, required=True, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
with Session(db.engine) as session: app = db.session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
app = session.execute(select(App).filter(App.id == args["app_id"])).scalar_one_or_none()
if not app: if not app:
raise NotFound(f"App '{args['app_id']}' is not found") raise NotFound(f"App '{args['app_id']}' is not found")
@ -78,38 +77,38 @@ class InsertExploreAppListApi(Resource):
select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]) select(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"])
).scalar_one_or_none() ).scalar_one_or_none()
if not recommended_app: if not recommended_app:
recommended_app = RecommendedApp( recommended_app = RecommendedApp(
app_id=app.id, app_id=app.id,
description=desc, description=desc,
copyright=copy_right, copyright=copy_right,
privacy_policy=privacy_policy, privacy_policy=privacy_policy,
custom_disclaimer=custom_disclaimer, custom_disclaimer=custom_disclaimer,
language=args["language"], language=args["language"],
category=args["category"], category=args["category"],
position=args["position"], position=args["position"],
) )
db.session.add(recommended_app) db.session.add(recommended_app)
app.is_public = True app.is_public = True
db.session.commit() db.session.commit()
return {"result": "success"}, 201 return {"result": "success"}, 201
else: else:
recommended_app.description = desc recommended_app.description = desc
recommended_app.copyright = copy_right recommended_app.copyright = copy_right
recommended_app.privacy_policy = privacy_policy recommended_app.privacy_policy = privacy_policy
recommended_app.custom_disclaimer = custom_disclaimer recommended_app.custom_disclaimer = custom_disclaimer
recommended_app.language = args["language"] recommended_app.language = args["language"]
recommended_app.category = args["category"] recommended_app.category = args["category"]
recommended_app.position = args["position"] recommended_app.position = args["position"]
app.is_public = True app.is_public = True
db.session.commit() db.session.commit()
return {"result": "success"}, 200 return {"result": "success"}, 200
class InsertExploreAppApi(Resource): class InsertExploreAppApi(Resource):

@ -208,7 +208,7 @@ class AnnotationBatchImportApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
return AppAnnotationService.batch_import_app_annotations(app_id, file) return AppAnnotationService.batch_import_app_annotations(app_id, file)

@ -17,6 +17,8 @@ from libs.login import login_required
from models import Account from models import Account
from models.model import App from models.model import App
from services.app_dsl_service import AppDslService, ImportStatus from services.app_dsl_service import AppDslService, ImportStatus
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
class AppImportApi(Resource): class AppImportApi(Resource):
@ -60,7 +62,9 @@ class AppImportApi(Resource):
app_id=args.get("app_id"), app_id=args.get("app_id"),
) )
session.commit() session.commit()
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
# update web app setting as private
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
# Return appropriate status code based on result # Return appropriate status code based on result
status = result.status status = result.status
if status == ImportStatus.FAILED.value: if status == ImportStatus.FAILED.value:

@ -1,5 +1,6 @@
import json import json
import logging import logging
from collections.abc import Sequence
from typing import cast from typing import cast
from flask import abort, request from flask import abort, request
@ -18,10 +19,12 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from extensions.ext_database import db from extensions.ext_database import db
from factories import variable_factory from factories import file_factory, variable_factory
from fields.workflow_fields import workflow_fields, workflow_pagination_fields from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from fields.workflow_run_fields import workflow_run_node_execution_fields from fields.workflow_run_fields import workflow_run_node_execution_fields
from libs import helper from libs import helper
@ -30,6 +33,7 @@ from libs.login import current_user, login_required
from models import App from models import App
from models.account import Account from models.account import Account
from models.model import AppMode from models.model import AppMode
from models.workflow import Workflow
from services.app_generate_service import AppGenerateService from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError from services.errors.llm import InvokeRateLimitError
@ -38,6 +42,24 @@ from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
# at the controller level rather than in the workflow logic. This would improve separation
# of concerns and make the code more maintainable.
def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
file_objs: Sequence[File] = []
if file_extra_config is None:
return file_objs
file_objs = file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)
return file_objs
class DraftWorkflowApi(Resource): class DraftWorkflowApi(Resource):
@setup_required @setup_required
@login_required @login_required
@ -402,15 +424,30 @@ class DraftWorkflowNodeRunApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json") parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("query", type=str, required=False, location="json", default="")
parser.add_argument("files", type=list, location="json", default=[])
args = parser.parse_args() args = parser.parse_args()
inputs = args.get("inputs") user_inputs = args.get("inputs")
if inputs == None: if user_inputs is None:
raise ValueError("missing inputs") raise ValueError("missing inputs")
workflow_srv = WorkflowService()
# fetch draft workflow by app_model
draft_workflow = workflow_srv.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = _parse_file(draft_workflow, args.get("files"))
workflow_service = WorkflowService() workflow_service = WorkflowService()
workflow_node_execution = workflow_service.run_draft_workflow_node( workflow_node_execution = workflow_service.run_draft_workflow_node(
app_model=app_model, node_id=node_id, user_inputs=inputs, account=current_user app_model=app_model,
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=args.get("query", ""),
files=files,
) )
return workflow_node_execution return workflow_node_execution
@ -731,6 +768,27 @@ class WorkflowByIdApi(Resource):
return None, 204 return None, 204
class DraftWorkflowNodeLastRunApi(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@marshal_with(workflow_run_node_execution_fields)
def get(self, app_model: App, node_id: str):
srv = WorkflowService()
workflow = srv.get_draft_workflow(app_model)
if not workflow:
raise NotFound("Workflow not found")
node_exec = srv.get_node_last_run(
app_model=app_model,
workflow=workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("last run not found")
return node_exec
api.add_resource( api.add_resource(
DraftWorkflowApi, DraftWorkflowApi,
"/apps/<uuid:app_id>/workflows/draft", "/apps/<uuid:app_id>/workflows/draft",
@ -795,3 +853,7 @@ api.add_resource(
WorkflowByIdApi, WorkflowByIdApi,
"/apps/<uuid:app_id>/workflows/<string:workflow_id>", "/apps/<uuid:app_id>/workflows/<string:workflow_id>",
) )
api.add_resource(
DraftWorkflowNodeLastRunApi,
"/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/last-run",
)

@ -34,6 +34,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument( parser.add_argument(
"created_at__after", type=str, location="args", help="Filter logs created after this timestamp" "created_at__after", type=str, location="args", help="Filter logs created after this timestamp"
) )
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -57,6 +71,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after, created_at_after=args.created_at__after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
) )
return workflow_app_log_pagination return workflow_app_log_pagination

@ -0,0 +1,421 @@
import logging
from typing import Any, NoReturn
from flask import Response
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.app.error import (
DraftWorkflowNotExist,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.file_factory import build_from_mapping, build_from_mappings
from factories.variable_factory import build_segment_with_type
from libs.login import current_user, login_required
from models import App, AppMode, db
from models.workflow import WorkflowDraftVariable
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
"page",
type=inputs.int_range(1, 100_000),
required=False,
default=1,
location="args",
help="the page of data requested",
)
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
return parser
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=_serialize_var_value),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda _: "env"),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
}
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
return var_list.variables
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
"total": fields.Raw(),
}
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
}
def _api_prerequisite(f):
"""Common prerequisites for all draft workflow variable APIs.
It ensures the following conditions are satisfied:
- Dify has been property setup.
- The request user has logged in and initialized.
- The requested app is a workflow or a chat flow.
- The request user has the edit permission for the app.
"""
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def wrapper(*args, **kwargs):
if not current_user.is_editor:
raise Forbidden()
return f(*args, **kwargs)
return wrapper
class WorkflowVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
def get(self, app_model: App):
"""
Get draft workflow
"""
parser = _create_pagination_parser()
args = parser.parse_args()
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
if not workflow_exist:
raise DraftWorkflowNotExist()
# fetch draft workflow by app_model
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
workflow_vars = draft_var_srv.list_variables_without_values(
app_id=app_model.id,
page=args.page,
limit=args.limit,
)
return workflow_vars
@_api_prerequisite
def delete(self, app_model: App):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
draft_var_srv.delete_workflow_variables(app_model.id)
db.session.commit()
return Response("", 204)
def validate_node_id(node_id: str) -> NoReturn | None:
if node_id in [
CONVERSATION_VARIABLE_NODE_ID,
SYSTEM_VARIABLE_NODE_ID,
]:
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
# with specific `node_id` in database, we still want to make the API separated. By disallowing
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
# we mitigate the risk that user of the API depending on the implementation detail of the API.
#
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
raise InvalidArgumentError(
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
)
return None
class NodeVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App, node_id: str):
validate_node_id(node_id)
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
return node_vars
@_api_prerequisite
def delete(self, app_model: App, node_id: str):
validate_node_id(node_id)
srv = WorkflowDraftVariableService(db.session())
srv.delete_node_variables(app_model.id, node_id)
db.session.commit()
return Response("", 204)
class VariableApi(Resource):
_PATCH_NAME_FIELD = "name"
_PATCH_VALUE_FIELD = "value"
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def get(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
return variable
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
def patch(self, app_model: App, variable_id: str):
# Request payload for file types:
#
# Local File:
#
# {
# "type": "image",
# "transfer_method": "local_file",
# "url": "",
# "upload_file_id": "daded54f-72c7-4f8e-9d18-9b0abdd9f190"
# }
#
# Remote File:
#
#
# {
# "type": "image",
# "transfer_method": "remote_url",
# "url": "http://127.0.0.1:5001/files/1602650a-4fe4-423c-85a2-af76c083e3c4/file-preview?timestamp=1750041099&nonce=...&sign=...=",
# "upload_file_id": "1602650a-4fe4-423c-85a2-af76c083e3c4"
# }
parser = reqparse.RequestParser()
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
# Parse 'value' field as-is to maintain its original data structure
parser.add_argument(self._PATCH_VALUE_FIELD, type=lambda x: x, required=False, nullable=True, location="json")
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
args = parser.parse_args(strict=True)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
new_name = args.get(self._PATCH_NAME_FIELD, None)
raw_value = args.get(self._PATCH_VALUE_FIELD, None)
if new_name is None and raw_value is None:
return variable
new_value = None
if raw_value is not None:
if variable.value_type == SegmentType.FILE:
if not isinstance(raw_value, dict):
raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}")
raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id)
elif variable.value_type == SegmentType.ARRAY_FILE:
if not isinstance(raw_value, list):
raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}")
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@_api_prerequisite
def delete(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
draft_var_srv.delete_variable(variable)
db.session.commit()
return Response("", 204)
class VariableResetApi(Resource):
@_api_prerequisite
def put(self, app_model: App, variable_id: str):
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(
f"Draft workflow not found, app_id={app_model.id}",
)
variable = draft_var_srv.get_variable(variable_id=variable_id)
if variable is None:
raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}")
resetted = draft_var_srv.reset_variable(draft_workflow, variable)
db.session.commit()
if resetted is None:
return Response("", 204)
else:
return marshal(resetted, _WORKFLOW_DRAFT_VARIABLE_FIELDS)
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
with Session(bind=db.engine, expire_on_commit=False) as session:
draft_var_srv = WorkflowDraftVariableService(
session=session,
)
if node_id == CONVERSATION_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
elif node_id == SYSTEM_VARIABLE_NODE_ID:
draft_vars = draft_var_srv.list_system_variables(app_model.id)
else:
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
return draft_vars
class ConversationVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
# NOTE(QuantumGhost): Prefill conversation variables into the draft variables table
# so their IDs can be returned to the caller.
workflow_srv = WorkflowService()
draft_workflow = workflow_srv.get_draft_workflow(app_model)
if draft_workflow is None:
raise NotFoundError(description=f"draft workflow not found, id={app_model.id}")
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(draft_workflow)
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
class SystemVariableCollectionApi(Resource):
@_api_prerequisite
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
def get(self, app_model: App):
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
class EnvironmentVariableCollectionApi(Resource):
@_api_prerequisite
def get(self, app_model: App):
"""
Get draft workflow
"""
# fetch draft workflow by app_model
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if workflow is None:
raise DraftWorkflowNotExist()
env_vars = workflow.environment_variables
env_vars_list = []
for v in env_vars:
env_vars_list.append(
{
"id": v.id,
"type": "env",
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,
"visible": True,
"editable": True,
}
)
return {"items": env_vars_list}
api.add_resource(
WorkflowVariableCollectionApi,
"/apps/<uuid:app_id>/workflows/draft/variables",
)
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
api.add_resource(VariableResetApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>/reset")
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")

@ -8,6 +8,15 @@ from libs.login import current_user
from models import App, AppMode from models import App, AppMode
def _load_app_model(app_id: str) -> Optional[App]:
app_model = (
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
return app_model
def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None): def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[AppMode], None] = None):
def decorator(view_func): def decorator(view_func):
@wraps(view_func) @wraps(view_func)
@ -20,11 +29,7 @@ def get_app_model(view: Optional[Callable] = None, *, mode: Union[AppMode, list[
del kwargs["app_id"] del kwargs["app_id"]
app_model = ( app_model = _load_app_model(app_id)
db.session.query(App)
.filter(App.id == app_id, App.tenant_id == current_user.current_tenant_id, App.status == "normal")
.first()
)
if not app_model: if not app_model:
raise AppNotFoundError() raise AppNotFoundError()

@ -686,6 +686,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.HUAWEI_CLOUD | VectorType.HUAWEI_CLOUD
| VectorType.TENCENT | VectorType.TENCENT
| VectorType.MATRIXONE
): ):
return { return {
"retrieval_method": [ "retrieval_method": [
@ -733,6 +734,7 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.TABLESTORE | VectorType.TABLESTORE
| VectorType.TENCENT | VectorType.TENCENT
| VectorType.HUAWEI_CLOUD | VectorType.HUAWEI_CLOUD
| VectorType.MATRIXONE
): ):
return { return {
"retrieval_method": [ "retrieval_method": [

@ -5,7 +5,7 @@ from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse from flask_restful import Resource, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc, select from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
@ -43,7 +43,6 @@ from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.document_fields import ( from fields.document_fields import (
dataset_and_document_fields, dataset_and_document_fields,
document_fields, document_fields,
@ -54,8 +53,6 @@ from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task
class DocumentResource(Resource): class DocumentResource(Resource):
@ -242,12 +239,10 @@ class DatasetDocumentListApi(Resource):
return response return response
documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String}
@setup_required @setup_required
@login_required @login_required
@account_initialization_required @account_initialization_required
@marshal_with(documents_and_batch_fields) @marshal_with(dataset_and_document_fields)
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_rate_limit_check("knowledge") @cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id): def post(self, dataset_id):
@ -293,6 +288,8 @@ class DatasetDocumentListApi(Resource):
try: try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
dataset = DatasetService.get_dataset(dataset_id)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
@ -300,7 +297,7 @@ class DatasetDocumentListApi(Resource):
except ModelCurrentlyNotSupportError: except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError() raise ProviderModelCurrentlyNotSupportError()
return {"documents": documents, "batch": batch} return {"dataset": dataset, "documents": documents, "batch": batch}
@setup_required @setup_required
@login_required @login_required
@ -862,77 +859,16 @@ class DocumentStatusApi(DocumentResource):
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
document_ids = request.args.getlist("document_id") document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)
indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None:
raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable":
if document.enabled:
continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue
document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
elif action == "archive":
if document.archived:
continue
document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id)
elif action == "un_archive":
if not document.archived:
continue
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id)
else: try:
raise InvalidActionError() DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
except services.errors.document.DocumentIndexingError as e:
raise InvalidActionError(str(e))
except ValueError as e:
raise InvalidActionError(str(e))
except NotFound as e:
raise NotFound(str(e))
return {"result": "success"}, 200 return {"result": "success"}, 200

@ -374,7 +374,7 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
if len(request.files) > 1: if len(request.files) > 1:
raise TooManyFilesError() raise TooManyFilesError()
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
try: try:

@ -15,7 +15,7 @@ class LoadBalancingCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str): def post(self, provider: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id
@ -64,7 +64,7 @@ class LoadBalancingConfigCredentialsValidateApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
def post(self, provider: str, config_id: str): def post(self, provider: str, config_id: str):
if not TenantAccountRole.is_privileged_role(current_user.current_tenant.current_role): if not TenantAccountRole.is_privileged_role(current_user.current_role):
raise Forbidden() raise Forbidden()
tenant_id = current_user.current_tenant_id tenant_id = current_user.current_tenant_id

@ -85,6 +85,7 @@ class MemberInviteEmailApi(Resource):
return { return {
"result": "success", "result": "success",
"invitation_results": invitation_results, "invitation_results": invitation_results,
"tenant_id": str(current_user.current_tenant.id),
}, 201 }, 201
@ -110,7 +111,7 @@ class MemberCancelInviteApi(Resource):
except Exception as e: except Exception as e:
raise ValueError(str(e)) raise ValueError(str(e))
return {"result": "success"}, 204 return {"result": "success", "tenant_id": str(current_user.current_tenant.id)}, 200
class MemberUpdateRoleApi(Resource): class MemberUpdateRoleApi(Resource):

@ -135,6 +135,20 @@ class WorkflowAppLogApi(Resource):
parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args") parser.add_argument("status", type=str, choices=["succeeded", "failed", "stopped"], location="args")
parser.add_argument("created_at__before", type=str, location="args") parser.add_argument("created_at__before", type=str, location="args")
parser.add_argument("created_at__after", type=str, location="args") parser.add_argument("created_at__after", type=str, location="args")
parser.add_argument(
"created_by_end_user_session_id",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument(
"created_by_account",
type=str,
location="args",
required=False,
default=None,
)
parser.add_argument("page", type=int_range(1, 99999), default=1, location="args") parser.add_argument("page", type=int_range(1, 99999), default=1, location="args")
parser.add_argument("limit", type=int_range(1, 100), default=20, location="args") parser.add_argument("limit", type=int_range(1, 100), default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
@ -158,6 +172,8 @@ class WorkflowAppLogApi(Resource):
created_at_after=args.created_at__after, created_at_after=args.created_at__after,
page=args.page, page=args.page,
limit=args.limit, limit=args.limit,
created_by_end_user_session_id=args.created_by_end_user_session_id,
created_by_account=args.created_by_account,
) )
return workflow_app_log_pagination return workflow_app_log_pagination

@ -4,8 +4,12 @@ from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service import services.dataset_service
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError from controllers.service_api.dataset.error import DatasetInUseError, DatasetNameDuplicateError, InvalidActionError
from controllers.service_api.wraps import DatasetApiResource, validate_dataset_token from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
validate_dataset_token,
)
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin import ModelProviderID
from core.provider_manager import ProviderManager from core.provider_manager import ProviderManager
@ -13,7 +17,7 @@ from fields.dataset_fields import dataset_detail_fields
from fields.tag_fields import tag_fields from fields.tag_fields import tag_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, DatasetPermissionEnum from models.dataset import Dataset, DatasetPermissionEnum
from services.dataset_service import DatasetPermissionService, DatasetService from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import RetrievalModel from services.entities.knowledge_entities.knowledge_entities import RetrievalModel
from services.tag_service import TagService from services.tag_service import TagService
@ -70,6 +74,7 @@ class DatasetListApi(DatasetApiResource):
response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page} response = {"data": data, "has_more": len(datasets) == limit, "limit": limit, "total": total, "page": page}
return response, 200 return response, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id): def post(self, tenant_id):
"""Resource for creating datasets.""" """Resource for creating datasets."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -193,6 +198,7 @@ class DatasetApi(DatasetApiResource):
return data, 200 return data, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, _, dataset_id): def patch(self, _, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -293,6 +299,7 @@ class DatasetApi(DatasetApiResource):
return result_data, 200 return result_data, 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, _, dataset_id): def delete(self, _, dataset_id):
""" """
Deletes a dataset given its ID. Deletes a dataset given its ID.
@ -322,6 +329,56 @@ class DatasetApi(DatasetApiResource):
raise DatasetInUseError() raise DatasetInUseError()
class DocumentStatusApi(DatasetApiResource):
"""Resource for batch document status operations."""
def patch(self, tenant_id, dataset_id, action):
"""
Batch update document status.
Args:
tenant_id: tenant id
dataset_id: dataset id
action: action to perform (enable, disable, archive, un_archive)
Returns:
dict: A dictionary with a key 'result' and a value 'success'
int: HTTP status code 200 indicating that the operation was successful.
Raises:
NotFound: If the dataset with the given ID does not exist.
Forbidden: If the user does not have permission.
InvalidActionError: If the action is invalid or cannot be performed.
"""
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
# Check user's permission
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Check dataset model setting
DatasetService.check_dataset_model_setting(dataset)
# Get document IDs from request body
data = request.get_json()
document_ids = data.get("document_ids", [])
try:
DocumentService.batch_update_document_status(dataset, document_ids, action, current_user)
except services.errors.document.DocumentIndexingError as e:
raise InvalidActionError(str(e))
except ValueError as e:
raise InvalidActionError(str(e))
return {"result": "success"}, 200
class DatasetTagsApi(DatasetApiResource): class DatasetTagsApi(DatasetApiResource):
@validate_dataset_token @validate_dataset_token
@marshal_with(tag_fields) @marshal_with(tag_fields)
@ -450,6 +507,7 @@ class DatasetTagsBindingStatusApi(DatasetApiResource):
api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>")
api.add_resource(DatasetTagsApi, "/datasets/tags") api.add_resource(DatasetTagsApi, "/datasets/tags")
api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding") api.add_resource(DatasetTagBindingApi, "/datasets/tags/binding")
api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding") api.add_resource(DatasetTagUnbindingApi, "/datasets/tags/unbinding")

@ -19,7 +19,11 @@ from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError, ArchivedDocumentImmutableError,
DocumentIndexingError, DocumentIndexingError,
) )
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import (
DatasetApiResource,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check,
)
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from extensions.ext_database import db from extensions.ext_database import db
from fields.document_fields import document_fields, document_status_fields from fields.document_fields import document_fields, document_status_fields
@ -35,6 +39,7 @@ class DocumentAddByTextApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by text.""" """Create document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -99,6 +104,7 @@ class DocumentUpdateByTextApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Update document by text.""" """Update document by text."""
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -158,6 +164,7 @@ class DocumentAddByFileApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_resource_check("documents", "dataset") @cloud_edition_billing_resource_check("documents", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by upload file.""" """Create document by upload file."""
args = {} args = {}
@ -232,6 +239,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
"""Resource for update documents.""" """Resource for update documents."""
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Update document by upload file.""" """Update document by upload file."""
args = {} args = {}
@ -302,6 +310,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
class DocumentDeleteApi(DatasetApiResource): class DocumentDeleteApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id): def delete(self, tenant_id, dataset_id, document_id):
"""Delete document.""" """Delete document."""
document_id = str(document_id) document_id = str(document_id)

@ -1,9 +1,10 @@
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase): class HitTestingApi(DatasetApiResource, DatasetsHitTestingBase):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)

@ -3,7 +3,7 @@ from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.wraps import DatasetApiResource from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_rate_limit_check
from fields.dataset_fields import dataset_metadata_fields from fields.dataset_fields import dataset_metadata_fields
from services.dataset_service import DatasetService from services.dataset_service import DatasetService
from services.entities.knowledge_entities.knowledge_entities import ( from services.entities.knowledge_entities.knowledge_entities import (
@ -14,6 +14,7 @@ from services.metadata_service import MetadataService
class DatasetMetadataCreateServiceApi(DatasetApiResource): class DatasetMetadataCreateServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("type", type=str, required=True, nullable=True, location="json") parser.add_argument("type", type=str, required=True, nullable=True, location="json")
@ -39,6 +40,7 @@ class DatasetMetadataCreateServiceApi(DatasetApiResource):
class DatasetMetadataServiceApi(DatasetApiResource): class DatasetMetadataServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, metadata_id): def patch(self, tenant_id, dataset_id, metadata_id):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("name", type=str, required=True, nullable=True, location="json") parser.add_argument("name", type=str, required=True, nullable=True, location="json")
@ -54,6 +56,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name")) metadata = MetadataService.update_metadata_name(dataset_id_str, metadata_id_str, args.get("name"))
return marshal(metadata, dataset_metadata_fields), 200 return marshal(metadata, dataset_metadata_fields), 200
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, metadata_id): def delete(self, tenant_id, dataset_id, metadata_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
metadata_id_str = str(metadata_id) metadata_id_str = str(metadata_id)
@ -73,6 +76,7 @@ class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource): class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, action): def post(self, tenant_id, dataset_id, action):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)
@ -88,6 +92,7 @@ class DatasetMetadataBuiltInFieldActionServiceApi(DatasetApiResource):
class DocumentMetadataEditServiceApi(DatasetApiResource): class DocumentMetadataEditServiceApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
dataset_id_str = str(dataset_id) dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str) dataset = DatasetService.get_dataset(dataset_id_str)

@ -8,6 +8,7 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import ( from controllers.service_api.wraps import (
DatasetApiResource, DatasetApiResource,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,
cloud_edition_billing_rate_limit_check,
cloud_edition_billing_resource_check, cloud_edition_billing_resource_check,
) )
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
@ -35,6 +36,7 @@ class SegmentApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id): def post(self, tenant_id, dataset_id, document_id):
"""Create single segment.""" """Create single segment."""
# check dataset # check dataset
@ -139,6 +141,7 @@ class SegmentApi(DatasetApiResource):
class DatasetSegmentApi(DatasetApiResource): class DatasetSegmentApi(DatasetApiResource):
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id): def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -162,6 +165,7 @@ class DatasetSegmentApi(DatasetApiResource):
return 204 return 204
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id): def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset # check dataset
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -236,6 +240,7 @@ class ChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id): def post(self, tenant_id, dataset_id, document_id, segment_id):
"""Create child chunk.""" """Create child chunk."""
# check dataset # check dataset
@ -332,6 +337,7 @@ class DatasetChildChunkApi(DatasetApiResource):
"""Resource for updating child chunks.""" """Resource for updating child chunks."""
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): def delete(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Delete child chunk.""" """Delete child chunk."""
# check dataset # check dataset
@ -370,6 +376,7 @@ class DatasetChildChunkApi(DatasetApiResource):
@cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
@cloud_edition_billing_rate_limit_check("knowledge", "dataset")
def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id): def patch(self, tenant_id, dataset_id, document_id, segment_id, child_chunk_id):
"""Update child chunk.""" """Update child chunk."""
# check dataset # check dataset

@ -139,3 +139,13 @@ class InvokeRateLimitError(BaseHTTPException):
error_code = "rate_limit_error" error_code = "rate_limit_error"
description = "Rate Limit Error" description = "Rate Limit Error"
code = 429 code = 429
class NotFoundError(BaseHTTPException):
error_code = "not_found"
code = 404
class InvalidArgumentError(BaseHTTPException):
error_code = "invalid_param"
code = 400

@ -138,14 +138,11 @@ class DatasetConfigManager:
if not config.get("dataset_configs"): if not config.get("dataset_configs"):
config["dataset_configs"] = {"retrieval_model": "single"} config["dataset_configs"] = {"retrieval_model": "single"}
if not config["dataset_configs"].get("datasets"):
config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
if not isinstance(config["dataset_configs"], dict): if not isinstance(config["dataset_configs"], dict):
raise ValueError("dataset_configs must be of object type") raise ValueError("dataset_configs must be of object type")
if not isinstance(config["dataset_configs"], dict): if not config["dataset_configs"].get("datasets"):
raise ValueError("dataset_configs must be of object type") config["dataset_configs"]["datasets"] = {"strategy": "router", "datasets": []}
need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get( need_manual_query_datasets = config.get("dataset_configs") and config["dataset_configs"].get(
"datasets", {} "datasets", {}

@ -104,6 +104,7 @@ class VariableEntity(BaseModel):
Variable Entity. Variable Entity.
""" """
# `variable` records the name of the variable in user inputs.
variable: str variable: str
label: str label: str
description: str = "" description: str = ""

@ -29,13 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -116,6 +117,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
) )
# parse files # parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config: if file_extra_config:
@ -261,6 +267,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -271,6 +284,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
variable_loader=var_loader,
) )
def single_loop_generate( def single_loop_generate(
@ -336,6 +350,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -346,6 +367,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
variable_loader=var_loader,
) )
def _generate( def _generate(
@ -359,6 +381,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None, conversation: Optional[Conversation] = None,
stream: bool = True, stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.
@ -367,6 +390,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user :param user: account or end user
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution :param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation :param conversation: conversation
:param stream: is stream :param stream: is stream
@ -409,6 +433,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"conversation_id": conversation.id, "conversation_id": conversation.id,
"message_id": message.id, "message_id": message.id,
"context": context, "context": context,
"variable_loader": variable_loader,
}, },
) )
@ -437,6 +462,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader,
) -> None: ) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
@ -453,8 +479,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = AdvancedChatAppRunner( runner = AdvancedChatAppRunner(
@ -463,6 +487,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
dialogue_count=self._dialogue_count, dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
) )
runner.run() runner.run()

@ -19,6 +19,7 @@ from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
@ -40,14 +41,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
dialogue_count: int, dialogue_count: int,
variable_loader: VariableLoader,
) -> None: ) -> None:
super().__init__(queue_manager) super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
self.message = message self.message = message
self._dialogue_count = dialogue_count self._dialogue_count = dialogue_count
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
app_config = self.application_generate_entity.app_config app_config = self.application_generate_entity.app_config
app_config = cast(AdvancedChatAppConfig, app_config) app_config = cast(AdvancedChatAppConfig, app_config)

@ -26,7 +26,6 @@ from factories import file_factory
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser from models import Account, App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -124,6 +123,11 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files # parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args.get("files") or [] files = args.get("files") or []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
@ -233,8 +237,6 @@ class AgentChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = AgentChatAppRunner() runner = AgentChatAppRunner()

@ -25,7 +25,6 @@ from factories import file_factory
from models.account import Account from models.account import Account
from models.model import App, EndUser from models.model import App, EndUser
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -115,6 +114,11 @@ class ChatAppGenerator(MessageBasedAppGenerator):
override_model_config_dict["retriever_resource"] = {"enabled": True} override_model_config_dict["retriever_resource"] = {"enabled": True}
# parse files # parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
@ -219,8 +223,6 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# get conversation and message # get conversation and message
conversation = self._get_conversation(conversation_id) conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError("Message not exists")
# chatbot app # chatbot app
runner = ChatAppRunner() runner = ChatAppRunner()

@ -48,6 +48,7 @@ from core.workflow.entities.workflow_execution import WorkflowExecution
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -125,7 +126,7 @@ class WorkflowResponseConverter:
id=workflow_execution.id_, id=workflow_execution.id_,
workflow_id=workflow_execution.workflow_id, workflow_id=workflow_execution.workflow_id,
status=workflow_execution.status, status=workflow_execution.status,
outputs=workflow_execution.outputs, outputs=WorkflowRuntimeTypeConverter().to_json_encodable(workflow_execution.outputs),
error=workflow_execution.error_message, error=workflow_execution.error_message,
elapsed_time=workflow_execution.elapsed_time, elapsed_time=workflow_execution.elapsed_time,
total_tokens=workflow_execution.total_tokens, total_tokens=workflow_execution.total_tokens,
@ -202,6 +203,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at: if not workflow_node_execution.finished_at:
return None return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeFinishStreamResponse( return NodeFinishStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id, workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -214,7 +217,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id, predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs, inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data, process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs, outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status, status=workflow_node_execution.status,
error=workflow_node_execution.error, error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time, elapsed_time=workflow_node_execution.elapsed_time,
@ -245,6 +248,8 @@ class WorkflowResponseConverter:
if not workflow_node_execution.finished_at: if not workflow_node_execution.finished_at:
return None return None
json_converter = WorkflowRuntimeTypeConverter()
return NodeRetryStreamResponse( return NodeRetryStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_node_execution.workflow_execution_id, workflow_run_id=workflow_node_execution.workflow_execution_id,
@ -257,7 +262,7 @@ class WorkflowResponseConverter:
predecessor_node_id=workflow_node_execution.predecessor_node_id, predecessor_node_id=workflow_node_execution.predecessor_node_id,
inputs=workflow_node_execution.inputs, inputs=workflow_node_execution.inputs,
process_data=workflow_node_execution.process_data, process_data=workflow_node_execution.process_data,
outputs=workflow_node_execution.outputs, outputs=json_converter.to_json_encodable(workflow_node_execution.outputs),
status=workflow_node_execution.status, status=workflow_node_execution.status,
error=workflow_node_execution.error, error=workflow_node_execution.error,
elapsed_time=workflow_node_execution.elapsed_time, elapsed_time=workflow_node_execution.elapsed_time,
@ -376,6 +381,7 @@ class WorkflowResponseConverter:
workflow_execution_id: str, workflow_execution_id: str,
event: QueueIterationCompletedEvent, event: QueueIterationCompletedEvent,
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
json_converter = WorkflowRuntimeTypeConverter()
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_execution_id, workflow_run_id=workflow_execution_id,
@ -384,7 +390,7 @@ class WorkflowResponseConverter:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_data.title,
outputs=event.outputs, outputs=json_converter.to_json_encodable(event.outputs),
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=event.inputs or {},
@ -463,7 +469,7 @@ class WorkflowResponseConverter:
node_id=event.node_id, node_id=event.node_id,
node_type=event.node_type.value, node_type=event.node_type.value,
title=event.node_data.title, title=event.node_data.title,
outputs=event.outputs, outputs=WorkflowRuntimeTypeConverter().to_json_encodable(event.outputs),
created_at=int(time.time()), created_at=int(time.time()),
extras={}, extras={},
inputs=event.inputs or {}, inputs=event.inputs or {},

@ -101,6 +101,11 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
) )
# parse files # parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
files = args["files"] if args.get("files") else [] files = args["files"] if args.get("files") else []
file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict())
if file_extra_config: if file_extra_config:
@ -196,8 +201,6 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
try: try:
# get message # get message
message = self._get_message(message_id) message = self._get_message(message_id)
if message is None:
raise MessageNotExistsError()
# chatbot app # chatbot app
runner = CompletionAppRunner() runner = CompletionAppRunner()

@ -29,6 +29,7 @@ from models.enums import CreatorUserRole
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -251,7 +252,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
return introduction or "" return introduction or ""
def _get_conversation(self, conversation_id: str): def _get_conversation(self, conversation_id: str) -> Conversation:
""" """
Get conversation by conversation id Get conversation by conversation id
:param conversation_id: conversation id :param conversation_id: conversation id
@ -260,11 +261,11 @@ class MessageBasedAppGenerator(BaseAppGenerator):
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first() conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
if not conversation: if not conversation:
raise ConversationNotExistsError() raise ConversationNotExistsError("Conversation not exists")
return conversation return conversation
def _get_message(self, message_id: str) -> Optional[Message]: def _get_message(self, message_id: str) -> Message:
""" """
Get message by message id Get message by message id
:param message_id: message id :param message_id: message id
@ -272,4 +273,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
""" """
message = db.session.query(Message).filter(Message.id == message_id).first() message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None:
raise MessageNotExistsError("Message not exists")
return message return message

@ -27,11 +27,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,6 +96,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
files: Sequence[Mapping[str, Any]] = args.get("files") or [] files: Sequence[Mapping[str, Any]] = args.get("files") or []
# parse files # parse files
# TODO(QuantumGhost): Move file parsing logic to the API controller layer
# for better separation of concerns.
#
# For implementation reference, see the `_parse_file` function and
# `DraftWorkflowNodeRunApi` class which handle this properly.
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
system_files = file_factory.build_from_mappings( system_files = file_factory.build_from_mappings(
mappings=files, mappings=files,
@ -186,6 +193,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -195,6 +203,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user :param user: account or end user
:param application_generate_entity: application generate entity :param application_generate_entity: application generate entity
:param invoke_from: invoke from source :param invoke_from: invoke from source
:param workflow_execution_repository: repository for workflow execution
:param workflow_node_execution_repository: repository for workflow node execution :param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream :param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
@ -218,6 +227,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
"queue_manager": queue_manager, "queue_manager": queue_manager,
"context": context, "context": context,
"workflow_thread_pool_id": workflow_thread_pool_id, "workflow_thread_pool_id": workflow_thread_pool_id,
"variable_loader": variable_loader,
}, },
) )
@ -302,6 +312,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -312,6 +329,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader,
) )
def single_loop_generate( def single_loop_generate(
@ -378,7 +396,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
@ -388,6 +412,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader,
) )
def _generate_worker( def _generate_worker(
@ -396,6 +421,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
@ -414,6 +440,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
) )
runner.run() runner.run()

@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
@ -37,10 +39,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
def run(self) -> None: def run(self) -> None:
""" """
Run application Run application

@ -1,6 +1,8 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import ( from core.app.entities.queue_entities import (
@ -33,6 +35,7 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.event import (
AgentLogEvent, AgentLogEvent,
BaseNodeEvent,
GraphEngineEvent, GraphEngineEvent,
GraphRunFailedEvent, GraphRunFailedEvent,
GraphRunPartialSucceededEvent, GraphRunPartialSucceededEvent,
@ -62,15 +65,23 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App from models.model import App
from models.workflow import Workflow from models.workflow import Workflow
from services.workflow_draft_variable_service import (
DraftVariableSaver,
)
class WorkflowBasedAppRunner(AppRunner): class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager): def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._variable_loader = variable_loader
def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph: def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
""" """
@ -173,6 +184,13 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool( WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping, variable_mapping=variable_mapping,
user_inputs=user_inputs, user_inputs=user_inputs,
@ -262,6 +280,12 @@ class WorkflowBasedAppRunner(AppRunner):
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool( WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping, variable_mapping=variable_mapping,
@ -376,6 +400,8 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
) )
) )
self._save_draft_var_for_event(event)
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
self._publish_event( self._publish_event(
QueueNodeFailedEvent( QueueNodeFailedEvent(
@ -438,6 +464,8 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
) )
) )
self._save_draft_var_for_event(event)
elif isinstance(event, NodeInIterationFailedEvent): elif isinstance(event, NodeInIterationFailedEvent):
self._publish_event( self._publish_event(
QueueNodeInIterationFailedEvent( QueueNodeInIterationFailedEvent(
@ -690,3 +718,30 @@ class WorkflowBasedAppRunner(AppRunner):
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
def _save_draft_var_for_event(self, event: BaseNodeEvent):
run_result = event.route_node_state.node_run_result
if run_result is None:
return
process_data = run_result.process_data
outputs = run_result.outputs
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
app_id=self._get_app_id(),
node_id=event.node_id,
node_type=event.node_type,
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
invoke_from=self.queue_manager._invoke_from,
node_execution_id=event.id,
enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
)
draft_var_saver.save(process_data=process_data, outputs=outputs)
def _remove_first_element_from_variable_string(key: str) -> str:
"""
Remove the first element from the prefix.
"""
prefix, remaining = key.split(".", maxsplit=1)
return remaining

@ -17,9 +17,24 @@ class InvokeFrom(Enum):
Invoke From. Invoke From.
""" """
# SERVICE_API indicates that this invocation is from an API call to Dify app.
#
# Description of service api in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/developing-with-apis
SERVICE_API = "service-api" SERVICE_API = "service-api"
# WEB_APP indicates that this invocation is from
# the web app of the workflow (or chatflow).
#
# Description of web app in Dify docs:
# https://docs.dify.ai/en/guides/application-publishing/launch-your-webapp-quickly/README
WEB_APP = "web-app" WEB_APP = "web-app"
# EXPLORE indicates that this invocation is from
# the workflow (or chatflow) explore page.
EXPLORE = "explore" EXPLORE = "explore"
# DEBUGGER indicates that this invocation is from
# the workflow (or chatflow) edit page.
DEBUGGER = "debugger" DEBUGGER = "debugger"
@classmethod @classmethod

@ -1 +1,11 @@
from typing import Any
# TODO(QuantumGhost): Refactor variable type identification. Instead of directly
# comparing `dify_model_identity` with constants throughout the codebase, extract
# this logic into a dedicated function. This would encapsulate the implementation
# details of how different variable types are identified.
FILE_MODEL_IDENTITY = "__dify__file__" FILE_MODEL_IDENTITY = "__dify__file__"
def maybe_file_object(o: Any) -> bool:
return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY

@ -534,7 +534,7 @@ class IndexingRunner:
# chunk nodes by chunk size # chunk nodes by chunk size
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
# create keyword index # create keyword index
create_keyword_thread = threading.Thread( create_keyword_thread = threading.Thread(
target=self._process_keyword_index, target=self._process_keyword_index,
@ -572,7 +572,7 @@ class IndexingRunner:
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX and dataset.indexing_technique == "economy":
create_keyword_thread.join() create_keyword_thread.join()
indexing_end_at = time.perf_counter() indexing_end_at = time.perf_counter()

@ -542,8 +542,6 @@ class LBModelManager:
return config return config
return None
def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None: def cooldown(self, config: ModelLoadBalancingConfiguration, expire: int = 60) -> None:
""" """
Cooldown model load balancing config Cooldown model load balancing config

@ -251,7 +251,7 @@ class OpsTraceManager:
provider_config_map[tracing_provider]["trace_instance"], provider_config_map[tracing_provider]["trace_instance"],
provider_config_map[tracing_provider]["config_class"], provider_config_map[tracing_provider]["config_class"],
) )
decrypt_trace_config_key = str(decrypt_trace_config) decrypt_trace_config_key = json.dumps(decrypt_trace_config, sort_keys=True)
tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key) tracing_instance = cls.ops_trace_instances_cache.get(decrypt_trace_config_key)
if tracing_instance is None: if tracing_instance is None:
# create new tracing_instance and update the cache if it absent # create new tracing_instance and update the cache if it absent

@ -156,9 +156,23 @@ class PluginInstallTaskStartResponse(BaseModel):
task_id: str = Field(description="The ID of the install task.") task_id: str = Field(description="The ID of the install task.")
class PluginUploadResponse(BaseModel): class PluginVerification(BaseModel):
"""
Verification of the plugin.
"""
class AuthorizedCategory(StrEnum):
Langgenius = "langgenius"
Partner = "partner"
Community = "community"
authorized_category: AuthorizedCategory = Field(description="The authorized category of the plugin.")
class PluginDecodeResponse(BaseModel):
unique_identifier: str = Field(description="The unique identifier of the plugin.") unique_identifier: str = Field(description="The unique identifier of the plugin.")
manifest: PluginDeclaration manifest: PluginDeclaration
verification: Optional[PluginVerification] = Field(default=None, description="Basic verification information")
class PluginOAuthAuthorizationUrlResponse(BaseModel): class PluginOAuthAuthorizationUrlResponse(BaseModel):

@ -1,3 +1,4 @@
import binascii
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any from typing import Any
@ -16,7 +17,7 @@ class OAuthHandler(BasePluginClient):
provider: str, provider: str,
system_credentials: Mapping[str, Any], system_credentials: Mapping[str, Any],
) -> PluginOAuthAuthorizationUrlResponse: ) -> PluginOAuthAuthorizationUrlResponse:
return self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
PluginOAuthAuthorizationUrlResponse, PluginOAuthAuthorizationUrlResponse,
@ -32,6 +33,9 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def get_credentials( def get_credentials(
self, self,
@ -49,7 +53,7 @@ class OAuthHandler(BasePluginClient):
# encode request to raw http request # encode request to raw http request
raw_request_bytes = self._convert_request_to_raw_data(request) raw_request_bytes = self._convert_request_to_raw_data(request)
return self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/dispatch/oauth/get_credentials", f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
PluginOAuthCredentialsResponse, PluginOAuthCredentialsResponse,
@ -58,7 +62,8 @@ class OAuthHandler(BasePluginClient):
"data": { "data": {
"provider": provider, "provider": provider,
"system_credentials": system_credentials, "system_credentials": system_credentials,
"raw_request_bytes": raw_request_bytes, # for json serialization
"raw_http_request": binascii.hexlify(raw_request_bytes).decode(),
}, },
}, },
headers={ headers={
@ -66,6 +71,9 @@ class OAuthHandler(BasePluginClient):
"Content-Type": "application/json", "Content-Type": "application/json",
}, },
) )
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for authorization URL request.")
def _convert_request_to_raw_data(self, request: Request) -> bytes: def _convert_request_to_raw_data(self, request: Request) -> bytes:
""" """
@ -79,7 +87,7 @@ class OAuthHandler(BasePluginClient):
""" """
# Start with the request line # Start with the request line
method = request.method method = request.method
path = request.path path = request.full_path
protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1") protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
raw_data = f"{method} {path} {protocol}\r\n".encode() raw_data = f"{method} {path} {protocol}\r\n".encode()

@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
PluginInstallationSource, PluginInstallationSource,
) )
from core.plugin.entities.plugin_daemon import ( from core.plugin.entities.plugin_daemon import (
PluginDecodeResponse,
PluginInstallTask, PluginInstallTask,
PluginInstallTaskStartResponse, PluginInstallTaskStartResponse,
PluginListResponse, PluginListResponse,
PluginUploadResponse,
) )
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
@ -53,7 +53,7 @@ class PluginInstaller(BasePluginClient):
tenant_id: str, tenant_id: str,
pkg: bytes, pkg: bytes,
verify_signature: bool = False, verify_signature: bool = False,
) -> PluginUploadResponse: ) -> PluginDecodeResponse:
""" """
Upload a plugin package and return the plugin unique identifier. Upload a plugin package and return the plugin unique identifier.
""" """
@ -68,7 +68,7 @@ class PluginInstaller(BasePluginClient):
return self._request_with_plugin_daemon_response( return self._request_with_plugin_daemon_response(
"POST", "POST",
f"plugin/{tenant_id}/management/install/upload/package", f"plugin/{tenant_id}/management/install/upload/package",
PluginUploadResponse, PluginDecodeResponse,
files=body, files=body,
data=data, data=data,
) )
@ -176,6 +176,18 @@ class PluginInstaller(BasePluginClient):
params={"plugin_unique_identifier": plugin_unique_identifier}, params={"plugin_unique_identifier": plugin_unique_identifier},
) )
def decode_plugin_from_identifier(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDecodeResponse:
"""
Decode a plugin from an identifier.
"""
return self._request_with_plugin_daemon_response(
"GET",
f"plugin/{tenant_id}/management/decode/from_identifier",
PluginDecodeResponse,
data={"plugin_unique_identifier": plugin_unique_identifier},
headers={"Content-Type": "application/json"},
)
def fetch_plugin_installation_by_ids( def fetch_plugin_installation_by_ids(
self, tenant_id: str, plugin_ids: Sequence[str] self, tenant_id: str, plugin_ids: Sequence[str]
) -> Sequence[PluginInstallation]: ) -> Sequence[PluginInstallation]:

@ -0,0 +1,233 @@
import json
import logging
import uuid
from functools import wraps
from typing import Any, Optional
from mo_vector.client import MoVectorClient # type: ignore
from pydantic import BaseModel, model_validator
from configs import dify_config
from core.rag.datasource.vdb.vector_base import BaseVector
from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.embedding_base import Embeddings
from core.rag.models.document import Document
from extensions.ext_redis import redis_client
from models.dataset import Dataset
logger = logging.getLogger(__name__)
class MatrixoneConfig(BaseModel):
host: str = "localhost"
port: int = 6001
user: str = "dump"
password: str = "111"
database: str = "dify"
metric: str = "l2"
@model_validator(mode="before")
@classmethod
def validate_config(cls, values: dict) -> dict:
if not values["host"]:
raise ValueError("config host is required")
if not values["port"]:
raise ValueError("config port is required")
if not values["user"]:
raise ValueError("config user is required")
if not values["password"]:
raise ValueError("config password is required")
if not values["database"]:
raise ValueError("config database is required")
return values
def ensure_client(func):
@wraps(func)
def wrapper(self, *args, **kwargs):
if self.client is None:
self.client = self._get_client(None, False)
return func(self, *args, **kwargs)
return wrapper
class MatrixoneVector(BaseVector):
"""
Matrixone vector storage implementation.
"""
def __init__(self, collection_name: str, config: MatrixoneConfig):
super().__init__(collection_name)
self.config = config
self.collection_name = collection_name.lower()
self.client = None
@property
def collection_name(self):
return self._collection_name
@collection_name.setter
def collection_name(self, value):
self._collection_name = value
def get_type(self) -> str:
return VectorType.MATRIXONE
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
return self.add_texts(texts, embeddings)
def _get_client(self, dimension: Optional[int] = None, create_table: bool = False) -> MoVectorClient:
"""
Create a new client for the collection.
The collection will be created if it doesn't exist.
"""
lock_name = f"vector_indexing_lock_{self._collection_name}"
with redis_client.lock(lock_name, timeout=20):
client = MoVectorClient(
connection_string=f"mysql+pymysql://{self.config.user}:{self.config.password}@{self.config.host}:{self.config.port}/{self.config.database}",
table_name=self.collection_name,
vector_dimension=dimension,
create_table=create_table,
)
collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
if redis_client.get(collection_exist_cache_key):
return client
try:
client.create_full_text_index()
except Exception as e:
logger.exception("Failed to create full text index")
redis_client.set(collection_exist_cache_key, 1, ex=3600)
return client
def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
if self.client is None:
self.client = self._get_client(len(embeddings[0]), True)
assert self.client is not None
ids = []
for _, doc in enumerate(documents):
if doc.metadata is not None:
doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
ids.append(doc_id)
self.client.insert(
texts=[doc.page_content for doc in documents],
embeddings=embeddings,
metadatas=[doc.metadata for doc in documents],
ids=ids,
)
return ids
@ensure_client
def text_exists(self, id: str) -> bool:
assert self.client is not None
result = self.client.get(ids=[id])
return len(result) > 0
@ensure_client
def delete_by_ids(self, ids: list[str]) -> None:
assert self.client is not None
if not ids:
return
self.client.delete(ids=ids)
@ensure_client
def get_ids_by_metadata_field(self, key: str, value: str):
assert self.client is not None
results = self.client.query_by_metadata(filter={key: value})
return [result.id for result in results]
@ensure_client
def delete_by_metadata_field(self, key: str, value: str) -> None:
assert self.client is not None
self.client.delete(filter={key: value})
@ensure_client
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
results = self.client.query(
query_vector=query_vector,
k=top_k,
filter=filter,
)
docs = []
# TODO: add the score threshold to the query
for result in results:
metadata = result.metadata
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
assert self.client is not None
top_k = kwargs.get("top_k", 5)
document_ids_filter = kwargs.get("document_ids_filter")
filter = None
if document_ids_filter:
filter = {"document_id": {"$in": document_ids_filter}}
score_threshold = float(kwargs.get("score_threshold", 0.0))
results = self.client.full_text_query(
keywords=[query],
k=top_k,
filter=filter,
)
docs = []
for result in results:
metadata = result.metadata
if isinstance(metadata, str):
import json
metadata = json.loads(metadata)
score = 1 - result.distance
if score >= score_threshold:
metadata["score"] = score
docs.append(
Document(
page_content=result.document,
metadata=metadata,
)
)
return docs
@ensure_client
def delete(self) -> None:
assert self.client is not None
self.client.delete()
class MatrixoneVectorFactory(AbstractVectorFactory):
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> MatrixoneVector:
if dataset.index_struct_dict:
class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
collection_name = class_prefix
else:
dataset_id = dataset.id
collection_name = Dataset.gen_collection_name_by_id(dataset_id)
dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.MATRIXONE, collection_name))
config = MatrixoneConfig(
host=dify_config.MATRIXONE_HOST or "localhost",
port=dify_config.MATRIXONE_PORT or 6001,
user=dify_config.MATRIXONE_USER or "dump",
password=dify_config.MATRIXONE_PASSWORD or "111",
database=dify_config.MATRIXONE_DATABASE or "dify",
metric=dify_config.MATRIXONE_METRIC or "l2",
)
return MatrixoneVector(collection_name=collection_name, config=config)

@ -164,6 +164,10 @@ class Vector:
from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
return HuaweiCloudVectorFactory return HuaweiCloudVectorFactory
case VectorType.MATRIXONE:
from core.rag.datasource.vdb.matrixone.matrixone_vector import MatrixoneVectorFactory
return MatrixoneVectorFactory
case _: case _:
raise ValueError(f"Vector store {vector_type} is not supported.") raise ValueError(f"Vector store {vector_type} is not supported.")

@ -29,3 +29,4 @@ class VectorType(StrEnum):
OPENGAUSS = "opengauss" OPENGAUSS = "opengauss"
TABLESTORE = "tablestore" TABLESTORE = "tablestore"
HUAWEI_CLOUD = "huawei_cloud" HUAWEI_CLOUD = "huawei_cloud"
MATRIXONE = "matrixone"

@ -45,7 +45,8 @@ class WeaviateVector(BaseVector):
# by changing the connection timeout to pypi.org from 1 second to 0.001 seconds. # by changing the connection timeout to pypi.org from 1 second to 0.001 seconds.
# TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher, # TODO: This can be removed once weaviate-client is updated to 3.26.7 or higher,
# which does not contain the deprecation check. # which does not contain the deprecation check.
weaviate.connect.connection.PYPI_TIMEOUT = 0.001 if hasattr(weaviate.connect.connection, "PYPI_TIMEOUT"):
weaviate.connect.connection.PYPI_TIMEOUT = 0.001
try: try:
client = weaviate.Client( client = weaviate.Client(

@ -22,6 +22,7 @@ class FirecrawlApp:
"formats": ["markdown"], "formats": ["markdown"],
"onlyMainContent": True, "onlyMainContent": True,
"timeout": 30000, "timeout": 30000,
"integration": "dify",
} }
if params: if params:
json_data.update(params) json_data.update(params)
@ -39,7 +40,7 @@ class FirecrawlApp:
def crawl_url(self, url, params=None) -> str: def crawl_url(self, url, params=None) -> str:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post # Documentation: https://docs.firecrawl.dev/api-reference/endpoint/crawl-post
headers = self._prepare_headers() headers = self._prepare_headers()
json_data = {"url": url} json_data = {"url": url, "integration": "dify"}
if params: if params:
json_data.update(params) json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers) response = self._post_request(f"{self.base_url}/v1/crawl", json_data, headers)
@ -49,7 +50,6 @@ class FirecrawlApp:
return cast(str, job_id) return cast(str, job_id)
else: else:
self._handle_error(response, "start crawl job") self._handle_error(response, "start crawl job")
# FIXME: unreachable code for mypy
return "" # unreachable return "" # unreachable
def check_crawl_status(self, job_id) -> dict[str, Any]: def check_crawl_status(self, job_id) -> dict[str, Any]:
@ -82,7 +82,6 @@ class FirecrawlApp:
) )
else: else:
self._handle_error(response, "check crawl status") self._handle_error(response, "check crawl status")
# FIXME: unreachable code for mypy
return {} # unreachable return {} # unreachable
def _format_crawl_status_response( def _format_crawl_status_response(
@ -126,4 +125,31 @@ class FirecrawlApp:
def _handle_error(self, response, action) -> None: def _handle_error(self, response, action) -> None:
error_message = response.json().get("error", "Unknown error occurred") error_message = response.json().get("error", "Unknown error occurred")
raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") raise Exception(f"Failed to {action}. Status code: {response.status_code}. Error: {error_message}") # type: ignore[return]
def search(self, query: str, params: dict[str, Any] | None = None) -> dict[str, Any]:
# Documentation: https://docs.firecrawl.dev/api-reference/endpoint/search
headers = self._prepare_headers()
json_data = {
"query": query,
"limit": 5,
"lang": "en",
"country": "us",
"timeout": 60000,
"ignoreInvalidURLs": False,
"scrapeOptions": {},
"integration": "dify",
}
if params:
json_data.update(params)
response = self._post_request(f"{self.base_url}/v1/search", json_data, headers)
if response.status_code == 200:
response_data = response.json()
if not response_data.get("success"):
raise Exception(f"Search failed. Error: {response_data.get('warning', 'Unknown error')}")
return cast(dict[str, Any], response_data)
elif response.status_code in {402, 409, 500, 429, 408}:
self._handle_error(response, "perform search")
return {} # Avoid additional exception after handling error
else:
raise Exception(f"Failed to perform search. Status code: {response.status_code}")

@ -68,22 +68,17 @@ class MarkdownExtractor(BaseExtractor):
continue continue
header_match = re.match(r"^#+\s", line) header_match = re.match(r"^#+\s", line)
if header_match: if header_match:
if current_header is not None: markdown_tups.append((current_header, current_text))
markdown_tups.append((current_header, current_text))
current_header = line current_header = line
current_text = "" current_text = ""
else: else:
current_text += line + "\n" current_text += line + "\n"
markdown_tups.append((current_header, current_text)) markdown_tups.append((current_header, current_text))
if current_header is not None: markdown_tups = [
# pass linting, assert keys are defined (re.sub(r"#", "", cast(str, key)).strip() if key else None, re.sub(r"<.*?>", "", value))
markdown_tups = [ for key, value in markdown_tups
(re.sub(r"#", "", cast(str, key)).strip(), re.sub(r"<.*?>", "", value)) for key, value in markdown_tups ]
]
else:
markdown_tups = [(key, re.sub("\n", "", value)) for key, value in markdown_tups]
return markdown_tups return markdown_tups

@ -76,6 +76,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
with_keywords = False
if with_keywords: if with_keywords:
keywords_list = kwargs.get("keywords_list") keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset) keyword = Keyword(dataset)
@ -91,6 +92,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
vector.delete_by_ids(node_ids) vector.delete_by_ids(node_ids)
else: else:
vector.delete() vector.delete()
with_keywords = False
if with_keywords: if with_keywords:
keyword = Keyword(dataset) keyword = Keyword(dataset)
if node_ids: if node_ids:

@ -104,7 +104,7 @@ class QAIndexProcessor(BaseIndexProcessor):
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
# check file type # check file type
if not file.filename or not file.filename.endswith(".csv"): if not file.filename or not file.filename.lower().endswith(".csv"):
raise ValueError("Invalid file type. Only CSV files are allowed") raise ValueError("Invalid file type. Only CSV files are allowed")
try: try:

@ -496,6 +496,8 @@ class DatasetRetrieval:
all_documents = self.calculate_keyword_score(query, all_documents, top_k) all_documents = self.calculate_keyword_score(query, all_documents, top_k)
elif index_type == "high_quality": elif index_type == "high_quality":
all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold) all_documents = self.calculate_vector_score(all_documents, top_k, score_threshold)
else:
all_documents = all_documents[:top_k] if top_k else all_documents
self._on_query(query, dataset_ids, app_id, user_from, user_id) self._on_query(query, dataset_ids, app_id, user_from, user_id)

@ -6,7 +6,7 @@ import json
import logging import logging
from typing import Optional, Union from typing import Optional, Union
from sqlalchemy import func, select from sqlalchemy import select
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -16,6 +16,7 @@ from core.workflow.entities.workflow_execution import (
WorkflowType, WorkflowType,
) )
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -146,26 +147,17 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository):
db_model.workflow_id = domain_model.workflow_id db_model.workflow_id = domain_model.workflow_id
db_model.triggered_from = self._triggered_from db_model.triggered_from = self._triggered_from
# Check if this is a new record # No sequence number generation needed anymore
with self._session_factory() as session:
existing = session.scalar(select(WorkflowRun).where(WorkflowRun.id == domain_model.id_))
if not existing:
# For new records, get the next sequence number
stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.app_id == self._app_id,
WorkflowRun.tenant_id == self._tenant_id,
)
max_sequence = session.scalar(stmt)
db_model.sequence_number = (max_sequence or 0) + 1
else:
# For updates, keep the existing sequence number
db_model.sequence_number = existing.sequence_number
db_model.type = domain_model.workflow_type db_model.type = domain_model.workflow_type
db_model.version = domain_model.workflow_version db_model.version = domain_model.workflow_version
db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None db_model.graph = json.dumps(domain_model.graph) if domain_model.graph else None
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None db_model.outputs = (
json.dumps(WorkflowRuntimeTypeConverter().to_json_encodable(domain_model.outputs))
if domain_model.outputs
else None
)
db_model.status = domain_model.status db_model.status = domain_model.status
db_model.error = domain_model.error_message if domain_model.error_message else None db_model.error = domain_model.error_message if domain_model.error_message else None
db_model.total_tokens = domain_model.total_tokens db_model.total_tokens = domain_model.total_tokens

@ -19,6 +19,7 @@ from core.workflow.entities.workflow_node_execution import (
) )
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from models import ( from models import (
Account, Account,
CreatorUserRole, CreatorUserRole,
@ -146,6 +147,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
if not self._creator_user_role: if not self._creator_user_role:
raise ValueError("created_by_role is required in repository constructor") raise ValueError("created_by_role is required in repository constructor")
json_converter = WorkflowRuntimeTypeConverter()
db_model = WorkflowNodeExecutionModel() db_model = WorkflowNodeExecutionModel()
db_model.id = domain_model.id db_model.id = domain_model.id
db_model.tenant_id = self._tenant_id db_model.tenant_id = self._tenant_id
@ -160,9 +162,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
db_model.node_id = domain_model.node_id db_model.node_id = domain_model.node_id
db_model.node_type = domain_model.node_type db_model.node_type = domain_model.node_type
db_model.title = domain_model.title db_model.title = domain_model.title
db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None db_model.inputs = (
db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None json.dumps(json_converter.to_json_encodable(domain_model.inputs)) if domain_model.inputs else None
db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None )
db_model.process_data = (
json.dumps(json_converter.to_json_encodable(domain_model.process_data))
if domain_model.process_data
else None
)
db_model.outputs = (
json.dumps(json_converter.to_json_encodable(domain_model.outputs)) if domain_model.outputs else None
)
db_model.status = domain_model.status db_model.status = domain_model.status
db_model.error = domain_model.error db_model.error = domain_model.error
db_model.elapsed_time = domain_model.elapsed_time db_model.elapsed_time = domain_model.elapsed_time

@ -75,6 +75,20 @@ class StringSegment(Segment):
class FloatSegment(Segment): class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER value_type: SegmentType = SegmentType.NUMBER
value: float value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
#
# def test_float_segment_and_nan():
# nan = float("nan")
# assert nan != nan
#
# f1 = FloatSegment(value=float("nan"))
# f2 = FloatSegment(value=float("nan"))
# assert f1 != f2
#
# f3 = FloatSegment(value=nan)
# f4 = FloatSegment(value=nan)
# assert f3 != f4
class IntegerSegment(Segment): class IntegerSegment(Segment):

@ -18,3 +18,17 @@ class SegmentType(StrEnum):
NONE = "none" NONE = "none"
GROUP = "group" GROUP = "group"
def is_array_type(self):
return self in _ARRAY_TYPES
_ARRAY_TYPES = frozenset(
[
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)

@ -1,8 +1,26 @@
import json
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
selectors = [node_id, name] selectors = [node_id, name]
if paths: if paths:
selectors.extend(paths) selectors.extend(paths)
return selectors return selectors
class SegmentJSONEncoder(json.JSONEncoder):
def default(self, o):
if isinstance(o, ArrayFileSegment):
return [v.model_dump() for v in o.value]
elif isinstance(o, FileSegment):
return o.value.model_dump()
elif isinstance(o, SegmentGroup):
return [self.default(seg) for seg in o.value]
elif isinstance(o, Segment):
return o.value
else:
super().default(o)

@ -0,0 +1,39 @@
import abc
from typing import Protocol
from core.variables import Variable
class ConversationVariableUpdater(Protocol):
"""
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
conversation variables.
Implementations may choose to batch updates. If batching is used, the `flush` method
should be implemented to persist buffered changes, and `update`
should handle buffering accordingly.
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable") -> None:
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
"""
pass
@abc.abstractmethod
def flush(self):
"""
Flushes all pending updates to the underlying storage system.
If the implementation does not buffer updates, this method can be a no-op.
"""
pass

@ -7,12 +7,12 @@ from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment from core.variables.segments import FileSegment, NoneSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from factories import variable_factory from factories import variable_factory
from ..constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from ..enums import SystemVariableKey
VariableValue = Union[str, int, float, dict, list, File] VariableValue = Union[str, int, float, dict, list, File]
VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}")
@ -30,9 +30,11 @@ class VariablePool(BaseModel):
# TODO: This user inputs is not used for pool. # TODO: This user inputs is not used for pool.
user_inputs: Mapping[str, Any] = Field( user_inputs: Mapping[str, Any] = Field(
description="User inputs", description="User inputs",
default_factory=dict,
) )
system_variables: Mapping[SystemVariableKey, Any] = Field( system_variables: Mapping[SystemVariableKey, Any] = Field(
description="System variables", description="System variables",
default_factory=dict,
) )
environment_variables: Sequence[Variable] = Field( environment_variables: Sequence[Variable] = Field(
description="Environment variables.", description="Environment variables.",
@ -43,28 +45,7 @@ class VariablePool(BaseModel):
default_factory=list, default_factory=list,
) )
def __init__( def model_post_init(self, context: Any, /) -> None:
self,
*,
system_variables: Mapping[SystemVariableKey, Any] | None = None,
user_inputs: Mapping[str, Any] | None = None,
environment_variables: Sequence[Variable] | None = None,
conversation_variables: Sequence[Variable] | None = None,
**kwargs,
):
environment_variables = environment_variables or []
conversation_variables = conversation_variables or []
user_inputs = user_inputs or {}
system_variables = system_variables or {}
super().__init__(
system_variables=system_variables,
user_inputs=user_inputs,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
**kwargs,
)
for key, value in self.system_variables.items(): for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Add environment variables to the variable pool # Add environment variables to the variable pool
@ -91,12 +72,12 @@ class VariablePool(BaseModel):
Returns: Returns:
None None
""" """
if len(selector) < 2: if len(selector) < MIN_SELECTORS_LENGTH:
raise ValueError("Invalid selector") raise ValueError("Invalid selector")
if isinstance(value, Variable): if isinstance(value, Variable):
variable = value variable = value
if isinstance(value, Segment): elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector) variable = variable_factory.segment_to_variable(segment=value, selector=selector)
else: else:
segment = variable_factory.build_segment(value) segment = variable_factory.build_segment(value)
@ -118,7 +99,7 @@ class VariablePool(BaseModel):
Raises: Raises:
ValueError: If the selector is invalid. ValueError: If the selector is invalid.
""" """
if len(selector) < 2: if len(selector) < MIN_SELECTORS_LENGTH:
return None return None
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))

@ -66,6 +66,8 @@ class BaseNodeEvent(GraphEngineEvent):
"""iteration id if node is in iteration""" """iteration id if node is in iteration"""
in_loop_id: Optional[str] = None in_loop_id: Optional[str] = None
"""loop id if node is in loop""" """loop id if node is in loop"""
# The version of the node, or "1" if not specified.
node_version: str = "1"
class NodeRunStartedEvent(BaseNodeEvent): class NodeRunStartedEvent(BaseNodeEvent):

@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -314,6 +315,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
raise e raise e
@ -627,6 +629,7 @@ class GraphEngine:
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
agent_strategy=agent_strategy, agent_strategy=agent_strategy,
node_version=node_instance.version(),
) )
max_retries = node_instance.node_data.retry_config.max_retries max_retries = node_instance.node_data.retry_config.max_retries
@ -639,26 +642,19 @@ class GraphEngine:
retry_start_at = datetime.now(UTC).replace(tzinfo=None) retry_start_at = datetime.now(UTC).replace(tzinfo=None)
# yield control to other threads # yield control to other threads
time.sleep(0.001) time.sleep(0.001)
generator = node_instance.run() event_stream = node_instance.run()
for item in generator: for event in event_stream:
if isinstance(item, GraphEngineEvent): if isinstance(event, GraphEngineEvent):
if isinstance(item, BaseIterationEvent): # add parallel info to iteration event
# add parallel info to iteration event if isinstance(event, BaseIterationEvent | BaseLoopEvent):
item.parallel_id = parallel_id event.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id event.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id event.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id event.parent_parallel_start_node_id = parent_parallel_start_node_id
elif isinstance(item, BaseLoopEvent): yield event
# add parallel info to loop event
item.parallel_id = parallel_id
item.parallel_start_node_id = parallel_start_node_id
item.parent_parallel_id = parent_parallel_id
item.parent_parallel_start_node_id = parent_parallel_start_node_id
yield item
else: else:
if isinstance(item, RunCompletedEvent): if isinstance(event, RunCompletedEvent):
run_result = item.run_result run_result = event.run_result
if run_result.status == WorkflowNodeExecutionStatus.FAILED: if run_result.status == WorkflowNodeExecutionStatus.FAILED:
if ( if (
retries == max_retries retries == max_retries
@ -684,6 +680,7 @@ class GraphEngine:
error=run_result.error or "Unknown error", error=run_result.error or "Unknown error",
retry_index=retries, retry_index=retries,
start_at=retry_start_at, start_at=retry_start_at,
node_version=node_instance.version(),
) )
time.sleep(retry_interval) time.sleep(retry_interval)
break break
@ -694,7 +691,7 @@ class GraphEngine:
# if run failed, handle error # if run failed, handle error
run_result = self._handle_continue_on_error( run_result = self._handle_continue_on_error(
node_instance, node_instance,
item.run_result, event.run_result,
self.graph_runtime_state.variable_pool, self.graph_runtime_state.variable_pool,
handle_exceptions=handle_exceptions, handle_exceptions=handle_exceptions,
) )
@ -719,6 +716,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
else: else:
@ -733,6 +731,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
@ -793,37 +792,40 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
should_continue_retry = False should_continue_retry = False
break break
elif isinstance(item, RunStreamChunkEvent): elif isinstance(event, RunStreamChunkEvent):
yield NodeRunStreamChunkEvent( yield NodeRunStreamChunkEvent(
id=node_instance.id, id=node_instance.id,
node_id=node_instance.node_id, node_id=node_instance.node_id,
node_type=node_instance.node_type, node_type=node_instance.node_type,
node_data=node_instance.node_data, node_data=node_instance.node_data,
chunk_content=item.chunk_content, chunk_content=event.chunk_content,
from_variable_selector=item.from_variable_selector, from_variable_selector=event.from_variable_selector,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
elif isinstance(item, RunRetrieverResourceEvent): elif isinstance(event, RunRetrieverResourceEvent):
yield NodeRunRetrieverResourceEvent( yield NodeRunRetrieverResourceEvent(
id=node_instance.id, id=node_instance.id,
node_id=node_instance.node_id, node_id=node_instance.node_id,
node_type=node_instance.node_type, node_type=node_instance.node_type,
node_data=node_instance.node_data, node_data=node_instance.node_data,
retriever_resources=item.retriever_resources, retriever_resources=event.retriever_resources,
context=item.context, context=event.context,
route_node_state=route_node_state, route_node_state=route_node_state,
parallel_id=parallel_id, parallel_id=parallel_id,
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
except GenerateTaskStoppedError: except GenerateTaskStoppedError:
# trigger node run failed event # trigger node run failed event
@ -840,6 +842,7 @@ class GraphEngine:
parallel_start_node_id=parallel_start_node_id, parallel_start_node_id=parallel_start_node_id,
parent_parallel_id=parent_parallel_id, parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id, parent_parallel_start_node_id=parent_parallel_start_node_id,
node_version=node_instance.version(),
) )
return return
except Exception as e: except Exception as e:
@ -854,16 +857,12 @@ class GraphEngine:
:param variable_value: variable value :param variable_value: variable value
:return: :return:
""" """
self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) variable_utils.append_variables_recursively(
self.graph_runtime_state.variable_pool,
# if variable_value is a dict, then recursively append variables node_id,
if isinstance(variable_value, dict): variable_key_list,
for key, value in variable_value.items(): variable_value,
# construct new key list )
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
""" """

@ -39,6 +39,10 @@ class AgentNode(ToolNode):
_node_data_cls = AgentNodeData # type: ignore _node_data_cls = AgentNodeData # type: ignore
_node_type = NodeType.AGENT _node_type = NodeType.AGENT
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator: def _run(self) -> Generator:
""" """
Run the agent node Run the agent node

@ -18,7 +18,11 @@ from core.workflow.utils.variable_template_parser import VariableTemplateParser
class AnswerNode(BaseNode[AnswerNodeData]): class AnswerNode(BaseNode[AnswerNodeData]):
_node_data_cls = AnswerNodeData _node_data_cls = AnswerNodeData
_node_type: NodeType = NodeType.ANSWER _node_type = NodeType.ANSWER
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
@ -45,7 +49,10 @@ class AnswerNode(BaseNode[AnswerNodeData]):
part = cast(TextGenerateRouteChunk, part) part = cast(TextGenerateRouteChunk, part)
answer += part.text answer += part.text
return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"answer": answer, "files": files}) return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"answer": answer, "files": ArrayFileSegment(value=files)},
)
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(

@ -109,6 +109,7 @@ class AnswerStreamProcessor(StreamProcessor):
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
from_variable_selector=[answer_node_id, "answer"], from_variable_selector=[answer_node_id, "answer"],
node_version=event.node_version,
) )
else: else:
route_chunk = cast(VarGenerateRouteChunk, route_chunk) route_chunk = cast(VarGenerateRouteChunk, route_chunk)
@ -134,6 +135,7 @@ class AnswerStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[answer_node_id] += 1 self.route_position[answer_node_id] += 1

@ -57,7 +57,6 @@ class StreamProcessor(ABC):
# The branch_identify parameter is added to ensure that # The branch_identify parameter is added to ensure that
# only nodes in the correct logical branch are included. # only nodes in the correct logical branch are included.
reachable_node_ids.append(edge.target_node_id)
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle) ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id, run_result.edge_source_handle)
reachable_node_ids.extend(ids) reachable_node_ids.extend(ids)
else: else:
@ -74,6 +73,8 @@ class StreamProcessor(ABC):
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
if node_id not in self.rest_node_ids:
self.rest_node_ids.append(node_id)
node_ids = [] node_ids = []
for edge in self.graph.edge_mapping.get(node_id, []): for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id == self.graph.root_node_id: if edge.target_node_id == self.graph.root_node_id:

@ -1,7 +1,7 @@
import logging import logging
from abc import abstractmethod from abc import abstractmethod
from collections.abc import Generator, Mapping, Sequence from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar, Union, cast from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -23,7 +23,7 @@ GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData)
class BaseNode(Generic[GenericNodeData]): class BaseNode(Generic[GenericNodeData]):
_node_data_cls: type[GenericNodeData] _node_data_cls: type[GenericNodeData]
_node_type: NodeType _node_type: ClassVar[NodeType]
def __init__( def __init__(
self, self,
@ -90,8 +90,38 @@ class BaseNode(Generic[GenericNodeData]):
graph_config: Mapping[str, Any], graph_config: Mapping[str, Any],
config: Mapping[str, Any], config: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]: ) -> Mapping[str, Sequence[str]]:
""" """Extracts references variable selectors from node configuration.
Extract variable selector to variable mapping
The `config` parameter represents the configuration for a specific node type and corresponds
to the `data` field in the node definition object.
The returned mapping has the following structure:
{'1747829548239.#1747829667553.result#': ['1747829667553', 'result']}
For loop and iteration nodes, the mapping may look like this:
{
"1748332301644.input_selector": ["1748332363630", "result"],
"1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"],
}
where `1748332301644` is the ID of the loop / iteration node,
and `1748332325079` is the ID of the node inside the loop or iteration node.
Here, the key consists of two parts: the current node ID (provided as the `node_id`
parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector,
enclosed in `#` symbols. These two parts are separated by a dot (`.`).
The value is a list of string representing the variable selector, where the first element is the node ID
of the referenced variable, and the second element is the variable name within that node.
The meaning of the above response is:
The node with ID `1747829548239` references the variable `result` from the node with
ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a
reference to the `result` output variable of node `1747829667553`.
:param graph_config: graph config :param graph_config: graph config
:param config: node config :param config: node config
:return: :return:
@ -101,9 +131,10 @@ class BaseNode(Generic[GenericNodeData]):
raise ValueError("Node ID is required when extracting variable selector to variable mapping.") raise ValueError("Node ID is required when extracting variable selector to variable mapping.")
node_data = cls._node_data_cls(**config.get("data", {})) node_data = cls._node_data_cls(**config.get("data", {}))
return cls._extract_variable_selector_to_variable_mapping( data = cls._extract_variable_selector_to_variable_mapping(
graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data)
) )
return data
@classmethod @classmethod
def _extract_variable_selector_to_variable_mapping( def _extract_variable_selector_to_variable_mapping(
@ -139,6 +170,16 @@ class BaseNode(Generic[GenericNodeData]):
""" """
return self._node_type return self._node_type
@classmethod
@abstractmethod
def version(cls) -> str:
"""`node_version` returns the version of current node type."""
# NOTE(QuantumGhost): This should be in sync with `NODE_TYPE_CLASSES_MAPPING`.
#
# If you have introduced a new node type, please add it to `NODE_TYPE_CLASSES_MAPPING`
# in `api/core/workflow/nodes/__init__.py`.
raise NotImplementedError("subclasses of BaseNode must implement `version` method.")
@property @property
def should_continue_on_error(self) -> bool: def should_continue_on_error(self) -> bool:
"""judge if should continue on error """judge if should continue on error

@ -40,6 +40,10 @@ class CodeNode(BaseNode[CodeNodeData]):
return code_provider.get_default_config() return code_provider.get_default_config()
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
# Get code language # Get code language
code_language = self.node_data.code_language code_language = self.node_data.code_language
@ -126,6 +130,9 @@ class CodeNode(BaseNode[CodeNodeData]):
prefix: str = "", prefix: str = "",
depth: int = 1, depth: int = 1,
): ):
# TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes.
# Note that `_transform_result` may produce lists containing `None` values,
# which don't conform to the type requirements of `Array*Segment` classes.
if depth > dify_config.CODE_MAX_DEPTH: if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.") raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")

@ -24,7 +24,7 @@ from configs import dify_config
from core.file import File, FileTransferMethod, file_manager from core.file import File, FileTransferMethod, file_manager
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
from core.variables import ArrayFileSegment from core.variables import ArrayFileSegment
from core.variables.segments import FileSegment from core.variables.segments import ArrayStringSegment, FileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -45,6 +45,10 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
_node_data_cls = DocumentExtractorNodeData _node_data_cls = DocumentExtractorNodeData
_node_type = NodeType.DOCUMENT_EXTRACTOR _node_type = NodeType.DOCUMENT_EXTRACTOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
variable_selector = self.node_data.variable_selector variable_selector = self.node_data.variable_selector
variable = self.graph_runtime_state.variable_pool.get(variable_selector) variable = self.graph_runtime_state.variable_pool.get(variable_selector)
@ -67,7 +71,7 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
process_data=process_data, process_data=process_data,
outputs={"text": extracted_text_list}, outputs={"text": ArrayStringSegment(value=extracted_text_list)},
) )
elif isinstance(value, File): elif isinstance(value, File):
extracted_text = _extract_text_from_file(value) extracted_text = _extract_text_from_file(value)
@ -447,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str:
df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore
# Combine multi-line text in column names into a single line # Combine multi-line text in column names into a single line
df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns])
# Manually construct the Markdown table # Manually construct the Markdown table
markdown_table += _construct_markdown_table(df) + "\n\n" markdown_table += _construct_markdown_table(df) + "\n\n"

@ -9,6 +9,10 @@ class EndNode(BaseNode[EndNodeData]):
_node_data_cls = EndNodeData _node_data_cls = EndNodeData
_node_type = NodeType.END _node_type = NodeType.END
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node

@ -139,6 +139,7 @@ class EndStreamProcessor(StreamProcessor):
route_node_state=event.route_node_state, route_node_state=event.route_node_state,
parallel_id=event.parallel_id, parallel_id=event.parallel_id,
parallel_start_node_id=event.parallel_start_node_id, parallel_start_node_id=event.parallel_start_node_id,
node_version=event.node_version,
) )
self.route_position[end_node_id] += 1 self.route_position[end_node_id] += 1

@ -6,6 +6,7 @@ from typing import Any, Optional
from configs import dify_config from configs import dify_config
from core.file import File, FileTransferMethod from core.file import File, FileTransferMethod
from core.tools.tool_file_manager import ToolFileManager from core.tools.tool_file_manager import ToolFileManager
from core.variables.segments import ArrayFileSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_entities import VariableSelector
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -60,6 +61,10 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
process_data = {} process_data = {}
try: try:
@ -92,7 +97,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={ outputs={
"status_code": response.status_code, "status_code": response.status_code,
"body": response.text if not files else "", "body": response.text if not files.value else "",
"headers": response.headers, "headers": response.headers,
"files": files, "files": files,
}, },
@ -166,7 +171,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
return mapping return mapping
def extract_files(self, url: str, response: Response) -> list[File]: def extract_files(self, url: str, response: Response) -> ArrayFileSegment:
""" """
Extract files from response by checking both Content-Type header and URL Extract files from response by checking both Content-Type header and URL
""" """
@ -178,7 +183,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
content_disposition_type = None content_disposition_type = None
if not is_file: if not is_file:
return files return ArrayFileSegment(value=[])
if parsed_content_disposition: if parsed_content_disposition:
content_disposition_filename = parsed_content_disposition.get_filename() content_disposition_filename = parsed_content_disposition.get_filename()
@ -211,4 +216,4 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
) )
files.append(file) files.append(file)
return files return ArrayFileSegment(value=files)

@ -1,4 +1,5 @@
from typing import Literal from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing_extensions import deprecated from typing_extensions import deprecated
@ -16,6 +17,10 @@ class IfElseNode(BaseNode[IfElseNodeData]):
_node_data_cls = IfElseNodeData _node_data_cls = IfElseNodeData
_node_type = NodeType.IF_ELSE _node_type = NodeType.IF_ELSE
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run node Run node
@ -87,6 +92,22 @@ class IfElseNode(BaseNode[IfElseNodeData]):
return data return data
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: IfElseNodeData,
) -> Mapping[str, Sequence[str]]:
var_mapping: dict[str, list[str]] = {}
for case in node_data.cases or []:
for condition in case.conditions:
key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector))
var_mapping[key] = condition.variable_selector
return var_mapping
@deprecated("This function is deprecated. You should use the new cases structure.") @deprecated("This function is deprecated. You should use the new cases structure.")
def _should_not_use_old_function( def _should_not_use_old_function(

@ -11,6 +11,7 @@ from flask import Flask, current_app
from configs import dify_config from configs import dify_config
from core.variables import ArrayVariable, IntegerVariable, NoneVariable from core.variables import ArrayVariable, IntegerVariable, NoneVariable
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import ( from core.workflow.entities.node_entities import (
NodeRunResult, NodeRunResult,
) )
@ -37,6 +38,7 @@ from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData
from factories.variable_factory import build_segment
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from .exc import ( from .exc import (
@ -72,6 +74,10 @@ class IterationNode(BaseNode[IterationNodeData]):
}, },
} }
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
""" """
Run the node. Run the node.
@ -85,10 +91,17 @@ class IterationNode(BaseNode[IterationNodeData]):
raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.")
if isinstance(variable, NoneVariable) or len(variable.value) == 0: if isinstance(variable, NoneVariable) or len(variable.value) == 0:
# Try our best to preserve the type informat.
if isinstance(variable, ArraySegment):
output = variable.model_copy(update={"value": []})
else:
output = ArrayAnySegment(value=[])
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": []}, # TODO(QuantumGhost): is it possible to compute the type of `output`
# from graph definition?
outputs={"output": output},
) )
) )
return return
@ -231,6 +244,7 @@ class IterationNode(BaseNode[IterationNodeData]):
# Flatten the list of lists # Flatten the list of lists
if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs): if isinstance(outputs, list) and all(isinstance(output, list) for output in outputs):
outputs = [item for sublist in outputs for item in sublist] outputs = [item for sublist in outputs for item in sublist]
output_segment = build_segment(outputs)
yield IterationRunSucceededEvent( yield IterationRunSucceededEvent(
iteration_id=self.id, iteration_id=self.id,
@ -247,7 +261,7 @@ class IterationNode(BaseNode[IterationNodeData]):
yield RunCompletedEvent( yield RunCompletedEvent(
run_result=NodeRunResult( run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={"output": outputs}, outputs={"output": output_segment},
metadata={ metadata={
WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map,
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens,

@ -13,6 +13,10 @@ class IterationStartNode(BaseNode[IterationStartNodeData]):
_node_data_cls = IterationStartNodeData _node_data_cls = IterationStartNodeData
_node_type = NodeType.ITERATION_START _node_type = NodeType.ITERATION_START
@classmethod
def version(cls) -> str:
return "1"
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
""" """
Run the node. Run the node.

@ -24,6 +24,7 @@ from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import StringSegment
from core.variables.segments import ArrayObjectSegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
@ -70,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode):
_node_data_cls = KnowledgeRetrievalNodeData # type: ignore _node_data_cls = KnowledgeRetrievalNodeData # type: ignore
_node_type = NodeType.KNOWLEDGE_RETRIEVAL _node_type = NodeType.KNOWLEDGE_RETRIEVAL
@classmethod
def version(cls):
return "1"
def _run(self) -> NodeRunResult: # type: ignore def _run(self) -> NodeRunResult: # type: ignore
node_data = cast(KnowledgeRetrievalNodeData, self.node_data) node_data = cast(KnowledgeRetrievalNodeData, self.node_data)
# extract variables # extract variables
@ -115,9 +120,12 @@ class KnowledgeRetrievalNode(LLMNode):
# retrieve knowledge # retrieve knowledge
try: try:
results = self._fetch_dataset_retriever(node_data=node_data, query=query) results = self._fetch_dataset_retriever(node_data=node_data, query=query)
outputs = {"result": results} outputs = {"result": ArrayObjectSegment(value=results)}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=outputs status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
process_data=None,
outputs=outputs, # type: ignore
) )
except KnowledgeRetrievalNodeError as e: except KnowledgeRetrievalNodeError as e:

@ -3,6 +3,7 @@ from typing import Any, Literal, Union
from core.file import File from core.file import File
from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -16,6 +17,10 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
_node_data_cls = ListOperatorNodeData _node_data_cls = ListOperatorNodeData
_node_type = NodeType.LIST_OPERATOR _node_type = NodeType.LIST_OPERATOR
@classmethod
def version(cls) -> str:
return "1"
def _run(self): def _run(self):
inputs: dict[str, list] = {} inputs: dict[str, list] = {}
process_data: dict[str, list] = {} process_data: dict[str, list] = {}
@ -30,7 +35,11 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
if not variable.value: if not variable.value:
inputs = {"variable": []} inputs = {"variable": []}
process_data = {"variable": []} process_data = {"variable": []}
outputs = {"result": [], "first_record": None, "last_record": None} if isinstance(variable, ArraySegment):
result = variable.model_copy(update={"value": []})
else:
result = ArrayAnySegment(value=[])
outputs = {"result": result, "first_record": None, "last_record": None}
return NodeRunResult( return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED, status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=inputs, inputs=inputs,
@ -71,7 +80,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]):
variable = self._apply_slice(variable) variable = self._apply_slice(variable)
outputs = { outputs = {
"result": variable.value, "result": variable,
"first_record": variable.value[0] if variable.value else None, "first_record": variable.value[0] if variable.value else None,
"last_record": variable.value[-1] if variable.value else None, "last_record": variable.value[-1] if variable.value else None,
} }

@ -119,9 +119,6 @@ class FileSaverImpl(LLMFileSaver):
size=len(data), size=len(data),
related_id=tool_file.id, related_id=tool_file.id,
url=url, url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key, storage_key=tool_file.file_key,
) )

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save