diff --git a/.github/actions/setup-poetry/action.yml b/.github/actions/setup-poetry/action.yml index 5feab33d1d..2e76676f37 100644 --- a/.github/actions/setup-poetry/action.yml +++ b/.github/actions/setup-poetry/action.yml @@ -4,7 +4,7 @@ inputs: python-version: description: Python version to use and the Poetry installed with required: true - default: '3.10' + default: '3.11' poetry-version: description: Poetry version to set up required: true diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 76e844aaad..e1c0bf33a4 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -20,7 +20,6 @@ jobs: strategy: matrix: python-version: - - "3.10" - "3.11" - "3.12" diff --git a/.github/workflows/db-migration-test.yml b/.github/workflows/db-migration-test.yml index f4eb0f8e33..3d881c4c3d 100644 --- a/.github/workflows/db-migration-test.yml +++ b/.github/workflows/db-migration-test.yml @@ -48,6 +48,8 @@ jobs: cp .env.example .env - name: Run DB Migration + env: + DEBUG: true run: | cd api poetry run python -m flask upgrade-db diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index caddd23bab..73af370063 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -8,6 +8,8 @@ on: - api/core/rag/datasource/** - docker/** - .github/workflows/vdb-tests.yml + - api/poetry.lock + - api/pyproject.toml concurrency: group: vdb-tests-${{ github.head_ref || github.run_id }} @@ -20,7 +22,6 @@ jobs: strategy: matrix: python-version: - - "3.10" - "3.11" - "3.12" diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md index 310c55090a..4fa43b24ee 100644 --- a/CONTRIBUTING_CN.md +++ b/CONTRIBUTING_CN.md @@ -71,7 +71,7 @@ Dify 依赖以下工具和库: - [Docker Compose](https://docs.docker.com/compose/install/) - [Node.js v18.x (LTS)](http://nodejs.org) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) -- [Python](https://www.python.org/) version 3.10.x +- [Python](https://www.python.org/) version 3.11.x or 3.12.x ### 4. 安装 diff --git a/CONTRIBUTING_JA.md b/CONTRIBUTING_JA.md index a68bdeddbc..22e30e9c03 100644 --- a/CONTRIBUTING_JA.md +++ b/CONTRIBUTING_JA.md @@ -74,7 +74,7 @@ Dify を構築するには次の依存関係が必要です。それらがシス - [Docker Compose](https://docs.docker.com/compose/install/) - [Node.js v18.x (LTS)](http://nodejs.org) - [npm](https://www.npmjs.com/) version 8.x.x or [Yarn](https://yarnpkg.com/) -- [Python](https://www.python.org/) version 3.10.x +- [Python](https://www.python.org/) version 3.11.x or 3.12.x ### 4. インストール diff --git a/CONTRIBUTING_VI.md b/CONTRIBUTING_VI.md index a77239ff38..ad41f51aeb 100644 --- a/CONTRIBUTING_VI.md +++ b/CONTRIBUTING_VI.md @@ -73,7 +73,7 @@ Dify yêu cầu các phụ thuộc sau để build, hãy đảm bảo chúng đ - [Docker Compose](https://docs.docker.com/compose/install/) - [Node.js v18.x (LTS)](http://nodejs.org) - [npm](https://www.npmjs.com/) phiên bản 8.x.x hoặc [Yarn](https://yarnpkg.com/) -- [Python](https://www.python.org/) phiên bản 3.10.x +- [Python](https://www.python.org/) phiên bản 3.11.x hoặc 3.12.x ### 4. Cài đặt @@ -153,4 +153,4 @@ Và thế là xong! Khi PR của bạn được merge, bạn sẽ được giớ ## Nhận trợ giúp -Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. \ No newline at end of file +Nếu bạn gặp khó khăn hoặc có câu hỏi cấp bách trong quá trình đóng góp, hãy đặt câu hỏi của bạn trong vấn đề GitHub liên quan, hoặc tham gia [Discord](https://discord.gg/8Tpq4AcN9c) của chúng tôi để trò chuyện nhanh chóng. diff --git a/README.md b/README.md index 4c2d803854..df6c481e78 100644 --- a/README.md +++ b/README.md @@ -147,6 +147,13 @@ Deploy Dify to Cloud Platform with a single click using [terraform](https://www. ##### Google Cloud - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Using AWS CDK for Deployment + +Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_AR.md b/README_AR.md index a4cfd744c0..d42c7508b1 100644 --- a/README_AR.md +++ b/README_AR.md @@ -190,6 +190,13 @@ docker compose up -d ##### Google Cloud - [Google Cloud Terraform بواسطة @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### استخدام AWS CDK للنشر + +انشر Dify على AWS باستخدام [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK بواسطة @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## المساهمة لأولئك الذين يرغبون في المساهمة، انظر إلى [دليل المساهمة](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) لدينا. @@ -222,3 +229,10 @@ docker compose up -d ## الرخصة هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية. +## الكشف عن الأمان + +لحماية خصوصيتك، يرجى تجنب نشر مشكلات الأمان على GitHub. بدلاً من ذلك، أرسل أسئلتك إلى security@dify.ai وسنقدم لك إجابة أكثر تفصيلاً. + +## الرخصة + +هذا المستودع متاح تحت [رخصة البرنامج الحر Dify](LICENSE)، والتي تعتبر بشكل أساسي Apache 2.0 مع بعض القيود الإضافية. diff --git a/README_CN.md b/README_CN.md index 2a3f12dd05..8d1cfbf274 100644 --- a/README_CN.md +++ b/README_CN.md @@ -213,6 +213,13 @@ docker compose up -d ##### Google Cloud - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### 使用 AWS CDK 部署 + +使用 [CDK](https://aws.amazon.com/cdk/) 将 Dify 部署到 AWS + +##### AWS +- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Star History [![Star History Chart](https://api.star-history.com/svg?repos=langgenius/dify&type=Date)](https://star-history.com/#langgenius/dify&Date) diff --git a/README_ES.md b/README_ES.md index ab79ec9f85..9763de69fb 100644 --- a/README_ES.md +++ b/README_ES.md @@ -215,6 +215,13 @@ Despliega Dify en una plataforma en la nube con un solo clic utilizando [terrafo ##### Google Cloud - [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Usando AWS CDK para el Despliegue + +Despliegue Dify en AWS usando [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Contribuir Para aquellos que deseen contribuir con código, consulten nuestra [Guía de contribución](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). @@ -248,3 +255,10 @@ Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En ## Licencia Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. +## Divulgación de Seguridad + +Para proteger tu privacidad, evita publicar problemas de seguridad en GitHub. En su lugar, envía tus preguntas a security@dify.ai y te proporcionaremos una respuesta más detallada. + +## Licencia + +Este repositorio está disponible bajo la [Licencia de Código Abierto de Dify](LICENSE), que es esencialmente Apache 2.0 con algunas restricciones adicionales. diff --git a/README_FR.md b/README_FR.md index 1c963b495f..974c0b9297 100644 --- a/README_FR.md +++ b/README_FR.md @@ -213,6 +213,13 @@ Déployez Dify sur une plateforme cloud en un clic en utilisant [terraform](http ##### Google Cloud - [Google Cloud Terraform par @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Utilisation d'AWS CDK pour le déploiement + +Déployez Dify sur AWS en utilisant [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK par @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Contribuer Pour ceux qui souhaitent contribuer du code, consultez notre [Guide de contribution](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). @@ -246,3 +253,10 @@ Pour protéger votre vie privée, veuillez éviter de publier des problèmes de ## Licence Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. +## Divulgation de sécurité + +Pour protéger votre vie privée, veuillez éviter de publier des problèmes de sécurité sur GitHub. Au lieu de cela, envoyez vos questions à security@dify.ai et nous vous fournirons une réponse plus détaillée. + +## Licence + +Ce référentiel est disponible sous la [Licence open source Dify](LICENSE), qui est essentiellement l'Apache 2.0 avec quelques restrictions supplémentaires. diff --git a/README_JA.md b/README_JA.md index b0f06ff259..9651219157 100644 --- a/README_JA.md +++ b/README_JA.md @@ -212,6 +212,13 @@ docker compose up -d ##### Google Cloud - [@sotazumによるGoogle Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) +#### AWS CDK を使用したデプロイ + +[CDK](https://aws.amazon.com/cdk/) を使用して、DifyをAWSにデプロイします + +##### AWS +- [@KevinZhaoによるAWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## 貢献 コードに貢献したい方は、[Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)を参照してください。 diff --git a/README_KL.md b/README_KL.md index be727774e9..dd37b8243b 100644 --- a/README_KL.md +++ b/README_KL.md @@ -213,6 +213,13 @@ wa'logh nIqHom neH ghun deployment toy'wI' [terraform](https://www.terraform.io/ ##### Google Cloud - [Google Cloud Terraform qachlot @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### AWS CDK atorlugh pilersitsineq + +wa'logh nIqHom neH ghun deployment toy'wI' [CDK](https://aws.amazon.com/cdk/) lo'laH. + +##### AWS +- [AWS CDK qachlot @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Contributing For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_KR.md b/README_KR.md index 9f8e072ba6..8edbb99226 100644 --- a/README_KR.md +++ b/README_KR.md @@ -205,6 +205,13 @@ Dify를 Kubernetes에 배포하고 프리미엄 스케일링 설정을 구성했 ##### Google Cloud - [sotazum의 Google Cloud Terraform](https://github.com/DeNA/dify-google-cloud-terraform) +#### AWS CDK를 사용한 배포 + +[CDK](https://aws.amazon.com/cdk/)를 사용하여 AWS에 Dify 배포 + +##### AWS +- [KevinZhao의 AWS CDK](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## 기여 코드에 기여하고 싶은 분들은 [기여 가이드](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md)를 참조하세요. diff --git a/README_PT.md b/README_PT.md index d822cbea67..f947538952 100644 --- a/README_PT.md +++ b/README_PT.md @@ -211,6 +211,13 @@ Implante o Dify na Plataforma Cloud com um único clique usando [terraform](http ##### Google Cloud - [Google Cloud Terraform por @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Usando AWS CDK para Implantação + +Implante o Dify na AWS usando [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK por @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Contribuindo Para aqueles que desejam contribuir com código, veja nosso [Guia de Contribuição](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md). diff --git a/README_SI.md b/README_SI.md index 41a44600e8..6badf47f01 100644 --- a/README_SI.md +++ b/README_SI.md @@ -145,6 +145,13 @@ namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www. ##### Google Cloud - [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Uporaba AWS CDK za uvajanje + +Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Prispevam Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah. diff --git a/README_TR.md b/README_TR.md index 38fada34e9..24ed0c9a08 100644 --- a/README_TR.md +++ b/README_TR.md @@ -211,6 +211,13 @@ Dify'ı bulut platformuna tek tıklamayla dağıtın [terraform](https://www.ter ##### Google Cloud - [Google Cloud Terraform tarafından @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### AWS CDK ile Dağıtım + +[CDK](https://aws.amazon.com/cdk/) kullanarak Dify'ı AWS'ye dağıtın + +##### AWS +- [AWS CDK tarafından @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Katkıda Bulunma Kod katkısında bulunmak isteyenler için [Katkı Kılavuzumuza](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) bakabilirsiniz. diff --git a/README_VI.md b/README_VI.md index 6f296e508c..9076fcaae7 100644 --- a/README_VI.md +++ b/README_VI.md @@ -207,6 +207,13 @@ Triển khai Dify lên nền tảng đám mây với một cú nhấp chuột b ##### Google Cloud - [Google Cloud Terraform bởi @sotazum](https://github.com/DeNA/dify-google-cloud-terraform) +#### Sử dụng AWS CDK để Triển khai + +Triển khai Dify trên AWS bằng [CDK](https://aws.amazon.com/cdk/) + +##### AWS +- [AWS CDK bởi @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws) + ## Đóng góp Đối với những người muốn đóng góp mã, xem [Hướng dẫn Đóng góp](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) của chúng tôi. diff --git a/api/.env.example b/api/.env.example index f8a2812563..52cdd9ecb2 100644 --- a/api/.env.example +++ b/api/.env.example @@ -329,6 +329,7 @@ NOTION_INTERNAL_SECRET=you-internal-secret ETL_TYPE=dify UNSTRUCTURED_API_URL= UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=true #ssrf SSRF_PROXY_HTTP_URL= @@ -382,7 +383,7 @@ LOG_DATEFORMAT=%Y-%m-%d %H:%M:%S LOG_TZ=UTC # Indexing configuration -INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 # Workflow runtime configuration WORKFLOW_MAX_EXECUTION_STEPS=500 @@ -410,4 +411,5 @@ POSITION_PROVIDER_EXCLUDES= # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 -CREATE_TIDB_SERVICE_JOB_ENABLED=false \ No newline at end of file +CREATE_TIDB_SERVICE_JOB_ENABLED=false + diff --git a/api/.ruff.toml b/api/.ruff.toml new file mode 100644 index 0000000000..0f3185223c --- /dev/null +++ b/api/.ruff.toml @@ -0,0 +1,96 @@ +exclude = [ + "migrations/*", +] +line-length = 120 + +[format] +quote-style = "double" + +[lint] +preview = true +select = [ + "B", # flake8-bugbear rules + "C4", # flake8-comprehensions + "E", # pycodestyle E rules + "F", # pyflakes rules + "FURB", # refurb rules + "I", # isort rules + "N", # pep8-naming + "PT", # flake8-pytest-style rules + "PLC0208", # iteration-over-set + "PLC2801", # unnecessary-dunder-call + "PLC0414", # useless-import-alias + "PLE0604", # invalid-all-object + "PLE0605", # invalid-all-format + "PLR0402", # manual-from-import + "PLR1711", # useless-return + "PLR1714", # repeated-equality-comparison + "RUF013", # implicit-optional + "RUF019", # unnecessary-key-check + "RUF100", # unused-noqa + "RUF101", # redirected-noqa + "RUF200", # invalid-pyproject-toml + "RUF022", # unsorted-dunder-all + "S506", # unsafe-yaml-load + "SIM", # flake8-simplify rules + "TRY400", # error-instead-of-exception + "TRY401", # verbose-log-message + "UP", # pyupgrade rules + "W191", # tab-indentation + "W605", # invalid-escape-sequence +] + +ignore = [ + "E402", # module-import-not-at-top-of-file + "E711", # none-comparison + "E712", # true-false-comparison + "E721", # type-comparison + "E722", # bare-except + "E731", # lambda-assignment + "F821", # undefined-name + "F841", # unused-variable + "FURB113", # repeated-append + "FURB152", # math-constant + "UP007", # non-pep604-annotation + "UP032", # f-string + "B005", # strip-with-multi-characters + "B006", # mutable-argument-default + "B007", # unused-loop-control-variable + "B026", # star-arg-unpacking-after-keyword-arg + "B904", # raise-without-from-inside-except + "B905", # zip-without-explicit-strict + "N806", # non-lowercase-variable-in-function + "N815", # mixed-case-variable-in-class-scope + "PT011", # pytest-raises-too-broad + "SIM102", # collapsible-if + "SIM103", # needless-bool + "SIM105", # suppressible-exception + "SIM107", # return-in-try-except-finally + "SIM108", # if-else-block-instead-of-if-exp + "SIM113", # eumerate-for-loop + "SIM117", # multiple-with-statements + "SIM210", # if-expr-with-true-false + "SIM300", # yoda-conditions, +] + +[lint.per-file-ignores] +"__init__.py" = [ + "F401", # unused-import + "F811", # redefined-while-unused +] +"configs/*" = [ + "N802", # invalid-function-name +] +"libs/gmpy2_pkcs10aep_cipher.py" = [ + "N803", # invalid-argument-name +] +"tests/*" = [ + "F811", # redefined-while-unused + "F401", # unused-import +] + +[lint.pyflakes] +extend-generics = [ + "_pytest.monkeypatch", + "tests.integration_tests", +] diff --git a/api/Dockerfile b/api/Dockerfile index e7b64f1107..b5b8f69829 100644 --- a/api/Dockerfile +++ b/api/Dockerfile @@ -55,7 +55,7 @@ RUN apt-get update \ && echo "deb http://deb.debian.org/debian testing main" > /etc/apt/sources.list \ && apt-get update \ # For Security - && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-7 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ + && apt-get install -y --no-install-recommends expat=2.6.4-1 libldap-2.5-0=2.5.18+dfsg-3+b1 perl=5.40.0-8 libsqlite3-0=3.46.1-1 zlib1g=1:1.3.dfsg+really1.3.1-1+b1 \ # install a chinese font to support the use of tools like matplotlib && apt-get install -y fonts-noto-cjk \ && apt-get autoremove -y \ diff --git a/api/app.py b/api/app.py index a667a84fd6..996e2e890f 100644 --- a/api/app.py +++ b/api/app.py @@ -1,111 +1,13 @@ -import os -import sys - -from configs import dify_config - -if not dify_config.DEBUG: - from gevent import monkey - - monkey.patch_all() - - import grpc.experimental.gevent - - grpc.experimental.gevent.init_gevent() - -import json -import threading -import time -import warnings - -from flask import Response - from app_factory import create_app +from libs import threadings_utils, version_utils -# DO NOT REMOVE BELOW -from events import event_handlers # noqa: F401 -from extensions.ext_database import db - -# TODO: Find a way to avoid importing models here -from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 - -# DO NOT REMOVE ABOVE - -if sys.version_info[:2] == (3, 10): - print("Warning: Python 3.10 will not be supported in the next version.") - - -warnings.simplefilter("ignore", ResourceWarning) - -os.environ["TZ"] = "UTC" -# windows platform not support tzset -if hasattr(time, "tzset"): - time.tzset() - +# preparation before creating app +version_utils.check_supported_python_version() +threadings_utils.apply_gevent_threading_patch() # create app app = create_app() celery = app.extensions["celery"] -if dify_config.TESTING: - print("App is running in TESTING mode") - - -@app.after_request -def after_request(response): - """Add Version headers to the response.""" - response.headers.add("X-Version", dify_config.CURRENT_VERSION) - response.headers.add("X-Env", dify_config.DEPLOY_ENV) - return response - - -@app.route("/health") -def health(): - return Response( - json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), - status=200, - content_type="application/json", - ) - - -@app.route("/threads") -def threads(): - num_threads = threading.active_count() - threads = threading.enumerate() - - thread_list = [] - for thread in threads: - thread_name = thread.name - thread_id = thread.ident - is_alive = thread.is_alive() - - thread_list.append( - { - "name": thread_name, - "id": thread_id, - "is_alive": is_alive, - } - ) - - return { - "pid": os.getpid(), - "thread_num": num_threads, - "threads": thread_list, - } - - -@app.route("/db-pool-stat") -def pool_stat(): - engine = db.engine - return { - "pid": os.getpid(), - "pool_size": engine.pool.size(), - "checked_in_connections": engine.pool.checkedin(), - "checked_out_connections": engine.pool.checkedout(), - "overflow_connections": engine.pool.overflow(), - "connection_timeout": engine.pool.timeout(), - "recycle_time": db.engine.pool._recycle, - } - - if __name__ == "__main__": app.run(host="0.0.0.0", port=5001) diff --git a/api/app_factory.py b/api/app_factory.py index 60a584798b..7dc08c4d93 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -1,54 +1,15 @@ +import logging import os +import time from configs import dify_config - -if not dify_config.DEBUG: - from gevent import monkey - - monkey.patch_all() - - import grpc.experimental.gevent - - grpc.experimental.gevent.init_gevent() - -import json - -from flask import Flask, Response, request -from flask_cors import CORS -from werkzeug.exceptions import Unauthorized - -import contexts -from commands import register_commands -from configs import dify_config -from extensions import ( - ext_celery, - ext_code_based_extension, - ext_compress, - ext_database, - ext_hosting_provider, - ext_logging, - ext_login, - ext_mail, - ext_migrate, - ext_proxy_fix, - ext_redis, - ext_sentry, - ext_storage, -) -from extensions.ext_database import db -from extensions.ext_login import login_manager -from libs.passport import PassportService -from services.account_service import AccountService - - -class DifyApp(Flask): - pass +from dify_app import DifyApp # ---------------------------- # Application Factory Function # ---------------------------- -def create_flask_app_with_configs() -> Flask: +def create_flask_app_with_configs() -> DifyApp: """ create a raw flask app with configs loaded from .env file @@ -68,111 +29,72 @@ def create_flask_app_with_configs() -> Flask: return dify_app -def create_app() -> Flask: +def create_app() -> DifyApp: + start_time = time.perf_counter() app = create_flask_app_with_configs() - app.secret_key = dify_config.SECRET_KEY initialize_extensions(app) - register_blueprints(app) - register_commands(app) - + end_time = time.perf_counter() + if dify_config.DEBUG: + logging.info(f"Finished create_app ({round((end_time - start_time) * 1000, 2)} ms)") return app -def initialize_extensions(app): - # Since the application instance is now created, pass it to each Flask - # extension instance to bind it to the Flask application instance (app) - ext_logging.init_app(app) - ext_compress.init_app(app) - ext_code_based_extension.init() - ext_database.init_app(app) - ext_migrate.init(app, db) - ext_redis.init_app(app) - ext_storage.init_app(app) - ext_celery.init_app(app) - ext_login.init_app(app) - ext_mail.init_app(app) - ext_hosting_provider.init_app(app) - ext_sentry.init_app(app) - ext_proxy_fix.init_app(app) - - -# Flask-Login configuration -@login_manager.request_loader -def load_user_from_request(request_from_flask_login): - """Load user based on the request.""" - if request.blueprint not in {"console", "inner_api"}: - return None - # Check if the user_id contains a dot, indicating the old format - auth_header = request.headers.get("Authorization", "") - if not auth_header: - auth_token = request.args.get("_token") - if not auth_token: - raise Unauthorized("Invalid Authorization token.") - else: - if " " not in auth_header: - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - auth_scheme, auth_token = auth_header.split(None, 1) - auth_scheme = auth_scheme.lower() - if auth_scheme != "bearer": - raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") - - decoded = PassportService().verify(auth_token) - user_id = decoded.get("user_id") - - logged_in_account = AccountService.load_logged_in_account(account_id=user_id) - if logged_in_account: - contexts.tenant_id.set(logged_in_account.current_tenant_id) - return logged_in_account - - -@login_manager.unauthorized_handler -def unauthorized_handler(): - """Handle unauthorized requests.""" - return Response( - json.dumps({"code": "unauthorized", "message": "Unauthorized."}), - status=401, - content_type="application/json", - ) - - -# register blueprint routers -def register_blueprints(app): - from controllers.console import bp as console_app_bp - from controllers.files import bp as files_bp - from controllers.inner_api import bp as inner_api_bp - from controllers.service_api import bp as service_api_bp - from controllers.web import bp as web_bp - - CORS( - service_api_bp, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], +def initialize_extensions(app: DifyApp): + from extensions import ( + ext_app_metrics, + ext_blueprints, + ext_celery, + ext_code_based_extension, + ext_commands, + ext_compress, + ext_database, + ext_hosting_provider, + ext_import_modules, + ext_logging, + ext_login, + ext_mail, + ext_migrate, + ext_proxy_fix, + ext_redis, + ext_sentry, + ext_set_secretkey, + ext_storage, + ext_timezone, + ext_warnings, ) - app.register_blueprint(service_api_bp) - - CORS( - web_bp, - resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization", "X-App-Code"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(web_bp) - - CORS( - console_app_bp, - resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, - supports_credentials=True, - allow_headers=["Content-Type", "Authorization"], - methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], - expose_headers=["X-Version", "X-Env"], - ) - - app.register_blueprint(console_app_bp) - - CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) - app.register_blueprint(files_bp) - app.register_blueprint(inner_api_bp) + extensions = [ + ext_timezone, + ext_logging, + ext_warnings, + ext_import_modules, + ext_set_secretkey, + ext_compress, + ext_code_based_extension, + ext_database, + ext_app_metrics, + ext_migrate, + ext_redis, + ext_storage, + ext_celery, + ext_login, + ext_mail, + ext_hosting_provider, + ext_sentry, + ext_proxy_fix, + ext_blueprints, + ext_commands, + ] + for ext in extensions: + short_name = ext.__name__.split(".")[-1] + is_enabled = ext.is_enabled() if hasattr(ext, "is_enabled") else True + if not is_enabled: + if dify_config.DEBUG: + logging.info(f"Skipped {short_name}") + continue + + start_time = time.perf_counter() + ext.init_app(app) + end_time = time.perf_counter() + if dify_config.DEBUG: + logging.info(f"Loaded {short_name} ({round((end_time - start_time) * 1000, 2)} ms)") diff --git a/api/commands.py b/api/commands.py index 23787f38bf..b6f3b52d04 100644 --- a/api/commands.py +++ b/api/commands.py @@ -640,15 +640,3 @@ where sites.id is null limit 1000""" break click.echo(click.style("Fix for missing app-related sites completed successfully!", fg="green")) - - -def register_commands(app): - app.cli.add_command(reset_password) - app.cli.add_command(reset_email) - app.cli.add_command(reset_encrypt_key_pair) - app.cli.add_command(vdb_migrate) - app.cli.add_command(convert_to_agent_apps) - app.cli.add_command(add_qdrant_doc_id_index) - app.cli.add_command(create_tenant) - app.cli.add_command(upgrade_db) - app.cli.add_command(fix_app_site_missing) diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 66d6a55b4c..950936d3c6 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -17,11 +17,6 @@ class DeploymentConfig(BaseSettings): default=False, ) - TESTING: bool = Field( - description="Enable testing mode for running automated tests", - default=False, - ) - EDITION: str = Field( description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", default="SELF_HOSTED", diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 99f86be12e..f1cb3efda7 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -585,6 +585,11 @@ class RagEtlConfig(BaseSettings): default=None, ) + SCARF_NO_ANALYTICS: Optional[str] = Field( + description="This is about whether to disable Scarf analytics in Unstructured library.", + default="false", + ) + class DataSetConfig(BaseSettings): """ @@ -640,7 +645,7 @@ class IndexingConfig(BaseSettings): INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: PositiveInt = Field( description="Maximum token length for text segmentation during indexing", - default=1000, + default=4000, ) diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index 1f2b8224e8..08f8728148 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="0.11.2", + default="0.13.0", ) COMMIT_SHA: str = Field( diff --git a/api/constants/languages.py b/api/constants/languages.py index a6394da819..1157ec4307 100644 --- a/api/constants/languages.py +++ b/api/constants/languages.py @@ -18,6 +18,7 @@ language_timezone_mapping = { "tr-TR": "Europe/Istanbul", "fa-IR": "Asia/Tehran", "sl-SI": "Europe/Ljubljana", + "th-TH": "Asia/Bangkok", } languages = list(language_timezone_mapping.keys()) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 9687b59cd1..da72b704c7 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -190,7 +190,7 @@ class AppCopyApi(Resource): ) session.commit() - stmt = select(App).where(App.id == result.app.id) + stmt = select(App).where(App.id == result.app_id) app = session.scalar(stmt) return app, 201 diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 7b78f622b9..a25004be4d 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime import pytz from flask_login import current_user @@ -314,7 +314,7 @@ def _get_conversation(app_model, conversation_id): raise NotFound("Conversation Not Exists.") if not conversation.read_at: - conversation.read_at = datetime.now(timezone.utc).replace(tzinfo=None) + conversation.read_at = datetime.now(UTC).replace(tzinfo=None) conversation.read_account_id = current_user.id db.session.commit() diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 2f5645852f..407f689819 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse @@ -75,7 +75,7 @@ class AppSite(Resource): setattr(site, attr_name, value) site.updated_by = current_user.id - site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + site.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return site @@ -99,7 +99,7 @@ class AppSiteAccessTokenReset(Resource): site.code = Site.generate_code(16) site.updated_by = current_user.id - site.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + site.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return site diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index cc05a0d509..c85d554069 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -100,11 +100,11 @@ class DraftWorkflowApi(Resource): try: environment_variables_list = args.get("environment_variables") or [] environment_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] conversation_variables_list = args.get("conversation_variables") or [] conversation_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] workflow = workflow_service.sync_draft_workflow( app_model=app_model, @@ -382,7 +382,7 @@ class DefaultBlockConfigApi(Resource): filters = None if args.get("q"): try: - filters = json.loads(args.get("q")) + filters = json.loads(args.get("q", "")) except json.JSONDecodeError: raise ValueError("Invalid filters") diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index be353cefac..d2aa7c903b 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -65,7 +65,7 @@ class ActivateApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 3c3f45260a..faca67bb17 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -34,7 +34,6 @@ class OAuthDataSource(Resource): OAUTH_DATASOURCE_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_DATASOURCE_PROVIDERS.get(provider) - print(vars(oauth_provider)) if not oauth_provider: return {"error": "Invalid provider"}, 400 if dify_config.NOTION_INTEGRATION_TYPE == "internal": diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index d27e3353c9..5de8c6766d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Optional import requests @@ -52,7 +52,6 @@ class OAuthLogin(Resource): OAUTH_PROVIDERS = get_oauth_providers() with current_app.app_context(): oauth_provider = OAUTH_PROVIDERS.get(provider) - print(vars(oauth_provider)) if not oauth_provider: return {"error": "Invalid provider"}, 400 @@ -106,7 +105,7 @@ class OAuthCallback(Resource): if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() try: diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index ef1e87905a..278295ca39 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -83,7 +83,7 @@ class DataSourceApi(Resource): if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: @@ -92,7 +92,7 @@ class DataSourceApi(Resource): if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(data_source_binding) db.session.commit() else: diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index f38408525a..de3b4f6262 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,6 @@ import logging from argparse import ArgumentTypeError -from datetime import datetime, timezone +from datetime import UTC, datetime from flask import request from flask_login import current_user @@ -106,6 +106,7 @@ class GetProcessRuleApi(Resource): # get default rules mode = DocumentService.DEFAULT_RULES["mode"] rules = DocumentService.DEFAULT_RULES["rules"] + limits = DocumentService.DEFAULT_RULES["limits"] if document_id: # get the latest process rule document = Document.query.get_or_404(document_id) @@ -132,7 +133,7 @@ class GetProcessRuleApi(Resource): mode = dataset_process_rule.mode rules = dataset_process_rule.rules_dict - return {"mode": mode, "rules": rules} + return {"mode": mode, "rules": rules, "limits": limits} class DatasetDocumentListApi(Resource): @@ -665,7 +666,7 @@ class DocumentProcessingApi(DocumentResource): raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id - document.paused_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.paused_at = datetime.now(UTC).replace(tzinfo=None) document.is_paused = True db.session.commit() @@ -745,7 +746,7 @@ class DocumentMetadataApi(DocumentResource): document.doc_metadata[key] = value document.doc_type = doc_type - document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return {"result": "success", "message": "Document metadata updated."}, 200 @@ -787,7 +788,7 @@ class DocumentStatusApi(DocumentResource): document.enabled = True document.disabled_at = None document.disabled_by = None - document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times @@ -804,9 +805,9 @@ class DocumentStatusApi(DocumentResource): raise InvalidActionError("Document already disabled.") document.enabled = False - document.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.disabled_at = datetime.now(UTC).replace(tzinfo=None) document.disabled_by = current_user.id - document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times @@ -821,9 +822,9 @@ class DocumentStatusApi(DocumentResource): raise InvalidActionError("Document already archived.") document.archived = True - document.archived_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.archived_at = datetime.now(UTC).replace(tzinfo=None) document.archived_by = current_user.id - document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() if document.enabled: @@ -840,7 +841,7 @@ class DocumentStatusApi(DocumentResource): document.archived = False document.archived_at = None document.archived_by = None - document.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() # Set cache to prevent indexing the same document multiple times diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 5d8d664e41..6f7ef86d2c 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime import pandas as pd from flask import request @@ -188,7 +188,7 @@ class DatasetDocumentSegmentApi(Resource): raise InvalidActionError("Segment is already disabled.") segment.enabled = False - segment.disabled_at = datetime.now(timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.now(UTC).replace(tzinfo=None) segment.disabled_by = current_user.id db.session.commit() diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 125bc1af8c..85c43f8101 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from flask_login import current_user from flask_restful import reqparse @@ -46,7 +46,7 @@ class CompletionApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) + installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() try: @@ -106,7 +106,7 @@ class ChatApi(InstalledAppResource): args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) + installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() try: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index d72715a38c..b60c4e176b 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from flask_login import current_user from flask_restful import Resource, inputs, marshal_with, reqparse @@ -81,7 +81,7 @@ class InstalledAppsListApi(Resource): tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(timezone.utc).replace(tzinfo=None), + last_used_at=datetime.now(UTC).replace(tzinfo=None), ) db.session.add(new_installed_app) db.session.commit() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index 750f65168f..f704783cff 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -60,7 +60,7 @@ class AccountInitApi(Resource): raise InvalidInvitationCodeError() invitation_code.status = "used" - invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -68,7 +68,7 @@ class AccountInitApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = "active" - account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() return {"result": "success"} diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py index 88b13faa52..ecff7d07e9 100644 --- a/api/controllers/service_api/app/app.py +++ b/api/controllers/service_api/app/app.py @@ -48,7 +48,8 @@ class AppInfoApi(Resource): @validate_app_token def get(self, app_model: App): """Get app information""" - return {"name": app_model.name, "description": app_model.description} + tags = [tag.name for tag in app_model.tags] + return {"name": app_model.name, "description": app_model.description, "tags": tags} api.add_resource(AppParameterApi, "/parameters") diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index b935b23ed6..2128c4c53f 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from functools import wraps from typing import Optional @@ -198,7 +198,7 @@ def validate_and_get_api_token(scope=None): if not api_token: raise Unauthorized("Access token is invalid") - api_token.last_used_at = datetime.now(timezone.utc).replace(tzinfo=None) + api_token.last_used_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return api_token diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 2f5e7c7793..ead293200e 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -2,7 +2,7 @@ import json import logging import uuid from collections.abc import Mapping, Sequence -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Optional, Union, cast from core.agent.entities import AgentEntity, AgentToolEntity @@ -412,7 +412,7 @@ class BaseAgentRunner(AppRunner): .first() ) - db_variables.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + db_variables.updated_at = datetime.now(UTC).replace(tzinfo=None) db_variables.variables_str = json.dumps(jsonable_encoder(tool_variables.pool)) db.session.commit() db.session.close() diff --git a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py index a22395b8e3..b9aae7904f 100644 --- a/api/core/app/app_config/easy_ui_based_app/dataset/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/dataset/manager.py @@ -1,3 +1,4 @@ +import uuid from typing import Optional from core.app.app_config.entities import DatasetEntity, DatasetRetrieveConfigEntity diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index b5e4554181..5adcf26f14 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,3 +1,6 @@ +from collections.abc import Mapping +from typing import Any + from core.app.app_config.entities import ModelConfigEntity from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.model_providers import model_provider_factory @@ -36,7 +39,7 @@ class ModelConfigManager: ) @classmethod - def validate_and_set_defaults(cls, tenant_id: str, config: dict) -> tuple[dict, list[str]]: + def validate_and_set_defaults(cls, tenant_id: str, config: Mapping[str, Any]) -> tuple[dict, list[str]]: """ Validate and set defaults for model config diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index 82a0e56ce8..fa30511f63 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,4 +1,5 @@ from core.app.app_config.entities import ( + AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, AdvancedCompletionPromptTemplateEntity, PromptTemplateEntity, @@ -25,7 +26,9 @@ class PromptTemplateConfigManager: chat_prompt_messages = [] for message in chat_prompt_config.get("prompt", []): chat_prompt_messages.append( - {"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + AdvancedChatMessageEntity( + **{"text": message["text"], "role": PromptMessageRole.value_of(message["role"])} + ) ) advanced_chat_prompt_template = AdvancedChatPromptTemplateEntity(messages=chat_prompt_messages) diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 9b72452d7a..15bd353484 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -1,5 +1,5 @@ from collections.abc import Sequence -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -88,7 +88,7 @@ class PromptTemplateEntity(BaseModel): advanced_completion_prompt_template: Optional[AdvancedCompletionPromptTemplateEntity] = None -class VariableEntityType(str, Enum): +class VariableEntityType(StrEnum): TEXT_INPUT = "text-input" SELECT = "select" PARAGRAPH = "paragraph" diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 00e5a74732..bd4fd9cd3b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -2,8 +2,8 @@ import contextvars import logging import threading import uuid -from collections.abc import Generator -from typing import Any, Literal, Optional, Union, overload +from collections.abc import Generator, Mapping +from typing import Any, Optional, Union from flask import Flask, current_app from pydantic import ValidationError @@ -23,6 +23,7 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.ops.ops_trace_manager import TraceQueueManager +from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from extensions.ext_database import db from factories import file_factory from models.account import Account @@ -33,37 +34,17 @@ logger = logging.getLogger(__name__) class AdvancedChatAppGenerator(MessageBasedAppGenerator): - @overload - def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: Literal[True] = True, - ) -> Generator[str, None, None]: ... - - @overload - def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... + _dialogue_count: int def generate( self, app_model: App, workflow: Workflow, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: bool = True, - ) -> dict[str, Any] | Generator[str, Any, None]: + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -127,12 +108,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + else self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, - stream=stream, + stream=streaming, invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, @@ -146,12 +129,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, application_generate_entity=application_generate_entity, conversation=conversation, - stream=stream, + stream=streaming, ) def single_iteration_generate( - self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True - ) -> dict[str, Any] | Generator[str, Any, None]: + self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, streaming: bool = True + ) -> Mapping[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -180,7 +163,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): query="", files=[], user_id=user.id, - stream=stream, + stream=streaming, invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, single_iteration_run=AdvancedChatAppGenerateEntity.SingleIterationRunEntity( @@ -195,7 +178,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, conversation=None, - stream=stream, + stream=streaming, ) def _generate( @@ -207,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): application_generate_entity: AdvancedChatAppGenerateEntity, conversation: Optional[Conversation] = None, stream: bool = True, - ) -> dict[str, Any] | Generator[str, Any, None]: + ) -> Mapping[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -231,6 +214,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): db.session.commit() db.session.refresh(conversation) + # get conversation dialogue count + self._dialogue_count = get_thread_messages_length(conversation.id) + # init queue manager queue_manager = MessageBasedAppQueueManager( task_id=application_generate_entity.task_id, @@ -301,6 +287,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): queue_manager=queue_manager, conversation=conversation, message=message, + dialogue_count=self._dialogue_count, ) runner.run() @@ -354,6 +341,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): message=message, user=user, stream=stream, + dialogue_count=self._dialogue_count, ) try: diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 65d744eddf..cf0c9d7593 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -39,12 +39,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): queue_manager: AppQueueManager, conversation: Conversation, message: Message, + dialogue_count: int, ) -> None: super().__init__(queue_manager) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message + self._dialogue_count = dialogue_count def run(self) -> None: app_config = self.application_generate_entity.app_config @@ -122,19 +124,13 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() - # Increment dialogue count. - self.conversation.dialogue_count += 1 - - conversation_dialogue_count = self.conversation.dialogue_count - db.session.commit() - # Create a variable pool. system_inputs = { SystemVariableKey.QUERY: query, SystemVariableKey.FILES: files, SystemVariableKey.CONVERSATION_ID: self.conversation.id, SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation_dialogue_count, + SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, SystemVariableKey.APP_ID: app_config.app_id, SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, SystemVariableKey.WORKFLOW_RUN_ID: self.application_generate_entity.workflow_run_id, diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index e1798957b9..cd12690e28 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -88,6 +88,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc message: Message, user: Union[Account, EndUser], stream: bool, + dialogue_count: int, ) -> None: """ Initialize AdvancedChatAppGenerateTaskPipeline. @@ -98,6 +99,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc :param message: message :param user: user :param stream: stream + :param dialogue_count: dialogue count """ super().__init__(application_generate_entity, queue_manager, user, stream) @@ -114,7 +116,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.CONVERSATION_ID: conversation.id, SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: conversation.dialogue_count, + SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, @@ -125,6 +127,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc self._conversation_name_generate_thread = None self._recorded_files: list[Mapping[str, Any]] = [] + self.total_tokens: int = 0 def process(self): """ @@ -358,6 +361,8 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc if not workflow_run: raise Exception("Workflow run not initialized.") + # FIXME for issue #11221 quick fix maybe have a better solution + self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0 yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) @@ -371,7 +376,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, + total_tokens=graph_runtime_state.total_tokens or self.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, conversation_id=self._conversation.id, diff --git a/api/core/app/apps/agent_chat/app_config_manager.py b/api/core/app/apps/agent_chat/app_config_manager.py index 9040f18bfd..417d23eccf 100644 --- a/api/core/app/apps/agent_chat/app_config_manager.py +++ b/api/core/app/apps/agent_chat/app_config_manager.py @@ -1,5 +1,6 @@ import uuid -from typing import Optional +from collections.abc import Mapping +from typing import Any, Optional from core.agent.entities import AgentEntity from core.app.app_config.base_app_config_manager import BaseAppConfigManager @@ -85,7 +86,7 @@ class AgentChatAppConfigManager(BaseAppConfigManager): return app_config @classmethod - def config_validate(cls, tenant_id: str, config: dict) -> dict: + def config_validate(cls, tenant_id: str, config: Mapping[str, Any]) -> dict: """ Validate for agent chat app model config diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index d1564a260e..b659c18556 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -1,8 +1,8 @@ import logging import threading import uuid -from collections.abc import Generator -from typing import Any, Literal, Union, overload +from collections.abc import Generator, Mapping +from typing import Any, Union from flask import Flask, current_app from pydantic import ValidationError @@ -28,34 +28,15 @@ logger = logging.getLogger(__name__) class AgentChatAppGenerator(MessageBasedAppGenerator): - @overload def generate( self, + *, app_model: App, user: Union[Account, EndUser], - args: dict, + args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: Literal[True] = True, - ) -> Generator[dict, None, None]: ... - - @overload - def generate( - self, - app_model: App, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: Literal[False] = False, - ) -> dict: ... - - def generate( - self, - app_model: App, - user: Union[Account, EndUser], - args: Any, - invoke_from: InvokeFrom, - stream: bool = True, - ) -> Union[dict, Generator[dict, None, None]]: + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -65,7 +46,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): :param invoke_from: invoke from source :param stream: is stream """ - if not stream: + if not streaming: raise ValueError("Agent Chat App does not support blocking mode") if not args.get("query"): @@ -96,7 +77,8 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # validate config override_model_config_dict = AgentChatAppConfigManager.config_validate( - tenant_id=app_model.tenant_id, config=args.get("model_config") + tenant_id=app_model.tenant_id, + config=args["model_config"], ) # always enable retriever resource in debugger mode @@ -134,12 +116,14 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + else self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, user_id=user.id, - stream=stream, + stream=streaming, invoke_from=invoke_from, extras=extras, call_depth=0, @@ -180,7 +164,7 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream, + stream=streaming, ) return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index 62e79ec444..210609b504 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from collections.abc import Generator +from collections.abc import Generator, Mapping from typing import Any, Union from core.app.entities.app_invoke_entities import InvokeFrom @@ -14,8 +14,10 @@ class AppGenerateResponseConverter(ABC): @classmethod def convert( - cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom - ) -> dict[str, Any] | Generator[str, Any, None]: + cls, + response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], + invoke_from: InvokeFrom, + ) -> Mapping[str, Any] | Generator[str, None, None]: if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}: if isinstance(response, AppBlockingResponse): return cls.convert_blocking_full_response(response) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 2c78d95778..85b7aced55 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,4 +1,4 @@ -from collections.abc import Mapping +from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Optional from core.app.app_config.entities import VariableEntityType @@ -6,7 +6,7 @@ from core.file import File, FileUploadConfig from factories import file_factory if TYPE_CHECKING: - from core.app.app_config.entities import AppConfig, VariableEntity + from core.app.app_config.entities import VariableEntity class BaseAppGenerator: @@ -14,23 +14,23 @@ class BaseAppGenerator: self, *, user_inputs: Optional[Mapping[str, Any]], - app_config: "AppConfig", + variables: Sequence["VariableEntity"], + tenant_id: str, ) -> Mapping[str, Any]: user_inputs = user_inputs or {} # Filter input variables from form configuration, handle required fields, default values, and option values - variables = app_config.variables user_inputs = { var.variable: self._validate_inputs(value=user_inputs.get(var.variable), variable_entity=var) for var in variables } user_inputs = {k: self._sanitize_value(v) for k, v in user_inputs.items()} # Convert files in inputs to File - entity_dictionary = {item.variable: item for item in app_config.variables} + entity_dictionary = {item.variable: item for item in variables} # Convert single file to File files_inputs = { k: file_factory.build_from_mapping( mapping=v, - tenant_id=app_config.tenant_id, + tenant_id=tenant_id, config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, @@ -44,7 +44,7 @@ class BaseAppGenerator: file_list_inputs = { k: file_factory.build_from_mappings( mappings=v, - tenant_id=app_config.tenant_id, + tenant_id=tenant_id, config=FileUploadConfig( allowed_file_types=entity_dictionary[k].allowed_file_types, allowed_file_extensions=entity_dictionary[k].allowed_file_extensions, diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index e683dfef3f..6a9e162388 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -55,7 +55,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, - stream: bool = True, + streaming: bool = True, ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -132,7 +132,9 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id if conversation else None, inputs=conversation.inputs if conversation - else self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + else self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), query=query, files=file_objs, parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, @@ -140,7 +142,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, - stream=stream, + stream=streaming, ) # init generate records @@ -177,7 +179,7 @@ class ChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream, + stream=streaming, ) return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 22ee8b0967..324e837a1c 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -50,7 +50,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): ) -> dict: ... def generate( - self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, stream: bool = True + self, app_model: App, user: Union[Account, EndUser], args: Any, invoke_from: InvokeFrom, streaming: bool = True ) -> Union[dict, Generator[str, None, None]]: """ Generate App response. @@ -113,11 +113,13 @@ class CompletionAppGenerator(MessageBasedAppGenerator): app_config=app_config, model_conf=ModelConfigConverter.convert(app_config), file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), query=query, files=file_objs, user_id=user.id, - stream=stream, + stream=streaming, invoke_from=invoke_from, extras=extras, trace_manager=trace_manager, @@ -156,7 +158,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, user=user, - stream=stream, + stream=streaming, ) return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index da206f01e7..95ae798ec1 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Generator -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import and_ @@ -200,7 +200,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.commit() db.session.refresh(conversation) else: - conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() message = Message( diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 65da39b220..7acf05326e 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -3,7 +3,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Literal, Optional, Union, overload +from typing import Any, Optional, Union from flask import Flask, current_app from pydantic import ValidationError @@ -30,43 +30,18 @@ logger = logging.getLogger(__name__) class WorkflowAppGenerator(BaseAppGenerator): - @overload - def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: Literal[True] = True, - call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, - ) -> Generator[str, None, None]: ... - - @overload - def generate( - self, - app_model: App, - workflow: Workflow, - user: Union[Account, EndUser], - args: dict, - invoke_from: InvokeFrom, - stream: Literal[False] = False, - call_depth: int = 0, - workflow_thread_pool_id: Optional[str] = None, - ) -> dict: ... - def generate( self, + *, app_model: App, workflow: Workflow, - user: Union[Account, EndUser], + user: Account | EndUser, args: Mapping[str, Any], invoke_from: InvokeFrom, - stream: bool = True, + streaming: bool = True, call_depth: int = 0, workflow_thread_pool_id: Optional[str] = None, - ): + ) -> Mapping[str, Any] | Generator[str, None, None]: files: Sequence[Mapping[str, Any]] = args.get("files") or [] # parse files @@ -96,10 +71,12 @@ class WorkflowAppGenerator(BaseAppGenerator): task_id=str(uuid.uuid4()), app_config=app_config, file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs(user_inputs=inputs, app_config=app_config), + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), files=system_files, user_id=user.id, - stream=stream, + stream=streaming, invoke_from=invoke_from, call_depth=call_depth, trace_manager=trace_manager, @@ -113,7 +90,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, application_generate_entity=application_generate_entity, invoke_from=invoke_from, - stream=stream, + streaming=streaming, workflow_thread_pool_id=workflow_thread_pool_id, ) @@ -125,20 +102,9 @@ class WorkflowAppGenerator(BaseAppGenerator): user: Union[Account, EndUser], application_generate_entity: WorkflowAppGenerateEntity, invoke_from: InvokeFrom, - stream: bool = True, + streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, - ) -> dict[str, Any] | Generator[str, None, None]: - """ - Generate App response. - - :param app_model: App - :param workflow: Workflow - :param user: account or end user - :param application_generate_entity: application generate entity - :param invoke_from: invoke from source - :param stream: is stream - :param workflow_thread_pool_id: workflow thread pool id - """ + ) -> Mapping[str, Any] | Generator[str, None, None]: # init queue manager queue_manager = WorkflowAppQueueManager( task_id=application_generate_entity.task_id, @@ -167,14 +133,20 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow=workflow, queue_manager=queue_manager, user=user, - stream=stream, + stream=streaming, ) return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( - self, app_model: App, workflow: Workflow, node_id: str, user: Account, args: dict, stream: bool = True - ) -> dict[str, Any] | Generator[str, Any, None]: + self, + app_model: App, + workflow: Workflow, + node_id: str, + user: Account, + args: Mapping[str, Any], + streaming: bool = True, + ) -> Mapping[str, Any] | Generator[str, None, None]: """ Generate App response. @@ -201,7 +173,7 @@ class WorkflowAppGenerator(BaseAppGenerator): inputs={}, files=[], user_id=user.id, - stream=stream, + stream=streaming, invoke_from=InvokeFrom.DEBUGGER, extras={"auto_generate_conversation_name": False}, single_iteration_run=WorkflowAppGenerateEntity.SingleIterationRunEntity( @@ -216,7 +188,7 @@ class WorkflowAppGenerator(BaseAppGenerator): user=user, invoke_from=InvokeFrom.DEBUGGER, application_generate_entity=application_generate_entity, - stream=stream, + streaming=streaming, ) def _generate_worker( diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index 9e4921d6a2..9966a1a9d1 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -106,6 +106,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa self._task_state = WorkflowTaskState() self._wip_workflow_node_executions = {} + self.total_tokens: int = 0 def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: """ @@ -319,6 +320,8 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa if not workflow_run: raise Exception("Workflow run not initialized.") + # FIXME for issue #11221 quick fix maybe have a better solution + self.total_tokens += event.metadata.get("total_tokens", 0) if event.metadata else 0 yield self._workflow_iteration_completed_to_stream_response( task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event ) @@ -332,7 +335,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa workflow_run = self._handle_workflow_run_success( workflow_run=workflow_run, start_at=graph_runtime_state.start_at, - total_tokens=graph_runtime_state.total_tokens, + total_tokens=graph_runtime_state.total_tokens or self.total_tokens, total_steps=graph_runtime_state.node_run_steps, outputs=event.outputs, conversation_id=None, diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 2872390d46..3d46b8bab0 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -43,8 +43,7 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType -from core.workflow.nodes.iteration import IterationNodeData -from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App @@ -139,7 +138,8 @@ class WorkflowBasedAppRunner(AppRunner): # Get node class node_type = NodeType(iteration_node_config.get("data", {}).get("type")) - node_cls = node_type_classes_mapping[node_type] + node_version = iteration_node_config.get("data", {}).get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] # init variable pool variable_pool = VariablePool( @@ -160,8 +160,6 @@ class WorkflowBasedAppRunner(AppRunner): user_inputs=user_inputs, variable_pool=variable_pool, tenant_id=workflow.tenant_id, - node_type=node_type, - node_data=IterationNodeData(**iteration_node_config.get("data", {})), ) return graph, variable_pool diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index 69bc0d7f9e..15543638fc 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional from pydantic import BaseModel, field_validator @@ -11,7 +11,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNodeData -class QueueEvent(str, Enum): +class QueueEvent(StrEnum): """ QueueEvent enum """ diff --git a/api/core/app/features/rate_limiting/rate_limit.py b/api/core/app/features/rate_limiting/rate_limit.py index 227182f5ab..154a49ebda 100644 --- a/api/core/app/features/rate_limiting/rate_limit.py +++ b/api/core/app/features/rate_limiting/rate_limit.py @@ -1,9 +1,9 @@ import logging import time import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from datetime import timedelta -from typing import Optional, Union +from typing import Any, Optional, Union from core.errors.error import AppInvokeQuotaExceededError from extensions.ext_redis import redis_client @@ -88,20 +88,17 @@ class RateLimit: def gen_request_key() -> str: return str(uuid.uuid4()) - def generate(self, generator: Union[Generator, callable, dict], request_id: str): - if isinstance(generator, dict): + def generate(self, generator: Union[Generator[str, None, None], Mapping[str, Any]], request_id: str): + if isinstance(generator, Mapping): return generator else: - return RateLimitGenerator(self, generator, request_id) + return RateLimitGenerator(rate_limit=self, generator=generator, request_id=request_id) class RateLimitGenerator: - def __init__(self, rate_limit: RateLimit, generator: Union[Generator, callable], request_id: str): + def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str): self.rate_limit = rate_limit - if callable(generator): - self.generator = generator() - else: - self.generator = generator + self.generator = generator self.request_id = request_id self.closed = False diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/app/task_pipeline/workflow_cycle_manage.py index 9d776f6337..57a02f8bc8 100644 --- a/api/core/app/task_pipeline/workflow_cycle_manage.py +++ b/api/core/app/task_pipeline/workflow_cycle_manage.py @@ -1,8 +1,9 @@ import json import time from collections.abc import Mapping, Sequence -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Optional, Union, cast +from uuid import uuid4 from sqlalchemy.orm import Session @@ -80,38 +81,38 @@ class WorkflowCycleManage: inputs[f"sys.{key.value}"] = value - inputs = WorkflowEntry.handle_special_values(inputs) - triggered_from = ( WorkflowRunTriggeredFrom.DEBUGGING if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER else WorkflowRunTriggeredFrom.APP_RUN ) + # handle special values + inputs = WorkflowEntry.handle_special_values(inputs) + # init workflow run - workflow_run = WorkflowRun() - workflow_run_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] - if workflow_run_id: - workflow_run.id = workflow_run_id - workflow_run.tenant_id = self._workflow.tenant_id - workflow_run.app_id = self._workflow.app_id - workflow_run.sequence_number = new_sequence_number - workflow_run.workflow_id = self._workflow.id - workflow_run.type = self._workflow.type - workflow_run.triggered_from = triggered_from.value - workflow_run.version = self._workflow.version - workflow_run.graph = self._workflow.graph - workflow_run.inputs = json.dumps(inputs) - workflow_run.status = WorkflowRunStatus.RUNNING.value - workflow_run.created_by_role = ( - CreatedByRole.ACCOUNT.value if isinstance(self._user, Account) else CreatedByRole.END_USER.value - ) - workflow_run.created_by = self._user.id + with Session(db.engine, expire_on_commit=False) as session: + workflow_run = WorkflowRun() + system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] + workflow_run.id = system_id or str(uuid4()) + workflow_run.tenant_id = self._workflow.tenant_id + workflow_run.app_id = self._workflow.app_id + workflow_run.sequence_number = new_sequence_number + workflow_run.workflow_id = self._workflow.id + workflow_run.type = self._workflow.type + workflow_run.triggered_from = triggered_from.value + workflow_run.version = self._workflow.version + workflow_run.graph = self._workflow.graph + workflow_run.inputs = json.dumps(inputs) + workflow_run.status = WorkflowRunStatus.RUNNING + workflow_run.created_by_role = ( + CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER + ) + workflow_run.created_by = self._user.id + workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) - db.session.add(workflow_run) - db.session.commit() - db.session.refresh(workflow_run) - db.session.close() + session.add(workflow_run) + session.commit() return workflow_run @@ -144,7 +145,7 @@ class WorkflowCycleManage: workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() db.session.refresh(workflow_run) @@ -191,7 +192,7 @@ class WorkflowCycleManage: workflow_run.elapsed_time = time.perf_counter() - start_at workflow_run.total_tokens = total_tokens workflow_run.total_steps = total_steps - workflow_run.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -211,7 +212,7 @@ class WorkflowCycleManage: for workflow_node_execution in running_workflow_node_executions: workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.error = error - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_node_execution.elapsed_time = ( workflow_node_execution.finished_at - workflow_node_execution.created_at ).total_seconds() @@ -262,7 +263,7 @@ class WorkflowCycleManage: NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, } ) - workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) session.add(workflow_node_execution) session.commit() @@ -285,7 +286,7 @@ class WorkflowCycleManage: execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None ) - finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update( @@ -329,7 +330,7 @@ class WorkflowCycleManage: inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() execution_metadata = ( json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None @@ -339,7 +340,7 @@ class WorkflowCycleManage: WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.FAILED.value, WorkflowNodeExecution.error: event.error, WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None, - WorkflowNodeExecution.process_data: json.dumps(event.process_data) if event.process_data else None, + WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None, WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None, WorkflowNodeExecution.finished_at: finished_at, WorkflowNodeExecution.elapsed_time: elapsed_time, @@ -657,7 +658,7 @@ class WorkflowCycleManage: if event.error is None else WorkflowNodeExecutionStatus.FAILED, error=None, - elapsed_time=(datetime.now(timezone.utc).replace(tzinfo=None) - event.start_at).total_seconds(), + elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(), total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0, execution_metadata=event.metadata, finished_at=int(time.time()), diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 807f09598c..d1b34db2fe 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -240,7 +240,7 @@ class ProviderConfiguration(BaseModel): if provider_record: provider_record.encrypted_config = json.dumps(credentials) provider_record.is_valid = True - provider_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + provider_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: provider_record = Provider( @@ -394,7 +394,7 @@ class ProviderConfiguration(BaseModel): if provider_model_record: provider_model_record.encrypted_config = json.dumps(credentials) provider_model_record.is_valid = True - provider_model_record.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + provider_model_record.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: provider_model_record = ProviderModel( @@ -468,7 +468,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: model_setting = ProviderModelSetting( @@ -503,7 +503,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: model_setting = ProviderModelSetting( @@ -570,7 +570,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = True - model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: model_setting = ProviderModelSetting( @@ -605,7 +605,7 @@ class ProviderConfiguration(BaseModel): if model_setting: model_setting.load_balancing_enabled = False - model_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + model_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: model_setting = ProviderModelSetting( diff --git a/api/core/file/__init__.py b/api/core/file/__init__.py index fe9e52258a..44749ebec3 100644 --- a/api/core/file/__init__.py +++ b/api/core/file/__init__.py @@ -7,13 +7,13 @@ from .models import ( ) __all__ = [ + "FILE_MODEL_IDENTITY", + "ArrayFileAttribute", + "File", + "FileAttribute", + "FileBelongsTo", + "FileTransferMethod", "FileType", "FileUploadConfig", - "FileTransferMethod", - "FileBelongsTo", - "File", "ImageConfig", - "FileAttribute", - "ArrayFileAttribute", - "FILE_MODEL_IDENTITY", ] diff --git a/api/core/file/enums.py b/api/core/file/enums.py index f4153f1676..06b99d3eb0 100644 --- a/api/core/file/enums.py +++ b/api/core/file/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class FileType(str, Enum): +class FileType(StrEnum): IMAGE = "image" DOCUMENT = "document" AUDIO = "audio" @@ -16,7 +16,7 @@ class FileType(str, Enum): raise ValueError(f"No matching enum found for value '{value}'") -class FileTransferMethod(str, Enum): +class FileTransferMethod(StrEnum): REMOTE_URL = "remote_url" LOCAL_FILE = "local_file" TOOL_FILE = "tool_file" @@ -29,7 +29,7 @@ class FileTransferMethod(str, Enum): raise ValueError(f"No matching enum found for value '{value}'") -class FileBelongsTo(str, Enum): +class FileBelongsTo(StrEnum): USER = "user" ASSISTANT = "assistant" @@ -41,7 +41,7 @@ class FileBelongsTo(str, Enum): raise ValueError(f"No matching enum found for value '{value}'") -class FileAttribute(str, Enum): +class FileAttribute(StrEnum): TYPE = "type" SIZE = "size" NAME = "name" @@ -51,5 +51,5 @@ class FileAttribute(str, Enum): EXTENSION = "extension" -class ArrayFileAttribute(str, Enum): +class ArrayFileAttribute(StrEnum): LENGTH = "length" diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 03c4b8d49d..011ff382ea 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -1,6 +1,6 @@ import logging from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from threading import Lock from typing import Any, Optional @@ -31,7 +31,7 @@ class CodeExecutionResponse(BaseModel): data: Data -class CodeLanguage(str, Enum): +class CodeLanguage(StrEnum): PYTHON3 = "python3" JINJA2 = "jinja2" JAVASCRIPT = "javascript" diff --git a/api/core/helper/ssrf_proxy.py b/api/core/helper/ssrf_proxy.py index 80f01fa12b..566293d125 100644 --- a/api/core/helper/ssrf_proxy.py +++ b/api/core/helper/ssrf_proxy.py @@ -53,8 +53,6 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs): response = client.request(method=method, url=url, **kwargs) if response.status_code not in STATUS_FORCELIST: - if stream: - return response.iter_bytes() return response else: logging.warning(f"Received status code {response.status_code} for URL {url} which is in the force list") diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index cae539b6a7..29e161cb74 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -86,7 +86,7 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except ObjectDeletedError: logging.warning("Document deleted, document id: {}".format(dataset_document.id)) @@ -94,7 +94,7 @@ class IndexingRunner: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_splitting_status(self, dataset_document: DatasetDocument): @@ -142,13 +142,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def run_in_indexing_status(self, dataset_document: DatasetDocument): @@ -200,13 +200,13 @@ class IndexingRunner: except ProviderTokenNotInitError as e: dataset_document.indexing_status = "error" dataset_document.error = str(e.description) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() except Exception as e: logging.exception("consume document failed") dataset_document.indexing_status = "error" dataset_document.error = str(e) - dataset_document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() def indexing_estimate( @@ -372,7 +372,7 @@ class IndexingRunner: after_indexing_status="splitting", extra_update_params={ DatasetDocument.word_count: sum(len(text_doc.page_content) for text_doc in text_docs), - DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DatasetDocument.parsing_completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) @@ -464,7 +464,7 @@ class IndexingRunner: doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -479,7 +479,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) @@ -680,7 +680,7 @@ class IndexingRunner: after_indexing_status="completed", extra_update_params={ DatasetDocument.tokens: tokens, - DatasetDocument.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DatasetDocument.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DatasetDocument.indexing_latency: indexing_end_at - indexing_start_at, DatasetDocument.error: None, }, @@ -705,7 +705,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) @@ -738,7 +738,7 @@ class IndexingRunner: { DocumentSegment.status: "completed", DocumentSegment.enabled: True, - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } ) @@ -849,7 +849,7 @@ class IndexingRunner: doc_store.add_documents(documents) # update document status to indexing - cur_time = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) self._update_document_index_status( document_id=dataset_document.id, after_indexing_status="indexing", @@ -864,7 +864,7 @@ class IndexingRunner: dataset_document_id=dataset_document.id, update_params={ DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), }, ) pass diff --git a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py index 182aeed98f..c451bf514c 100644 --- a/api/core/llm_generator/output_parser/suggested_questions_after_answer.py +++ b/api/core/llm_generator/output_parser/suggested_questions_after_answer.py @@ -15,6 +15,5 @@ class SuggestedQuestionsAfterAnswerOutputParser: json_obj = json.loads(action_match.group(0).strip()) else: json_obj = [] - print(f"Could not parse LLM output: {text}") return json_obj diff --git a/api/core/model_runtime/entities/__init__.py b/api/core/model_runtime/entities/__init__.py index 5e52f10b4c..1c73755cff 100644 --- a/api/core/model_runtime/entities/__init__.py +++ b/api/core/model_runtime/entities/__init__.py @@ -18,25 +18,25 @@ from .message_entities import ( from .model_entities import ModelPropertyKey __all__ = [ + "AssistantPromptMessage", + "AudioPromptMessageContent", + "DocumentPromptMessageContent", "ImagePromptMessageContent", - "VideoPromptMessageContent", - "PromptMessage", - "PromptMessageRole", + "LLMResult", + "LLMResultChunk", + "LLMResultChunkDelta", "LLMUsage", "ModelPropertyKey", - "AssistantPromptMessage", + "PromptMessage", "PromptMessage", "PromptMessageContent", + "PromptMessageContentType", "PromptMessageRole", + "PromptMessageRole", + "PromptMessageTool", "SystemPromptMessage", "TextPromptMessageContent", - "UserPromptMessage", - "PromptMessageTool", "ToolPromptMessage", - "PromptMessageContentType", - "LLMResult", - "LLMResultChunk", - "LLMResultChunkDelta", - "AudioPromptMessageContent", - "DocumentPromptMessageContent", + "UserPromptMessage", + "VideoPromptMessageContent", ] diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py index e86fb37522..f2870209bb 100644 --- a/api/core/model_runtime/entities/message_entities.py +++ b/api/core/model_runtime/entities/message_entities.py @@ -1,6 +1,6 @@ from abc import ABC from collections.abc import Sequence -from enum import Enum +from enum import Enum, StrEnum from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator @@ -49,7 +49,7 @@ class PromptMessageFunction(BaseModel): function: PromptMessageTool -class PromptMessageContentType(str, Enum): +class PromptMessageContentType(StrEnum): """ Enum class for prompt message content type. """ @@ -95,7 +95,7 @@ class ImagePromptMessageContent(PromptMessageContent): Model class for image prompt message content. """ - class DETAIL(str, Enum): + class DETAIL(StrEnum): LOW = "low" HIGH = "high" diff --git a/api/core/model_runtime/entities/model_entities.py b/api/core/model_runtime/entities/model_entities.py index 4e1ce17533..edc6eac517 100644 --- a/api/core/model_runtime/entities/model_entities.py +++ b/api/core/model_runtime/entities/model_entities.py @@ -1,5 +1,5 @@ from decimal import Decimal -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional from pydantic import BaseModel, ConfigDict @@ -92,7 +92,7 @@ class ModelFeature(Enum): AUDIO = "audio" -class DefaultParameterName(str, Enum): +class DefaultParameterName(StrEnum): """ Enum class for parameter template variable. """ diff --git a/api/core/model_runtime/model_providers/anthropic/llm/llm.py b/api/core/model_runtime/model_providers/anthropic/llm/llm.py index 79701e4ea4..3faf5abbe8 100644 --- a/api/core/model_runtime/model_providers/anthropic/llm/llm.py +++ b/api/core/model_runtime/model_providers/anthropic/llm/llm.py @@ -453,7 +453,7 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): return credentials_kwargs - def _convert_prompt_messages(self, prompt_messages: list[PromptMessage]) -> tuple[str, list[dict]]: + def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) -> tuple[str, list[dict]]: """ Convert prompt messages to dict list and system """ @@ -461,7 +461,15 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): first_loop = True for message in prompt_messages: if isinstance(message, SystemPromptMessage): - message.content = message.content.strip() + if isinstance(message.content, str): + message.content = message.content.strip() + elif isinstance(message.content, list): + # System prompt only support text + message.content = "".join( + c.data.strip() for c in message.content if isinstance(c, TextPromptMessageContent) + ) + else: + raise ValueError(f"Unknown system prompt message content type {type(message.content)}") if first_loop: system = message.content first_loop = False @@ -475,6 +483,10 @@ class AnthropicLargeLanguageModel(LargeLanguageModel): if isinstance(message, UserPromptMessage): message = cast(UserPromptMessage, message) if isinstance(message.content, str): + # handle empty user prompt see #10013 #10520 + # responses, ignore user prompts containing only whitespace, the Claude API can't handle it. + if not message.content.strip(): + continue message_dict = {"role": "user", "content": message.content} prompt_message_dicts.append(message_dict) else: diff --git a/api/core/model_runtime/model_providers/azure_openai/_constant.py b/api/core/model_runtime/model_providers/azure_openai/_constant.py index e61a9e0474..4cf58275d7 100644 --- a/api/core/model_runtime/model_providers/azure_openai/_constant.py +++ b/api/core/model_runtime/model_providers/azure_openai/_constant.py @@ -779,7 +779,7 @@ LLM_BASE_MODELS = [ name="frequency_penalty", **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY], ), - _get_max_tokens(default=512, min_val=1, max_val=4096), + _get_max_tokens(default=512, min_val=1, max_val=16384), ParameterRule( name="seed", label=I18nObject(zh_Hans="种子", en_US="Seed"), diff --git a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py index 95c8f36271..c5d7a83a4e 100644 --- a/api/core/model_runtime/model_providers/azure_openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/azure_openai/llm/llm.py @@ -598,6 +598,9 @@ class AzureOpenAILargeLanguageModel(_CommonAzureOpenAI, LargeLanguageModel): # message = cast(AssistantPromptMessage, message) message_dict = {"role": "assistant", "content": message.content} if message.tool_calls: + # fix azure when enable json schema cant process content = "" in assistant fix with None + if not message.content: + message_dict["content"] = None message_dict["tool_calls"] = [helper.dump_model(tool_call) for tool_call in message.tool_calls] elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) diff --git a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py index 133cc9f76e..173b9d250c 100644 --- a/api/core/model_runtime/model_providers/azure_openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/azure_openai/tts/tts.py @@ -14,7 +14,7 @@ from core.model_runtime.model_providers.azure_openai._constant import TTS_BASE_M class AzureOpenAIText2SpeechModel(_CommonAzureOpenAI, TTSModel): """ - Model class for OpenAI Speech to text model. + Model class for OpenAI text2speech model. """ def _invoke( diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml new file mode 100644 index 0000000000..5aaf50473e --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-lite-v1.yaml @@ -0,0 +1,52 @@ +model: amazon.nova-lite-v1:0 +label: + en_US: Nova Lite V1 +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 300000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.00006' + output: '0.00024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-micro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-micro-v1.yaml new file mode 100644 index 0000000000..4ba8da660b --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-micro-v1.yaml @@ -0,0 +1,52 @@ +model: amazon.nova-micro-v1:0 +label: + en_US: Nova Micro V1 +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.000035' + output: '0.00014' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml new file mode 100644 index 0000000000..75e53e74a9 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/amazon.nova-pro-v1.yaml @@ -0,0 +1,52 @@ +model: amazon.nova-pro-v1:0 +label: + en_US: Nova Pro V1 +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 300000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.0008' + output: '0.0032' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index ef4dfaf6f1..e6e8a765ee 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -70,6 +70,8 @@ class BedrockLargeLanguageModel(LargeLanguageModel): {"prefix": "cohere.command-r", "support_system_prompts": True, "support_tool_use": True}, {"prefix": "amazon.titan", "support_system_prompts": False, "support_tool_use": False}, {"prefix": "ai21.jamba-1-5", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "amazon.nova", "support_system_prompts": True, "support_tool_use": False}, + {"prefix": "us.amazon.nova", "support_system_prompts": True, "support_tool_use": False}, ] @staticmethod @@ -194,6 +196,13 @@ class BedrockLargeLanguageModel(LargeLanguageModel): if model_info["support_tool_use"] and tools: parameters["toolConfig"] = self._convert_converse_tool_config(tools=tools) try: + # for issue #10976 + conversations_list = parameters["messages"] + # if two consecutive user messages found, combine them into one message + for i in range(len(conversations_list) - 2, -1, -1): + if conversations_list[i]["role"] == conversations_list[i + 1]["role"]: + conversations_list[i]["content"].extend(conversations_list.pop(i + 1)["content"]) + if stream: response = bedrock_client.converse_stream(**parameters) return self._handle_converse_stream_response( diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml new file mode 100644 index 0000000000..594f304617 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-lite-v1.yaml @@ -0,0 +1,52 @@ +model: us.amazon.nova-lite-v1:0 +label: + en_US: Nova Lite V1 (US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 300000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.00006' + output: '0.00024' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-micro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-micro-v1.yaml new file mode 100644 index 0000000000..4662d0a203 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-micro-v1.yaml @@ -0,0 +1,52 @@ +model: us.amazon.nova-micro-v1:0 +label: + en_US: Nova Micro V1 (US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 128000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.000035' + output: '0.00014' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml new file mode 100644 index 0000000000..dfb3e5f210 --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/llm/us.amazon.nova-pro-v1.yaml @@ -0,0 +1,52 @@ +model: us.amazon.nova-pro-v1:0 +label: + en_US: Nova Pro V1 (US.Cross Region Inference) +model_type: llm +features: + - agent-thought + - tool-call + - stream-tool-call +model_properties: + mode: chat + context_size: 300000 +parameter_rules: + - name: max_new_tokens + use_template: max_tokens + required: true + default: 2048 + min: 1 + max: 5000 + - name: temperature + use_template: temperature + required: false + type: float + default: 1 + min: 0.0 + max: 1.0 + help: + zh_Hans: 生成内容的随机性。 + en_US: The amount of randomness injected into the response. + - name: top_p + required: false + type: float + default: 0.999 + min: 0.000 + max: 1.000 + help: + zh_Hans: 在核采样中,Anthropic Claude 按概率递减顺序计算每个后续标记的所有选项的累积分布,并在达到 top_p 指定的特定概率时将其切断。您应该更改温度或top_p,但不能同时更改两者。 + en_US: In nucleus sampling, Anthropic Claude computes the cumulative distribution over all the options for each subsequent token in decreasing probability order and cuts it off once it reaches a particular probability specified by top_p. You should alter either temperature or top_p, but not both. + - name: top_k + required: false + type: int + default: 0 + min: 0 + # tip docs from aws has error, max value is 500 + max: 500 + help: + zh_Hans: 对于每个后续标记,仅从前 K 个选项中进行采样。使用 top_k 删除长尾低概率响应。 + en_US: Only sample from the top K options for each subsequent token. Use top_k to remove long tail low probability responses. +pricing: + input: '0.0008' + output: '0.0032' + unit: '0.001' + currency: USD diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml index 4973ac8ad6..0bbd27ad74 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-chat.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - tool-call - multi-tool-call - stream-tool-call model_properties: @@ -72,7 +73,7 @@ parameter_rules: - text - json_object pricing: - input: '1' - output: '2' - unit: '0.000001' + input: "1" + output: "2" + unit: "0.000001" currency: RMB diff --git a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml index caafeadadd..97310e76b9 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml +++ b/api/core/model_runtime/model_providers/deepseek/llm/deepseek-coder.yaml @@ -5,6 +5,7 @@ label: model_type: llm features: - agent-thought + - tool-call - multi-tool-call - stream-tool-call model_properties: diff --git a/api/core/model_runtime/model_providers/deepseek/llm/llm.py b/api/core/model_runtime/model_providers/deepseek/llm/llm.py index 6d0a3ee262..610dc7b458 100644 --- a/api/core/model_runtime/model_providers/deepseek/llm/llm.py +++ b/api/core/model_runtime/model_providers/deepseek/llm/llm.py @@ -1,18 +1,17 @@ from collections.abc import Generator from typing import Optional, Union -from urllib.parse import urlparse -import tiktoken +from yarl import URL -from core.model_runtime.entities.llm_entities import LLMResult +from core.model_runtime.entities.llm_entities import LLMMode, LLMResult from core.model_runtime.entities.message_entities import ( PromptMessage, PromptMessageTool, ) -from core.model_runtime.model_providers.openai.llm.llm import OpenAILargeLanguageModel +from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel -class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): +class DeepseekLargeLanguageModel(OAIAPICompatLargeLanguageModel): def _invoke( self, model: str, @@ -25,92 +24,15 @@ class DeepSeekLargeLanguageModel(OpenAILargeLanguageModel): user: Optional[str] = None, ) -> Union[LLMResult, Generator]: self._add_custom_parameters(credentials) - - return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) + return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream) def validate_credentials(self, model: str, credentials: dict) -> None: self._add_custom_parameters(credentials) super().validate_credentials(model, credentials) - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_string(self, model: str, text: str, tools: Optional[list[PromptMessageTool]] = None) -> int: - """ - Calculate num tokens for text completion model with tiktoken package. - - :param model: model name - :param text: prompt text - :param tools: tools for tool calling - :return: number of tokens - """ - encoding = tiktoken.get_encoding("cl100k_base") - num_tokens = len(encoding.encode(text)) - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - - # refactored from openai model runtime, use cl100k_base for calculate token number - def _num_tokens_from_messages( - self, model: str, messages: list[PromptMessage], tools: Optional[list[PromptMessageTool]] = None - ) -> int: - """Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package. - - Official documentation: https://github.com/openai/openai-cookbook/blob/ - main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb""" - encoding = tiktoken.get_encoding("cl100k_base") - tokens_per_message = 3 - tokens_per_name = 1 - - num_tokens = 0 - messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages] - for message in messages_dict: - num_tokens += tokens_per_message - for key, value in message.items(): - # Cast str(value) in case the message value is not a string - # This occurs with function messages - # TODO: The current token calculation method for the image type is not implemented, - # which need to download the image and then get the resolution for calculation, - # and will increase the request delay - if isinstance(value, list): - text = "" - for item in value: - if isinstance(item, dict) and item["type"] == "text": - text += item["text"] - - value = text - - if key == "tool_calls": - for tool_call in value: - for t_key, t_value in tool_call.items(): - num_tokens += len(encoding.encode(t_key)) - if t_key == "function": - for f_key, f_value in t_value.items(): - num_tokens += len(encoding.encode(f_key)) - num_tokens += len(encoding.encode(f_value)) - else: - num_tokens += len(encoding.encode(t_key)) - num_tokens += len(encoding.encode(t_value)) - else: - num_tokens += len(encoding.encode(str(value))) - - if key == "name": - num_tokens += tokens_per_name - - # every reply is primed with assistant - num_tokens += 3 - - if tools: - num_tokens += self._num_tokens_for_tools(encoding, tools) - - return num_tokens - @staticmethod - def _add_custom_parameters(credentials: dict) -> None: - credentials["mode"] = "chat" - credentials["openai_api_key"] = credentials["api_key"] - if "endpoint_url" not in credentials or credentials["endpoint_url"] == "": - credentials["openai_api_base"] = "https://api.deepseek.com" - else: - parsed_url = urlparse(credentials["endpoint_url"]) - credentials["openai_api_base"] = f"{parsed_url.scheme}://{parsed_url.netloc}" + def _add_custom_parameters(credentials) -> None: + credentials["endpoint_url"] = str(URL(credentials.get("endpoint_url", "https://api.deepseek.com"))) + credentials["mode"] = LLMMode.CHAT.value + credentials["function_calling_type"] = "tool_call" + credentials["stream_function_calling"] = "support" diff --git a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py index 1b85b7fced..0c253a4a0a 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/gitee_ai/llm/llm.py @@ -32,12 +32,12 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel): return super()._invoke(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) def validate_credentials(self, model: str, credentials: dict) -> None: - self._add_custom_parameters(credentials, model, None) + self._add_custom_parameters(credentials, None) super().validate_credentials(model, credentials) - def _add_custom_parameters(self, credentials: dict, model: str, model_parameters: dict) -> None: + def _add_custom_parameters(self, credentials: dict, model: Optional[str]) -> None: if model is None: - model = "bge-large-zh-v1.5" + model = "Qwen2-72B-Instruct" model_identity = GiteeAILargeLanguageModel.MODEL_TO_IDENTITY.get(model, model) credentials["endpoint_url"] = f"https://ai.gitee.com/api/serverless/{model_identity}/" @@ -47,5 +47,7 @@ class GiteeAILargeLanguageModel(OAIAPICompatLargeLanguageModel): credentials["mode"] = LLMMode.CHAT.value schema = self.get_model_schema(model, credentials) + assert schema is not None, f"Model schema not found for model {model}" + assert schema.features is not None, f"Model features not found for model {model}" if ModelFeature.TOOL_CALL in schema.features or ModelFeature.MULTI_TOOL_CALL in schema.features: credentials["function_calling_type"] = "tool_call" diff --git a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py index 231345c2f4..832ba92740 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gitee_ai/rerank/rerank.py @@ -122,7 +122,7 @@ class GiteeAIRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))}, ) return entity diff --git a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py index ed2bd5b13d..36dcea405d 100644 --- a/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py +++ b/api/core/model_runtime/model_providers/gitee_ai/tts/tts.py @@ -10,7 +10,7 @@ from core.model_runtime.model_providers.gitee_ai._common import _CommonGiteeAI class GiteeAIText2SpeechModel(_CommonGiteeAI, TTSModel): """ - Model class for OpenAI Speech to text model. + Model class for OpenAI text2speech model. """ def _invoke( diff --git a/api/core/model_runtime/model_providers/google/llm/llm.py b/api/core/model_runtime/model_providers/google/llm/llm.py index 77e0801b63..c19e860d2e 100644 --- a/api/core/model_runtime/model_providers/google/llm/llm.py +++ b/api/core/model_runtime/model_providers/google/llm/llm.py @@ -254,8 +254,12 @@ class GoogleLargeLanguageModel(LargeLanguageModel): assistant_prompt_message = AssistantPromptMessage(content=response.text) # calculate num tokens - prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) - completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) + if response.usage_metadata: + prompt_tokens = response.usage_metadata.prompt_token_count + completion_tokens = response.usage_metadata.candidates_token_count + else: + prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages) + completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message]) # transform usage usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens) diff --git a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py index 5ea7532564..feb5777028 100644 --- a/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/gpustack/rerank/rerank.py @@ -140,7 +140,7 @@ class GPUStackRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))}, ) return entity diff --git a/api/core/model_runtime/model_providers/jina/rerank/rerank.py b/api/core/model_runtime/model_providers/jina/rerank/rerank.py index aacc8e75d3..22f882be6b 100644 --- a/api/core/model_runtime/model_providers/jina/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/jina/rerank/rerank.py @@ -128,7 +128,7 @@ class JinaRerankModel(RerankModel): label=I18nObject(en_US=model), model_type=ModelType.RERANK, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))}, ) return entity diff --git a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py index 49c558f4a4..f5be7a9828 100644 --- a/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/jina/text_embedding/text_embedding.py @@ -193,7 +193,7 @@ class JinaTextEmbeddingModel(TextEmbeddingModel): label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000))}, ) return entity diff --git a/api/core/model_runtime/model_providers/moonshot/llm/llm.py b/api/core/model_runtime/model_providers/moonshot/llm/llm.py index 5c955c86d3..90d015942e 100644 --- a/api/core/model_runtime/model_providers/moonshot/llm/llm.py +++ b/api/core/model_runtime/model_providers/moonshot/llm/llm.py @@ -252,7 +252,7 @@ class MoonshotLargeLanguageModel(OAIAPICompatLargeLanguageModel): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + decoded_chunk = chunk.strip().removeprefix("data: ") chunk_json = None try: chunk_json = json.loads(decoded_chunk) diff --git a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py index a16c91cd7e..83c4facc8d 100644 --- a/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/ollama/text_embedding/text_embedding.py @@ -139,7 +139,7 @@ class OllamaEmbeddingModel(TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/openai/llm/llm.py b/api/core/model_runtime/model_providers/openai/llm/llm.py index aea884e002..07cb1e2d10 100644 --- a/api/core/model_runtime/model_providers/openai/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai/llm/llm.py @@ -943,6 +943,9 @@ class OpenAILargeLanguageModel(_CommonOpenAI, LargeLanguageModel): } elif isinstance(message, SystemPromptMessage): message = cast(SystemPromptMessage, message) + if isinstance(message.content, list): + text_contents = filter(lambda c: isinstance(c, TextPromptMessageContent), message.content) + message.content = "".join(c.data for c in text_contents) message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) diff --git a/api/core/model_runtime/model_providers/openai/tts/tts.py b/api/core/model_runtime/model_providers/openai/tts/tts.py index 2e57b95944..dac37f0c7f 100644 --- a/api/core/model_runtime/model_providers/openai/tts/tts.py +++ b/api/core/model_runtime/model_providers/openai/tts/tts.py @@ -11,7 +11,7 @@ from core.model_runtime.model_providers.openai._common import _CommonOpenAI class OpenAIText2SpeechModel(_CommonOpenAI, TTSModel): """ - Model class for OpenAI Speech to text model. + Model class for OpenAI text2speech model. """ def _invoke( diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py index e1342fe985..26c090d30e 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/llm/llm.py @@ -462,7 +462,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOaiApiCompat, LargeLanguageModel): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + decoded_chunk = chunk.strip().removeprefix("data: ") if decoded_chunk == "[DONE]": # Some provider returns "data: [DONE]" continue diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml index 69a081f35c..2b8dcb72d8 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml +++ b/api/core/model_runtime/model_providers/openai_api_compatible/openai_api_compatible.yaml @@ -9,6 +9,7 @@ supported_model_types: - text-embedding - speech2text - rerank + - tts configurate_methods: - customizable-model model_credential_schema: @@ -67,7 +68,7 @@ model_credential_schema: - variable: __model_type value: llm type: text-input - default: '4096' + default: "4096" placeholder: zh_Hans: 在此输入您的模型上下文长度 en_US: Enter your Model context size @@ -80,7 +81,7 @@ model_credential_schema: - variable: __model_type value: text-embedding type: text-input - default: '4096' + default: "4096" placeholder: zh_Hans: 在此输入您的模型上下文长度 en_US: Enter your Model context size @@ -93,7 +94,7 @@ model_credential_schema: - variable: __model_type value: rerank type: text-input - default: '4096' + default: "4096" placeholder: zh_Hans: 在此输入您的模型上下文长度 en_US: Enter your Model context size @@ -104,7 +105,7 @@ model_credential_schema: show_on: - variable: __model_type value: llm - default: '4096' + default: "4096" type: text-input - variable: function_calling_type show_on: @@ -174,3 +175,19 @@ model_credential_schema: value: llm default: '\n\n' type: text-input + - variable: voices + show_on: + - variable: __model_type + value: tts + label: + en_US: Available Voices (comma-separated) + zh_Hans: 可用声音(用英文逗号分隔) + type: text-input + required: false + default: "alloy" + placeholder: + en_US: "alloy,echo,fable,onyx,nova,shimmer" + zh_Hans: "alloy,echo,fable,onyx,nova,shimmer" + help: + en_US: "List voice names separated by commas. First voice will be used as default." + zh_Hans: "用英文逗号分隔的声音列表。第一个声音将作为默认值。" diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py index c2b7297aac..9da8f55d0a 100644 --- a/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/openai_api_compatible/text_embedding/text_embedding.py @@ -139,13 +139,17 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): if api_key: headers["Authorization"] = f"Bearer {api_key}" - endpoint_url = credentials.get("endpoint_url") + endpoint_url = credentials.get("endpoint_url", "") if not endpoint_url.endswith("/"): endpoint_url += "/" endpoint_url = urljoin(endpoint_url, "embeddings") payload = {"input": "ping", "model": model} + # For nvidia models, the "input_type":"query" need in the payload + # more to check issue #11193 or NvidiaTextEmbeddingModel + if model.startswith("nvidia/"): + payload["input_type"] = "query" response = requests.post(url=endpoint_url, headers=headers, data=json.dumps(payload), timeout=(10, 300)) @@ -176,7 +180,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/__init__.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py new file mode 100644 index 0000000000..8239c625f7 --- /dev/null +++ b/api/core/model_runtime/model_providers/openai_api_compatible/tts/tts.py @@ -0,0 +1,145 @@ +from collections.abc import Iterable +from typing import Optional +from urllib.parse import urljoin + +import requests + +from core.model_runtime.entities.common_entities import I18nObject +from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.openai_api_compatible._common import _CommonOaiApiCompat + + +class OAICompatText2SpeechModel(_CommonOaiApiCompat, TTSModel): + """ + Model class for OpenAI-compatible text2speech model. + """ + + def _invoke( + self, + model: str, + tenant_id: str, + credentials: dict, + content_text: str, + voice: str, + user: Optional[str] = None, + ) -> Iterable[bytes]: + """ + Invoke TTS model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model voice/speaker + :param user: unique user id + :return: audio data as bytes iterator + """ + # Set up headers with authentication if provided + headers = {} + if api_key := credentials.get("api_key"): + headers["Authorization"] = f"Bearer {api_key}" + + # Construct endpoint URL + endpoint_url = credentials.get("endpoint_url") + if not endpoint_url.endswith("/"): + endpoint_url += "/" + endpoint_url = urljoin(endpoint_url, "audio/speech") + + # Get audio format from model properties + audio_format = self._get_model_audio_type(model, credentials) + + # Split text into chunks if needed based on word limit + word_limit = self._get_model_word_limit(model, credentials) + sentences = self._split_text_into_sentences(content_text, word_limit) + + for sentence in sentences: + # Prepare request payload + payload = {"model": model, "input": sentence, "voice": voice, "response_format": audio_format} + + # Make POST request + response = requests.post(endpoint_url, headers=headers, json=payload, stream=True) + + if response.status_code != 200: + raise InvokeBadRequestError(response.text) + + # Stream the audio data + for chunk in response.iter_content(chunk_size=4096): + if chunk: + yield chunk + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + # Get default voice for validation + voice = self._get_model_default_voice(model, credentials) + + # Test with a simple text + next( + self._invoke( + model=model, tenant_id="validate", credentials=credentials, content_text="Test.", voice=voice + ) + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def get_customizable_model_schema(self, model: str, credentials: dict) -> Optional[AIModelEntity]: + """ + Get customizable model schema + """ + # Parse voices from comma-separated string + voice_names = credentials.get("voices", "alloy").strip().split(",") + voices = [] + + for voice in voice_names: + voice = voice.strip() + if not voice: + continue + + # Use en-US for all voices + voices.append( + { + "name": voice, + "mode": voice, + "language": "en-US", + } + ) + + # If no voices provided or all voices were empty strings, use 'alloy' as default + if not voices: + voices = [{"name": "Alloy", "mode": "alloy", "language": "en-US"}] + + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_type=ModelType.TTS, + model_properties={ + ModelPropertyKey.AUDIO_TYPE: credentials.get("audio_type", "mp3"), + ModelPropertyKey.WORD_LIMIT: int(credentials.get("word_limit", 4096)), + ModelPropertyKey.DEFAULT_VOICE: voices[0]["mode"], + ModelPropertyKey.VOICES: voices, + }, + ) + + def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list: + """ + Override base get_tts_model_voices to handle customizable voices + """ + model_schema = self.get_customizable_model_schema(model, credentials) + + if not model_schema or ModelPropertyKey.VOICES not in model_schema.model_properties: + raise ValueError("this model does not support voice") + + voices = model_schema.model_properties[ModelPropertyKey.VOICES] + + # Always return all voices regardless of language + return [{"name": d["name"], "value": d["mode"]} for d in voices] diff --git a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py index d78bdaa75e..7bbd31e87c 100644 --- a/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/perfxcloud/text_embedding/text_embedding.py @@ -182,7 +182,7 @@ class OAICompatEmbeddingModel(_CommonOaiApiCompat, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml index f010e4c826..b52df3e4e3 100644 --- a/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/llm/_position.yaml @@ -24,4 +24,3 @@ - meta-llama/Meta-Llama-3.1-8B-Instruct - google/gemma-2-27b-it - google/gemma-2-9b-it -- deepseek-ai/DeepSeek-V2-Chat diff --git a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml index 71f9a92381..73a9e80769 100644 --- a/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml +++ b/api/core/model_runtime/model_providers/siliconflow/siliconflow.yaml @@ -18,6 +18,7 @@ supported_model_types: - text-embedding - rerank - speech2text + - tts configurate_methods: - predefined-model - customizable-model diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/__init__.py b/api/core/model_runtime/model_providers/siliconflow/tts/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml new file mode 100644 index 0000000000..4adfd05c60 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/tts/fish-speech-1.4.yaml @@ -0,0 +1,37 @@ +model: fishaudio/fish-speech-1.4 +model_type: tts +model_properties: + default_voice: 'fishaudio/fish-speech-1.4:alex' + voices: + - mode: "fishaudio/fish-speech-1.4:alex" + name: "Alex(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:benjamin" + name: "Benjamin(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:charles" + name: "Charles(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:david" + name: "David(男声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:anna" + name: "Anna(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:bella" + name: "Bella(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:claire" + name: "Claire(女声)" + language: [ "zh-Hans", "en-US" ] + - mode: "fishaudio/fish-speech-1.4:diana" + name: "Diana(女声)" + language: [ "zh-Hans", "en-US" ] + audio_type: 'mp3' + max_workers: 5 + # stream: false +pricing: + input: '0.015' + output: '0' + unit: '0.001' + currency: RMB diff --git a/api/core/model_runtime/model_providers/siliconflow/tts/tts.py b/api/core/model_runtime/model_providers/siliconflow/tts/tts.py new file mode 100644 index 0000000000..a5554abb73 --- /dev/null +++ b/api/core/model_runtime/model_providers/siliconflow/tts/tts.py @@ -0,0 +1,105 @@ +import concurrent.futures +from typing import Any, Optional + +from openai import OpenAI + +from core.model_runtime.errors.invoke import InvokeBadRequestError +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.tts_model import TTSModel +from core.model_runtime.model_providers.openai._common import _CommonOpenAI + + +class SiliconFlowText2SpeechModel(_CommonOpenAI, TTSModel): + """ + Model class for SiliconFlow Speech to text model. + """ + + def _invoke( + self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str, user: Optional[str] = None + ) -> Any: + """ + _invoke text2speech model + + :param model: model name + :param tenant_id: user tenant id + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :param user: unique user id + :return: text translated to audio file + """ + if not voice or voice not in [ + d["value"] for d in self.get_tts_model_voices(model=model, credentials=credentials) + ]: + voice = self._get_model_default_voice(model, credentials) + # if streaming: + return self._tts_invoke_streaming(model=model, credentials=credentials, content_text=content_text, voice=voice) + + def validate_credentials(self, model: str, credentials: dict, user: Optional[str] = None) -> None: + """ + validate credentials text2speech model + + :param model: model name + :param credentials: model credentials + :param user: unique user id + :return: text translated to audio file + """ + try: + self._tts_invoke_streaming( + model=model, + credentials=credentials, + content_text="Hello SiliconFlow!", + voice=self._get_model_default_voice(model, credentials), + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + def _tts_invoke_streaming(self, model: str, credentials: dict, content_text: str, voice: str) -> Any: + """ + _tts_invoke_streaming text2speech model + + :param model: model name + :param credentials: model credentials + :param content_text: text content to be translated + :param voice: model timbre + :return: text translated to audio file + """ + try: + # doc: https://docs.siliconflow.cn/capabilities/text-to-speech + self._add_custom_parameters(credentials) + credentials_kwargs = self._to_credential_kwargs(credentials) + client = OpenAI(**credentials_kwargs) + model_support_voice = [ + x.get("value") for x in self.get_tts_model_voices(model=model, credentials=credentials) + ] + if not voice or voice not in model_support_voice: + voice = self._get_model_default_voice(model, credentials) + if len(content_text) > 4096: + sentences = self._split_text_into_sentences(content_text, max_length=4096) + executor = concurrent.futures.ThreadPoolExecutor(max_workers=min(3, len(sentences))) + futures = [ + executor.submit( + client.audio.speech.with_streaming_response.create, + model=model, + response_format="mp3", + input=sentences[i], + voice=voice, + ) + for i in range(len(sentences)) + ] + for future in futures: + yield from future.result().__enter__().iter_bytes(1024) # noqa:PLC2801 + + else: + response = client.audio.speech.with_streaming_response.create( + model=model, voice=voice, response_format="mp3", input=content_text.strip() + ) + + yield from response.__enter__().iter_bytes(1024) # noqa:PLC2801 + except Exception as ex: + raise InvokeBadRequestError(str(ex)) + + @classmethod + def _add_custom_parameters(cls, credentials: dict) -> None: + credentials["openai_api_base"] = "https://api.siliconflow.cn" + credentials["openai_api_key"] = credentials["api_key"] diff --git a/api/core/model_runtime/model_providers/stepfun/llm/llm.py b/api/core/model_runtime/model_providers/stepfun/llm/llm.py index 43b91a1aec..686809ff2b 100644 --- a/api/core/model_runtime/model_providers/stepfun/llm/llm.py +++ b/api/core/model_runtime/model_providers/stepfun/llm/llm.py @@ -250,7 +250,7 @@ class StepfunLargeLanguageModel(OAIAPICompatLargeLanguageModel): # ignore sse comments if chunk.startswith(":"): continue - decoded_chunk = chunk.strip().lstrip("data: ").lstrip() + decoded_chunk = chunk.strip().removeprefix("data: ") chunk_json = None try: chunk_json = json.loads(decoded_chunk) diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 43233e6126..9cd0c78d99 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -173,7 +173,7 @@ class VertexAiTextEmbeddingModel(_CommonVertexAi, TextEmbeddingModel): model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, model_properties={ - ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size")), + ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512)), ModelPropertyKey.MAX_CHUNKS: 1, }, parameter_rules=[], diff --git a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py index 8b3eb157be..2a269557f2 100644 --- a/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py +++ b/api/core/model_runtime/model_providers/volcengine_maas/legacy/volc_sdk/__init__.py @@ -1,4 +1,4 @@ from .common import ChatRole from .maas import MaasError, MaasService -__all__ = ["MaasService", "ChatRole", "MaasError"] +__all__ = ["ChatRole", "MaasError", "MaasService"] diff --git a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py index e69c9fccba..16f1bd43d8 100644 --- a/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/voyage/text_embedding/text_embedding.py @@ -166,7 +166,7 @@ class VoyageTextEmbeddingModel(TextEmbeddingModel): label=I18nObject(en_US=model), model_type=ModelType.TEXT_EMBEDDING, fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, - model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size"))}, + model_properties={ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 512))}, ) return entity diff --git a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py index 9e6a7dd99e..1c7b949018 100644 --- a/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/wenxin/rerank/rerank.py @@ -17,7 +17,13 @@ class WenxinRerank(_CommonWenxin): def rerank(self, model: str, query: str, docs: list[str], top_n: Optional[int] = None): access_token = self._get_access_token() url = f"{self.api_bases[model]}?access_token={access_token}" - + # For issue #11252 + # for wenxin Rerank model top_n length should be equal or less than docs length + if top_n is not None and top_n > len(docs): + top_n = len(docs) + # for wenxin Rerank model, query should not be an empty string + if query == "": + query = " " # FIXME: this is a workaround for wenxin rerank model for better user experience. try: response = httpx.post( url, @@ -25,7 +31,11 @@ class WenxinRerank(_CommonWenxin): headers={"Content-Type": "application/json"}, ) response.raise_for_status() - return response.json() + data = response.json() + # wenxin error handling + if "error_code" in data: + raise InternalServerError(data["error_msg"]) + return data except httpx.HTTPStatusError as e: raise InternalServerError(str(e)) @@ -69,6 +79,9 @@ class WenxinRerankModel(RerankModel): results = wenxin_rerank.rerank(model, query, docs, top_n) rerank_documents = [] + if "results" not in results: + raise ValueError("results key not found in response") + for result in results["results"]: index = result["index"] if "document" in result: diff --git a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml index 7c305735b9..bb71de2bad 100644 --- a/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml +++ b/api/core/model_runtime/model_providers/x/llm/grok-beta.yaml @@ -1,9 +1,12 @@ model: grok-beta label: - en_US: Grok beta + en_US: Grok Beta model_type: llm features: + - agent-thought + - tool-call - multi-tool-call + - stream-tool-call model_properties: mode: chat context_size: 131072 diff --git a/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml new file mode 100644 index 0000000000..844f0520bc --- /dev/null +++ b/api/core/model_runtime/model_providers/x/llm/grok-vision-beta.yaml @@ -0,0 +1,64 @@ +model: grok-vision-beta +label: + en_US: Grok Vision Beta +model_type: llm +features: + - agent-thought + - vision +model_properties: + mode: chat + context_size: 8192 +parameter_rules: + - name: temperature + label: + en_US: "Temperature" + zh_Hans: "采样温度" + type: float + default: 0.7 + min: 0.0 + max: 2.0 + precision: 1 + required: true + help: + en_US: "The randomness of the sampling temperature control output. The temperature value is within the range of [0.0, 1.0]. The higher the value, the more random and creative the output; the lower the value, the more stable it is. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样温度控制输出的随机性。温度值在 [0.0, 1.0] 范围内,值越高,输出越随机和创造性;值越低,输出越稳定。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: top_p + label: + en_US: "Top P" + zh_Hans: "Top P" + type: float + default: 0.7 + min: 0.0 + max: 1.0 + precision: 1 + required: true + help: + en_US: "The value range of the sampling method is [0.0, 1.0]. The top_p value determines that the model selects tokens from the top p% of candidate words with the highest probability; when top_p is 0, this parameter is invalid. It is recommended to adjust either top_p or temperature parameters according to your needs to avoid adjusting both at the same time." + zh_Hans: "采样方法的取值范围为 [0.0,1.0]。top_p 值确定模型从概率最高的前p%的候选词中选取 tokens;当 top_p 为 0 时,此参数无效。建议根据需求调整 top_p 或 temperature 参数,避免同时调整两者。" + + - name: frequency_penalty + use_template: frequency_penalty + label: + en_US: "Frequency Penalty" + zh_Hans: "频率惩罚" + type: float + default: 0 + min: 0 + max: 2.0 + precision: 1 + required: false + help: + en_US: "Number between 0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim." + zh_Hans: "介于0和2.0之间的数字。正值会根据新标记在文本中迄今为止的现有频率来惩罚它们,从而降低模型一字不差地重复同一句话的可能性。" + + - name: user + use_template: text + label: + en_US: "User" + zh_Hans: "用户" + type: string + required: false + help: + en_US: "Used to track and differentiate conversation requests from different users." + zh_Hans: "用于追踪和区分不同用户的对话请求。" diff --git a/api/core/model_runtime/model_providers/x/llm/llm.py b/api/core/model_runtime/model_providers/x/llm/llm.py index 3f5325a857..eacd086fee 100644 --- a/api/core/model_runtime/model_providers/x/llm/llm.py +++ b/api/core/model_runtime/model_providers/x/llm/llm.py @@ -35,3 +35,5 @@ class XAILargeLanguageModel(OAIAPICompatLargeLanguageModel): credentials["endpoint_url"] = str(URL(credentials["endpoint_url"])) or "https://api.x.ai/v1" credentials["mode"] = LLMMode.CHAT.value credentials["function_calling_type"] = "tool_call" + credentials["stream_function_calling"] = "support" + credentials["vision_support"] = "support" diff --git a/api/core/model_runtime/model_providers/xinference/llm/llm.py b/api/core/model_runtime/model_providers/xinference/llm/llm.py index b82f0430c5..8d86d6937d 100644 --- a/api/core/model_runtime/model_providers/xinference/llm/llm.py +++ b/api/core/model_runtime/model_providers/xinference/llm/llm.py @@ -63,6 +63,9 @@ from core.model_runtime.model_providers.xinference.xinference_helper import ( ) from core.model_runtime.utils import helper +DEFAULT_MAX_RETRIES = 3 +DEFAULT_INVOKE_TIMEOUT = 60 + class XinferenceAILargeLanguageModel(LargeLanguageModel): def _invoke( @@ -315,7 +318,12 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): message_dict = {"role": "system", "content": message.content} elif isinstance(message, ToolPromptMessage): message = cast(ToolPromptMessage, message) - message_dict = {"tool_call_id": message.tool_call_id, "role": "tool", "content": message.content} + message_dict = { + "tool_call_id": message.tool_call_id, + "role": "tool", + "content": message.content, + "name": message.name, + } else: raise ValueError(f"Unknown message type {type(message)}") @@ -466,8 +474,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel): client = OpenAI( base_url=f'{credentials["server_url"]}/v1', api_key=api_key, - max_retries=3, - timeout=60, + max_retries=int(credentials.get("max_retries") or DEFAULT_MAX_RETRIES), + timeout=int(credentials.get("invoke_timeout") or DEFAULT_INVOKE_TIMEOUT), ) xinference_client = Client( diff --git a/api/core/model_runtime/model_providers/xinference/xinference.yaml b/api/core/model_runtime/model_providers/xinference/xinference.yaml index be9073c1ca..3500136693 100644 --- a/api/core/model_runtime/model_providers/xinference/xinference.yaml +++ b/api/core/model_runtime/model_providers/xinference/xinference.yaml @@ -56,3 +56,23 @@ model_credential_schema: placeholder: zh_Hans: 在此输入您的API密钥 en_US: Enter the api key + - variable: invoke_timeout + label: + zh_Hans: 调用超时时间 (单位:秒) + en_US: invoke timeout (unit:second) + type: text-input + required: true + default: '60' + placeholder: + zh_Hans: 在此输入调用超时时间 + en_US: Enter invoke timeout value + - variable: max_retries + label: + zh_Hans: 调用重试次数 + en_US: max retries + type: text-input + required: true + default: '3' + placeholder: + zh_Hans: 在此输入调用重试次数 + en_US: Enter max retries diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml index 7c8da51d1b..035d9881eb 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-0520.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml index 7a7b4b0892..c3ee76141d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-air.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml index 09ad842801..1926db7ac3 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-airx.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 8192 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml index aee82a0602..e54b5de4a1 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flash.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml index 40ff7609c7..724fe48909 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm-4-flashx.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml index 791a77ba15..fa5b1e1fe9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_3_turbo.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml index 13ed1e49c9..e1eb13df3d 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml index badcee22db..c0c4e04d37 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_long.yaml @@ -8,7 +8,7 @@ features: - stream-tool-call model_properties: mode: chat - context_size: 10240 + context_size: 1048576 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml index e2f785e1bc..c4f26f8ba9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4_plus.yaml @@ -8,6 +8,7 @@ features: - stream-tool-call model_properties: mode: chat + context_size: 131072 parameter_rules: - name: temperature use_template: temperature diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml index 3baa298300..0d99f89cb8 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v.yaml @@ -4,6 +4,7 @@ label: model_type: llm model_properties: mode: chat + context_size: 2048 features: - vision parameter_rules: diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml index dbda18b888..5cd0e16b0e 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml +++ b/api/core/model_runtime/model_providers/zhipuai/llm/glm_4v_plus.yaml @@ -4,6 +4,7 @@ label: model_type: llm model_properties: mode: chat + context_size: 8192 features: - vision - video diff --git a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py index eddb94aba3..e0601d681c 100644 --- a/api/core/model_runtime/model_providers/zhipuai/llm/llm.py +++ b/api/core/model_runtime/model_providers/zhipuai/llm/llm.py @@ -22,18 +22,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large from core.model_runtime.model_providers.zhipuai._common import _CommonZhipuaiAI from core.model_runtime.utils import helper -GLM_JSON_MODE_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object you can found in the instructions, use {"answer": "$your_answer"} as the default structure -if you are not sure about the structure. - -And you should always end the block with a "```" to indicate the end of the JSON object. - - -{{instructions}} - - -```JSON""" # noqa: E501 - class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): def _invoke( @@ -64,42 +52,8 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): credentials_kwargs = self._to_credential_kwargs(credentials) # invoke model - # stop = stop or [] - # self._transform_json_prompts(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user) return self._generate(model, credentials_kwargs, prompt_messages, model_parameters, tools, stop, stream, user) - # def _transform_json_prompts(self, model: str, credentials: dict, - # prompt_messages: list[PromptMessage], model_parameters: dict, - # tools: list[PromptMessageTool] | None = None, stop: list[str] | None = None, - # stream: bool = True, user: str | None = None) \ - # -> None: - # """ - # Transform json prompts to model prompts - # """ - # if "}\n\n" not in stop: - # stop.append("}\n\n") - - # # check if there is a system message - # if len(prompt_messages) > 0 and isinstance(prompt_messages[0], SystemPromptMessage): - # # override the system message - # prompt_messages[0] = SystemPromptMessage( - # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", prompt_messages[0].content) - # ) - # else: - # # insert the system message - # prompt_messages.insert(0, SystemPromptMessage( - # content=GLM_JSON_MODE_PROMPT.replace("{{instructions}}", "Please output a valid JSON object.") - # )) - # # check if the last message is a user message - # if len(prompt_messages) > 0 and isinstance(prompt_messages[-1], UserPromptMessage): - # # add ```JSON\n to the last message - # prompt_messages[-1].content += "\n```JSON\n" - # else: - # # append a user message - # prompt_messages.append(UserPromptMessage( - # content="```JSON\n" - # )) - def get_num_tokens( self, model: str, @@ -170,7 +124,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): :return: full response or stream response chunk generator result """ extra_model_kwargs = {} - # request to glm-4v-plus with stop words will always response "finish_reason":"network_error" + # request to glm-4v-plus with stop words will always respond "finish_reason":"network_error" if stop and model != "glm-4v-plus": extra_model_kwargs["stop"] = stop @@ -186,7 +140,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): # resolve zhipuai model not support system message and user message, assistant message must be in sequence new_prompt_messages: list[PromptMessage] = [] for prompt_message in prompt_messages: - copy_prompt_message = prompt_message.copy() + copy_prompt_message = prompt_message.model_copy() if copy_prompt_message.role in {PromptMessageRole.USER, PromptMessageRole.SYSTEM, PromptMessageRole.TOOL}: if isinstance(copy_prompt_message.content, list): # check if model is 'glm-4v' @@ -238,59 +192,38 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): params = self._construct_glm_4v_parameter(model, new_prompt_messages, model_parameters) else: params = {"model": model, "messages": [], **model_parameters} - # glm model - if not model.startswith("chatglm"): - for prompt_message in new_prompt_messages: - if prompt_message.role == PromptMessageRole.TOOL: + for prompt_message in new_prompt_messages: + if prompt_message.role == PromptMessageRole.TOOL: + params["messages"].append( + { + "role": "tool", + "content": prompt_message.content, + "tool_call_id": prompt_message.tool_call_id, + } + ) + elif isinstance(prompt_message, AssistantPromptMessage): + if prompt_message.tool_calls: params["messages"].append( { - "role": "tool", + "role": "assistant", "content": prompt_message.content, - "tool_call_id": prompt_message.tool_call_id, + "tool_calls": [ + { + "id": tool_call.id, + "type": tool_call.type, + "function": { + "name": tool_call.function.name, + "arguments": tool_call.function.arguments, + }, + } + for tool_call in prompt_message.tool_calls + ], } ) - elif isinstance(prompt_message, AssistantPromptMessage): - if prompt_message.tool_calls: - params["messages"].append( - { - "role": "assistant", - "content": prompt_message.content, - "tool_calls": [ - { - "id": tool_call.id, - "type": tool_call.type, - "function": { - "name": tool_call.function.name, - "arguments": tool_call.function.arguments, - }, - } - for tool_call in prompt_message.tool_calls - ], - } - ) - else: - params["messages"].append({"role": "assistant", "content": prompt_message.content}) else: - params["messages"].append( - {"role": prompt_message.role.value, "content": prompt_message.content} - ) - else: - # chatglm model - for prompt_message in new_prompt_messages: - # merge system message to user message - if prompt_message.role in { - PromptMessageRole.SYSTEM, - PromptMessageRole.TOOL, - PromptMessageRole.USER, - }: - if len(params["messages"]) > 0 and params["messages"][-1]["role"] == "user": - params["messages"][-1]["content"] += "\n\n" + prompt_message.content - else: - params["messages"].append({"role": "user", "content": prompt_message.content}) - else: - params["messages"].append( - {"role": prompt_message.role.value, "content": prompt_message.content} - ) + params["messages"].append({"role": "assistant", "content": prompt_message.content}) + else: + params["messages"].append({"role": prompt_message.role.value, "content": prompt_message.content}) if tools and len(tools) > 0: params["tools"] = [{"type": "function", "function": helper.dump_model(tool)} for tool in tools] @@ -406,7 +339,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): Handle llm stream response :param model: model name - :param response: response + :param responses: response :param prompt_messages: prompt messages :return: llm response chunk generator result """ @@ -505,7 +438,7 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel): if tools and len(tools) > 0: text += "\n\nTools:" for tool in tools: - text += f"\n{tool.json()}" + text += f"\n{tool.model_dump_json()}" # trim off the trailing ' ' that might come from the "Assistant: " return text.rstrip() diff --git a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py index f629b62fd5..2428284ba9 100644 --- a/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/zhipuai/text_embedding/text_embedding.py @@ -105,17 +105,6 @@ class ZhipuAITextEmbeddingModel(_CommonZhipuaiAI, TextEmbeddingModel): return [list(map(float, e)) for e in embeddings], embedding_used_tokens - def embed_query(self, text: str) -> list[float]: - """Call out to ZhipuAI's embedding endpoint. - - Args: - text: The text to embed. - - Returns: - Embeddings for the text. - """ - return self.embed_documents([text])[0] - def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage: """ Calculate response usage diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 256595286f..71ff03b6ef 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -122,7 +122,7 @@ trace_info_info_map = { } -class TraceTaskName(str, Enum): +class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index 447b799f1f..f486da3a6d 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator @@ -39,7 +39,7 @@ def validate_input_output(v, field_name): return v -class LevelEnum(str, Enum): +class LevelEnum(StrEnum): DEBUG = "DEBUG" WARNING = "WARNING" ERROR = "ERROR" @@ -178,7 +178,7 @@ class LangfuseSpan(BaseModel): return validate_input_output(v, field_name) -class UnitEnum(str, Enum): +class UnitEnum(StrEnum): CHARACTERS = "CHARACTERS" TOKENS = "TOKENS" SECONDS = "SECONDS" diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 16c76f363c..99221d669b 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,5 +1,5 @@ from datetime import datetime -from enum import Enum +from enum import StrEnum from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -8,7 +8,7 @@ from pydantic_core.core_schema import ValidationInfo from core.ops.utils import replace_text_with_content -class LangSmithRunType(str, Enum): +class LangSmithRunType(StrEnum): tool = "tool" chain = "chain" llm = "llm" diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 1069889abd..b7799ce1fb 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -445,7 +445,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, @@ -521,7 +521,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, @@ -570,7 +570,7 @@ class TraceTask: "ls_provider": message_data.model_provider, "ls_model_name": message_data.model_id, "status": message_data.status, - "from_end_user_id": message_data.from_account_id, + "from_end_user_id": message_data.from_end_user_id, "from_account_id": message_data.from_account_id, "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, diff --git a/api/core/prompt/prompt_templates/advanced_prompt_templates.py b/api/core/prompt/prompt_templates/advanced_prompt_templates.py index 0ab7f526cc..e55966eeee 100644 --- a/api/core/prompt/prompt_templates/advanced_prompt_templates.py +++ b/api/core/prompt/prompt_templates/advanced_prompt_templates.py @@ -5,7 +5,7 @@ BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找 CHAT_APP_COMPLETION_PROMPT_CONFIG = { "completion_prompt_config": { "prompt": { - "text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501 + "text": "{{#pre_prompt#}}\nHere are the chat histories between human and assistant, inside XML tags.\n\n\n{{#histories#}}\n\n\n\nHuman: {{#query#}}\n\nAssistant: " # noqa: E501 }, "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, }, diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 5a3481b963..93dd92f188 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from core.file.models import File -class ModelMode(str, enum.Enum): +class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" diff --git a/api/core/prompt/utils/get_thread_messages_length.py b/api/core/prompt/utils/get_thread_messages_length.py new file mode 100644 index 0000000000..f49466db6d --- /dev/null +++ b/api/core/prompt/utils/get_thread_messages_length.py @@ -0,0 +1,32 @@ +from core.prompt.utils.extract_thread_messages import extract_thread_messages +from extensions.ext_database import db +from models.model import Message + + +def get_thread_messages_length(conversation_id: str) -> int: + """ + Get the number of thread messages based on the parent message id. + """ + # Fetch all messages related to the conversation + query = ( + db.session.query( + Message.id, + Message.parent_message_id, + Message.answer, + ) + .filter( + Message.conversation_id == conversation_id, + ) + .order_by(Message.created_at.desc()) + ) + + messages = query.all() + + # Extract thread messages + thread_messages = extract_thread_messages(messages) + + # Exclude the newly created message with an empty answer + if thread_messages and not thread_messages[0].answer: + thread_messages.pop(0) + + return len(thread_messages) diff --git a/api/core/rag/datasource/keyword/keyword_type.py b/api/core/rag/datasource/keyword/keyword_type.py index d6deba3fb0..d845c7111d 100644 --- a/api/core/rag/datasource/keyword/keyword_type.py +++ b/api/core/rag/datasource/keyword/keyword_type.py @@ -1,5 +1,5 @@ -from enum import Enum +from enum import StrEnum -class KeyWordType(str, Enum): +class KeyWordType(StrEnum): JIEBA = "jieba" diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 57af05861c..b2141396d6 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -110,8 +110,12 @@ class RetrievalService: str(dataset.tenant_id), reranking_mode, reranking_model, weights, False ) all_documents = data_post_processor.invoke( - query=query, documents=all_documents, score_threshold=score_threshold, top_n=top_k + query=query, + documents=all_documents, + score_threshold=score_threshold, + top_n=top_k, ) + return all_documents @classmethod @@ -178,7 +182,10 @@ class RetrievalService: ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), ) ) else: @@ -220,7 +227,10 @@ class RetrievalService: ) all_documents.extend( data_post_processor.invoke( - query=query, documents=documents, score_threshold=score_threshold, top_n=len(documents) + query=query, + documents=documents, + score_threshold=score_threshold, + top_n=len(documents), ) ) else: diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py index 8dd26a073b..c44338d42a 100644 --- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py +++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py @@ -104,8 +104,7 @@ class OceanBaseVector(BaseVector): val = int(row[6]) vals.append(val) if len(vals) == 0: - print("ob_vector_memory_limit_percentage not found in parameters.") - exit(1) + raise ValueError("ob_vector_memory_limit_percentage not found in parameters.") if any(val == 0 for val in vals): try: self._client.perform_raw_text_sql("ALTER SYSTEM SET ob_vector_memory_limit_percentage = 30") @@ -200,10 +199,10 @@ class OceanBaseVectorFactory(AbstractVectorFactory): return OceanBaseVector( collection_name, OceanBaseVectorConfig( - host=dify_config.OCEANBASE_VECTOR_HOST, - port=dify_config.OCEANBASE_VECTOR_PORT, - user=dify_config.OCEANBASE_VECTOR_USER, + host=dify_config.OCEANBASE_VECTOR_HOST or "", + port=dify_config.OCEANBASE_VECTOR_PORT or 0, + user=dify_config.OCEANBASE_VECTOR_USER or "", password=(dify_config.OCEANBASE_VECTOR_PASSWORD or ""), - database=dify_config.OCEANBASE_VECTOR_DATABASE, + database=dify_config.OCEANBASE_VECTOR_DATABASE or "", ), ) diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py index 4ced5d61e5..71c58c9d0c 100644 --- a/api/core/rag/datasource/vdb/oracle/oraclevector.py +++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py @@ -230,7 +230,6 @@ class OracleVector(BaseVector): except LookupError: nltk.download("punkt") nltk.download("stopwords") - print("run download") e_str = re.sub(r"[^\w ]", "", query) all_tokens = nltk.word_tokenize(e_str) stop_words = stopwords.words("english") diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py index a38f84636e..cfd47aac5b 100644 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py +++ b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py @@ -375,7 +375,6 @@ class TidbOnQdrantVector(BaseVector): for result in results: if result: document = self._document_from_scored_point(result, Field.CONTENT_KEY.value, Field.METADATA_KEY.value) - document.metadata["vector"] = result.vector documents.append(document) return documents @@ -394,6 +393,7 @@ class TidbOnQdrantVector(BaseVector): ) -> Document: return Document( page_content=scored_point.payload.get(content_payload_key), + vector=scored_point.vector, metadata=scored_point.payload.get(metadata_payload_key) or {}, ) diff --git a/api/core/rag/datasource/vdb/upstash/upstash_vector.py b/api/core/rag/datasource/vdb/upstash/upstash_vector.py index df1b550b40..5c3fee98a9 100644 --- a/api/core/rag/datasource/vdb/upstash/upstash_vector.py +++ b/api/core/rag/datasource/vdb/upstash/upstash_vector.py @@ -64,7 +64,7 @@ class UpstashVector(BaseVector): item_ids = [] for doc_id in ids: ids = self.get_ids_by_metadata_field("doc_id", doc_id) - if id: + if ids: item_ids += ids self._delete_by_ids(ids=item_ids) @@ -95,9 +95,10 @@ class UpstashVector(BaseVector): metadata = record.metadata text = record.data score = record.score - metadata["score"] = score - if score > score_threshold: - docs.append(Document(page_content=text, metadata=metadata)) + if metadata is not None and text is not None: + metadata["score"] = score + if score > score_threshold: + docs.append(Document(page_content=text, metadata=metadata)) return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: @@ -123,7 +124,7 @@ class UpstashVectorFactory(AbstractVectorFactory): return UpstashVector( collection_name=collection_name, config=UpstashVectorConfig( - url=dify_config.UPSTASH_VECTOR_URL, - token=dify_config.UPSTASH_VECTOR_TOKEN, + url=dify_config.UPSTASH_VECTOR_URL or "", + token=dify_config.UPSTASH_VECTOR_TOKEN or "", ), ) diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py index 8e53e3ae84..05183c0371 100644 --- a/api/core/rag/datasource/vdb/vector_type.py +++ b/api/core/rag/datasource/vdb/vector_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class VectorType(str, Enum): +class VectorType(StrEnum): ANALYTICDB = "analyticdb" CHROMA = "chroma" MILVUS = "milvus" diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 1157c5c8e4..fc8e0440c3 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -102,7 +102,8 @@ class CacheEmbedding(Embeddings): embedding = redis_client.get(embedding_cache_key) if embedding: redis_client.expire(embedding_cache_key, 600) - return list(np.frombuffer(base64.b64decode(embedding), dtype="float")) + decoded_embedding = np.frombuffer(base64.b64decode(embedding), dtype="float") + return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index 313bdce48b..0c38a9c076 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -86,7 +86,7 @@ class WordExtractor(BaseExtractor): image_count += 1 if rel.is_external: url = rel.reltype - response = ssrf_proxy.get(url, stream=True) + response = ssrf_proxy.get(url) if response.status_code == 200: image_ext = mimetypes.guess_extension(response.headers["Content-Type"]) file_uuid = str(uuid.uuid4()) @@ -114,10 +114,10 @@ class WordExtractor(BaseExtractor): mime_type=mime_type or "", created_by=self.user_id, created_by_role=CreatedByRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) db.session.add(upload_file) diff --git a/api/core/rag/rerank/rerank_type.py b/api/core/rag/rerank/rerank_type.py index d71eb2daa8..b2d1314654 100644 --- a/api/core/rag/rerank/rerank_type.py +++ b/api/core/rag/rerank/rerank_type.py @@ -1,6 +1,6 @@ -from enum import Enum +from enum import StrEnum -class RerankMode(str, Enum): +class RerankMode(StrEnum): RERANKING_MODEL = "reranking_model" WEIGHTED_SCORE = "weighted_score" diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index d8637fd2cb..4fc383f91b 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Optional, Union, cast from pydantic import BaseModel, Field, field_validator @@ -137,7 +137,7 @@ class ToolParameterOption(BaseModel): class ToolParameter(BaseModel): - class ToolParameterType(str, Enum): + class ToolParameterType(StrEnum): STRING = "string" NUMBER = "number" BOOLEAN = "boolean" diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py index 48755753ac..9896081221 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.py @@ -12,7 +12,7 @@ class LambdaTranslateUtilsTool(BuiltinTool): def _invoke_lambda(self, text_content, src_lang, dest_lang, model_id, dictionary_name, request_type, lambda_name): msg = { - "src_content": text_content, + "src_contents": [text_content], "src_lang": src_lang, "dest_lang": dest_lang, "dictionary_id": dictionary_name, diff --git a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml index 3bb133c7ec..646602fcd6 100644 --- a/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml +++ b/api/core/tools/provider/builtin/aws/tools/lambda_translate_utils.yaml @@ -8,9 +8,9 @@ identity: icon: icon.svg description: human: - en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag - zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/ybalbert001/dynamodb-rag - pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/ybalbert001/dynamodb-rag + en_US: A util tools for LLM translation, extra deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock + zh_Hans: 大语言模型翻译工具(专词映射获取),需要在AWS上进行额外部署,可参考Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock + pt_BR: A util tools for LLM translation, specific Lambda Function deployment is needed on AWS. Please refer Github Repo - https://github.com/aws-samples/rag-based-translation-with-dynamodb-and-bedrock llm: A util tools for translation. parameters: - name: text_content diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py new file mode 100644 index 0000000000..e05e2d9bf7 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.py @@ -0,0 +1,67 @@ +import json +from typing import Any, Union + +import boto3 + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +# 定义标签映射 +LABEL_MAPPING = {"LABEL_0": "SAFE", "LABEL_1": "NO_SAFE"} + + +class ContentModerationTool(BuiltinTool): + sagemaker_client: Any = None + sagemaker_endpoint: str = None + + def _invoke_sagemaker(self, payload: dict, endpoint: str): + response = self.sagemaker_client.invoke_endpoint( + EndpointName=endpoint, + Body=json.dumps(payload), + ContentType="application/json", + ) + # Parse response + response_body = response["Body"].read().decode("utf8") + + json_obj = json.loads(response_body) + + # Handle nested JSON if present + if isinstance(json_obj, dict) and "body" in json_obj: + body_content = json.loads(json_obj["body"]) + raw_label = body_content.get("label") + else: + raw_label = json_obj.get("label") + + # 映射标签并返回 + result = LABEL_MAPPING.get(raw_label, "NO_SAFE") # 如果映射中没有找到,默认返回NO_SAFE + return result + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.sagemaker_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + else: + self.sagemaker_client = boto3.client("sagemaker-runtime") + + if not self.sagemaker_endpoint: + self.sagemaker_endpoint = tool_parameters.get("sagemaker_endpoint") + + content_text = tool_parameters.get("content_text") + + payload = {"text": content_text} + + result = self._invoke_sagemaker(payload, self.sagemaker_endpoint) + + return self.create_text_message(text=result) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml new file mode 100644 index 0000000000..76dcb89632 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/sagemaker_chinese_toxicity_detector.yaml @@ -0,0 +1,46 @@ +identity: + name: chinese_toxicity_detector + author: AWS + label: + en_US: Chinese Toxicity Detector + zh_Hans: 中文有害内容检测 + icon: icon.svg +description: + human: + en_US: A tool to detect Chinese toxicity + zh_Hans: 检测中文有害内容的工具 + llm: A tool that checks if Chinese content is safe for work +parameters: + - name: sagemaker_endpoint + type: string + required: true + label: + en_US: sagemaker endpoint for moderation + zh_Hans: 内容审核的SageMaker端点 + human_description: + en_US: sagemaker endpoint for content moderation + zh_Hans: 内容审核的SageMaker端点 + llm_description: sagemaker endpoint for content moderation + form: form + - name: content_text + type: string + required: true + label: + en_US: content text + zh_Hans: 待审核文本 + human_description: + en_US: text content to be moderated + zh_Hans: 需要审核的文本内容 + llm_description: text content to be moderated + form: llm + - name: aws_region + type: string + required: false + label: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + human_description: + en_US: region of sagemaker endpoint + zh_Hans: SageMaker 端点所在的region + llm_description: region of sagemaker endpoint + form: form diff --git a/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py new file mode 100644 index 0000000000..7520f6bca8 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.py @@ -0,0 +1,418 @@ +import json +import logging +import os +import re +import time +import uuid +from typing import Any, Union +from urllib.parse import urlparse + +import boto3 +import requests +from botocore.exceptions import ClientError +from requests.exceptions import RequestException + +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.tool.builtin_tool import BuiltinTool + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +LanguageCodeOptions = [ + "af-ZA", + "ar-AE", + "ar-SA", + "da-DK", + "de-CH", + "de-DE", + "en-AB", + "en-AU", + "en-GB", + "en-IE", + "en-IN", + "en-US", + "en-WL", + "es-ES", + "es-US", + "fa-IR", + "fr-CA", + "fr-FR", + "he-IL", + "hi-IN", + "id-ID", + "it-IT", + "ja-JP", + "ko-KR", + "ms-MY", + "nl-NL", + "pt-BR", + "pt-PT", + "ru-RU", + "ta-IN", + "te-IN", + "tr-TR", + "zh-CN", + "zh-TW", + "th-TH", + "en-ZA", + "en-NZ", + "vi-VN", + "sv-SE", + "ab-GE", + "ast-ES", + "az-AZ", + "ba-RU", + "be-BY", + "bg-BG", + "bn-IN", + "bs-BA", + "ca-ES", + "ckb-IQ", + "ckb-IR", + "cs-CZ", + "cy-WL", + "el-GR", + "et-ET", + "eu-ES", + "fi-FI", + "gl-ES", + "gu-IN", + "ha-NG", + "hr-HR", + "hu-HU", + "hy-AM", + "is-IS", + "ka-GE", + "kab-DZ", + "kk-KZ", + "kn-IN", + "ky-KG", + "lg-IN", + "lt-LT", + "lv-LV", + "mhr-RU", + "mi-NZ", + "mk-MK", + "ml-IN", + "mn-MN", + "mr-IN", + "mt-MT", + "no-NO", + "or-IN", + "pa-IN", + "pl-PL", + "ps-AF", + "ro-RO", + "rw-RW", + "si-LK", + "sk-SK", + "sl-SI", + "so-SO", + "sr-RS", + "su-ID", + "sw-BI", + "sw-KE", + "sw-RW", + "sw-TZ", + "sw-UG", + "tl-PH", + "tt-RU", + "ug-CN", + "uk-UA", + "uz-UZ", + "wo-SN", + "zu-ZA", +] + +MediaFormat = ["mp3", "mp4", "wav", "flac", "ogg", "amr", "webm", "m4a"] + + +def is_url(text): + if not text: + return False + text = text.strip() + # Regular expression pattern for URL validation + pattern = re.compile( + r"^" # Start of the string + r"(?:http|https)://" # Protocol (http or https) + r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}\.?)|" # Domain + r"localhost|" # localhost + r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3})" # IP address + r"(?::\d+)?" # Optional port + r"(?:/?|[/?]\S+)" # Path + r"$", # End of the string + re.IGNORECASE, + ) + return bool(pattern.match(text)) + + +def upload_file_from_url_to_s3(s3_client, url, bucket_name, s3_key=None, max_retries=3): + """ + Upload a file from a URL to an S3 bucket with retries and better error handling. + + Parameters: + - s3_client + - url (str): The URL of the file to upload + - bucket_name (str): The name of the S3 bucket + - s3_key (str): The desired key (path) in S3. If None, will use the filename from URL + - max_retries (int): Maximum number of retry attempts + + Returns: + - tuple: (bool, str) - (Success status, Message) + """ + + # Validate inputs + if not url or not bucket_name: + return False, "URL and bucket name are required" + + retry_count = 0 + while retry_count < max_retries: + try: + # Download the file from URL + response = requests.get(url, stream=True, timeout=30) + response.raise_for_status() + + # If s3_key is not provided, try to get filename from URL + if not s3_key: + parsed_url = urlparse(url) + filename = os.path.basename(parsed_url.path.split("/file-preview")[0]) + s3_key = "transcribe-files/" + filename + + # Upload the file to S3 + s3_client.upload_fileobj( + response.raw, + bucket_name, + s3_key, + ExtraArgs={ + "ContentType": response.headers.get("content-type"), + "ACL": "private", # Ensure the uploaded file is private + }, + ) + + return f"s3://{bucket_name}/{s3_key}", f"Successfully uploaded file to s3://{bucket_name}/{s3_key}" + + except RequestException as e: + retry_count += 1 + if retry_count == max_retries: + return None, f"Failed to download file from URL after {max_retries} attempts: {str(e)}" + continue + + except ClientError as e: + return None, f"AWS S3 error: {str(e)}" + + except Exception as e: + return None, f"Unexpected error: {str(e)}" + + return None, "Maximum retries exceeded" + + +class TranscribeTool(BuiltinTool): + s3_client: Any = None + transcribe_client: Any = None + + """ + Note that you must include one of LanguageCode, IdentifyLanguage, + or IdentifyMultipleLanguages in your request. + If you include more than one of these parameters, your transcription job fails. + """ + + def _transcribe_audio(self, audio_file_uri, file_type, **extra_args): + uuid_str = str(uuid.uuid4()) + job_name = f"{int(time.time())}-{uuid_str}" + try: + # Start transcription job + response = self.transcribe_client.start_transcription_job( + TranscriptionJobName=job_name, Media={"MediaFileUri": audio_file_uri}, **extra_args + ) + + # Wait for the job to complete + while True: + status = self.transcribe_client.get_transcription_job(TranscriptionJobName=job_name) + if status["TranscriptionJob"]["TranscriptionJobStatus"] in ["COMPLETED", "FAILED"]: + break + time.sleep(5) + + if status["TranscriptionJob"]["TranscriptionJobStatus"] == "COMPLETED": + return status["TranscriptionJob"]["Transcript"]["TranscriptFileUri"], None + else: + return None, f"Error: TranscriptionJobStatus:{status['TranscriptionJob']['TranscriptionJobStatus']} " + + except Exception as e: + return None, f"Error: {str(e)}" + + def _download_and_read_transcript(self, transcript_file_uri: str, max_retries: int = 3) -> tuple[str, str]: + """ + Download and read the transcript file from the given URI. + + Parameters: + - transcript_file_uri (str): The URI of the transcript file + - max_retries (int): Maximum number of retry attempts + + Returns: + - tuple: (text, error) - (Transcribed text if successful, error message if failed) + """ + retry_count = 0 + while retry_count < max_retries: + try: + # Download the transcript file + response = requests.get(transcript_file_uri, timeout=30) + response.raise_for_status() + + # Parse the JSON content + transcript_data = response.json() + + # Check if speaker labels are present and enabled + has_speaker_labels = ( + "results" in transcript_data + and "speaker_labels" in transcript_data["results"] + and "segments" in transcript_data["results"]["speaker_labels"] + ) + + if has_speaker_labels: + # Get speaker segments + segments = transcript_data["results"]["speaker_labels"]["segments"] + items = transcript_data["results"]["items"] + + # Create a mapping of start_time -> speaker_label + time_to_speaker = {} + for segment in segments: + speaker_label = segment["speaker_label"] + for item in segment["items"]: + time_to_speaker[item["start_time"]] = speaker_label + + # Build transcript with speaker labels + current_speaker = None + transcript_parts = [] + + for item in items: + # Skip non-pronunciation items (like punctuation) + if item["type"] == "punctuation": + transcript_parts.append(item["alternatives"][0]["content"]) + continue + + start_time = item["start_time"] + speaker = time_to_speaker.get(start_time) + + if speaker != current_speaker: + current_speaker = speaker + transcript_parts.append(f"\n[{speaker}]: ") + + transcript_parts.append(item["alternatives"][0]["content"]) + + return " ".join(transcript_parts).strip(), None + else: + # Extract the transcription text + # The transcript text is typically in the 'results' -> 'transcripts' array + if "results" in transcript_data and "transcripts" in transcript_data["results"]: + transcripts = transcript_data["results"]["transcripts"] + if transcripts: + # Combine all transcript segments + full_text = " ".join(t.get("transcript", "") for t in transcripts) + return full_text, None + + return None, "No transcripts found in the response" + + except requests.exceptions.RequestException as e: + retry_count += 1 + if retry_count == max_retries: + return None, f"Failed to download transcript file after {max_retries} attempts: {str(e)}" + continue + + except json.JSONDecodeError as e: + return None, f"Failed to parse transcript JSON: {str(e)}" + + except Exception as e: + return None, f"Unexpected error while processing transcript: {str(e)}" + + return None, "Maximum retries exceeded" + + def _invoke( + self, + user_id: str, + tool_parameters: dict[str, Any], + ) -> Union[ToolInvokeMessage, list[ToolInvokeMessage]]: + """ + invoke tools + """ + try: + if not self.transcribe_client: + aws_region = tool_parameters.get("aws_region") + if aws_region: + self.transcribe_client = boto3.client("transcribe", region_name=aws_region) + self.s3_client = boto3.client("s3", region_name=aws_region) + else: + self.transcribe_client = boto3.client("transcribe") + self.s3_client = boto3.client("s3") + + file_url = tool_parameters.get("file_url") + file_type = tool_parameters.get("file_type") + language_code = tool_parameters.get("language_code") + identify_language = tool_parameters.get("identify_language", True) + identify_multiple_languages = tool_parameters.get("identify_multiple_languages", False) + language_options_str = tool_parameters.get("language_options") + s3_bucket_name = tool_parameters.get("s3_bucket_name") + ShowSpeakerLabels = tool_parameters.get("ShowSpeakerLabels", True) + MaxSpeakerLabels = tool_parameters.get("MaxSpeakerLabels", 2) + + # Check the input params + if not s3_bucket_name: + return self.create_text_message(text="s3_bucket_name is required") + language_options = None + if language_options_str: + language_options = language_options_str.split("|") + for lang in language_options: + if lang not in LanguageCodeOptions: + return self.create_text_message( + text=f"{lang} is not supported, should be one of {LanguageCodeOptions}" + ) + if language_code and language_code not in LanguageCodeOptions: + err_msg = f"language_code:{language_code} is not supported, should be one of {LanguageCodeOptions}" + return self.create_text_message(text=err_msg) + + err_msg = f"identify_language:{identify_language}, \ + identify_multiple_languages:{identify_multiple_languages}, \ + Note that you must include one of LanguageCode, IdentifyLanguage, \ + or IdentifyMultipleLanguages in your request. \ + If you include more than one of these parameters, \ + your transcription job fails." + if not language_code: + if identify_language and identify_multiple_languages: + return self.create_text_message(text=err_msg) + else: + if identify_language or identify_multiple_languages: + return self.create_text_message(text=err_msg) + + extra_args = { + "IdentifyLanguage": identify_language, + "IdentifyMultipleLanguages": identify_multiple_languages, + } + if language_code: + extra_args["LanguageCode"] = language_code + if language_options: + extra_args["LanguageOptions"] = language_options + if ShowSpeakerLabels: + extra_args["Settings"] = {"ShowSpeakerLabels": ShowSpeakerLabels, "MaxSpeakerLabels": MaxSpeakerLabels} + + # upload to s3 bucket + s3_path_result, error = upload_file_from_url_to_s3(self.s3_client, url=file_url, bucket_name=s3_bucket_name) + if not s3_path_result: + return self.create_text_message(text=error) + + transcript_file_uri, error = self._transcribe_audio( + audio_file_uri=s3_path_result, + file_type=file_type, + **extra_args, + ) + if not transcript_file_uri: + return self.create_text_message(text=error) + + # Download and read the transcript + transcript_text, error = self._download_and_read_transcript(transcript_file_uri) + if not transcript_text: + return self.create_text_message(text=error) + + return self.create_text_message(text=transcript_text) + + except Exception as e: + return self.create_text_message(f"Exception {str(e)}") diff --git a/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml new file mode 100644 index 0000000000..0dccd615d2 --- /dev/null +++ b/api/core/tools/provider/builtin/aws/tools/transcribe_asr.yaml @@ -0,0 +1,133 @@ +identity: + name: transcribe_asr + author: AWS + label: + en_US: TranscribeASR + zh_Hans: Transcribe语音识别转录 + pt_BR: TranscribeASR + icon: icon.svg +description: + human: + en_US: A tool for ASR (Automatic Speech Recognition) - https://github.com/aws-samples/dify-aws-tool + zh_Hans: AWS 语音识别转录服务, 请参考 https://aws.amazon.com/cn/pm/transcribe/#Learn_More_About_Amazon_Transcribe + pt_BR: A tool for ASR (Automatic Speech Recognition). + llm: A tool for ASR (Automatic Speech Recognition). +parameters: + - name: file_url + type: string + required: true + label: + en_US: video or audio file url for transcribe + zh_Hans: 语音或者视频文件url + pt_BR: video or audio file url for transcribe + human_description: + en_US: video or audio file url for transcribe + zh_Hans: 语音或者视频文件url + pt_BR: video or audio file url for transcribe + llm_description: video or audio file url for transcribe + form: llm + - name: language_code + type: string + required: false + label: + en_US: Language Code + zh_Hans: 语言编码 + pt_BR: Language Code + human_description: + en_US: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + zh_Hans: 语言编码,例如zh-CN, en-US 可参考 https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + pt_BR: The language code used to create your transcription job. refer to :https://docs.aws.amazon.com/transcribe/latest/dg/supported-languages.html + llm_description: The language code used to create your transcription job. + form: llm + - name: identify_language + type: boolean + default: true + required: false + label: + en_US: Automactically Identify Language + zh_Hans: 自动识别语言 + pt_BR: Automactically Identify Language + human_description: + en_US: Automactically Identify Language + zh_Hans: 自动识别语言 + pt_BR: Automactically Identify Language + llm_description: Enable Automactically Identify Language + form: form + - name: identify_multiple_languages + type: boolean + required: false + label: + en_US: Automactically Identify Multiple Languages + zh_Hans: 自动识别多种语言 + pt_BR: Automactically Identify Multiple Languages + human_description: + en_US: Automactically Identify Multiple Languages + zh_Hans: 自动识别多种语言 + pt_BR: Automactically Identify Multiple Languages + llm_description: Enable Automactically Identify Multiple Languages + form: form + - name: language_options + type: string + required: false + label: + en_US: Language Options + zh_Hans: 语言种类选项 + pt_BR: Language Options + human_description: + en_US: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + zh_Hans: 您可以指定两个或更多的语言代码来表示您认为可能出现在媒体中的语言。用|分隔,如 zh-CN|en-US + pt_BR: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + llm_description: Seperated by |, e.g:zh-CN|en-US, You can specify two or more language codes that represent the languages you think may be present in your media + form: llm + - name: s3_bucket_name + type: string + required: true + label: + en_US: s3 bucket name + zh_Hans: s3 存储桶名称 + pt_BR: s3 bucket name + human_description: + en_US: s3 bucket name to store transcribe files (don't add prefix s3://) + zh_Hans: s3 存储桶名称,用于存储转录文件 (不需要前缀 s3://) + pt_BR: s3 bucket name to store transcribe files (don't add prefix s3://) + llm_description: s3 bucket name to store transcribe files + form: form + - name: ShowSpeakerLabels + type: boolean + required: true + default: true + label: + en_US: ShowSpeakerLabels + zh_Hans: 显示说话人标签 + pt_BR: ShowSpeakerLabels + human_description: + en_US: Enables speaker partitioning (diarization) in your transcription output + zh_Hans: 在转录输出中启用说话人分区(说话人分离) + pt_BR: Enables speaker partitioning (diarization) in your transcription output + llm_description: Enables speaker partitioning (diarization) in your transcription output + form: form + - name: MaxSpeakerLabels + type: number + required: true + default: 2 + label: + en_US: MaxSpeakerLabels + zh_Hans: 说话人标签数量 + pt_BR: MaxSpeakerLabels + human_description: + en_US: Specify the maximum number of speakers you want to partition in your media + zh_Hans: 指定您希望在媒体中划分的最多演讲者数量。 + pt_BR: Specify the maximum number of speakers you want to partition in your media + llm_description: Specify the maximum number of speakers you want to partition in your media + form: form + - name: aws_region + type: string + required: false + label: + en_US: AWS Region + zh_Hans: AWS 区域 + human_description: + en_US: Please enter the AWS region for the transcribe service, for example 'us-east-1'. + zh_Hans: 请输入Transcribe的 AWS 区域,例如 'us-east-1'。 + llm_description: Please enter the AWS region for the transcribe service, for example 'us-east-1'. + form: form diff --git a/api/core/tools/provider/builtin/chart/chart.py b/api/core/tools/provider/builtin/chart/chart.py index dfa3fbea6a..8fa647d9ed 100644 --- a/api/core/tools/provider/builtin/chart/chart.py +++ b/api/core/tools/provider/builtin/chart/chart.py @@ -1,3 +1,4 @@ +import matplotlib import matplotlib.pyplot as plt from matplotlib.font_manager import FontProperties, fontManager @@ -5,7 +6,7 @@ from core.tools.provider.builtin_tool_provider import BuiltinToolProviderControl def set_chinese_font(): - font_list = [ + to_find_fonts = [ "PingFang SC", "SimHei", "Microsoft YaHei", @@ -15,16 +16,16 @@ def set_chinese_font(): "Noto Sans CJK SC", "Noto Sans CJK JP", ] - - for font in font_list: - if font in fontManager.ttflist: - chinese_font = FontProperties(font) - if chinese_font.get_name() == font: - return chinese_font + installed_fonts = frozenset(fontInfo.name for fontInfo in fontManager.ttflist) + for font in to_find_fonts: + if font in installed_fonts: + return FontProperties(font) return FontProperties() +# use non-interactive backend to prevent `RuntimeError: main thread is not in main loop` +matplotlib.use("Agg") # use a business theme plt.style.use("seaborn-v0_8-darkgrid") plt.rcParams["axes.unicode_minus"] = False diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py index 54bb38755a..b3c630878f 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.py @@ -18,6 +18,12 @@ class DuckDuckGoImageSearchTool(BuiltinTool): "size": tool_parameters.get("size"), "max_results": tool_parameters.get("max_results"), } + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + query_dict["keywords"] = final_query + response = DDGS().images(**query_dict) markdown_result = "\n\n" json_result = [] diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml index 168cface22..a543d1e218 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_img.yaml @@ -86,3 +86,14 @@ parameters: en_US: The size of the image to be searched. zh_Hans: 要搜索的图片的大小 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:unsplash.com" + zh_Hans: 定向搜索 e.g. "site:unsplash.com" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py index 3a6fd394a8..11da6f5cf7 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.py @@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool SUMMARY_PROMPT = """ -User's query: +User's query: {query} Here are the news results: @@ -30,6 +30,12 @@ class DuckDuckGoNewsSearchTool(BuiltinTool): "safesearch": "moderate", "region": "wt-wt", } + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + query_dict["keywords"] = final_query + try: response = list(DDGS().news(**query_dict)) if not response: diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml index eb2b67b7c9..6e181e0f41 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_news.yaml @@ -69,3 +69,14 @@ parameters: en_US: Whether to pass the news results to llm for summarization. zh_Hans: 是否需要将新闻结果传给大模型总结 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:msn.com" + zh_Hans: 定向搜索 e.g. "site:msn.com" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py index cbd65d2e77..3cd35d16a6 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.py @@ -7,7 +7,7 @@ from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool.builtin_tool import BuiltinTool SUMMARY_PROMPT = """ -User's query: +User's query: {query} Here is the search engine result: @@ -26,7 +26,12 @@ class DuckDuckGoSearchTool(BuiltinTool): query = tool_parameters.get("query") max_results = tool_parameters.get("max_results", 5) require_summary = tool_parameters.get("require_summary", False) - response = DDGS().text(query, max_results=max_results) + + # Add query_prefix handling + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query}".strip() + + response = DDGS().text(final_query, max_results=max_results) if require_summary: results = "\n".join([res.get("body") for res in response]) results = self.summary_results(user_id=user_id, content=results, query=query) diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml index 333c0cb093..54e27d9905 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_search.yaml @@ -39,3 +39,14 @@ parameters: en_US: Whether to pass the search results to llm for summarization. zh_Hans: 是否需要将搜索结果传给大模型总结 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:wikipedia.org" + zh_Hans: 定向搜索 e.g. "site:wikipedia.org" diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py index 4b74b223c1..1eef0b1ba2 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.py @@ -24,7 +24,7 @@ max-width: 100%; border-radius: 8px;"> def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]: query_dict = { - "keywords": tool_parameters.get("query"), + "keywords": tool_parameters.get("query"), # LLM's query "region": tool_parameters.get("region", "wt-wt"), "safesearch": tool_parameters.get("safesearch", "moderate"), "timelimit": tool_parameters.get("timelimit"), @@ -40,6 +40,12 @@ max-width: 100%; border-radius: 8px;"> # Get proxy URL from parameters proxy_url = tool_parameters.get("proxy_url", "").strip() + query_prefix = tool_parameters.get("query_prefix", "").strip() + final_query = f"{query_prefix} {query_dict['keywords']}".strip() + + # Update the keywords in query_dict with the final_query + query_dict["keywords"] = final_query + response = DDGS().videos(**query_dict) # Create HTML result with embedded iframes @@ -51,9 +57,13 @@ max-width: 100%; border-radius: 8px;"> embed_html = res.get("embed_html", "") description = res.get("description", "") content_url = res.get("content", "") + transcript_url = None # Handle TED.com videos - if not embed_html and "ted.com/talks" in content_url: + if "ted.com/talks" in content_url: + # Create transcript URL + transcript_url = f"{content_url}/transcript" + # Create embed URL embed_url = content_url.replace("www.ted.com", "embed.ted.com") if proxy_url: embed_url = f"{proxy_url}{embed_url}" @@ -68,8 +78,14 @@ max-width: 100%; border-radius: 8px;"> markdown_result += f"{title}\n\n" markdown_result += f"{embed_html}\n\n" + if description: + markdown_result += f"{description}\n\n" markdown_result += "---\n\n" - json_result.append(self.create_json_message(res)) + # Add transcript_url to the JSON result if available + result_dict = res.copy() + if transcript_url: + result_dict["transcript_url"] = transcript_url + json_result.append(self.create_json_message(result_dict)) return [self.create_text_message(markdown_result)] + json_result diff --git a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml index a516d3cb98..d846244e3d 100644 --- a/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml +++ b/api/core/tools/provider/builtin/duckduckgo/tools/ddgo_video.yaml @@ -95,3 +95,14 @@ parameters: en_US: Proxy URL zh_Hans: 视频代理地址 form: form + - name: query_prefix + label: + en_US: Query Prefix + zh_Hans: 查询前缀 + type: string + required: false + default: "" + form: form + human_description: + en_US: Specific Search e.g. "site:www.ted.com" + zh_Hans: 定向搜索 e.g. "site:www.ted.com" diff --git a/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml index 4c886b69c0..81adb3db7d 100644 --- a/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml +++ b/api/core/tools/provider/builtin/gitlab/tools/gitlab_mergerequests.yaml @@ -6,9 +6,9 @@ identity: zh_Hans: GitLab 合并请求查询 description: human: - en_US: A tool for query GitLab merge requests, Input should be a exists reposity or branch. + en_US: A tool for query GitLab merge requests, Input should be a exists repository or branch. zh_Hans: 一个用于查询 GitLab 代码合并请求的工具,输入的内容应该是一个已存在的仓库名或者分支。 - llm: A tool for query GitLab merge requests, Input should be a exists reposity or branch. + llm: A tool for query GitLab merge requests, Input should be a exists repository or branch. parameters: - name: repository type: string diff --git a/api/core/tools/provider/builtin/searchapi/tools/google.py b/api/core/tools/provider/builtin/searchapi/tools/google.py index 17e2978194..29d36f5f23 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py index c478bc108b..de42360898 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_jobs.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/google_news.py b/api/core/tools/provider/builtin/searchapi/tools/google_news.py index 562bc01964..c8b3ccda05 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/google_news.py +++ b/api/core/tools/provider/builtin/searchapi/tools/google_news.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict, type: str) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if type == "text": diff --git a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py index 1867cf7be7..b14821f831 100644 --- a/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py +++ b/api/core/tools/provider/builtin/searchapi/tools/youtube_transcripts.py @@ -45,7 +45,7 @@ class SearchAPI: def _process_response(res: dict) -> str: """Process response from SearchAPI.""" if "error" in res: - raise ValueError(f"Got error from SearchApi: {res['error']}") + return res["error"] toret = "" if "transcripts" in res and "text" in res["transcripts"][0]: diff --git a/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py index 74742bf4b7..aa4ee63e97 100644 --- a/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py +++ b/api/core/tools/provider/builtin/slidespeak/tools/slides_generator.py @@ -149,7 +149,7 @@ class SlidesGeneratorTool(BuiltinTool): presentation_bytes = await self._fetch_presentation(session, download_url) return [ - self.create_text_message("Presentation generated successfully"), + self.create_text_message(download_url), self.create_blob_message( blob=presentation_bytes, meta={"mime_type": "application/vnd.openxmlformats-officedocument.presentationml.presentation"}, diff --git a/api/core/tools/provider/builtin/time/tools/current_time.py b/api/core/tools/provider/builtin/time/tools/current_time.py index cc38739c16..6464bb6602 100644 --- a/api/core/tools/provider/builtin/time/tools/current_time.py +++ b/api/core/tools/provider/builtin/time/tools/current_time.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Union from pytz import timezone as pytz_timezone @@ -20,7 +20,7 @@ class CurrentTimeTool(BuiltinTool): tz = tool_parameters.get("timezone", "UTC") fm = tool_parameters.get("format") or "%Y-%m-%d %H:%M:%S %Z" if tz == "UTC": - return self.create_text_message(f"{datetime.now(timezone.utc).strftime(fm)}") + return self.create_text_message(f"{datetime.now(UTC).strftime(fm)}") try: tz = pytz_timezone(tz) diff --git a/api/core/tools/tool/tool.py b/api/core/tools/tool/tool.py index f17a26dfbd..8d40450381 100644 --- a/api/core/tools/tool/tool.py +++ b/api/core/tools/tool/tool.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from copy import deepcopy -from enum import Enum +from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from pydantic import BaseModel, ConfigDict, field_validator @@ -62,7 +62,7 @@ class Tool(BaseModel, ABC): def __init__(self, **data: Any): super().__init__(**data) - class VariableKey(str, Enum): + class VariableKey(StrEnum): IMAGE = "image" DOCUMENT = "document" VIDEO = "video" @@ -324,7 +324,12 @@ class Tool(BaseModel, ABC): :param blob: the blob :return: the blob message """ - return ToolInvokeMessage(type=ToolInvokeMessage.MessageType.BLOB, message=blob, meta=meta, save_as=save_as) + return ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.BLOB, + message=blob, + meta=meta or {}, + save_as=save_as, + ) def create_json_message(self, object: dict) -> ToolInvokeMessage: """ diff --git a/api/core/tools/tool/workflow_tool.py b/api/core/tools/tool/workflow_tool.py index 721fa06c54..33b4ad021a 100644 --- a/api/core/tools/tool/workflow_tool.py +++ b/api/core/tools/tool/workflow_tool.py @@ -58,11 +58,11 @@ class WorkflowTool(Tool): user=self._get_user(user_id), args={"inputs": tool_parameters, "files": files}, invoke_from=self.runtime.invoke_from, - stream=False, + streaming=False, call_depth=self.workflow_call_depth + 1, workflow_thread_pool_id=self.thread_pool_id, ) - + assert isinstance(result, dict) data = result.get("data", {}) if data.get("error"): diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 01a1fe330f..f92b43608e 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping from copy import deepcopy -from datetime import datetime, timezone +from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Optional, Union @@ -61,7 +61,12 @@ class ToolEngine: if parameters and len(parameters) == 1: tool_parameters = {parameters[0].name: tool_parameters} else: - raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") + try: + tool_parameters = json.loads(tool_parameters) + except Exception as e: + pass + if not isinstance(tool_parameters, dict): + raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}") # invoke the tool try: @@ -158,7 +163,7 @@ class ToolEngine: """ Invoke the tool with the given arguments. """ - started_at = datetime.now(timezone.utc) + started_at = datetime.now(UTC) meta = ToolInvokeMeta( time_cost=0.0, error=None, @@ -176,7 +181,7 @@ class ToolEngine: meta.error = str(e) raise ToolEngineInvokeError(meta) finally: - ended_at = datetime.now(timezone.utc) + ended_at = datetime.now(UTC) meta.time_cost = (ended_at - started_at).total_seconds() return meta, response diff --git a/api/core/variables/__init__.py b/api/core/variables/__init__.py index 144c1b899f..2b1a58f93a 100644 --- a/api/core/variables/__init__.py +++ b/api/core/variables/__init__.py @@ -32,32 +32,32 @@ from .variables import ( ) __all__ = [ - "IntegerVariable", - "FloatVariable", - "ObjectVariable", - "SecretVariable", - "StringVariable", - "ArrayAnyVariable", - "Variable", - "SegmentType", - "SegmentGroup", - "Segment", - "NoneSegment", - "NoneVariable", - "IntegerSegment", - "FloatSegment", - "ObjectSegment", "ArrayAnySegment", - "StringSegment", - "ArrayStringVariable", - "ArrayNumberVariable", - "ArrayObjectVariable", - "ArraySegment", + "ArrayAnyVariable", "ArrayFileSegment", + "ArrayFileVariable", "ArrayNumberSegment", + "ArrayNumberVariable", "ArrayObjectSegment", + "ArrayObjectVariable", + "ArraySegment", "ArrayStringSegment", + "ArrayStringVariable", "FileSegment", "FileVariable", - "ArrayFileVariable", + "FloatSegment", + "FloatVariable", + "IntegerSegment", + "IntegerVariable", + "NoneSegment", + "NoneVariable", + "ObjectSegment", + "ObjectVariable", + "SecretVariable", + "Segment", + "SegmentGroup", + "SegmentType", + "StringSegment", + "StringVariable", + "Variable", ] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 53c2e8a3aa..4387e9693e 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,17 +1,20 @@ -from enum import Enum +from enum import StrEnum -class SegmentType(str, Enum): - NONE = "none" +class SegmentType(StrEnum): NUMBER = "number" STRING = "string" + OBJECT = "object" SECRET = "secret" + + FILE = "file" + ARRAY_ANY = "array[any]" ARRAY_STRING = "array[string]" ARRAY_NUMBER = "array[number]" ARRAY_OBJECT = "array[object]" - OBJECT = "object" - FILE = "file" ARRAY_FILE = "array[file]" + NONE = "none" + GROUP = "group" diff --git a/api/core/workflow/callbacks/__init__.py b/api/core/workflow/callbacks/__init__.py index 403fbbaa2f..fba86c1e2e 100644 --- a/api/core/workflow/callbacks/__init__.py +++ b/api/core/workflow/callbacks/__init__.py @@ -2,6 +2,6 @@ from .base_workflow_callback import WorkflowCallback from .workflow_logging_callback import WorkflowLoggingCallback __all__ = [ - "WorkflowLoggingCallback", "WorkflowCallback", + "WorkflowLoggingCallback", ] diff --git a/api/core/workflow/entities/node_entities.py b/api/core/workflow/entities/node_entities.py index a747266661..e174d3baa0 100644 --- a/api/core/workflow/entities/node_entities.py +++ b/api/core/workflow/entities/node_entities.py @@ -1,5 +1,5 @@ from collections.abc import Mapping -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import BaseModel @@ -8,7 +8,7 @@ from core.model_runtime.entities.llm_entities import LLMUsage from models.workflow import WorkflowNodeExecutionStatus -class NodeRunMetadataKey(str, Enum): +class NodeRunMetadataKey(StrEnum): """ Node Run Metadata Key. """ @@ -36,7 +36,7 @@ class NodeRunResult(BaseModel): inputs: Optional[Mapping[str, Any]] = None # node inputs process_data: Optional[dict[str, Any]] = None # process data - outputs: Optional[dict[str, Any]] = None # node outputs + outputs: Optional[Mapping[str, Any]] = None # node outputs metadata: Optional[dict[NodeRunMetadataKey, Any]] = None # node metadata llm_usage: Optional[LLMUsage] = None # llm usage diff --git a/api/core/workflow/enums.py b/api/core/workflow/enums.py index 213ed57f57..9642efa1a5 100644 --- a/api/core/workflow/enums.py +++ b/api/core/workflow/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class SystemVariableKey(str, Enum): +class SystemVariableKey(StrEnum): """ System Variables. """ diff --git a/api/core/workflow/graph_engine/entities/runtime_route_state.py b/api/core/workflow/graph_engine/entities/runtime_route_state.py index bb24b51112..baeec9bf01 100644 --- a/api/core/workflow/graph_engine/entities/runtime_route_state.py +++ b/api/core/workflow/graph_engine/entities/runtime_route_state.py @@ -1,5 +1,5 @@ import uuid -from datetime import datetime, timezone +from datetime import UTC, datetime from enum import Enum from typing import Optional @@ -63,7 +63,7 @@ class RouteNodeState(BaseModel): raise Exception(f"Invalid route status {run_result.status}") self.node_run_result = run_result - self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + self.finished_at = datetime.now(UTC).replace(tzinfo=None) class RuntimeRouteState(BaseModel): @@ -81,7 +81,7 @@ class RuntimeRouteState(BaseModel): :param node_id: node id """ - state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + state = RouteNodeState(node_id=node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) self.node_state_mapping[state.id] = state return state diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 60a5901b21..7cffd7bc8e 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -38,7 +38,7 @@ from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProce from core.workflow.nodes.base import BaseNode from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from extensions.ext_database import db from models.enums import UserFrom from models.workflow import WorkflowNodeExecutionStatus, WorkflowType @@ -64,7 +64,6 @@ class GraphEngineThreadPool(ThreadPoolExecutor): self.submit_count -= 1 def check_is_full(self) -> None: - print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}") if self.submit_count > self.max_submit_count: raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.") @@ -228,7 +227,8 @@ class GraphEngine: # convert to specific node node_type = NodeType(node_config.get("data", {}).get("type")) - node_cls = node_type_classes_mapping[node_type] + node_version = node_config.get("data", {}).get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None diff --git a/api/core/workflow/nodes/answer/__init__.py b/api/core/workflow/nodes/answer/__init__.py index 7a10f47eed..ee7676c7e4 100644 --- a/api/core/workflow/nodes/answer/__init__.py +++ b/api/core/workflow/nodes/answer/__init__.py @@ -1,4 +1,4 @@ from .answer_node import AnswerNode from .entities import AnswerStreamGenerateRoute -__all__ = ["AnswerStreamGenerateRoute", "AnswerNode"] +__all__ = ["AnswerNode", "AnswerStreamGenerateRoute"] diff --git a/api/core/workflow/nodes/answer/answer_stream_generate_router.py b/api/core/workflow/nodes/answer/answer_stream_generate_router.py index 96e24a7db3..8c78016f09 100644 --- a/api/core/workflow/nodes/answer/answer_stream_generate_router.py +++ b/api/core/workflow/nodes/answer/answer_stream_generate_router.py @@ -153,7 +153,7 @@ class AnswerStreamGeneratorRouter: NodeType.IF_ELSE, NodeType.QUESTION_CLASSIFIER, NodeType.ITERATION, - NodeType.CONVERSATION_VARIABLE_ASSIGNER, + NodeType.VARIABLE_ASSIGNER, }: answer_dependencies[answer_node_id].append(source_node_id) else: diff --git a/api/core/workflow/nodes/base/__init__.py b/api/core/workflow/nodes/base/__init__.py index 61f727740c..72d6392d4e 100644 --- a/api/core/workflow/nodes/base/__init__.py +++ b/api/core/workflow/nodes/base/__init__.py @@ -1,4 +1,4 @@ from .entities import BaseIterationNodeData, BaseIterationState, BaseNodeData from .node import BaseNode -__all__ = ["BaseNode", "BaseNodeData", "BaseIterationNodeData", "BaseIterationState"] +__all__ = ["BaseIterationNodeData", "BaseIterationState", "BaseNode", "BaseNodeData"] diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index 2a864dd7a8..fb50fbd6e8 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -7,6 +7,7 @@ from pydantic import BaseModel class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + version: str = "1" class BaseIterationNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 1871fff618..d0fbed31cd 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -55,7 +55,9 @@ class BaseNode(Generic[GenericNodeData]): raise ValueError("Node ID is required.") self.node_id = node_id - self.node_data: GenericNodeData = cast(GenericNodeData, self._node_data_cls(**config.get("data", {}))) + + node_data = self._node_data_cls.model_validate(config.get("data", {})) + self.node_data = cast(GenericNodeData, node_data) @abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index c3cacdab7f..d490a2eb03 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -4,8 +4,8 @@ import json import docx import pandas as pd -import pypdfium2 -import yaml +import pypdfium2 # type: ignore +import yaml # type: ignore from unstructured.partition.api import partition_via_api from unstructured.partition.email import partition_email from unstructured.partition.epub import partition_epub @@ -113,7 +113,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str: """Extract text from a file based on its file extension.""" match file_extension: - case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml": + case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt": return _extract_text_from_plain_text(file_content) case ".json": return _extract_text_from_json(file_content) @@ -237,15 +237,17 @@ def _extract_text_from_csv(file_content: bytes) -> str: def _extract_text_from_excel(file_content: bytes) -> str: """Extract text from an Excel file using pandas.""" - try: - df = pd.read_excel(io.BytesIO(file_content)) - - # Drop rows where all elements are NaN - df.dropna(how="all", inplace=True) - - # Convert DataFrame to Markdown table - markdown_table = df.to_markdown(index=False) + excel_file = pd.ExcelFile(io.BytesIO(file_content)) + markdown_table = "" + for sheet_name in excel_file.sheet_names: + try: + df = excel_file.parse(sheet_name=sheet_name) + df.dropna(how="all", inplace=True) + # Create Markdown table two times to separate tables with a newline + markdown_table += df.to_markdown(index=False) + "\n\n" + except Exception as e: + continue return markdown_table except Exception as e: raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e diff --git a/api/core/workflow/nodes/end/__init__.py b/api/core/workflow/nodes/end/__init__.py index adb381701c..c4c00e3ddc 100644 --- a/api/core/workflow/nodes/end/__init__.py +++ b/api/core/workflow/nodes/end/__init__.py @@ -1,4 +1,4 @@ from .end_node import EndNode from .entities import EndStreamParam -__all__ = ["EndStreamParam", "EndNode"] +__all__ = ["EndNode", "EndStreamParam"] diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 208144655b..44be403ee6 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class NodeType(str, Enum): +class NodeType(StrEnum): START = "start" END = "end" ANSWER = "answer" @@ -14,11 +14,11 @@ class NodeType(str, Enum): HTTP_REQUEST = "http-request" TOOL = "tool" VARIABLE_AGGREGATOR = "variable-aggregator" - VARIABLE_ASSIGNER = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. + LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database. LOOP = "loop" ITERATION = "iteration" ITERATION_START = "iteration-start" # Fake start node for iteration. PARAMETER_EXTRACTOR = "parameter-extractor" - CONVERSATION_VARIABLE_ASSIGNER = "assigner" + VARIABLE_ASSIGNER = "assigner" DOCUMENT_EXTRACTOR = "document-extractor" LIST_OPERATOR = "list-operator" diff --git a/api/core/workflow/nodes/event/__init__.py b/api/core/workflow/nodes/event/__init__.py index 581def9553..5e3b31e48b 100644 --- a/api/core/workflow/nodes/event/__init__.py +++ b/api/core/workflow/nodes/event/__init__.py @@ -2,9 +2,9 @@ from .event import ModelInvokeCompletedEvent, RunCompletedEvent, RunRetrieverRes from .types import NodeEvent __all__ = [ + "ModelInvokeCompletedEvent", + "NodeEvent", "RunCompletedEvent", "RunRetrieverResourceEvent", "RunStreamChunkEvent", - "NodeEvent", - "ModelInvokeCompletedEvent", ] diff --git a/api/core/workflow/nodes/http_request/__init__.py b/api/core/workflow/nodes/http_request/__init__.py index 9408c2dde0..c51c678999 100644 --- a/api/core/workflow/nodes/http_request/__init__.py +++ b/api/core/workflow/nodes/http_request/__init__.py @@ -1,4 +1,4 @@ from .entities import BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeData from .node import HttpRequestNode -__all__ = ["HttpRequestNodeData", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "BodyData", "HttpRequestNode"] +__all__ = ["BodyData", "HttpRequestNode", "HttpRequestNodeAuthorization", "HttpRequestNodeBody", "HttpRequestNodeData"] diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 5b399bed63..2a92a16ede 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -1,11 +1,9 @@ import logging from collections.abc import Mapping, Sequence -from mimetypes import guess_extension -from os import path from typing import Any from configs import dify_config -from core.file import File, FileTransferMethod, FileType +from core.file import File, FileTransferMethod from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector @@ -107,6 +105,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): node_data: HttpRequestNodeData, ) -> Mapping[str, Sequence[str]]: selectors: list[VariableSelector] = [] + selectors += variable_template_parser.extract_selectors_from_template(node_data.url) selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) selectors += variable_template_parser.extract_selectors_from_template(node_data.params) if node_data.body: @@ -149,11 +148,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): content = response.content if is_file and content_type: - # extract filename from url - filename = path.basename(url) - # extract extension if possible - extension = guess_extension(content_type) or ".bin" - tool_file = ToolFileManager.create_file_by_raw( user_id=self.user_id, tenant_id=self.tenant_id, @@ -164,7 +158,6 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): mapping = { "tool_file_id": tool_file.id, - "type": FileType.IMAGE.value, "transfer_method": FileTransferMethod.TOOL_FILE.value, } file = file_factory.build_from_mapping( diff --git a/api/core/workflow/nodes/iteration/entities.py b/api/core/workflow/nodes/iteration/entities.py index ebcb6f82fb..7a489dd725 100644 --- a/api/core/workflow/nodes/iteration/entities.py +++ b/api/core/workflow/nodes/iteration/entities.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from typing import Any, Optional from pydantic import Field @@ -6,7 +6,7 @@ from pydantic import Field from core.workflow.nodes.base import BaseIterationNodeData, BaseIterationState, BaseNodeData -class ErrorHandleMode(str, Enum): +class ErrorHandleMode(StrEnum): TERMINATED = "terminated" CONTINUE_ON_ERROR = "continue-on-error" REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index d5428f0286..bba6ac20d3 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -2,7 +2,7 @@ import logging import uuid from collections.abc import Generator, Mapping, Sequence from concurrent.futures import Future, wait -from datetime import datetime, timezone +from datetime import UTC, datetime from queue import Empty, Queue from typing import TYPE_CHECKING, Any, Optional, cast @@ -135,7 +135,7 @@ class IterationNode(BaseNode[IterationNodeData]): thread_pool_id=self.thread_pool_id, ) - start_at = datetime.now(timezone.utc).replace(tzinfo=None) + start_at = datetime.now(UTC).replace(tzinfo=None) yield IterationRunStartedEvent( iteration_id=self.id, @@ -297,12 +297,13 @@ class IterationNode(BaseNode[IterationNodeData]): # variable selector to variable mapping try: # Get node class - from core.workflow.nodes.node_mapping import node_type_classes_mapping + from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING node_type = NodeType(sub_node_config.get("data", {}).get("type")) - node_cls = node_type_classes_mapping.get(node_type) - if not node_cls: + if node_type not in NODE_TYPE_CLASSES_MAPPING: continue + node_version = sub_node_config.get("data", {}).get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=graph_config, config=sub_node_config @@ -367,7 +368,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ run single iteration """ - iter_start_at = datetime.now(timezone.utc).replace(tzinfo=None) + iter_start_at = datetime.now(UTC).replace(tzinfo=None) try: rst = graph_engine.run() @@ -440,7 +441,7 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.add([self.node_id, "index"], next_index) if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -461,7 +462,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, @@ -503,7 +504,7 @@ class IterationNode(BaseNode[IterationNodeData]): if next_index < len(iterator_list_value): variable_pool.add([self.node_id, "item"], iterator_list_value[next_index]) - duration = (datetime.now(timezone.utc).replace(tzinfo=None) - iter_start_at).total_seconds() + duration = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() iter_run_map[iteration_run_id] = duration yield IterationRunNextEvent( iteration_id=self.id, diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 2529c76942..8ab0d8b2eb 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -20,6 +20,7 @@ from core.model_runtime.entities import ( from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage from core.model_runtime.entities.message_entities import ( AssistantPromptMessage, + PromptMessageContent, PromptMessageRole, SystemPromptMessage, UserPromptMessage, @@ -38,6 +39,7 @@ from core.variables import ( ObjectSegment, StringSegment, ) +from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.variable_pool import VariablePool @@ -65,7 +67,6 @@ from .entities import ( ModelConfig, ) from .exc import ( - FileTypeNotSupportError, InvalidContextStructureError, InvalidVariableTypeError, LLMModeRequiredError, @@ -133,11 +134,15 @@ class LLMNode(BaseNode[LLMNodeData]): # fetch memory memory = self._fetch_memory(node_data_memory=self.node_data.memory, model_instance=model_instance) - # fetch prompt messages + query = None if self.node_data.memory: query = self.node_data.memory.query_prompt_template - else: - query = None + if not query and ( + query_variable := self.graph_runtime_state.variable_pool.get( + (SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY) + ) + ): + query = query_variable.text prompt_messages, stop = self._fetch_prompt_messages( user_query=query, @@ -192,7 +197,6 @@ class LLMNode(BaseNode[LLMNodeData]): ) return except Exception as e: - logger.exception(f"Node {self.node_id} failed to run") yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -670,7 +674,7 @@ class LLMNode(BaseNode[LLMNodeData]): and ModelFeature.AUDIO not in model_config.model_schema.features ) ): - raise FileTypeNotSupportError(type_name=content_item.type) + continue prompt_message_content.append(content_item) if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: prompt_message.content = prompt_message_content[0].data @@ -811,7 +815,7 @@ class LLMNode(BaseNode[LLMNodeData]): "completion_model": { "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, "prompt": { - "text": "Here is the chat histories between human and assistant, inside " + "text": "Here are the chat histories between human and assistant, inside " " XML tags.\n\n\n{{" "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", "edition_type": "basic", @@ -823,14 +827,14 @@ class LLMNode(BaseNode[LLMNodeData]): } -def _combine_text_message_with_role(*, text: str, role: PromptMessageRole): +def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole): match role: case PromptMessageRole.USER: - return UserPromptMessage(content=[TextPromptMessageContent(data=text)]) + return UserPromptMessage(content=contents) case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=[TextPromptMessageContent(data=text)]) + return AssistantPromptMessage(content=contents) case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=[TextPromptMessageContent(data=text)]) + return SystemPromptMessage(content=contents) raise NotImplementedError(f"Role {role} is not supported") @@ -872,7 +876,9 @@ def _handle_list_messages( jinjia2_variables=jinja2_variables, variable_pool=variable_pool, ) - prompt_message = _combine_text_message_with_role(text=result_text, role=message.role) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=message.role + ) prompt_messages.append(prompt_message) else: # Get segment group from basic message @@ -903,12 +909,14 @@ def _handle_list_messages( # Create message with text from all segments plain_text = segment_group.text if plain_text: - prompt_message = _combine_text_message_with_role(text=plain_text, role=message.role) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=plain_text)], role=message.role + ) prompt_messages.append(prompt_message) if file_contents: # Create message with image contents - prompt_message = UserPromptMessage(content=file_contents) + prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role) prompt_messages.append(prompt_message) return prompt_messages @@ -1013,6 +1021,8 @@ def _handle_completion_template( else: template_text = template.text result_text = variable_pool.convert_template(template_text).text - prompt_message = _combine_text_message_with_role(text=result_text, role=PromptMessageRole.USER) + prompt_message = _combine_message_content_with_role( + contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER + ) prompt_messages.append(prompt_message) return prompt_messages diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index c13b5ff76f..51fc5129cd 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -1,3 +1,5 @@ +from collections.abc import Mapping + from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.base import BaseNode from core.workflow.nodes.code import CodeNode @@ -16,26 +18,87 @@ from core.workflow.nodes.start import StartNode from core.workflow.nodes.template_transform import TemplateTransformNode from core.workflow.nodes.tool import ToolNode from core.workflow.nodes.variable_aggregator import VariableAggregatorNode -from core.workflow.nodes.variable_assigner import VariableAssignerNode +from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1 +from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2 + +LATEST_VERSION = "latest" -node_type_classes_mapping: dict[NodeType, type[BaseNode]] = { - NodeType.START: StartNode, - NodeType.END: EndNode, - NodeType.ANSWER: AnswerNode, - NodeType.LLM: LLMNode, - NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode, - NodeType.IF_ELSE: IfElseNode, - NodeType.CODE: CodeNode, - NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode, - NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode, - NodeType.HTTP_REQUEST: HttpRequestNode, - NodeType.TOOL: ToolNode, - NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode, - NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR - NodeType.ITERATION: IterationNode, - NodeType.ITERATION_START: IterationStartNode, - NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode, - NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode, - NodeType.DOCUMENT_EXTRACTOR: DocumentExtractorNode, - NodeType.LIST_OPERATOR: ListOperatorNode, +NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { + NodeType.START: { + LATEST_VERSION: StartNode, + "1": StartNode, + }, + NodeType.END: { + LATEST_VERSION: EndNode, + "1": EndNode, + }, + NodeType.ANSWER: { + LATEST_VERSION: AnswerNode, + "1": AnswerNode, + }, + NodeType.LLM: { + LATEST_VERSION: LLMNode, + "1": LLMNode, + }, + NodeType.KNOWLEDGE_RETRIEVAL: { + LATEST_VERSION: KnowledgeRetrievalNode, + "1": KnowledgeRetrievalNode, + }, + NodeType.IF_ELSE: { + LATEST_VERSION: IfElseNode, + "1": IfElseNode, + }, + NodeType.CODE: { + LATEST_VERSION: CodeNode, + "1": CodeNode, + }, + NodeType.TEMPLATE_TRANSFORM: { + LATEST_VERSION: TemplateTransformNode, + "1": TemplateTransformNode, + }, + NodeType.QUESTION_CLASSIFIER: { + LATEST_VERSION: QuestionClassifierNode, + "1": QuestionClassifierNode, + }, + NodeType.HTTP_REQUEST: { + LATEST_VERSION: HttpRequestNode, + "1": HttpRequestNode, + }, + NodeType.TOOL: { + LATEST_VERSION: ToolNode, + "1": ToolNode, + }, + NodeType.VARIABLE_AGGREGATOR: { + LATEST_VERSION: VariableAggregatorNode, + "1": VariableAggregatorNode, + }, + NodeType.LEGACY_VARIABLE_AGGREGATOR: { + LATEST_VERSION: VariableAggregatorNode, + "1": VariableAggregatorNode, + }, # original name of VARIABLE_AGGREGATOR + NodeType.ITERATION: { + LATEST_VERSION: IterationNode, + "1": IterationNode, + }, + NodeType.ITERATION_START: { + LATEST_VERSION: IterationStartNode, + "1": IterationStartNode, + }, + NodeType.PARAMETER_EXTRACTOR: { + LATEST_VERSION: ParameterExtractorNode, + "1": ParameterExtractorNode, + }, + NodeType.VARIABLE_ASSIGNER: { + LATEST_VERSION: VariableAssignerNodeV2, + "1": VariableAssignerNodeV1, + "2": VariableAssignerNodeV2, + }, + NodeType.DOCUMENT_EXTRACTOR: { + LATEST_VERSION: DocumentExtractorNode, + "1": DocumentExtractorNode, + }, + NodeType.LIST_OPERATOR: { + LATEST_VERSION: ListOperatorNode, + "1": ListOperatorNode, + }, } diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index b64bde8ac5..5b960ea615 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -235,7 +235,7 @@ class ParameterExtractorNode(LLMNode): raise InvalidInvokeResultError(f"Invalid invoke result: {invoke_result}") text = invoke_result.message.content - if not isinstance(text, str): + if not isinstance(text, str | None): raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") usage = invoke_result.usage diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py index 58fcecc53b..e603add170 100644 --- a/api/core/workflow/nodes/parameter_extractor/prompts.py +++ b/api/core/workflow/nodes/parameter_extractor/prompts.py @@ -98,7 +98,7 @@ Step 3: Structure the extracted parameters to JSON object as specified in XML tags. +Here are the chat histories between human and assistant, inside XML tags. {histories} @@ -125,7 +125,7 @@ CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and out The structure of the JSON object you can found in the instructions. ### Memory -Here is the chat histories between human and assistant, inside XML tags. +Here are the chat histories between human and assistant, inside XML tags. {histories} diff --git a/api/core/workflow/nodes/question_classifier/__init__.py b/api/core/workflow/nodes/question_classifier/__init__.py index 70414c4199..4d06b6bea3 100644 --- a/api/core/workflow/nodes/question_classifier/__init__.py +++ b/api/core/workflow/nodes/question_classifier/__init__.py @@ -1,4 +1,4 @@ from .entities import QuestionClassifierNodeData from .question_classifier_node import QuestionClassifierNode -__all__ = ["QuestionClassifierNodeData", "QuestionClassifierNode"] +__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py index 4bca2d9dd4..53fc136b2c 100644 --- a/api/core/workflow/nodes/question_classifier/template_prompts.py +++ b/api/core/workflow/nodes/question_classifier/template_prompts.py @@ -8,7 +8,7 @@ QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ ### Constraint DO NOT include anything other than the JSON array in your response. ### Memory - Here is the chat histories between human and assistant, inside XML tags. + Here are the chat histories between human and assistant, inside XML tags. {histories} @@ -66,7 +66,7 @@ User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{" Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} ### Memory -Here is the chat histories between human and assistant, inside XML tags. +Here are the chat histories between human and assistant, inside XML tags. {histories} diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 5560f26456..951e5330a3 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -250,9 +250,8 @@ class ToolNode(BaseNode[ToolNodeData]): f"{message.message}" if message.type == ToolInvokeMessage.MessageType.TEXT else f"Link: {message.message}" - if message.type == ToolInvokeMessage.MessageType.LINK - else "" for message in tool_response + if message.type in {ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.LINK} ] ) diff --git a/api/core/workflow/nodes/variable_assigner/__init__.py b/api/core/workflow/nodes/variable_assigner/__init__.py index 83da4bdc79..e69de29bb2 100644 --- a/api/core/workflow/nodes/variable_assigner/__init__.py +++ b/api/core/workflow/nodes/variable_assigner/__init__.py @@ -1,8 +0,0 @@ -from .node import VariableAssignerNode -from .node_data import VariableAssignerData, WriteMode - -__all__ = [ - "VariableAssignerNode", - "VariableAssignerData", - "WriteMode", -] diff --git a/api/core/workflow/nodes/variable_assigner/common/__init__.py b/api/core/workflow/nodes/variable_assigner/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/core/workflow/nodes/variable_assigner/common/exc.py b/api/core/workflow/nodes/variable_assigner/common/exc.py new file mode 100644 index 0000000000..a1178fb020 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/exc.py @@ -0,0 +1,4 @@ +class VariableOperatorNodeError(Exception): + """Base error type, don't use directly.""" + + pass diff --git a/api/core/workflow/nodes/variable_assigner/common/helpers.py b/api/core/workflow/nodes/variable_assigner/common/helpers.py new file mode 100644 index 0000000000..8031b57fa8 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/common/helpers.py @@ -0,0 +1,19 @@ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.variables import Variable +from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from extensions.ext_database import db +from models import ConversationVariable + + +def update_conversation_variable(conversation_id: str, variable: Variable): + stmt = select(ConversationVariable).where( + ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id + ) + with Session(db.engine) as session: + row = session.scalar(stmt) + if not row: + raise VariableOperatorNodeError("conversation variable not found in the database") + row.data = variable.model_dump_json() + session.commit() diff --git a/api/core/workflow/nodes/variable_assigner/exc.py b/api/core/workflow/nodes/variable_assigner/exc.py deleted file mode 100644 index 914be22256..0000000000 --- a/api/core/workflow/nodes/variable_assigner/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableAssignerNodeError(Exception): - pass diff --git a/api/core/workflow/nodes/variable_assigner/v1/__init__.py b/api/core/workflow/nodes/variable_assigner/v1/__init__.py new file mode 100644 index 0000000000..7eb1428e50 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v1/__init__.py @@ -0,0 +1,3 @@ +from .node import VariableAssignerNode + +__all__ = ["VariableAssignerNode"] diff --git a/api/core/workflow/nodes/variable_assigner/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py similarity index 69% rename from api/core/workflow/nodes/variable_assigner/node.py rename to api/core/workflow/nodes/variable_assigner/v1/node.py index 4e66f640df..8eb4bd5c2d 100644 --- a/api/core/workflow/nodes/variable_assigner/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -1,40 +1,36 @@ -from sqlalchemy import select -from sqlalchemy.orm import Session - from core.variables import SegmentType, Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.nodes.base import BaseNode, BaseNodeData from core.workflow.nodes.enums import NodeType -from extensions.ext_database import db +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory -from models import ConversationVariable from models.workflow import WorkflowNodeExecutionStatus -from .exc import VariableAssignerNodeError from .node_data import VariableAssignerData, WriteMode class VariableAssignerNode(BaseNode[VariableAssignerData]): _node_data_cls: type[BaseNodeData] = VariableAssignerData - _node_type: NodeType = NodeType.CONVERSATION_VARIABLE_ASSIGNER + _node_type = NodeType.VARIABLE_ASSIGNER def _run(self) -> NodeRunResult: # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(self.node_data.assigned_variable_selector) if not isinstance(original_variable, Variable): - raise VariableAssignerNodeError("assigned variable not found") + raise VariableOperatorNodeError("assigned variable not found") match self.node_data.write_mode: case WriteMode.OVER_WRITE: income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError("input value not found") + raise VariableOperatorNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) if not income_value: - raise VariableAssignerNodeError("input value not found") + raise VariableOperatorNodeError("input value not found") updated_value = original_variable.value + [income_value.value] updated_variable = original_variable.model_copy(update={"value": updated_value}) @@ -43,7 +39,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableAssignerNodeError(f"unsupported write mode: {self.node_data.write_mode}") + raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(self.node_data.assigned_variable_selector, updated_variable) @@ -52,8 +48,8 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): # Update conversation variable. conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) if not conversation_id: - raise VariableAssignerNodeError("conversation_id not found") - update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) + raise VariableOperatorNodeError("conversation_id not found") + common_helpers.update_conversation_variable(conversation_id=conversation_id.text, variable=updated_variable) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -63,18 +59,6 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): ) -def update_conversation_variable(conversation_id: str, variable: Variable): - stmt = select(ConversationVariable).where( - ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id - ) - with Session(db.engine) as session: - row = session.scalar(stmt) - if not row: - raise VariableAssignerNodeError("conversation variable not found in the database") - row.data = variable.model_dump_json() - session.commit() - - def get_zero_value(t: SegmentType): match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: @@ -86,4 +70,4 @@ def get_zero_value(t: SegmentType): case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: - raise VariableAssignerNodeError(f"unsupported variable type: {t}") + raise VariableOperatorNodeError(f"unsupported variable type: {t}") diff --git a/api/core/workflow/nodes/variable_assigner/node_data.py b/api/core/workflow/nodes/variable_assigner/v1/node_data.py similarity index 65% rename from api/core/workflow/nodes/variable_assigner/node_data.py rename to api/core/workflow/nodes/variable_assigner/v1/node_data.py index 70ae29d45f..9734d64712 100644 --- a/api/core/workflow/nodes/variable_assigner/node_data.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node_data.py @@ -1,19 +1,16 @@ from collections.abc import Sequence -from enum import Enum -from typing import Optional +from enum import StrEnum from core.workflow.nodes.base import BaseNodeData -class WriteMode(str, Enum): +class WriteMode(StrEnum): OVER_WRITE = "over-write" APPEND = "append" CLEAR = "clear" class VariableAssignerData(BaseNodeData): - title: str = "Variable Assigner" - desc: Optional[str] = "Assign a value to a variable" assigned_variable_selector: Sequence[str] write_mode: WriteMode input_variable_selector: Sequence[str] diff --git a/api/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/core/workflow/nodes/variable_assigner/v2/__init__.py new file mode 100644 index 0000000000..7eb1428e50 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/__init__.py @@ -0,0 +1,3 @@ +from .node import VariableAssignerNode + +__all__ = ["VariableAssignerNode"] diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py new file mode 100644 index 0000000000..3797bfa77a --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -0,0 +1,11 @@ +from core.variables import SegmentType + +EMPTY_VALUE_MAPPING = { + SegmentType.STRING: "", + SegmentType.NUMBER: 0, + SegmentType.OBJECT: {}, + SegmentType.ARRAY_ANY: [], + SegmentType.ARRAY_STRING: [], + SegmentType.ARRAY_NUMBER: [], + SegmentType.ARRAY_OBJECT: [], +} diff --git a/api/core/workflow/nodes/variable_assigner/v2/entities.py b/api/core/workflow/nodes/variable_assigner/v2/entities.py new file mode 100644 index 0000000000..01df33b6d4 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/entities.py @@ -0,0 +1,20 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import BaseModel + +from core.workflow.nodes.base import BaseNodeData + +from .enums import InputType, Operation + + +class VariableOperationItem(BaseModel): + variable_selector: Sequence[str] + input_type: InputType + operation: Operation + value: Any | None = None + + +class VariableAssignerNodeData(BaseNodeData): + version: str = "2" + items: Sequence[VariableOperationItem] diff --git a/api/core/workflow/nodes/variable_assigner/v2/enums.py b/api/core/workflow/nodes/variable_assigner/v2/enums.py new file mode 100644 index 0000000000..36cf68aa19 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/enums.py @@ -0,0 +1,18 @@ +from enum import StrEnum + + +class Operation(StrEnum): + OVER_WRITE = "over-write" + CLEAR = "clear" + APPEND = "append" + EXTEND = "extend" + SET = "set" + ADD = "+=" + SUBTRACT = "-=" + MULTIPLY = "*=" + DIVIDE = "/=" + + +class InputType(StrEnum): + VARIABLE = "variable" + CONSTANT = "constant" diff --git a/api/core/workflow/nodes/variable_assigner/v2/exc.py b/api/core/workflow/nodes/variable_assigner/v2/exc.py new file mode 100644 index 0000000000..5b1ef4b04f --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/exc.py @@ -0,0 +1,31 @@ +from collections.abc import Sequence +from typing import Any + +from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError + +from .enums import InputType, Operation + + +class OperationNotSupportedError(VariableOperatorNodeError): + def __init__(self, *, operation: Operation, varialbe_type: str): + super().__init__(f"Operation {operation} is not supported for type {varialbe_type}") + + +class InputTypeNotSupportedError(VariableOperatorNodeError): + def __init__(self, *, input_type: InputType, operation: Operation): + super().__init__(f"Input type {input_type} is not supported for operation {operation}") + + +class VariableNotFoundError(VariableOperatorNodeError): + def __init__(self, *, variable_selector: Sequence[str]): + super().__init__(f"Variable {variable_selector} not found") + + +class InvalidInputValueError(VariableOperatorNodeError): + def __init__(self, *, value: Any): + super().__init__(f"Invalid input value {value}") + + +class ConversationIDNotFoundError(VariableOperatorNodeError): + def __init__(self): + super().__init__("conversation_id not found") diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py new file mode 100644 index 0000000000..a86c7eb94a --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -0,0 +1,91 @@ +from typing import Any + +from core.variables import SegmentType + +from .enums import Operation + + +def is_operation_supported(*, variable_type: SegmentType, operation: Operation): + match operation: + case Operation.OVER_WRITE | Operation.CLEAR: + return True + case Operation.SET: + return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: + # Only number variable can be added, subtracted, multiplied or divided + return variable_type == SegmentType.NUMBER + case Operation.APPEND | Operation.EXTEND: + # Only array variable can be appended or extended + return variable_type in { + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_FILE, + } + case _: + return False + + +def is_variable_input_supported(*, operation: Operation): + if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: + return False + return True + + +def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): + match variable_type: + case SegmentType.STRING | SegmentType.OBJECT: + return operation in {Operation.OVER_WRITE, Operation.SET} + case SegmentType.NUMBER: + return operation in { + Operation.OVER_WRITE, + Operation.SET, + Operation.ADD, + Operation.SUBTRACT, + Operation.MULTIPLY, + Operation.DIVIDE, + } + case _: + return False + + +def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): + if operation == Operation.CLEAR: + return True + match variable_type: + case SegmentType.STRING: + return isinstance(value, str) + + case SegmentType.NUMBER: + if not isinstance(value, int | float): + return False + if operation == Operation.DIVIDE and value == 0: + return False + return True + + case SegmentType.OBJECT: + return isinstance(value, dict) + + # Array & Append + case SegmentType.ARRAY_ANY if operation == Operation.APPEND: + return isinstance(value, str | float | int | dict) + case SegmentType.ARRAY_STRING if operation == Operation.APPEND: + return isinstance(value, str) + case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: + return isinstance(value, int | float) + case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: + return isinstance(value, dict) + + # Array & Extend / Overwrite + case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: + return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) + case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: + return isinstance(value, list) and all(isinstance(item, str) for item in value) + case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: + return isinstance(value, list) and all(isinstance(item, int | float) for item in value) + case SegmentType.ARRAY_OBJECT if operation in {Operation.EXTEND, Operation.OVER_WRITE}: + return isinstance(value, list) and all(isinstance(item, dict) for item in value) + + case _: + return False diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py new file mode 100644 index 0000000000..ea59a2f170 --- /dev/null +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -0,0 +1,159 @@ +import json +from typing import Any + +from core.variables import SegmentType, Variable +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.variable_assigner.common import helpers as common_helpers +from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError +from models.workflow import WorkflowNodeExecutionStatus + +from . import helpers +from .constants import EMPTY_VALUE_MAPPING +from .entities import VariableAssignerNodeData +from .enums import InputType, Operation +from .exc import ( + ConversationIDNotFoundError, + InputTypeNotSupportedError, + InvalidInputValueError, + OperationNotSupportedError, + VariableNotFoundError, +) + + +class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): + _node_data_cls = VariableAssignerNodeData + _node_type = NodeType.VARIABLE_ASSIGNER + + def _run(self) -> NodeRunResult: + inputs = self.node_data.model_dump() + process_data = {} + # NOTE: This node has no outputs + updated_variables: list[Variable] = [] + + try: + for item in self.node_data.items: + variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) + + # ==================== Validation Part + + # Check if variable exists + if not isinstance(variable, Variable): + raise VariableNotFoundError(variable_selector=item.variable_selector) + + # Check if operation is supported + if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation): + raise OperationNotSupportedError(operation=item.operation, varialbe_type=variable.value_type) + + # Check if variable input is supported + if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported( + operation=item.operation + ): + raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation) + + # Check if constant input is supported + if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported( + variable_type=variable.value_type, operation=item.operation + ): + raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) + + # Get value from variable pool + if ( + item.input_type == InputType.VARIABLE + and item.operation != Operation.CLEAR + and item.value is not None + ): + value = self.graph_runtime_state.variable_pool.get(item.value) + if value is None: + raise VariableNotFoundError(variable_selector=item.value) + # Skip if value is NoneSegment + if value.value_type == SegmentType.NONE: + continue + item.value = value.value + + # If set string / bytes / bytearray to object, try convert string to object. + if ( + item.operation == Operation.SET + and variable.value_type == SegmentType.OBJECT + and isinstance(item.value, str | bytes | bytearray) + ): + try: + item.value = json.loads(item.value) + except json.JSONDecodeError: + raise InvalidInputValueError(value=item.value) + + # Check if input value is valid + if not helpers.is_input_value_valid( + variable_type=variable.value_type, operation=item.operation, value=item.value + ): + raise InvalidInputValueError(value=item.value) + + # ==================== Execution Part + + updated_value = self._handle_item( + variable=variable, + operation=item.operation, + value=item.value, + ) + variable = variable.model_copy(update={"value": updated_value}) + updated_variables.append(variable) + except VariableOperatorNodeError as e: + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=inputs, + process_data=process_data, + error=str(e), + ) + + # Update variables + for variable in updated_variables: + self.graph_runtime_state.variable_pool.add(variable.selector, variable) + process_data[variable.name] = variable.value + + if variable.selector[0] == CONVERSATION_VARIABLE_NODE_ID: + conversation_id = self.graph_runtime_state.variable_pool.get(["sys", "conversation_id"]) + if not conversation_id: + raise ConversationIDNotFoundError + else: + conversation_id = conversation_id.value + common_helpers.update_conversation_variable( + conversation_id=conversation_id, + variable=variable, + ) + + return NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs=inputs, + process_data=process_data, + ) + + def _handle_item( + self, + *, + variable: Variable, + operation: Operation, + value: Any, + ): + match operation: + case Operation.OVER_WRITE: + return value + case Operation.CLEAR: + return EMPTY_VALUE_MAPPING[variable.value_type] + case Operation.APPEND: + return variable.value + [value] + case Operation.EXTEND: + return variable.value + value + case Operation.SET: + return value + case Operation.ADD: + return variable.value + value + case Operation.SUBTRACT: + return variable.value - value + case Operation.MULTIPLY: + return variable.value * value + case Operation.DIVIDE: + return variable.value / value + case _: + raise OperationNotSupportedError(operation=operation, varialbe_type=variable.value_type) diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 84b251223f..811e40c11e 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -2,13 +2,12 @@ import logging import time import uuid from collections.abc import Generator, Mapping, Sequence -from typing import Any, Optional, cast +from typing import Any, Optional from configs import dify_config -from core.app.app_config.entities import FileUploadConfig from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from core.file.models import File, FileTransferMethod, ImageConfig +from core.file.models import File from core.workflow.callbacks import WorkflowCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.errors import WorkflowNodeRunFailedError @@ -18,10 +17,9 @@ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParam from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.graph_engine.graph_engine import GraphEngine from core.workflow.nodes import NodeType -from core.workflow.nodes.base import BaseNode, BaseNodeData +from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent -from core.workflow.nodes.llm import LLMNodeData -from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from factories import file_factory from models.enums import UserFrom from models.workflow import ( @@ -115,7 +113,12 @@ class WorkflowEntry: @classmethod def single_step_run( - cls, workflow: Workflow, node_id: str, user_id: str, user_inputs: dict + cls, + *, + workflow: Workflow, + node_id: str, + user_id: str, + user_inputs: dict, ) -> tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]: """ Single step run workflow node @@ -135,29 +138,18 @@ class WorkflowEntry: raise ValueError("nodes not found in workflow graph") # fetch node config from node id - node_config = None - for node in nodes: - if node.get("id") == node_id: - node_config = node - break - - if not node_config: + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: raise ValueError("node id not found in workflow graph") # Get node class node_type = NodeType(node_config.get("data", {}).get("type")) - node_cls = node_type_classes_mapping.get(node_type) - node_cls = cast(type[BaseNode], node_cls) - - if not node_cls: - raise ValueError(f"Node class not found for node type {node_type}") + node_version = node_config.get("data", {}).get("version", "1") + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] # init variable pool - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - environment_variables=workflow.environment_variables, - ) + variable_pool = VariablePool(environment_variables=workflow.environment_variables) # init graph graph = Graph.init(graph_config=workflow.graph_dict) @@ -183,28 +175,24 @@ class WorkflowEntry: try: # variable selector to variable mapping - try: - variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=workflow.graph_dict, config=node_config - ) - except NotImplementedError: - variable_mapping = {} - - cls.mapping_user_inputs_to_variable_pool( - variable_mapping=variable_mapping, - user_inputs=user_inputs, - variable_pool=variable_pool, - tenant_id=workflow.tenant_id, - node_type=node_type, - node_data=node_instance.node_data, + variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( + graph_config=workflow.graph_dict, config=node_config ) + except NotImplementedError: + variable_mapping = {} + cls.mapping_user_inputs_to_variable_pool( + variable_mapping=variable_mapping, + user_inputs=user_inputs, + variable_pool=variable_pool, + tenant_id=workflow.tenant_id, + ) + try: # run node generator = node_instance.run() - - return node_instance, generator except Exception as e: raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + return node_instance, generator @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: @@ -231,12 +219,11 @@ class WorkflowEntry: @classmethod def mapping_user_inputs_to_variable_pool( cls, + *, variable_mapping: Mapping[str, Sequence[str]], user_inputs: dict, variable_pool: VariablePool, tenant_id: str, - node_type: NodeType, - node_data: BaseNodeData, ) -> None: for node_variable, variable_selector in variable_mapping.items(): # fetch node id and variable key from node_variable @@ -254,40 +241,21 @@ class WorkflowEntry: # fetch variable node id from variable selector variable_node_id = variable_selector[0] variable_key_list = variable_selector[1:] - variable_key_list = cast(list[str], variable_key_list) + variable_key_list = list(variable_key_list) # get input value input_value = user_inputs.get(node_variable) if not input_value: input_value = user_inputs.get(node_variable_key) - # FIXME: temp fix for image type - if node_type == NodeType.LLM: - new_value = [] - if isinstance(input_value, list): - node_data = cast(LLMNodeData, node_data) - - detail = node_data.vision.configs.detail if node_data.vision.configs else None - - for item in input_value: - if isinstance(item, dict) and "type" in item and item["type"] == "image": - transfer_method = FileTransferMethod.value_of(item.get("transfer_method")) - mapping = { - "id": item.get("id"), - "transfer_method": transfer_method, - "upload_file_id": item.get("upload_file_id"), - "url": item.get("url"), - } - config = FileUploadConfig(image_config=ImageConfig(detail=detail) if detail else None) - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - new_value.append(file) - - if new_value: - input_value = new_value + if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: + input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + if ( + isinstance(input_value, list) + and all(isinstance(item, dict) for item in input_value) + and all("type" in item and "transfer_method" in item for item in input_value) + ): + input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) # append variable and value to variable pool variable_pool.add([variable_node_id] + variable_key_list, input_value) diff --git a/api/dify_app.py b/api/dify_app.py new file mode 100644 index 0000000000..d6deb8e007 --- /dev/null +++ b/api/dify_app.py @@ -0,0 +1,5 @@ +from flask import Flask + + +class DifyApp(Flask): + pass diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 5af45e1e50..24fa013697 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -33,7 +33,7 @@ def handle(sender, **kwargs): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py index a80572c0de..f225ef8e88 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from events.message_event import message_was_created @@ -17,5 +17,5 @@ def handle(sender, **kwargs): db.session.query(Provider).filter( Provider.tenant_id == application_generate_entity.app_config.tenant_id, Provider.provider_name == application_generate_entity.model_conf.provider, - ).update({"last_used": datetime.now(timezone.utc).replace(tzinfo=None)}) + ).update({"last_used": datetime.now(UTC).replace(tzinfo=None)}) db.session.commit() diff --git a/api/extensions/ext_app_metrics.py b/api/extensions/ext_app_metrics.py new file mode 100644 index 0000000000..de1cdfeb98 --- /dev/null +++ b/api/extensions/ext_app_metrics.py @@ -0,0 +1,65 @@ +import json +import os +import threading + +from flask import Response + +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + @app.after_request + def after_request(response): + """Add Version headers to the response.""" + response.headers.add("X-Version", dify_config.CURRENT_VERSION) + response.headers.add("X-Env", dify_config.DEPLOY_ENV) + return response + + @app.route("/health") + def health(): + return Response( + json.dumps({"pid": os.getpid(), "status": "ok", "version": dify_config.CURRENT_VERSION}), + status=200, + content_type="application/json", + ) + + @app.route("/threads") + def threads(): + num_threads = threading.active_count() + threads = threading.enumerate() + + thread_list = [] + for thread in threads: + thread_name = thread.name + thread_id = thread.ident + is_alive = thread.is_alive() + + thread_list.append( + { + "name": thread_name, + "id": thread_id, + "is_alive": is_alive, + } + ) + + return { + "pid": os.getpid(), + "thread_num": num_threads, + "threads": thread_list, + } + + @app.route("/db-pool-stat") + def pool_stat(): + from extensions.ext_database import db + + engine = db.engine + return { + "pid": os.getpid(), + "pool_size": engine.pool.size(), + "checked_in_connections": engine.pool.checkedin(), + "checked_out_connections": engine.pool.checkedout(), + "overflow_connections": engine.pool.overflow(), + "connection_timeout": engine.pool.timeout(), + "recycle_time": db.engine.pool._recycle, + } diff --git a/api/extensions/ext_blueprints.py b/api/extensions/ext_blueprints.py new file mode 100644 index 0000000000..fcd1547a2f --- /dev/null +++ b/api/extensions/ext_blueprints.py @@ -0,0 +1,48 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + # register blueprint routers + + from flask_cors import CORS + + from controllers.console import bp as console_app_bp + from controllers.files import bp as files_bp + from controllers.inner_api import bp as inner_api_bp + from controllers.service_api import bp as service_api_bp + from controllers.web import bp as web_bp + + CORS( + service_api_bp, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + ) + app.register_blueprint(service_api_bp) + + CORS( + web_bp, + resources={r"/*": {"origins": dify_config.WEB_API_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization", "X-App-Code"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(web_bp) + + CORS( + console_app_bp, + resources={r"/*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}}, + supports_credentials=True, + allow_headers=["Content-Type", "Authorization"], + methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"], + expose_headers=["X-Version", "X-Env"], + ) + + app.register_blueprint(console_app_bp) + + CORS(files_bp, allow_headers=["Content-Type"], methods=["GET", "PUT", "POST", "DELETE", "OPTIONS", "PATCH"]) + app.register_blueprint(files_bp) + + app.register_blueprint(inner_api_bp) diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 1b78e36a57..9dbc4b93d4 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -3,12 +3,12 @@ from datetime import timedelta import pytz from celery import Celery, Task from celery.schedules import crontab -from flask import Flask from configs import dify_config +from dify_app import DifyApp -def init_app(app: Flask) -> Celery: +def init_app(app: DifyApp) -> Celery: class FlaskTask(Task): def __call__(self, *args: object, **kwargs: object) -> object: with app.app_context(): @@ -86,7 +86,7 @@ def init_app(app: Flask) -> Celery: }, "update_tidb_serverless_status_task": { "task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task", - "schedule": crontab(minute="30", hour="*"), + "schedule": timedelta(minutes=10), }, "clean_messages": { "task": "schedule.clean_messages.clean_messages", diff --git a/api/extensions/ext_code_based_extension.py b/api/extensions/ext_code_based_extension.py index a8ae733aa6..9e4b4a41d9 100644 --- a/api/extensions/ext_code_based_extension.py +++ b/api/extensions/ext_code_based_extension.py @@ -1,7 +1,8 @@ from core.extension.extension import Extension +from dify_app import DifyApp -def init(): +def init_app(app: DifyApp): code_based_extension.init() diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py new file mode 100644 index 0000000000..ccf0d316ca --- /dev/null +++ b/api/extensions/ext_commands.py @@ -0,0 +1,29 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + from commands import ( + add_qdrant_doc_id_index, + convert_to_agent_apps, + create_tenant, + fix_app_site_missing, + reset_email, + reset_encrypt_key_pair, + reset_password, + upgrade_db, + vdb_migrate, + ) + + cmds_to_register = [ + reset_password, + reset_email, + reset_encrypt_key_pair, + vdb_migrate, + convert_to_agent_apps, + add_qdrant_doc_id_index, + create_tenant, + upgrade_db, + fix_app_site_missing, + ] + for cmd in cmds_to_register: + app.cli.add_command(cmd) diff --git a/api/extensions/ext_compress.py b/api/extensions/ext_compress.py index a6de28597b..9c3a663af4 100644 --- a/api/extensions/ext_compress.py +++ b/api/extensions/ext_compress.py @@ -1,17 +1,13 @@ -from flask import Flask - from configs import dify_config +from dify_app import DifyApp + +def is_enabled() -> bool: + return dify_config.API_COMPRESSION_ENABLED -def init_app(app: Flask): - if dify_config.API_COMPRESSION_ENABLED: - from flask_compress import Compress - app.config["COMPRESS_MIMETYPES"] = [ - "application/json", - "image/svg+xml", - "text/html", - ] +def init_app(app: DifyApp): + from flask_compress import Compress - compress = Compress() - compress.init_app(app) + compress = Compress() + compress.init_app(app) diff --git a/api/extensions/ext_database.py b/api/extensions/ext_database.py index f6ffa53634..e293afa111 100644 --- a/api/extensions/ext_database.py +++ b/api/extensions/ext_database.py @@ -1,6 +1,8 @@ from flask_sqlalchemy import SQLAlchemy from sqlalchemy import MetaData +from dify_app import DifyApp + POSTGRES_INDEXES_NAMING_CONVENTION = { "ix": "%(column_0_label)s_idx", "uq": "%(table_name)s_%(column_0_name)s_key", @@ -13,5 +15,5 @@ metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION) db = SQLAlchemy(metadata=metadata) -def init_app(app): +def init_app(app: DifyApp): db.init_app(app) diff --git a/api/extensions/ext_hosting_provider.py b/api/extensions/ext_hosting_provider.py index 49e2fcb0c7..3980eccf8e 100644 --- a/api/extensions/ext_hosting_provider.py +++ b/api/extensions/ext_hosting_provider.py @@ -1,9 +1,10 @@ -from flask import Flask - from core.hosting_configuration import HostingConfiguration hosting_configuration = HostingConfiguration() -def init_app(app: Flask): +from dify_app import DifyApp + + +def init_app(app: DifyApp): hosting_configuration.init_app(app) diff --git a/api/extensions/ext_import_modules.py b/api/extensions/ext_import_modules.py new file mode 100644 index 0000000000..eefdfd3823 --- /dev/null +++ b/api/extensions/ext_import_modules.py @@ -0,0 +1,6 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + from events import event_handlers # noqa: F401 + from models import account, dataset, model, source, task, tool, tools, web # noqa: F401 diff --git a/api/extensions/ext_logging.py b/api/extensions/ext_logging.py index a15c73bd71..738d5c7bd2 100644 --- a/api/extensions/ext_logging.py +++ b/api/extensions/ext_logging.py @@ -3,12 +3,11 @@ import os import sys from logging.handlers import RotatingFileHandler -from flask import Flask - from configs import dify_config +from dify_app import DifyApp -def init_app(app: Flask): +def init_app(app: DifyApp): log_handlers = [] log_file = dify_config.LOG_FILE if log_file: diff --git a/api/extensions/ext_login.py b/api/extensions/ext_login.py index f7d5cffdda..b295530714 100644 --- a/api/extensions/ext_login.py +++ b/api/extensions/ext_login.py @@ -1,7 +1,62 @@ +import json + import flask_login +from flask import Response, request +from flask_login import user_loaded_from_request, user_logged_in +from werkzeug.exceptions import Unauthorized + +import contexts +from dify_app import DifyApp +from libs.passport import PassportService +from services.account_service import AccountService login_manager = flask_login.LoginManager() -def init_app(app): +# Flask-Login configuration +@login_manager.request_loader +def load_user_from_request(request_from_flask_login): + """Load user based on the request.""" + if request.blueprint not in {"console", "inner_api"}: + return None + # Check if the user_id contains a dot, indicating the old format + auth_header = request.headers.get("Authorization", "") + if not auth_header: + auth_token = request.args.get("_token") + if not auth_token: + raise Unauthorized("Invalid Authorization token.") + else: + if " " not in auth_header: + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + auth_scheme, auth_token = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise Unauthorized("Invalid Authorization header format. Expected 'Bearer ' format.") + + decoded = PassportService().verify(auth_token) + user_id = decoded.get("user_id") + + logged_in_account = AccountService.load_logged_in_account(account_id=user_id) + return logged_in_account + + +@user_logged_in.connect +@user_loaded_from_request.connect +def on_user_logged_in(_sender, user): + """Called when a user logged in.""" + if user: + contexts.tenant_id.set(user.current_tenant_id) + + +@login_manager.unauthorized_handler +def unauthorized_handler(): + """Handle unauthorized requests.""" + return Response( + json.dumps({"code": "unauthorized", "message": "Unauthorized."}), + status=401, + content_type="application/json", + ) + + +def init_app(app: DifyApp): login_manager.init_app(app) diff --git a/api/extensions/ext_mail.py b/api/extensions/ext_mail.py index 5c5b331d8a..468aedd47e 100644 --- a/api/extensions/ext_mail.py +++ b/api/extensions/ext_mail.py @@ -1,10 +1,10 @@ import logging from typing import Optional -import resend from flask import Flask from configs import dify_config +from dify_app import DifyApp class Mail: @@ -26,6 +26,8 @@ class Mail: match mail_type: case "resend": + import resend + api_key = dify_config.RESEND_API_KEY if not api_key: raise ValueError("RESEND_API_KEY is not set") @@ -84,7 +86,11 @@ class Mail: ) -def init_app(app: Flask): +def is_enabled() -> bool: + return dify_config.MAIL_TYPE is not None and dify_config.MAIL_TYPE != "" + + +def init_app(app: DifyApp): mail.init_app(app) diff --git a/api/extensions/ext_migrate.py b/api/extensions/ext_migrate.py index e7b278fc38..6d8f35c30d 100644 --- a/api/extensions/ext_migrate.py +++ b/api/extensions/ext_migrate.py @@ -1,5 +1,9 @@ -import flask_migrate +from dify_app import DifyApp -def init(app, db): +def init_app(app: DifyApp): + import flask_migrate + + from extensions.ext_database import db + flask_migrate.Migrate(app, db) diff --git a/api/extensions/ext_proxy_fix.py b/api/extensions/ext_proxy_fix.py index c106a4384a..3b895ac95b 100644 --- a/api/extensions/ext_proxy_fix.py +++ b/api/extensions/ext_proxy_fix.py @@ -1,9 +1,8 @@ -from flask import Flask - from configs import dify_config +from dify_app import DifyApp -def init_app(app: Flask): +def init_app(app: DifyApp): if dify_config.RESPECT_XFORWARD_HEADERS_ENABLED: from werkzeug.middleware.proxy_fix import ProxyFix diff --git a/api/extensions/ext_redis.py b/api/extensions/ext_redis.py index 36f06c1104..da41805707 100644 --- a/api/extensions/ext_redis.py +++ b/api/extensions/ext_redis.py @@ -1,9 +1,12 @@ +from typing import Any, Union + import redis from redis.cluster import ClusterNode, RedisCluster from redis.connection import Connection, SSLConnection from redis.sentinel import Sentinel from configs import dify_config +from dify_app import DifyApp class RedisClientWrapper: @@ -43,13 +46,13 @@ class RedisClientWrapper: redis_client = RedisClientWrapper() -def init_app(app): +def init_app(app: DifyApp): global redis_client - connection_class = Connection + connection_class: type[Union[Connection, SSLConnection]] = Connection if dify_config.REDIS_USE_SSL: connection_class = SSLConnection - redis_params = { + redis_params: dict[str, Any] = { "username": dify_config.REDIS_USERNAME, "password": dify_config.REDIS_PASSWORD, "db": dify_config.REDIS_DB, @@ -59,6 +62,7 @@ def init_app(app): } if dify_config.REDIS_USE_SENTINEL: + assert dify_config.REDIS_SENTINELS is not None, "REDIS_SENTINELS must be set when REDIS_USE_SENTINEL is True" sentinel_hosts = [ (node.split(":")[0], int(node.split(":")[1])) for node in dify_config.REDIS_SENTINELS.split(",") ] @@ -73,11 +77,13 @@ def init_app(app): master = sentinel.master_for(dify_config.REDIS_SENTINEL_SERVICE_NAME, **redis_params) redis_client.initialize(master) elif dify_config.REDIS_USE_CLUSTERS: + assert dify_config.REDIS_CLUSTERS is not None, "REDIS_CLUSTERS must be set when REDIS_USE_CLUSTERS is True" nodes = [ - ClusterNode(host=node.split(":")[0], port=int(node.split.split(":")[1])) + ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) for node in dify_config.REDIS_CLUSTERS.split(",") ] - redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) + # FIXME: mypy error here, try to figure out how to fix it + redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore else: redis_params.update( { diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 11f1dd93c6..8016356a3e 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -1,25 +1,26 @@ -import openai -import sentry_sdk -from langfuse import parse_error -from sentry_sdk.integrations.celery import CeleryIntegration -from sentry_sdk.integrations.flask import FlaskIntegration -from werkzeug.exceptions import HTTPException - from configs import dify_config -from core.model_runtime.errors.invoke import InvokeRateLimitError +from dify_app import DifyApp + +def init_app(app: DifyApp): + if dify_config.SENTRY_DSN: + import openai + import sentry_sdk + from langfuse import parse_error + from sentry_sdk.integrations.celery import CeleryIntegration + from sentry_sdk.integrations.flask import FlaskIntegration + from werkzeug.exceptions import HTTPException -def before_send(event, hint): - if "exc_info" in hint: - exc_type, exc_value, tb = hint["exc_info"] - if parse_error.defaultErrorResponse in str(exc_value): - return None + from core.model_runtime.errors.invoke import InvokeRateLimitError - return event + def before_send(event, hint): + if "exc_info" in hint: + exc_type, exc_value, tb = hint["exc_info"] + if parse_error.defaultErrorResponse in str(exc_value): + return None + return event -def init_app(app): - if dify_config.SENTRY_DSN: sentry_sdk.init( dsn=dify_config.SENTRY_DSN, integrations=[FlaskIntegration(), CeleryIntegration()], diff --git a/api/extensions/ext_set_secretkey.py b/api/extensions/ext_set_secretkey.py new file mode 100644 index 0000000000..dfb87c0167 --- /dev/null +++ b/api/extensions/ext_set_secretkey.py @@ -0,0 +1,6 @@ +from configs import dify_config +from dify_app import DifyApp + + +def init_app(app: DifyApp): + app.secret_key = dify_config.SECRET_KEY diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index fa88da68b7..6c30b7a257 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -5,6 +5,7 @@ from typing import Union from flask import Flask from configs import dify_config +from dify_app import DifyApp from extensions.storage.base_storage import BaseStorage from extensions.storage.storage_type import StorageType @@ -122,5 +123,5 @@ class Storage: storage = Storage() -def init_app(app: Flask): +def init_app(app: DifyApp): storage.init_app(app) diff --git a/api/extensions/ext_timezone.py b/api/extensions/ext_timezone.py new file mode 100644 index 0000000000..77650bf972 --- /dev/null +++ b/api/extensions/ext_timezone.py @@ -0,0 +1,11 @@ +import os +import time + +from dify_app import DifyApp + + +def init_app(app: DifyApp): + os.environ["TZ"] = "UTC" + # windows platform not support tzset + if hasattr(time, "tzset"): + time.tzset() diff --git a/api/extensions/ext_warnings.py b/api/extensions/ext_warnings.py new file mode 100644 index 0000000000..246f977af5 --- /dev/null +++ b/api/extensions/ext_warnings.py @@ -0,0 +1,7 @@ +from dify_app import DifyApp + + +def init_app(app: DifyApp): + import warnings + + warnings.simplefilter("ignore", ResourceWarning) diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 11a7544274..b26caa8671 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from azure.storage.blob import AccountSasPermissions, BlobServiceClient, ResourceTypes, generate_account_sas @@ -67,7 +67,7 @@ class AzureBlobStorage(BaseStorage): account_key=self.account_key, resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(timezone.utc).replace(tzinfo=None) + timedelta(hours=1), + expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url, credential=sas_token) diff --git a/api/extensions/storage/storage_type.py b/api/extensions/storage/storage_type.py index 415bf251f6..e7fa405afa 100644 --- a/api/extensions/storage/storage_type.py +++ b/api/extensions/storage/storage_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class StorageType(str, Enum): +class StorageType(StrEnum): ALIYUN_OSS = "aliyun-oss" AZURE_BLOB = "azure-blob" BAIDU_OBS = "baidu-obs" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 4da0140d19..8538775a67 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -1,10 +1,11 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence -from typing import Any +from typing import Any, cast import httpx from sqlalchemy import select +from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS from core.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig from core.helper import ssrf_proxy from extensions.ext_database import db @@ -51,8 +52,6 @@ def build_from_mapping( tenant_id: str, config: FileUploadConfig | None = None, ) -> File: - config = config or FileUploadConfig() - transfer_method = FileTransferMethod.value_of(mapping.get("transfer_method")) build_functions: dict[FileTransferMethod, Callable] = { @@ -71,7 +70,12 @@ def build_from_mapping( transfer_method=transfer_method, ) - if not _is_file_valid_with_config(file=file, config=config): + if config and not _is_file_valid_with_config( + input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension, + file_transfer_method=file.transfer_method, + config=config, + ): raise ValueError(f"File validation failed for file: {file.filename}") return file @@ -80,12 +84,9 @@ def build_from_mapping( def build_from_mappings( *, mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None, + config: FileUploadConfig | None = None, tenant_id: str, ) -> Sequence[File]: - if not config: - return [] - files = [ build_from_mapping( mapping=mapping, @@ -96,13 +97,14 @@ def build_from_mappings( ] if ( + config # If image config is set. - config.image_config + and config.image_config # And the number of image files exceeds the maximum limit and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits ): raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config.number_limits and len(files) > config.number_limits: + if config and config.number_limits and len(files) > config.number_limits: raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") return files @@ -114,17 +116,18 @@ def _build_from_local_file( tenant_id: str, transfer_method: FileTransferMethod, ) -> File: - file_type = FileType.value_of(mapping.get("type")) stmt = select(UploadFile).where( UploadFile.id == mapping.get("upload_file_id"), UploadFile.tenant_id == tenant_id, ) row = db.session.scalar(stmt) - if row is None: raise ValueError("Invalid upload file") + file_type = FileType(mapping.get("type", "custom")) + file_type = _standardize_file_type(file_type, extension="." + row.extension, mime_type=row.mime_type) + return File( id=mapping.get("id"), filename=row.name, @@ -152,11 +155,14 @@ def _build_from_remote_url( mime_type, filename, file_size = _get_remote_file_info(url) extension = mimetypes.guess_extension(mime_type) or "." + filename.split(".")[-1] if "." in filename else ".bin" + file_type = FileType(mapping.get("type", "custom")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=mime_type) + return File( id=mapping.get("id"), filename=filename, tenant_id=tenant_id, - type=FileType.value_of(mapping.get("type")), + type=file_type, transfer_method=transfer_method, remote_url=url, mime_type=mime_type, @@ -171,6 +177,7 @@ def _get_remote_file_info(url: str): mime_type = mimetypes.guess_type(filename)[0] or "" resp = ssrf_proxy.head(url, follow_redirects=True) + resp = cast(httpx.Response, resp) if resp.status_code == httpx.codes.OK: if content_disposition := resp.headers.get("Content-Disposition"): filename = str(content_disposition.split("filename=")[-1].strip('"')) @@ -180,20 +187,6 @@ def _get_remote_file_info(url: str): return mime_type, filename, file_size -def _get_file_type_by_mimetype(mime_type: str) -> FileType: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - def _build_from_tool_file( *, mapping: Mapping[str, Any], @@ -213,7 +206,8 @@ def _build_from_tool_file( raise ValueError(f"ToolFile {mapping.get('tool_file_id')} not found") extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - file_type = mapping.get("type", _get_file_type_by_mimetype(tool_file.mimetype)) + file_type = FileType(mapping.get("type", "custom")) + file_type = _standardize_file_type(file_type, extension=extension, mime_type=tool_file.mimetype) return File( id=mapping.get("id"), @@ -229,18 +223,69 @@ def _build_from_tool_file( ) -def _is_file_valid_with_config(*, file: File, config: FileUploadConfig) -> bool: - if config.allowed_file_types and file.type not in config.allowed_file_types and file.type != FileType.CUSTOM: - return False - - if config.allowed_file_extensions and file.extension not in config.allowed_file_extensions: +def _is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): return False - if config.allowed_file_upload_methods and file.transfer_method not in config.allowed_file_upload_methods: + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): return False - if file.type == FileType.IMAGE and config.image_config: - if config.image_config.transfer_methods and file.transfer_method not in config.image_config.transfer_methods: + if input_file_type == FileType.IMAGE and config.image_config: + if config.image_config.transfer_methods and file_transfer_method not in config.image_config.transfer_methods: return False return True + + +def _standardize_file_type(file_type: FileType, /, *, extension: str = "", mime_type: str = "") -> FileType: + """ + If custom type, try to guess the file type by extension and mime_type. + """ + if file_type != FileType.CUSTOM: + return FileType(file_type) + guessed_type = None + if extension: + guessed_type = _get_file_type_by_extension(extension) + if guessed_type is None and mime_type: + guessed_type = _get_file_type_by_mimetype(mime_type) + return guessed_type or FileType.CUSTOM + + +def _get_file_type_by_extension(extension: str) -> FileType | None: + extension = extension.lstrip(".") + if extension in IMAGE_EXTENSIONS: + return FileType.IMAGE + elif extension in VIDEO_EXTENSIONS: + return FileType.VIDEO + elif extension in AUDIO_EXTENSIONS: + return FileType.AUDIO + elif extension in DOCUMENT_EXTENSIONS: + return FileType.DOCUMENT + + +def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: + if "image" in mime_type: + file_type = FileType.IMAGE + elif "video" in mime_type: + file_type = FileType.VIDEO + elif "audio" in mime_type: + file_type = FileType.AUDIO + elif "text" in mime_type or "pdf" in mime_type: + file_type = FileType.DOCUMENT + else: + file_type = FileType.CUSTOM + return file_type diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 5b004405b4..16a578728a 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -36,6 +36,7 @@ from core.variables.variables import ( StringVariable, Variable, ) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID class InvalidSelectorError(ValueError): @@ -62,11 +63,25 @@ SEGMENT_TO_VARIABLE_MAP = { } -def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: - if (value_type := mapping.get("value_type")) is None: - raise VariableError("missing value type") +def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: + if not mapping.get("name"): + raise VariableError("missing name") + return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]]) + + +def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: if not mapping.get("name"): raise VariableError("missing name") + return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]]) + + +def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable: + """ + This factory function is used to create the environment variable or the conversation variable, + not support the File type. + """ + if (value_type := mapping.get("value_type")) is None: + raise VariableError("missing value type") if (value := mapping.get("value")) is None: raise VariableError("missing value") match value_type: @@ -92,6 +107,8 @@ def build_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable: raise VariableError(f"not supported value type {value_type}") if result.size > dify_config.MAX_VARIABLE_SIZE: raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}") + if not result.selector: + result = result.model_copy(update={"selector": selector}) return result diff --git a/api/libs/helper.py b/api/libs/helper.py index 023240a9a4..7652d73c8b 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -6,7 +6,7 @@ import string import subprocess import time import uuid -from collections.abc import Generator +from collections.abc import Generator, Mapping from datetime import datetime from hashlib import sha256 from typing import Any, Optional, Union @@ -31,12 +31,12 @@ class AppIconUrlField(fields.Raw): if obj is None: return None - from models.model import App, IconType + from models.model import App, IconType, Site if isinstance(obj, dict) and "app" in obj: obj = obj["app"] - if isinstance(obj, App) and obj.icon_type == IconType.IMAGE.value: + if isinstance(obj, App | Site) and obj.icon_type == IconType.IMAGE.value: return file_helpers.get_signed_file_url(obj.icon) return None @@ -180,7 +180,9 @@ def generate_text_hash(text: str) -> str: return sha256(hash_text.encode()).hexdigest() -def compact_generate_response(response: Union[dict, RateLimitGenerator]) -> Response: +def compact_generate_response( + response: Union[Mapping[str, Any], RateLimitGenerator, Generator[str, None, None]], +) -> Response: if isinstance(response, dict): return Response(response=json.dumps(response), status=200, mimetype="application/json") else: diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index e747ea97ad..48249e4a35 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -70,7 +70,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -106,7 +106,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -141,7 +141,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() else: raise ValueError("Data source binding not found") @@ -221,15 +221,29 @@ class NotionOAuth(OAuthDataSource): return pages def notion_page_search(self, access_token: str): - data = {"filter": {"value": "page", "property": "object"}} - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - "Notion-Version": "2022-06-28", - } - response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) - response_json = response.json() - results = response_json.get("results", []) + results = [] + next_cursor = None + has_more = True + + while has_more: + data = { + "filter": {"value": "page", "property": "object"}, + **({"start_cursor": next_cursor} if next_cursor else {}), + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + + response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response_json = response.json() + + results.extend(response_json.get("results", [])) + + has_more = response_json.get("has_more", False) + next_cursor = response_json.get("next_cursor", None) + return results def notion_block_parent_page_id(self, access_token: str, block_id: str): @@ -260,13 +274,26 @@ class NotionOAuth(OAuthDataSource): return "workspace" def notion_database_search(self, access_token: str): - data = {"filter": {"value": "database", "property": "object"}} - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {access_token}", - "Notion-Version": "2022-06-28", - } - response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) - response_json = response.json() - results = response_json.get("results", []) + results = [] + next_cursor = None + has_more = True + + while has_more: + data = { + "filter": {"value": "database", "property": "object"}, + **({"start_cursor": next_cursor} if next_cursor else {}), + } + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "Notion-Version": "2022-06-28", + } + response = requests.post(url=self._NOTION_PAGE_SEARCH, json=data, headers=headers) + response_json = response.json() + + results.extend(response_json.get("results", [])) + + has_more = response_json.get("has_more", False) + next_cursor = response_json.get("next_cursor", None) + return results diff --git a/api/libs/threadings_utils.py b/api/libs/threadings_utils.py new file mode 100644 index 0000000000..d356def418 --- /dev/null +++ b/api/libs/threadings_utils.py @@ -0,0 +1,19 @@ +from configs import dify_config + + +def apply_gevent_threading_patch(): + """ + Run threading patch by gevent + to make standard library threading compatible. + Patching should be done as early as possible in the lifecycle of the program. + :return: + """ + if not dify_config.DEBUG: + from gevent import monkey + from grpc.experimental import gevent as grpc_gevent + + # gevent + monkey.patch_all() + + # grpc gevent + grpc_gevent.init_gevent() diff --git a/api/libs/version_utils.py b/api/libs/version_utils.py new file mode 100644 index 0000000000..10edf8a058 --- /dev/null +++ b/api/libs/version_utils.py @@ -0,0 +1,12 @@ +import sys + + +def check_supported_python_version(): + python_version = sys.version_info + if not ((3, 11) <= python_version < (3, 13)): + print( + "Aborted to launch the service " + f" with unsupported Python version {python_version.major}.{python_version.minor}." + " Please ensure Python 3.11 or 3.12." + ) + raise SystemExit(1) diff --git a/api/models/__init__.py b/api/models/__init__.py index cd6c7674da..61a38870cf 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -24,30 +24,30 @@ from .workflow import ( ) __all__ = [ + "Account", + "AccountIntegrate", + "ApiToken", + "App", + "AppMode", + "Conversation", "ConversationVariable", - "Document", + "DataSourceOauthBinding", "Dataset", "DatasetProcessRule", + "Document", "DocumentSegment", - "DataSourceOauthBinding", - "AppMode", - "Workflow", - "App", - "Message", "EndUser", - "MessageFile", - "UploadFile", - "Account", - "WorkflowAppLog", - "WorkflowRun", - "Site", "InstalledApp", - "RecommendedApp", - "ApiToken", - "AccountIntegrate", "InvitationCode", - "Tenant", - "Conversation", + "Message", "MessageAnnotation", + "MessageFile", + "RecommendedApp", + "Site", + "Tenant", "ToolFile", + "UploadFile", + "Workflow", + "WorkflowAppLog", + "WorkflowRun", ] diff --git a/api/models/account.py b/api/models/account.py index 60b4f11aad..951e836dec 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -8,7 +8,7 @@ from extensions.ext_database import db from .types import StringUUID -class AccountStatus(str, enum.Enum): +class AccountStatus(enum.StrEnum): PENDING = "pending" UNINITIALIZED = "uninitialized" ACTIVE = "active" @@ -56,8 +56,8 @@ class Account(UserMixin, db.Model): self._current_tenant = tenant @property - def current_tenant_id(self): - return self._current_tenant.id + def current_tenant_id(self) -> str | None: + return self._current_tenant.id if self._current_tenant else None @current_tenant_id.setter def current_tenant_id(self, value: str): @@ -108,6 +108,10 @@ class Account(UserMixin, db.Model): def is_admin_or_owner(self): return TenantAccountRole.is_privileged_role(self._current_tenant.current_role) + @property + def is_admin(self): + return TenantAccountRole.is_admin_role(self._current_tenant.current_role) + @property def is_editor(self): return TenantAccountRole.is_editing_role(self._current_tenant.current_role) @@ -121,12 +125,12 @@ class Account(UserMixin, db.Model): return self._current_tenant.current_role == TenantAccountRole.DATASET_OPERATOR -class TenantStatus(str, enum.Enum): +class TenantStatus(enum.StrEnum): NORMAL = "normal" ARCHIVE = "archive" -class TenantAccountRole(str, enum.Enum): +class TenantAccountRole(enum.StrEnum): OWNER = "owner" ADMIN = "admin" EDITOR = "editor" @@ -147,6 +151,10 @@ class TenantAccountRole(str, enum.Enum): def is_privileged_role(role: str) -> bool: return role and role in {TenantAccountRole.OWNER, TenantAccountRole.ADMIN} + @staticmethod + def is_admin_role(role: str) -> bool: + return role and role == TenantAccountRole.ADMIN + @staticmethod def is_non_owner_role(role: str) -> bool: return role and role in { diff --git a/api/models/dataset.py b/api/models/dataset.py index a8b2c419d1..8ab957e875 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -23,7 +23,7 @@ from .model import App, Tag, TagBinding, UploadFile from .types import StringUUID -class DatasetPermissionEnum(str, enum.Enum): +class DatasetPermissionEnum(enum.StrEnum): ONLY_ME = "only_me" ALL_TEAM = "all_team_members" PARTIAL_TEAM = "partial_members" diff --git a/api/models/enums.py b/api/models/enums.py index a83d35e042..7b9500ebe4 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,16 +1,16 @@ -from enum import Enum +from enum import StrEnum -class CreatedByRole(str, Enum): +class CreatedByRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" -class UserFrom(str, Enum): +class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" -class WorkflowRunTriggeredFrom(str, Enum): +class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" diff --git a/api/models/model.py b/api/models/model.py index 4562d8cf29..03b8e0bea5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,7 +3,7 @@ import re import uuid from collections.abc import Mapping from datetime import datetime -from enum import Enum +from enum import Enum, StrEnum from typing import Any, Literal, Optional import sqlalchemy as sa @@ -32,7 +32,7 @@ class DifySetup(db.Model): setup_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) -class AppMode(str, Enum): +class AppMode(StrEnum): COMPLETION = "completion" WORKFLOW = "workflow" CHAT = "chat" diff --git a/api/models/task.py b/api/models/task.py index 57b147c78d..5d89ff85ac 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime from celery import states @@ -16,8 +16,8 @@ class CeleryTask(db.Model): result = db.Column(db.PickleType, nullable=True) date_done = db.Column( db.DateTime, - default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), - onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None), + default=lambda: datetime.now(UTC).replace(tzinfo=None), + onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True, ) traceback = db.Column(db.Text, nullable=True) @@ -37,4 +37,4 @@ class CeleryTaskSet(db.Model): id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(timezone.utc).replace(tzinfo=None), nullable=True) + date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) diff --git a/api/models/workflow.py b/api/models/workflow.py index c6b3000083..c0e70889a8 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,7 @@ import json from collections.abc import Mapping, Sequence -from datetime import datetime, timezone -from enum import Enum +from datetime import UTC, datetime +from enum import Enum, StrEnum from typing import Any, Optional, Union import sqlalchemy as sa @@ -108,7 +108,7 @@ class Workflow(db.Model): ) updated_by: Mapped[Optional[str]] = mapped_column(StringUUID) updated_at: Mapped[datetime] = mapped_column( - sa.DateTime, nullable=False, default=datetime.now(tz=timezone.utc), server_onupdate=func.current_timestamp() + sa.DateTime, nullable=False, default=datetime.now(tz=UTC), server_onupdate=func.current_timestamp() ) _environment_variables: Mapped[str] = mapped_column( "environment_variables", db.Text, nullable=False, server_default="{}" @@ -238,7 +238,9 @@ class Workflow(db.Model): tenant_id = contexts.tenant_id.get() environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables) - results = [variable_factory.build_variable_from_mapping(v) for v in environment_variables_dict.values()] + results = [ + variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values() + ] # decrypt secret variables value decrypt_func = ( @@ -303,7 +305,7 @@ class Workflow(db.Model): self._conversation_variables = "{}" variables_dict: dict[str, Any] = json.loads(self._conversation_variables) - results = [variable_factory.build_variable_from_mapping(v) for v in variables_dict.values()] + results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()] return results @conversation_variables.setter @@ -314,7 +316,7 @@ class Workflow(db.Model): ) -class WorkflowRunStatus(Enum): +class WorkflowRunStatus(StrEnum): """ Workflow Run Status Enum """ @@ -393,13 +395,13 @@ class WorkflowRun(db.Model): version = db.Column(db.String(255), nullable=False) graph = db.Column(db.Text) inputs = db.Column(db.Text) - status = db.Column(db.String(255), nullable=False) - outputs: Mapped[str] = db.Column(db.Text) + status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped + outputs: Mapped[str] = mapped_column(sa.Text, default="{}") error = db.Column(db.Text) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0")) - created_by_role = db.Column(db.String(255), nullable=False) + created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by = db.Column(StringUUID, nullable=False) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) finished_at = db.Column(db.DateTime) @@ -793,4 +795,4 @@ class ConversationVariable(db.Model): def to_variable(self) -> Variable: mapping = json.loads(self.data) - return variable_factory.build_variable_from_mapping(mapping) + return variable_factory.build_conversation_variable_from_mapping(mapping) diff --git a/api/poetry.lock b/api/poetry.lock index 944b6c8316..dcd8982c37 100644 --- a/api/poetry.lock +++ b/api/poetry.lock @@ -114,7 +114,6 @@ files = [ [package.dependencies] aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" @@ -483,10 +482,8 @@ files = [ ] [package.dependencies] -exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} idna = ">=2.8" sniffio = ">=1.1" -typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] @@ -519,9 +516,6 @@ files = [ {file = "asgiref-3.8.1.tar.gz", hash = "sha256:c343bd80a0bec947a9860adb4c432ffa7db769836c64238fc34bdc3fec84d590"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version < \"3.11\""} - [package.extras] tests = ["mypy (>=0.800)", "pytest", "pytest-asyncio"] @@ -1092,10 +1086,8 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "os_name == \"nt\""} -importlib-metadata = {version = ">=4.6", markers = "python_full_version < \"3.10.2\""} packaging = ">=19.1" pyproject_hooks = "*" -tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} [package.extras] docs = ["furo (>=2023.08.17)", "sphinx (>=7.0,<8.0)", "sphinx-argparse-cli (>=1.5)", "sphinx-autodoc-typehints (>=1.10)", "sphinx-issues (>=3.0.0)"] @@ -1388,36 +1380,40 @@ files = [ [[package]] name = "chroma-hnswlib" -version = "0.7.3" +version = "0.7.6" description = "Chromas fork of hnswlib" optional = false python-versions = "*" files = [ - {file = "chroma-hnswlib-0.7.3.tar.gz", hash = "sha256:b6137bedde49fffda6af93b0297fe00429fc61e5a072b1ed9377f909ed95a932"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:59d6a7c6f863c67aeb23e79a64001d537060b6995c3eca9a06e349ff7b0998ca"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d71a3f4f232f537b6152947006bd32bc1629a8686df22fd97777b70f416c127a"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c92dc1ebe062188e53970ba13f6b07e0ae32e64c9770eb7f7ffa83f149d4210"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49da700a6656fed8753f68d44b8cc8ae46efc99fc8a22a6d970dc1697f49b403"}, - {file = "chroma_hnswlib-0.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:108bc4c293d819b56476d8f7865803cb03afd6ca128a2a04d678fffc139af029"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:11e7ca93fb8192214ac2b9c0943641ac0daf8f9d4591bb7b73be808a83835667"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6f552e4d23edc06cdeb553cdc757d2fe190cdeb10d43093d6a3319f8d4bf1c6b"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f96f4d5699e486eb1fb95849fe35ab79ab0901265805be7e60f4eaa83ce263ec"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:368e57fe9ebae05ee5844840fa588028a023d1182b0cfdb1d13f607c9ea05756"}, - {file = "chroma_hnswlib-0.7.3-cp311-cp311-win_amd64.whl", hash = "sha256:b7dca27b8896b494456db0fd705b689ac6b73af78e186eb6a42fea2de4f71c6f"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:70f897dc6218afa1d99f43a9ad5eb82f392df31f57ff514ccf4eeadecd62f544"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5aef10b4952708f5a1381c124a29aead0c356f8d7d6e0b520b778aaa62a356f4"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ee2d8d1529fca3898d512079144ec3e28a81d9c17e15e0ea4665697a7923253"}, - {file = "chroma_hnswlib-0.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:a4021a70e898783cd6f26e00008b494c6249a7babe8774e90ce4766dd288c8ba"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a8f61fa1d417fda848e3ba06c07671f14806a2585272b175ba47501b066fe6b1"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d7563be58bc98e8f0866907368e22ae218d6060601b79c42f59af4eccbbd2e0a"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51b8d411486ee70d7b66ec08cc8b9b6620116b650df9c19076d2d8b6ce2ae914"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9d706782b628e4f43f1b8a81e9120ac486837fbd9bcb8ced70fe0d9b95c72d77"}, - {file = "chroma_hnswlib-0.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:54f053dedc0e3ba657f05fec6e73dd541bc5db5b09aa8bc146466ffb734bdc86"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e607c5a71c610a73167a517062d302c0827ccdd6e259af6e4869a5c1306ffb5d"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c2358a795870156af6761890f9eb5ca8cade57eb10c5f046fe94dae1faa04b9e"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7cea425df2e6b8a5e201fff0d922a1cc1d165b3cfe762b1408075723c8892218"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:454df3dd3e97aa784fba7cf888ad191e0087eef0fd8c70daf28b753b3b591170"}, - {file = "chroma_hnswlib-0.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:df587d15007ca701c6de0ee7d5585dd5e976b7edd2b30ac72bc376b3c3f85882"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f35192fbbeadc8c0633f0a69c3d3e9f1a4eab3a46b65458bbcbcabdd9e895c36"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6f007b608c96362b8f0c8b6b2ac94f67f83fcbabd857c378ae82007ec92f4d82"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:456fd88fa0d14e6b385358515aef69fc89b3c2191706fd9aee62087b62aad09c"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5dfaae825499c2beaa3b75a12d7ec713b64226df72a5c4097203e3ed532680da"}, + {file = "chroma_hnswlib-0.7.6-cp310-cp310-win_amd64.whl", hash = "sha256:2487201982241fb1581be26524145092c95902cb09fc2646ccfbc407de3328ec"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:81181d54a2b1e4727369486a631f977ffc53c5533d26e3d366dda243fb0998ca"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4b4ab4e11f1083dd0a11ee4f0e0b183ca9f0f2ed63ededba1935b13ce2b3606f"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:53db45cd9173d95b4b0bdccb4dbff4c54a42b51420599c32267f3abbeb795170"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5c093f07a010b499c00a15bc9376036ee4800d335360570b14f7fe92badcdcf9"}, + {file = "chroma_hnswlib-0.7.6-cp311-cp311-win_amd64.whl", hash = "sha256:0540b0ac96e47d0aa39e88ea4714358ae05d64bbe6bf33c52f316c664190a6a3"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:e87e9b616c281bfbe748d01705817c71211613c3b063021f7ed5e47173556cb7"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ec5ca25bc7b66d2ecbf14502b5729cde25f70945d22f2aaf523c2d747ea68912"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:305ae491de9d5f3c51e8bd52d84fdf2545a4a2bc7af49765cda286b7bb30b1d4"}, + {file = "chroma_hnswlib-0.7.6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:822ede968d25a2c88823ca078a58f92c9b5c4142e38c7c8b4c48178894a0a3c5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2fe6ea949047beed19a94b33f41fe882a691e58b70c55fdaa90274ae78be046f"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:feceff971e2a2728c9ddd862a9dd6eb9f638377ad98438876c9aeac96c9482f5"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bb0633b60e00a2b92314d0bf5bbc0da3d3320be72c7e3f4a9b19f4609dc2b2ab"}, + {file = "chroma_hnswlib-0.7.6-cp37-cp37m-win_amd64.whl", hash = "sha256:a566abe32fab42291f766d667bdbfa234a7f457dcbd2ba19948b7a978c8ca624"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6be47853d9a58dedcfa90fc846af202b071f028bbafe1d8711bf64fe5a7f6111"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3a7af35bdd39a88bffa49f9bb4bf4f9040b684514a024435a1ef5cdff980579d"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a53b1f1551f2b5ad94eb610207bde1bb476245fc5097a2bec2b476c653c58bde"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3085402958dbdc9ff5626ae58d696948e715aef88c86d1e3f9285a88f1afd3bc"}, + {file = "chroma_hnswlib-0.7.6-cp38-cp38-win_amd64.whl", hash = "sha256:77326f658a15adfb806a16543f7db7c45f06fd787d699e643642d6bde8ed49c4"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:93b056ab4e25adab861dfef21e1d2a2756b18be5bc9c292aa252fa12bb44e6ae"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fe91f018b30452c16c811fd6c8ede01f84e5a9f3c23e0758775e57f1c3778871"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6c0e627476f0f4d9e153420d36042dd9c6c3671cfd1fe511c0253e38c2a1039"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e9796a4536b7de6c6d76a792ba03e08f5aaa53e97e052709568e50b4d20c04f"}, + {file = "chroma_hnswlib-0.7.6-cp39-cp39-win_amd64.whl", hash = "sha256:d30e2db08e7ffdcc415bd072883a322de5995eb6ec28a8f8c054103bbd3ec1e0"}, + {file = "chroma_hnswlib-0.7.6.tar.gz", hash = "sha256:4dce282543039681160259d29fcde6151cc9106c6461e0485f57cdccd83059b7"}, ] [package.dependencies] @@ -1425,26 +1421,26 @@ numpy = "*" [[package]] name = "chromadb" -version = "0.5.1" +version = "0.5.20" description = "Chroma." optional = false python-versions = ">=3.8" files = [ - {file = "chromadb-0.5.1-py3-none-any.whl", hash = "sha256:61f1f75a672b6edce7f1c8875c67e2aaaaf130dc1c1684431fbc42ad7240d01d"}, - {file = "chromadb-0.5.1.tar.gz", hash = "sha256:e2b2b6a34c2a949bedcaa42fa7775f40c7f6667848fc8094dcbf97fc0d30bee7"}, + {file = "chromadb-0.5.20-py3-none-any.whl", hash = "sha256:9550ba1b6dce911e35cac2568b301badf4b42f457b99a432bdeec2b6b9dd3680"}, + {file = "chromadb-0.5.20.tar.gz", hash = "sha256:19513a23b2d20059866216bfd80195d1d4a160ffba234b8899f5e80978160ca7"}, ] [package.dependencies] bcrypt = ">=4.0.1" build = ">=1.0.3" -chroma-hnswlib = "0.7.3" +chroma-hnswlib = "0.7.6" fastapi = ">=0.95.2" grpcio = ">=1.58.0" httpx = ">=0.27.0" importlib-resources = "*" kubernetes = ">=28.1.0" mmh3 = ">=4.0.1" -numpy = ">=1.22.5,<2.0.0" +numpy = ">=1.22.5" onnxruntime = ">=1.14.1" opentelemetry-api = ">=1.2.0" opentelemetry-exporter-otlp-proto-grpc = ">=1.2.0" @@ -1456,7 +1452,7 @@ posthog = ">=2.4.0" pydantic = ">=1.9" pypika = ">=0.48.9" PyYAML = ">=6.0.0" -requests = ">=2.28" +rich = ">=10.11.0" tenacity = ">=8.2.3" tokenizers = ">=0.13.2" tqdm = ">=4.65.0" @@ -2036,9 +2032,6 @@ files = [ {file = "dataclass_wizard-0.28.0-py2.py3-none-any.whl", hash = "sha256:996fa46475b9192a48a057c34f04597bc97be5bc2f163b99cb1de6f778ca1f7f"}, ] -[package.dependencies] -typing-extensions = {version = ">=4", markers = "python_version == \"3.9\" or python_version == \"3.10\""} - [package.extras] dev = ["Sphinx (==7.4.7)", "Sphinx (==8.1.3)", "bump2version (==1.0.1)", "coverage (>=6.2)", "dataclass-factory (==2.16)", "dataclass-wizard[toml]", "dataclasses-json (==0.6.7)", "flake8 (>=3)", "jsons (==1.6.3)", "pip (>=21.3.1)", "pytest (==8.3.3)", "pytest-cov (==6.0.0)", "pytest-mock (>=3.6.1)", "pytimeparse (==1.1.8)", "sphinx-issues (==5.0.0)", "tomli (>=2,<3)", "tomli (>=2,<3)", "tomli-w (>=1,<2)", "tox (==4.23.2)", "twine (==5.1.1)", "watchdog[watchmedo] (==6.0.0)", "wheel (==0.45.0)"] timedelta = ["pytimeparse (>=1.1.7)"] @@ -2409,20 +2402,6 @@ files = [ [package.extras] tests = ["pytest"] -[[package]] -name = "exceptiongroup" -version = "1.2.2" -description = "Backport of PEP 654 (exception groups)" -optional = false -python-versions = ">=3.7" -files = [ - {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, - {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, -] - -[package.extras] -test = ["pytest (>=6)"] - [[package]] name = "faker" version = "32.1.0" @@ -2613,21 +2592,21 @@ six = ">=1.10.0" [[package]] name = "flask" -version = "3.0.3" +version = "3.1.0" description = "A simple framework for building complex web applications." optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "flask-3.0.3-py3-none-any.whl", hash = "sha256:34e815dfaa43340d1d15a5c3a02b8476004037eb4840b34910c6e21679d288f3"}, - {file = "flask-3.0.3.tar.gz", hash = "sha256:ceb27b0af3823ea2737928a4d99d125a06175b8512c445cbd9a9ce200ef76842"}, + {file = "flask-3.1.0-py3-none-any.whl", hash = "sha256:d667207822eb83f1c4b50949b1623c8fc8d51f2341d65f72e1a1815397551136"}, + {file = "flask-3.1.0.tar.gz", hash = "sha256:5f873c5184c897c8d9d1b05df1e3d01b14910ce69607a117bd3277098a5836ac"}, ] [package.dependencies] -blinker = ">=1.6.2" +blinker = ">=1.9" click = ">=8.1.3" -itsdangerous = ">=2.1.2" +itsdangerous = ">=2.2" Jinja2 = ">=3.1.2" -Werkzeug = ">=3.0.0" +Werkzeug = ">=3.1" [package.extras] async = ["asgiref (>=3.2)"] @@ -2635,19 +2614,23 @@ dotenv = ["python-dotenv"] [[package]] name = "flask-compress" -version = "1.14" -description = "Compress responses in your Flask app with gzip, deflate or brotli." +version = "1.17" +description = "Compress responses in your Flask app with gzip, deflate, brotli or zstandard." optional = false -python-versions = "*" +python-versions = ">=3.9" files = [ - {file = "Flask-Compress-1.14.tar.gz", hash = "sha256:e46528f37b91857012be38e24e65db1a248662c3dc32ee7808b5986bf1d123ee"}, - {file = "Flask_Compress-1.14-py3-none-any.whl", hash = "sha256:b86c9808f0f38ea2246c9730972cf978f2cdf6a9a1a69102ba81e07891e6b26c"}, + {file = "Flask_Compress-1.17-py3-none-any.whl", hash = "sha256:415131f197c41109f08e8fdfc3a6628d83d81680fb5ecd0b3a97410e02397b20"}, + {file = "flask_compress-1.17.tar.gz", hash = "sha256:1ebb112b129ea7c9e7d6ee6d5cc0d64f226cbc50c4daddf1a58b9bd02253fbd8"}, ] [package.dependencies] brotli = {version = "*", markers = "platform_python_implementation != \"PyPy\""} brotlicffi = {version = "*", markers = "platform_python_implementation == \"PyPy\""} flask = "*" +zstandard = [ + {version = "*", markers = "platform_python_implementation != \"PyPy\""}, + {version = "*", extras = ["cffi"], markers = "platform_python_implementation == \"PyPy\""}, +] [[package]] name = "flask-cors" @@ -3210,14 +3193,8 @@ files = [ [package.dependencies] google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" -grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] -grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, - {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, -] +grpcio = {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} +grpcio-status = {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""} proto-plus = ">=1.22.3,<2.0.0dev" protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" requests = ">=2.18.0,<3.0.0.dev0" @@ -5550,9 +5527,6 @@ files = [ {file = "multidict-6.1.0.tar.gz", hash = "sha256:22ae2ebf9b0c69d206c003e2f6a914ea33f0a932d4aa16f236afc049d9958f4a"}, ] -[package.dependencies] -typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} - [[package]] name = "multiprocess" version = "0.70.17" @@ -6410,7 +6384,6 @@ bottleneck = {version = ">=1.3.6", optional = true, markers = "extra == \"perfor numba = {version = ">=0.56.4", optional = true, markers = "extra == \"performance\""} numexpr = {version = ">=2.8.4", optional = true, markers = "extra == \"performance\""} numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] @@ -6694,7 +6667,6 @@ files = [ deprecation = ">=2.1.0,<3.0.0" httpx = {version = ">=0.26,<0.28", extras = ["http2"]} pydantic = ">=1.9,<3.0" -strenum = {version = ">=0.4.9,<0.5.0", markers = "python_version < \"3.11\""} [[package]] name = "posthog" @@ -7426,9 +7398,6 @@ files = [ {file = "pypdf-5.1.0.tar.gz", hash = "sha256:425a129abb1614183fd1aca6982f650b47f8026867c0ce7c4b9f281c443d2740"}, ] -[package.dependencies] -typing_extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} - [package.extras] crypto = ["cryptography"] cryptodome = ["PyCryptodome"] @@ -7517,11 +7486,9 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} -exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2" -tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -7559,7 +7526,6 @@ files = [ [package.dependencies] pytest = ">=8.3.3" -tomli = {version = ">=2.0.1", markers = "python_version < \"3.11\""} [package.extras] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.1)", "pytest-mock (>=3.14)"] @@ -8377,7 +8343,6 @@ files = [ [package.dependencies] markdown-it-py = ">=2.2.0" pygments = ">=2.13.0,<3.0.0" -typing-extensions = {version = ">=4.0.0,<5.0", markers = "python_version < \"3.11\""} [package.extras] jupyter = ["ipywidgets (>=7.5.1,<9)"] @@ -8497,29 +8462,29 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.7.4" +version = "0.8.1" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" files = [ - {file = "ruff-0.7.4-py3-none-linux_armv6l.whl", hash = "sha256:a4919925e7684a3f18e18243cd6bea7cfb8e968a6eaa8437971f681b7ec51478"}, - {file = "ruff-0.7.4-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:cfb365c135b830778dda8c04fb7d4280ed0b984e1aec27f574445231e20d6c63"}, - {file = "ruff-0.7.4-py3-none-macosx_11_0_arm64.whl", hash = "sha256:63a569b36bc66fbadec5beaa539dd81e0527cb258b94e29e0531ce41bacc1f20"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d06218747d361d06fd2fdac734e7fa92df36df93035db3dc2ad7aa9852cb109"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e0cea28d0944f74ebc33e9f934238f15c758841f9f5edd180b5315c203293452"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80094ecd4793c68b2571b128f91754d60f692d64bc0d7272ec9197fdd09bf9ea"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:997512325c6620d1c4c2b15db49ef59543ef9cd0f4aa8065ec2ae5103cedc7e7"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:00b4cf3a6b5fad6d1a66e7574d78956bbd09abfd6c8a997798f01f5da3d46a05"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7dbdc7d8274e1422722933d1edddfdc65b4336abf0b16dfcb9dedd6e6a517d06"}, - {file = "ruff-0.7.4-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e92dfb5f00eaedb1501b2f906ccabfd67b2355bdf117fea9719fc99ac2145bc"}, - {file = "ruff-0.7.4-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:3bd726099f277d735dc38900b6a8d6cf070f80828877941983a57bca1cd92172"}, - {file = "ruff-0.7.4-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:2e32829c429dd081ee5ba39aef436603e5b22335c3d3fff013cd585806a6486a"}, - {file = "ruff-0.7.4-py3-none-musllinux_1_2_i686.whl", hash = "sha256:662a63b4971807623f6f90c1fb664613f67cc182dc4d991471c23c541fee62dd"}, - {file = "ruff-0.7.4-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:876f5e09eaae3eb76814c1d3b68879891d6fde4824c015d48e7a7da4cf066a3a"}, - {file = "ruff-0.7.4-py3-none-win32.whl", hash = "sha256:75c53f54904be42dd52a548728a5b572344b50d9b2873d13a3f8c5e3b91f5cac"}, - {file = "ruff-0.7.4-py3-none-win_amd64.whl", hash = "sha256:745775c7b39f914238ed1f1b0bebed0b9155a17cd8bc0b08d3c87e4703b990d6"}, - {file = "ruff-0.7.4-py3-none-win_arm64.whl", hash = "sha256:11bff065102c3ae9d3ea4dc9ecdfe5a5171349cdd0787c1fc64761212fc9cf1f"}, - {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, + {file = "ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5"}, + {file = "ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087"}, + {file = "ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5"}, + {file = "ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790"}, + {file = "ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6"}, + {file = "ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737"}, + {file = "ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f"}, ] [[package]] @@ -8771,11 +8736,6 @@ files = [ {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f60021ec1574e56632be2a36b946f8143bf4e5e6af4a06d85281adc22938e0dd"}, {file = "scikit_learn-1.5.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:394397841449853c2290a32050382edaec3da89e35b3e03d6cc966aebc6a8ae6"}, {file = "scikit_learn-1.5.2-cp312-cp312-win_amd64.whl", hash = "sha256:57cc1786cfd6bd118220a92ede80270132aa353647684efa385a74244a41e3b1"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e9a702e2de732bbb20d3bad29ebd77fc05a6b427dc49964300340e4c9328b3f5"}, - {file = "scikit_learn-1.5.2-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:b0768ad641981f5d3a198430a1d31c3e044ed2e8a6f22166b4d546a5116d7908"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:178ddd0a5cb0044464fc1bfc4cca5b1833bfc7bb022d70b05db8530da4bb3dd3"}, - {file = "scikit_learn-1.5.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7284ade780084d94505632241bf78c44ab3b6f1e8ccab3d2af58e0e950f9c12"}, - {file = "scikit_learn-1.5.2-cp313-cp313-win_amd64.whl", hash = "sha256:b7b0f9a0b1040830d38c39b91b3a44e1b643f4b36e36567b80b7c6bd2202a27f"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:757c7d514ddb00ae249832fe87100d9c73c6ea91423802872d9e74970a0e40b9"}, {file = "scikit_learn-1.5.2-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:52788f48b5d8bca5c0736c175fa6bdaab2ef00a8f536cda698db61bd89c551c1"}, {file = "scikit_learn-1.5.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:643964678f4b5fbdc95cbf8aec638acc7aa70f5f79ee2cdad1eec3df4ba6ead8"}, @@ -9214,22 +9174,6 @@ httpx = {version = ">=0.26,<0.28", extras = ["http2"]} python-dateutil = ">=2.8.2,<3.0.0" typing-extensions = ">=4.2.0,<5.0.0" -[[package]] -name = "strenum" -version = "0.4.15" -description = "An Enum that inherits from str." -optional = false -python-versions = "*" -files = [ - {file = "StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659"}, - {file = "StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff"}, -] - -[package.extras] -docs = ["myst-parser[linkify]", "sphinx", "sphinx-rtd-theme"] -release = ["twine"] -test = ["pylint", "pytest", "pytest-black", "pytest-cov", "pytest-pylint"] - [[package]] name = "strictyaml" version = "1.7.3" @@ -9636,17 +9580,6 @@ files = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] -[[package]] -name = "tomli" -version = "2.1.0" -description = "A lil' TOML parser" -optional = false -python-versions = ">=3.8" -files = [ - {file = "tomli-2.1.0-py3-none-any.whl", hash = "sha256:a5c57c3d1c56f5ccdf89f6523458f60ef716e210fc47c4cfb188c5ba473e0391"}, - {file = "tomli-2.1.0.tar.gz", hash = "sha256:3f646cae2aec94e17d04973e4249548320197cfabdf130015d023de4b74d8ab8"}, -] - [[package]] name = "tos" version = "2.7.2" @@ -10067,7 +10000,6 @@ h11 = ">=0.8" httptools = {version = ">=0.5.0", optional = true, markers = "extra == \"standard\""} python-dotenv = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} pyyaml = {version = ">=5.1", optional = true, markers = "extra == \"standard\""} -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.11\""} uvloop = {version = ">=0.14.0,<0.15.0 || >0.15.0,<0.15.1 || >0.15.1", optional = true, markers = "(sys_platform != \"win32\" and sys_platform != \"cygwin\") and platform_python_implementation != \"PyPy\" and extra == \"standard\""} watchfiles = {version = ">=0.13", optional = true, markers = "extra == \"standard\""} websockets = {version = ">=10.4", optional = true, markers = "extra == \"standard\""} @@ -10487,13 +10419,13 @@ files = [ [[package]] name = "werkzeug" -version = "3.0.6" +version = "3.1.3" description = "The comprehensive WSGI web application library." optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ - {file = "werkzeug-3.0.6-py3-none-any.whl", hash = "sha256:1bc0c2310d2fbb07b1dd1105eba2f7af72f322e1e455f2f93c993bee8c8a5f17"}, - {file = "werkzeug-3.0.6.tar.gz", hash = "sha256:a8dd59d4de28ca70471a34cba79bed5f7ef2e036a76b3ab0835474246eb41f8d"}, + {file = "werkzeug-3.1.3-py3-none-any.whl", hash = "sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e"}, + {file = "werkzeug-3.1.3.tar.gz", hash = "sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746"}, ] [package.dependencies] @@ -11043,12 +10975,12 @@ files = [ ] [package.dependencies] -cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\""} +cffi = {version = ">=1.11", optional = true, markers = "platform_python_implementation == \"PyPy\" or extra == \"cffi\""} [package.extras] cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.0" -python-versions = ">=3.10,<3.13" -content-hash = "152b8e11ceffaa482fee6920b7991f52427aa1ffed75614e78ec1065dd5f6898" +python-versions = ">=3.11,<3.13" +content-hash = "b762e282fd140c87ae1b0be8d56ec0e1be6515ced28996f1ab0a23f3842120af" diff --git a/api/pyproject.toml b/api/pyproject.toml index f02c063182..1bdcf5a1a8 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,108 +1,12 @@ [project] -requires-python = ">=3.10,<3.13" +name = "dify-api" +requires-python = ">=3.11,<3.13" [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" -[tool.ruff] -exclude=[ - "migrations/*", -] -line-length = 120 - -[tool.ruff.lint] -preview = true -select = [ - "B", # flake8-bugbear rules - "C4", # flake8-comprehensions - "E", # pycodestyle E rules - "F", # pyflakes rules - "FURB", # refurb rules - "I", # isort rules - "N", # pep8-naming - "PT", # flake8-pytest-style rules - "PLC0208", # iteration-over-set - "PLC2801", # unnecessary-dunder-call - "PLC0414", # useless-import-alias - "PLR0402", # manual-from-import - "PLR1711", # useless-return - "PLR1714", # repeated-equality-comparison - "RUF013", # implicit-optional - "RUF019", # unnecessary-key-check - "RUF100", # unused-noqa - "RUF101", # redirected-noqa - "S506", # unsafe-yaml-load - "SIM", # flake8-simplify rules - "TRY400", # error-instead-of-exception - "TRY401", # verbose-log-message - "UP", # pyupgrade rules - "W191", # tab-indentation - "W605", # invalid-escape-sequence -] -ignore = [ - "E402", # module-import-not-at-top-of-file - "E711", # none-comparison - "E712", # true-false-comparison - "E721", # type-comparison - "E722", # bare-except - "E731", # lambda-assignment - "F821", # undefined-name - "F841", # unused-variable - "FURB113", # repeated-append - "FURB152", # math-constant - "UP007", # non-pep604-annotation - "UP032", # f-string - "B005", # strip-with-multi-characters - "B006", # mutable-argument-default - "B007", # unused-loop-control-variable - "B026", # star-arg-unpacking-after-keyword-arg - "B904", # raise-without-from-inside-except - "B905", # zip-without-explicit-strict - "N806", # non-lowercase-variable-in-function - "N815", # mixed-case-variable-in-class-scope - "PT011", # pytest-raises-too-broad - "SIM102", # collapsible-if - "SIM103", # needless-bool - "SIM105", # suppressible-exception - "SIM107", # return-in-try-except-finally - "SIM108", # if-else-block-instead-of-if-exp - "SIM113", # eumerate-for-loop - "SIM117", # multiple-with-statements - "SIM210", # if-expr-with-true-false - "SIM300", # yoda-conditions, -] - -[tool.ruff.lint.per-file-ignores] -"app.py" = [ -] -"__init__.py" = [ - "F401", # unused-import - "F811", # redefined-while-unused -] -"configs/*" = [ - "N802", # invalid-function-name -] -"libs/gmpy2_pkcs10aep_cipher.py" = [ - "N803", # invalid-argument-name -] -"tests/*" = [ - "F811", # redefined-while-unused - "F401", # unused-import -] - -[tool.ruff.lint.pyflakes] -extend-generics=[ - "_pytest.monkeypatch", - "tests.integration_tests", -] - -[tool.ruff.format] -exclude = [ -] - [tool.poetry] -name = "dify-api" package-mode = false ############################################################ @@ -124,11 +28,11 @@ chardet = "~5.1.0" cohere = "~5.2.4" dashscope = { version = "~1.17.0", extras = ["tokenizer"] } fal-client = "0.5.6" -flask = "~3.0.1" -flask-compress = "~1.14" +flask = "~3.1.0" +flask-compress = "~1.17" flask-cors = "~4.0.0" flask-login = "~0.6.3" -flask-migrate = "~4.0.5" +flask-migrate = "~4.0.7" flask-restful = "~0.3.10" flask-sqlalchemy = "~3.1.1" gevent = "~24.11.1" @@ -163,7 +67,7 @@ pydantic-settings = "~2.6.0" pydantic_extra_types = "~2.9.0" pyjwt = "~2.8.0" pypdfium2 = "~4.17.0" -python = ">=3.10,<3.13" +python = ">=3.11,<3.13" python-docx = "~1.1.0" python-dotenv = "1.0.0" pyyaml = "~6.0.1" @@ -184,7 +88,6 @@ unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "msg", "pp validators = "0.21.0" volcengine-python-sdk = {extras = ["ark"], version = "~1.0.98"} websocket-client = "~1.7.0" -werkzeug = "~3.0.1" xinference-client = "0.15.2" yarl = "~1.9.4" youtube-transcript-api = "~0.6.2" @@ -242,7 +145,7 @@ tos = "~2.7.1" [tool.poetry.group.vdb.dependencies] alibabacloud_gpdb20160503 = "~3.8.0" alibabacloud_tea_openapi = "~0.3.9" -chromadb = "0.5.1" +chromadb = "0.5.20" clickhouse-connect = "~0.7.16" couchbase = "~4.3.0" elasticsearch = "8.14.0" @@ -282,4 +185,4 @@ pytest-mock = "~3.14.0" optional = true [tool.poetry.group.lint.dependencies] dotenv-linter = "~0.5.0" -ruff = "~0.7.3" +ruff = "~0.8.1" diff --git a/api/schedule/update_tidb_serverless_status_task.py b/api/schedule/update_tidb_serverless_status_task.py index 07eca3173b..b2d8746f9c 100644 --- a/api/schedule/update_tidb_serverless_status_task.py +++ b/api/schedule/update_tidb_serverless_status_task.py @@ -12,21 +12,18 @@ from models.dataset import TidbAuthBinding def update_tidb_serverless_status_task(): click.echo(click.style("Update tidb serverless status task.", fg="green")) start_at = time.perf_counter() - while True: - try: - # check the number of idle tidb serverless - tidb_serverless_list = TidbAuthBinding.query.filter( - TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" - ).all() - if len(tidb_serverless_list) == 0: - break - # update tidb serverless status - iterations_per_thread = 20 - update_clusters(tidb_serverless_list) - - except Exception as e: - click.echo(click.style(f"Error: {e}", fg="red")) - break + try: + # check the number of idle tidb serverless + tidb_serverless_list = TidbAuthBinding.query.filter( + TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" + ).all() + if len(tidb_serverless_list) == 0: + return + # update tidb serverless status + update_clusters(tidb_serverless_list) + + except Exception as e: + click.echo(click.style(f"Error: {e}", fg="red")) end_at = time.perf_counter() click.echo( diff --git a/api/services/account_service.py b/api/services/account_service.py index 3d7f9e7dfb..f0c6ac7ebd 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -4,7 +4,7 @@ import logging import random import secrets import uuid -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from hashlib import sha256 from typing import Any, Optional @@ -115,15 +115,15 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(timezone.utc).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.now(timezone.utc).replace(tzinfo=None) + if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): + account.last_active_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return account @staticmethod def get_account_jwt_token(account: Account) -> str: - exp_dt = datetime.now(timezone.utc) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) + exp_dt = datetime.now(UTC) + timedelta(minutes=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES) exp = int(exp_dt.timestamp()) payload = { "user_id": account.id, @@ -160,7 +160,7 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() @@ -253,7 +253,7 @@ class AccountService: # If it exists, update the record account_integrate.open_id = open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) else: # If it does not exist, create a new record account_integrate = AccountIntegrate( @@ -288,7 +288,7 @@ class AccountService: @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" - account.last_login_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.last_login_at = datetime.now(UTC).replace(tzinfo=None) account.last_login_ip = ip_address db.session.add(account) db.session.commit() @@ -573,7 +573,7 @@ class TenantService: return tenant @staticmethod - def switch_tenant(account: Account, tenant_id: Optional[int] = None) -> None: + def switch_tenant(account: Account, tenant_id: Optional[str] = None) -> None: """Switch the current workspace for the account""" # Ensure tenant_id is provided @@ -672,7 +672,7 @@ class TenantService: return db.session.query(func.count(Tenant.id)).scalar() @staticmethod - def check_member_permission(tenant: Tenant, operator: Account, member: Account, action: str) -> None: + def check_member_permission(tenant: Tenant, operator: Account, member: Account | None, action: str) -> None: """Check member permission""" perms = { "add": [TenantAccountRole.OWNER, TenantAccountRole.ADMIN], @@ -765,7 +765,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -805,7 +805,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(timezone.utc).replace(tzinfo=None) + account.initialized_at = datetime.now(UTC).replace(tzinfo=None) if open_id is not None or provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py index 915d37ec03..f45c21cb18 100644 --- a/api/services/annotation_service.py +++ b/api/services/annotation_service.py @@ -429,7 +429,7 @@ class AppAnnotationService: raise NotFound("App annotation not found") annotation_setting.score_threshold = args["score_threshold"] annotation_setting.updated_user_id = current_user.id - annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) db.session.commit() diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 02a3d0a9a1..8180c3b400 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -1,6 +1,6 @@ import logging import uuid -from enum import Enum +from enum import StrEnum from typing import Optional from uuid import uuid4 @@ -22,15 +22,15 @@ logger = logging.getLogger(__name__) IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" IMPORT_INFO_REDIS_EXPIRY = 180 # 3 minutes -CURRENT_DSL_VERSION = "0.1.3" +CURRENT_DSL_VERSION = "0.1.4" -class ImportMode(str, Enum): +class ImportMode(StrEnum): YAML_CONTENT = "yaml-content" YAML_URL = "yaml-url" -class ImportStatus(str, Enum): +class ImportStatus(StrEnum): COMPLETED = "completed" COMPLETED_WITH_WARNINGS = "completed-with-warnings" PENDING = "pending" @@ -113,6 +113,10 @@ class AppDslService: ) try: max_size = 10 * 1024 * 1024 # 10MB + # tricky way to handle url from github to github raw url + if yaml_url.startswith("https://github.com") and yaml_url.endswith((".yml", ".yaml")): + yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com") + yaml_url = yaml_url.replace("/blob/", "/") response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10)) response.raise_for_status() content = response.content @@ -383,11 +387,11 @@ class AppDslService: environment_variables_list = workflow_data.get("environment_variables", []) environment_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in environment_variables_list + variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list ] conversation_variables_list = workflow_data.get("conversation_variables", []) conversation_variables = [ - variable_factory.build_variable_from_mapping(obj) for obj in conversation_variables_list + variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list ] workflow_service = WorkflowService() diff --git a/api/services/app_generate_service.py b/api/services/app_generate_service.py index 83a9a16904..545de8190c 100644 --- a/api/services/app_generate_service.py +++ b/api/services/app_generate_service.py @@ -43,50 +43,66 @@ class AppGenerateService: request_id = rate_limit.enter(request_id) if app_model.mode == AppMode.COMPLETION.value: return rate_limit.generate( - CompletionAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + generator=CompletionAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, ), - request_id, + request_id=request_id, ) elif app_model.mode == AppMode.AGENT_CHAT.value or app_model.is_agent: + generator = AgentChatAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + ) return rate_limit.generate( - AgentChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming - ), - request_id, + generator=generator, + request_id=request_id, ) elif app_model.mode == AppMode.CHAT.value: return rate_limit.generate( - ChatAppGenerator().generate( - app_model=app_model, user=user, args=args, invoke_from=invoke_from, stream=streaming + generator=ChatAppGenerator().generate( + app_model=app_model, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, ), - request_id, + request_id=request_id, ) elif app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, invoke_from) return rate_limit.generate( - AdvancedChatAppGenerator().generate( + generator=AdvancedChatAppGenerator().generate( app_model=app_model, workflow=workflow, user=user, args=args, invoke_from=invoke_from, - stream=streaming, + streaming=streaming, ), - request_id, + request_id=request_id, ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, invoke_from) + generator = WorkflowAppGenerator().generate( + app_model=app_model, + workflow=workflow, + user=user, + args=args, + invoke_from=invoke_from, + streaming=streaming, + call_depth=0, + workflow_thread_pool_id=None, + ) return rate_limit.generate( - WorkflowAppGenerator().generate( - app_model=app_model, - workflow=workflow, - user=user, - args=args, - invoke_from=invoke_from, - stream=streaming, - ), - request_id, + generator=generator, + request_id=request_id, ) else: raise ValueError(f"Invalid app mode {app_model.mode}") @@ -108,12 +124,17 @@ class AppGenerateService: if app_model.mode == AppMode.ADVANCED_CHAT.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return AdvancedChatAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming + app_model=app_model, + workflow=workflow, + node_id=node_id, + user=user, + args=args, + streaming=streaming, ) elif app_model.mode == AppMode.WORKFLOW.value: workflow = cls._get_workflow(app_model, InvokeFrom.DEBUGGER) return WorkflowAppGenerator().single_iteration_generate( - app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, stream=streaming + app_model=app_model, workflow=workflow, node_id=node_id, user=user, args=args, streaming=streaming ) else: raise ValueError(f"Invalid app mode {app_model.mode}") diff --git a/api/services/app_service.py b/api/services/app_service.py index 26e3d004b6..8d8ba735ec 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,6 @@ import json import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import cast from flask_login import current_user @@ -223,7 +223,7 @@ class AppService: app.icon_background = args.get("icon_background") app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() if app.max_active_requests is not None: @@ -240,7 +240,7 @@ class AppService: """ app.name = name app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -256,7 +256,7 @@ class AppService: app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -273,7 +273,7 @@ class AppService: app.enable_site = enable_site app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app @@ -290,7 +290,7 @@ class AppService: app.enable_api = enable_api app.updated_by = current_user.id - app.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + app.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return app diff --git a/api/services/auth/auth_type.py b/api/services/auth/auth_type.py index 2d6e901447..2e1946841f 100644 --- a/api/services/auth/auth_type.py +++ b/api/services/auth/auth_type.py @@ -1,6 +1,6 @@ -from enum import Enum +from enum import StrEnum -class AuthType(str, Enum): +class AuthType(StrEnum): FIRECRAWL = "firecrawl" JINA = "jinareader" diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index f9e41988c0..8642972710 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,4 +1,5 @@ -from datetime import datetime, timezone +from collections.abc import Callable +from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import asc, desc, or_ @@ -74,14 +75,14 @@ class ConversationService: return InfiniteScrollPagination(data=conversations, limit=limit, has_more=has_more) @classmethod - def _get_sort_params(cls, sort_by: str) -> tuple[str, callable]: + def _get_sort_params(cls, sort_by: str): if sort_by.startswith("-"): return sort_by[1:], desc return sort_by, asc @classmethod def _build_filter_condition( - cls, sort_field: str, sort_direction: callable, reference_conversation: Conversation, is_next_page: bool = False + cls, sort_field: str, sort_direction: Callable, reference_conversation: Conversation, is_next_page: bool = False ): field_value = getattr(reference_conversation, sort_field) if (sort_direction == desc and not is_next_page) or (sort_direction == asc and is_next_page): @@ -104,7 +105,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name - conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return conversation @@ -160,5 +161,5 @@ class ConversationService: conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True - conversation.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 806dbdf8c5..a1014e8e0a 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -406,6 +406,9 @@ class DocumentService: ], "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50}, }, + "limits": { + "indexing_max_segmentation_tokens_length": dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH, + }, } DOCUMENT_METADATA_SCHEMA = { @@ -600,7 +603,7 @@ class DocumentService: # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) db.session.commit() @@ -1072,7 +1075,7 @@ class DocumentService: document.parsing_completed_at = None document.cleaning_completed_at = None document.splitting_completed_at = None - document.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.created_from = created_from document.doc_form = document_data["doc_form"] db.session.add(document) @@ -1409,8 +1412,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -1429,7 +1432,7 @@ class SegmentService: except Exception as e: logging.exception("create segment index failed") segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -1481,8 +1484,8 @@ class SegmentService: word_count=len(content), tokens=tokens, status="completed", - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), created_by=current_user.id, ) if document.doc_form == "qa_model": @@ -1508,7 +1511,7 @@ class SegmentService: logging.exception("create segment index failed") for segment_document in segment_data_list: segment_document.enabled = False - segment_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment_document.status = "error" segment_document.error = str(e) db.session.commit() @@ -1526,7 +1529,7 @@ class SegmentService: if segment.enabled != action: if not action: segment.enabled = action - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.disabled_by = current_user.id db.session.add(segment) db.session.commit() @@ -1585,10 +1588,10 @@ class SegmentService: segment.word_count = len(content) segment.tokens = tokens segment.status = "completed" - segment.indexing_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) - segment.completed_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.indexing_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + segment.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.updated_by = current_user.id - segment.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.enabled = True segment.disabled_at = None segment.disabled_by = None @@ -1608,7 +1611,7 @@ class SegmentService: except Exception as e: logging.exception("update segment index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/services/errors/__init__.py b/api/services/errors/__init__.py index bb5711145c..eb1f055708 100644 --- a/api/services/errors/__init__.py +++ b/api/services/errors/__init__.py @@ -14,16 +14,16 @@ from . import ( ) __all__ = [ - "base", - "conversation", - "message", - "index", - "app_model_config", "account", - "document", - "dataset", "app", - "completion", + "app_model_config", "audio", + "base", + "completion", + "conversation", + "dataset", + "document", "file", + "index", + "message", ] diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 98e5d9face..7e3cd87f1e 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,6 @@ import json from copy import deepcopy -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Any, Optional, Union import httpx @@ -99,7 +99,7 @@ class ExternalDatasetService: external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) external_knowledge_api.updated_by = user_id - external_knowledge_api.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) db.session.commit() return external_knowledge_api diff --git a/api/services/feature_service.py b/api/services/feature_service.py index d0b04628cf..6bd82a2757 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import StrEnum from pydantic import BaseModel, ConfigDict @@ -22,7 +22,7 @@ class LimitationModel(BaseModel): limit: int = 0 -class LicenseStatus(str, Enum): +class LicenseStatus(StrEnum): NONE = "none" INACTIVE = "inactive" ACTIVE = "active" @@ -171,8 +171,10 @@ class FeatureService: features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] if "license" in enterprise_info: - if "status" in enterprise_info["license"]: - features.license.status = enterprise_info["license"]["status"] + license_info = enterprise_info["license"] - if "expired_at" in enterprise_info["license"]: - features.license.expired_at = enterprise_info["license"]["expired_at"] + if "status" in license_info: + features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) + + if "expired_at" in license_info: + features.license.expired_at = license_info["expired_at"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 976111502c..b12b95ca13 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -77,7 +77,7 @@ class FileService: mime_type=mimetype, created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), created_by=user.id, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=False, hash=hashlib.sha3_256(content).hexdigest(), source_url=source_url, @@ -123,10 +123,10 @@ class FileService: mime_type="text/plain", created_by=current_user.id, created_by_role=CreatedByRole.ACCOUNT, - created_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=current_user.id, - used_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + used_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) db.session.add(upload_file) diff --git a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index e7b9422cfe..b20bda8755 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -371,7 +371,7 @@ class ModelLoadBalancingService: load_balancing_config.name = name load_balancing_config.enabled = enabled - load_balancing_config.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + load_balancing_config.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() self._clear_credentials_cache(tenant_id, config_id) diff --git a/api/services/recommend_app/recommend_app_type.py b/api/services/recommend_app/recommend_app_type.py index 7ea93b3f64..e60e435b3a 100644 --- a/api/services/recommend_app/recommend_app_type.py +++ b/api/services/recommend_app/recommend_app_type.py @@ -1,7 +1,7 @@ -from enum import Enum +from enum import StrEnum -class RecommendAppType(str, Enum): +class RecommendAppType(StrEnum): REMOTE = "remote" BUILDIN = "builtin" DATABASE = "db" diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 7187d40517..37d7d0937c 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,7 +1,7 @@ import json import time from collections.abc import Sequence -from datetime import datetime, timezone +from datetime import UTC, datetime from typing import Optional from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -12,7 +12,7 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.nodes import NodeType from core.workflow.nodes.event import RunCompletedEvent -from core.workflow.nodes.node_mapping import node_type_classes_mapping +from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -115,7 +115,7 @@ class WorkflowService: workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables @@ -148,7 +148,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=str(datetime.now(timezone.utc).replace(tzinfo=None)), + version=str(datetime.now(UTC).replace(tzinfo=None)), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -176,7 +176,8 @@ class WorkflowService: """ # return default block config default_block_configs = [] - for node_type, node_class in node_type_classes_mapping.items(): + for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values(): + node_class = node_class_mapping[LATEST_VERSION] default_config = node_class.get_default_config() if default_config: default_block_configs.append(default_config) @@ -190,13 +191,13 @@ class WorkflowService: :param filters: filter by node config parameters. :return: """ - node_type_enum: NodeType = NodeType(node_type) + node_type_enum = NodeType(node_type) # return default block config - node_class = node_type_classes_mapping.get(node_type_enum) - if not node_class: + if node_type_enum not in NODE_TYPE_CLASSES_MAPPING: return None + node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION] default_config = node_class.get_default_config(filters=filters) if not default_config: return None @@ -257,18 +258,22 @@ class WorkflowService: workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value workflow_node_execution.created_by = account.id - workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) - workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) if run_succeeded and node_run_result: # create workflow node execution - workflow_node_execution.inputs = json.dumps(node_run_result.inputs) if node_run_result.inputs else None - workflow_node_execution.process_data = ( - json.dumps(node_run_result.process_data) if node_run_result.process_data else None - ) - workflow_node_execution.outputs = ( - json.dumps(jsonable_encoder(node_run_result.outputs)) if node_run_result.outputs else None + inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None + process_data = ( + WorkflowEntry.handle_special_values(node_run_result.process_data) + if node_run_result.process_data + else None ) + outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None + + workflow_node_execution.inputs = json.dumps(inputs) + workflow_node_execution.process_data = json.dumps(process_data) + workflow_node_execution.outputs = json.dumps(outputs) workflow_node_execution.execution_metadata = ( json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None ) @@ -303,10 +308,10 @@ class WorkflowService: new_app = workflow_converter.convert_to_workflow( app_model=app_model, account=account, - name=args.get("name"), - icon_type=args.get("icon_type"), - icon=args.get("icon"), - icon_background=args.get("icon_background"), + name=args.get("name", "Default Name"), + icon_type=args.get("icon_type", "emoji"), + icon=args.get("icon", "🤖"), + icon_background=args.get("icon_background", "#FFEAD5"), ) return new_app diff --git a/api/tasks/add_document_to_index_task.py b/api/tasks/add_document_to_index_task.py index b50876cc79..09be661216 100644 --- a/api/tasks/add_document_to_index_task.py +++ b/api/tasks/add_document_to_index_task.py @@ -18,9 +18,9 @@ from models.dataset import DocumentSegment def add_document_to_index_task(dataset_document_id: str): """ Async Add document to index - :param document_id: + :param dataset_document_id: - Usage: add_document_to_index.delay(document_id) + Usage: add_document_to_index.delay(dataset_document_id) """ logging.info(click.style("Start add document to index: {}".format(dataset_document_id), fg="green")) start_at = time.perf_counter() @@ -74,7 +74,7 @@ def add_document_to_index_task(dataset_document_id: str): except Exception as e: logging.exception("add document to index failed") dataset_document.enabled = False - dataset_document.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + dataset_document.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) dataset_document.status = "error" dataset_document.error = str(e) db.session.commit() diff --git a/api/tasks/annotation/enable_annotation_reply_task.py b/api/tasks/annotation/enable_annotation_reply_task.py index e819bf3635..0bdcd0eccd 100644 --- a/api/tasks/annotation/enable_annotation_reply_task.py +++ b/api/tasks/annotation/enable_annotation_reply_task.py @@ -52,7 +52,7 @@ def enable_annotation_reply_task( annotation_setting.score_threshold = score_threshold annotation_setting.collection_binding_id = dataset_collection_binding.id annotation_setting.updated_user_id = user_id - annotation_setting.updated_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + annotation_setting.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(annotation_setting) else: new_app_annotation_setting = AppAnnotationSetting( diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 5ee72c27fc..dcb7009e44 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -80,9 +80,9 @@ def batch_create_segment_to_index_task( word_count=len(content), tokens=tokens, created_by=user_id, - indexing_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + indexing_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), status="completed", - completed_at=datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + completed_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), ) if dataset_document.doc_form == "qa_model": segment_document.answer = segment["answer"] diff --git a/api/tasks/create_segment_to_index_task.py b/api/tasks/create_segment_to_index_task.py index 26375743b6..315b01f157 100644 --- a/api/tasks/create_segment_to_index_task.py +++ b/api/tasks/create_segment_to_index_task.py @@ -38,7 +38,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment status to indexing update_params = { DocumentSegment.status: "indexing", - DocumentSegment.indexing_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -75,7 +75,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] # update segment to completed update_params = { DocumentSegment.status: "completed", - DocumentSegment.completed_at: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None), + DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), } DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.commit() @@ -87,7 +87,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] except Exception as e: logging.exception("create segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 6dd755ab03..1831691393 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -67,7 +67,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): # check the page is updated if last_edited_time != page_edited_time: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index df1177d578..734dd2478a 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -50,7 +50,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.add(document) db.session.commit() return @@ -64,7 +64,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index cb38bc668d..1a52a6636b 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -30,7 +30,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) db.session.commit() # delete all document segment and index diff --git a/api/tasks/enable_segment_to_index_task.py b/api/tasks/enable_segment_to_index_task.py index 1412ad9ec7..12639db939 100644 --- a/api/tasks/enable_segment_to_index_task.py +++ b/api/tasks/enable_segment_to_index_task.py @@ -71,7 +71,7 @@ def enable_segment_to_index_task(segment_id: str): except Exception as e: logging.exception("enable segment to index failed") segment.enabled = False - segment.disabled_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) + segment.disabled_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) segment.status = "error" segment.error = str(e) db.session.commit() diff --git a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py index 4f44d2ffd6..5dd4754e8e 100644 --- a/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py +++ b/api/tests/integration_tests/vdb/analyticdb/test_analyticdb.py @@ -1,4 +1,4 @@ -from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbConfig, AnalyticdbVector +from core.rag.datasource.vdb.analyticdb.analyticdb_vector import AnalyticdbVector from core.rag.datasource.vdb.analyticdb.analyticdb_vector_openapi import AnalyticdbVectorOpenAPIConfig from core.rag.datasource.vdb.analyticdb.analyticdb_vector_sql import AnalyticdbVectorBySqlConfig from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, setup_mock_redis diff --git a/api/tests/unit_tests/configs/test_dify_config.py b/api/tests/unit_tests/configs/test_dify_config.py index 3f639ccacc..0eb310a51a 100644 --- a/api/tests/unit_tests/configs/test_dify_config.py +++ b/api/tests/unit_tests/configs/test_dify_config.py @@ -71,7 +71,6 @@ def test_flask_configs(example_env_file): assert config["EDITION"] == "SELF_HOSTED" assert config["API_COMPRESSION_ENABLED"] is False assert config["SENTRY_TRACES_SAMPLE_RATE"] == 1.0 - assert config["TESTING"] == False # value from env file assert config["CONSOLE_API_URL"] == "https://example.com" diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 621c995a4b..e09acc4c39 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -10,7 +10,6 @@ ABS_PATH = os.path.dirname(os.path.abspath(__file__)) PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir)) CACHED_APP = Flask(__name__) -CACHED_APP.config.update({"TESTING": True}) @pytest.fixture diff --git a/api/tests/unit_tests/core/app/segments/test_factory.py b/api/tests/unit_tests/core/app/segments/test_factory.py index 882a87239b..e6e289c12a 100644 --- a/api/tests/unit_tests/core/app/segments/test_factory.py +++ b/api/tests/unit_tests/core/app/segments/test_factory.py @@ -19,36 +19,36 @@ from factories import variable_factory def test_string_variable(): test_data = {"value_type": "string", "name": "test_text", "value": "Hello, World!"} - result = variable_factory.build_variable_from_mapping(test_data) + result = variable_factory.build_conversation_variable_from_mapping(test_data) assert isinstance(result, StringVariable) def test_integer_variable(): test_data = {"value_type": "number", "name": "test_int", "value": 42} - result = variable_factory.build_variable_from_mapping(test_data) + result = variable_factory.build_conversation_variable_from_mapping(test_data) assert isinstance(result, IntegerVariable) def test_float_variable(): test_data = {"value_type": "number", "name": "test_float", "value": 3.14} - result = variable_factory.build_variable_from_mapping(test_data) + result = variable_factory.build_conversation_variable_from_mapping(test_data) assert isinstance(result, FloatVariable) def test_secret_variable(): test_data = {"value_type": "secret", "name": "test_secret", "value": "secret_value"} - result = variable_factory.build_variable_from_mapping(test_data) + result = variable_factory.build_conversation_variable_from_mapping(test_data) assert isinstance(result, SecretVariable) def test_invalid_value_type(): test_data = {"value_type": "unknown", "name": "test_invalid", "value": "value"} with pytest.raises(VariableError): - variable_factory.build_variable_from_mapping(test_data) + variable_factory.build_conversation_variable_from_mapping(test_data) def test_build_a_blank_string(): - result = variable_factory.build_variable_from_mapping( + result = variable_factory.build_conversation_variable_from_mapping( { "value_type": "string", "name": "blank", @@ -80,7 +80,7 @@ def test_object_variable(): "key2": 2, }, } - variable = variable_factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_conversation_variable_from_mapping(mapping) assert isinstance(variable, ObjectSegment) assert isinstance(variable.value["key1"], str) assert isinstance(variable.value["key2"], int) @@ -97,7 +97,7 @@ def test_array_string_variable(): "text", ], } - variable = variable_factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_conversation_variable_from_mapping(mapping) assert isinstance(variable, ArrayStringVariable) assert isinstance(variable.value[0], str) assert isinstance(variable.value[1], str) @@ -114,7 +114,7 @@ def test_array_number_variable(): 2.0, ], } - variable = variable_factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_conversation_variable_from_mapping(mapping) assert isinstance(variable, ArrayNumberVariable) assert isinstance(variable.value[0], int) assert isinstance(variable.value[1], float) @@ -137,7 +137,7 @@ def test_array_object_variable(): }, ], } - variable = variable_factory.build_variable_from_mapping(mapping) + variable = variable_factory.build_conversation_variable_from_mapping(mapping) assert isinstance(variable, ArrayObjectVariable) assert isinstance(variable.value[0], dict) assert isinstance(variable.value[1], dict) @@ -149,7 +149,7 @@ def test_array_object_variable(): def test_variable_cannot_large_than_200_kb(): with pytest.raises(VariableError): - variable_factory.build_variable_from_mapping( + variable_factory.build_conversation_variable_from_mapping( { "id": str(uuid4()), "value_type": "string", diff --git a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py index f6b3be8250..f6555cfdde 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer_stream_processor.py @@ -1,6 +1,6 @@ import uuid from collections.abc import Generator -from datetime import datetime, timezone +from datetime import UTC, datetime, timezone from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey @@ -29,7 +29,7 @@ def _recursive_process(graph: Graph, next_node_id: str) -> Generator[GraphEngine def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEvent, None, None]: - route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None)) + route_node_state = RouteNodeState(node_id=next_node_id, start_at=datetime.now(UTC).replace(tzinfo=None)) parallel_id = graph.node_parallel_mapping.get(next_node_id) parallel_start_node_id = None @@ -68,7 +68,7 @@ def _publish_events(graph: Graph, next_node_id: str) -> Generator[GraphEngineEve ) route_node_state.status = RouteNodeState.Status.SUCCESS - route_node_state.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) + route_node_state.finished_at = datetime.now(UTC).replace(tzinfo=None) yield NodeRunSucceededEvent( id=node_execution_id, node_id=next_node_id, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py similarity index 92% rename from api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py rename to api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py index 096ae0ea52..9793da129d 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_variable_assigner.py +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py @@ -10,7 +10,8 @@ from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState -from core.workflow.nodes.variable_assigner import VariableAssignerNode, WriteMode +from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode +from core.workflow.nodes.variable_assigner.v1.node_data import WriteMode from models.enums import UserFrom from models.workflow import WorkflowType @@ -84,6 +85,7 @@ def test_overwrite_string_variable(): config={ "id": "node_id", "data": { + "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], "write_mode": WriteMode.OVER_WRITE.value, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], @@ -91,7 +93,7 @@ def test_overwrite_string_variable(): }, ) - with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: list(node.run()) mock_run.assert_called_once() @@ -166,6 +168,7 @@ def test_append_variable_to_array(): config={ "id": "node_id", "data": { + "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], "write_mode": WriteMode.APPEND.value, "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], @@ -173,7 +176,7 @@ def test_append_variable_to_array(): }, ) - with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: list(node.run()) mock_run.assert_called_once() @@ -237,6 +240,7 @@ def test_clear_array(): config={ "id": "node_id", "data": { + "title": "test", "assigned_variable_selector": ["conversation", conversation_variable.name], "write_mode": WriteMode.CLEAR.value, "input_variable_selector": [], @@ -244,7 +248,7 @@ def test_clear_array(): }, ) - with mock.patch("core.workflow.nodes.variable_assigner.node.update_conversation_variable") as mock_run: + with mock.patch("core.workflow.nodes.variable_assigner.common.helpers.update_conversation_variable") as mock_run: list(node.run()) mock_run.assert_called_once() diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py new file mode 100644 index 0000000000..16c1370018 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py @@ -0,0 +1,24 @@ +import pytest + +from core.variables import SegmentType +from core.workflow.nodes.variable_assigner.v2.enums import Operation +from core.workflow.nodes.variable_assigner.v2.helpers import is_input_value_valid + + +def test_is_input_value_valid_overwrite_array_string(): + # Valid cases + assert is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] + ) + assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) + + # Invalid cases + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" + ) + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] + ) + assert not is_input_value_valid( + variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] + ) diff --git a/api/tests/unit_tests/models/test_conversation_variable.py b/api/tests/unit_tests/models/test_conversation_variable.py index b879afa3e7..5d84a2ec85 100644 --- a/api/tests/unit_tests/models/test_conversation_variable.py +++ b/api/tests/unit_tests/models/test_conversation_variable.py @@ -6,7 +6,7 @@ from models import ConversationVariable def test_from_variable_and_to_variable(): - variable = variable_factory.build_variable_from_mapping( + variable = variable_factory.build_conversation_variable_from_mapping( { "id": str(uuid4()), "name": "name", diff --git a/api/tests/unit_tests/models/test_workflow.py b/api/tests/unit_tests/models/test_workflow.py index 478fa8012b..fe56f18f1b 100644 --- a/api/tests/unit_tests/models/test_workflow.py +++ b/api/tests/unit_tests/models/test_workflow.py @@ -24,10 +24,18 @@ def test_environment_variables(): ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) - variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) - variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) - variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) + variable1 = StringVariable.model_validate( + {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]} + ) + variable2 = IntegerVariable.model_validate( + {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]} + ) + variable3 = SecretVariable.model_validate( + {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]} + ) + variable4 = FloatVariable.model_validate( + {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} + ) with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), @@ -58,10 +66,18 @@ def test_update_environment_variables(): ) # Create some EnvironmentVariable instances - variable1 = StringVariable.model_validate({"name": "var1", "value": "value1", "id": str(uuid4())}) - variable2 = IntegerVariable.model_validate({"name": "var2", "value": 123, "id": str(uuid4())}) - variable3 = SecretVariable.model_validate({"name": "var3", "value": "secret", "id": str(uuid4())}) - variable4 = FloatVariable.model_validate({"name": "var4", "value": 3.14, "id": str(uuid4())}) + variable1 = StringVariable.model_validate( + {"name": "var1", "value": "value1", "id": str(uuid4()), "selector": ["env", "var1"]} + ) + variable2 = IntegerVariable.model_validate( + {"name": "var2", "value": 123, "id": str(uuid4()), "selector": ["env", "var2"]} + ) + variable3 = SecretVariable.model_validate( + {"name": "var3", "value": "secret", "id": str(uuid4()), "selector": ["env", "var3"]} + ) + variable4 = FloatVariable.model_validate( + {"name": "var4", "value": 3.14, "id": str(uuid4()), "selector": ["env", "var4"]} + ) with ( mock.patch("core.helper.encrypter.encrypt_token", return_value="encrypted_token"), diff --git a/docker-legacy/docker-compose.chroma.yaml b/docker-legacy/docker-compose.chroma.yaml index a943d620c0..63354305de 100644 --- a/docker-legacy/docker-compose.chroma.yaml +++ b/docker-legacy/docker-compose.chroma.yaml @@ -1,7 +1,7 @@ services: # Chroma vector store. chroma: - image: ghcr.io/chroma-core/chroma:0.5.1 + image: ghcr.io/chroma-core/chroma:0.5.20 restart: always volumes: - ./volumes/chroma:/chroma/chroma diff --git a/docker-legacy/docker-compose.yaml b/docker-legacy/docker-compose.yaml index 7bf2cd4708..ea9f5fc493 100644 --- a/docker-legacy/docker-compose.yaml +++ b/docker-legacy/docker-compose.yaml @@ -2,7 +2,7 @@ version: '3' services: # API service api: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.13.0 restart: always environment: # Startup mode, 'api' starts the API server. @@ -210,7 +210,7 @@ services: SSRF_PROXY_HTTP_URL: 'http://ssrf_proxy:3128' SSRF_PROXY_HTTPS_URL: 'http://ssrf_proxy:3128' # Indexing configuration - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 1000 + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: 4000 depends_on: - db - redis @@ -227,7 +227,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.13.0 restart: always environment: CONSOLE_WEB_URL: '' @@ -397,7 +397,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.11.2 + image: langgenius/dify-web:0.13.0 restart: always environment: # The base URL of console application api server, refers to the Console base URL of WEB service if console domain is diff --git a/docker/.env.example b/docker/.env.example index 50dc56a5c9..719a025877 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -579,6 +579,7 @@ ETL_TYPE=dify # For example: http://unstructured:8000/general/v0/general UNSTRUCTURED_API_URL= UNSTRUCTURED_API_KEY= +SCARF_NO_ANALYTICS=true # ------------------------------ # Model Configuration @@ -682,7 +683,7 @@ SMTP_OPPORTUNISTIC_TLS=false # ------------------------------ # Maximum length of segmentation tokens for indexing -INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=1000 +INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH=4000 # Member invitation link valid time (hours), # Default: 72. diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index e316ec0c5a..d0ec9e5977 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -177,7 +177,7 @@ x-shared-env: &shared-api-worker-env ELASTICSEARCH_PASSWORD: ${ELASTICSEARCH_PASSWORD:-elastic} LINDORM_URL: ${LINDORM_URL:-http://lindorm:30070} LINDORM_USERNAME: ${LINDORM_USERNAME:-lindorm} - LINDORM_PASSWORD: ${LINDORM_USERNAME:-lindorm } + LINDORM_PASSWORD: ${LINDORM_PASSWORD:-lindorm } KIBANA_PORT: ${KIBANA_PORT:-5601} # AnalyticDB configuration ANALYTICDB_KEY_ID: ${ANALYTICDB_KEY_ID:-} @@ -247,7 +247,7 @@ x-shared-env: &shared-api-worker-env SMTP_OPPORTUNISTIC_TLS: ${SMTP_OPPORTUNISTIC_TLS:-false} RESEND_API_KEY: ${RESEND_API_KEY:-your-resend-api-key} RESEND_API_URL: ${RESEND_API_URL:-https://api.resend.com} - INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-1000} + INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: ${INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH:-4000} INVITE_EXPIRY_HOURS: ${INVITE_EXPIRY_HOURS:-72} RESET_PASSWORD_TOKEN_EXPIRY_MINUTES: ${RESET_PASSWORD_TOKEN_EXPIRY_MINUTES:-5} CODE_EXECUTION_ENDPOINT: ${CODE_EXECUTION_ENDPOINT:-http://sandbox:8194} @@ -287,11 +287,12 @@ x-shared-env: &shared-api-worker-env OCEANBASE_CLUSTER_NAME: ${OCEANBASE_CLUSTER_NAME:-difyai} OCEANBASE_MEMORY_LIMIT: ${OCEANBASE_MEMORY_LIMIT:-6G} CREATE_TIDB_SERVICE_JOB_ENABLED: ${CREATE_TIDB_SERVICE_JOB_ENABLED:-false} + RETRIEVAL_TOP_N: ${RETRIEVAL_TOP_N:-0} services: # API service api: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.13.0 restart: always environment: # Use the shared environment variables. @@ -311,7 +312,7 @@ services: # worker service # The Celery worker for processing the queue. worker: - image: langgenius/dify-api:0.11.2 + image: langgenius/dify-api:0.13.0 restart: always environment: # Use the shared environment variables. @@ -330,7 +331,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:0.11.2 + image: langgenius/dify-web:0.13.0 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -605,7 +606,7 @@ services: # Chroma vector database chroma: - image: ghcr.io/chroma-core/chroma:0.5.1 + image: ghcr.io/chroma-core/chroma:0.5.20 profiles: - chroma restart: always diff --git a/web/app/(commonLayout)/apps/page.tsx b/web/app/(commonLayout)/apps/page.tsx index 76985de34f..ab9852e462 100644 --- a/web/app/(commonLayout)/apps/page.tsx +++ b/web/app/(commonLayout)/apps/page.tsx @@ -1,23 +1,27 @@ +'use client' +import { useContextSelector } from 'use-context-selector' +import { useTranslation } from 'react-i18next' import style from '../list.module.css' import Apps from './Apps' import classNames from '@/utils/classnames' -import { getLocaleOnServer, useTranslation as translate } from '@/i18n/server' +import AppContext from '@/context/app-context' +import { LicenseStatus } from '@/types/feature' -const AppList = async () => { - const locale = getLocaleOnServer() - const { t } = await translate(locale, 'app') +const AppList = () => { + const { t } = useTranslation() + const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures) return (
-
-

{t('join')}

-

{t('communityIntro')}

+ {systemFeatures.license.status === LicenseStatus.NONE &&
+

{t('app.join')}

+

{t('app.communityIntro')}

-
+
}
) } diff --git a/web/app/components/app/annotation/index.tsx b/web/app/components/app/annotation/index.tsx index 0783c3fa66..418079abe8 100644 --- a/web/app/components/app/annotation/index.tsx +++ b/web/app/components/app/annotation/index.tsx @@ -2,19 +2,17 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { Pagination } from 'react-headless-pagination' import { useDebounce } from 'ahooks' -import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import Toast from '../../base/toast' import Filter from './filter' import type { QueryParam } from './filter' import List from './list' import EmptyElement from './empty-element' import HeaderOpts from './header-opts' -import s from './style.module.css' import { AnnotationEnableStatus, type AnnotationItem, type AnnotationItemBasic, JobStatus } from './type' import ViewAnnotationModal from './view-annotation-modal' import cn from '@/utils/classnames' +import Pagination from '@/app/components/base/pagination' import Switch from '@/app/components/base/switch' import { addAnnotation, delAnnotation, fetchAnnotationConfig as doFetchAnnotationConfig, editAnnotation, fetchAnnotationList, queryAnnotationJobStatus, updateAnnotationScore, updateAnnotationStatus } from '@/service/annotation' import Loading from '@/app/components/base/loading' @@ -69,9 +67,10 @@ const Annotation: FC = ({ const [queryParams, setQueryParams] = useState({}) const [currPage, setCurrPage] = React.useState(0) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) + const [limit, setLimit] = React.useState(APP_PAGE_LIMIT) const query = { page: currPage + 1, - limit: APP_PAGE_LIMIT, + limit, keyword: debouncedQueryParams.keyword || '', } @@ -228,35 +227,12 @@ const Annotation: FC = ({ {/* Show Pagination only if the total is more than the limit */} {(total && total > APP_PAGE_LIMIT) ? - - - {t('appLog.table.pagination.previous')} - -
- -
- - {t('appLog.table.pagination.next')} - - -
+ current={currPage} + onChange={setCurrPage} + total={total} + limit={limit} + onLimitChange={setLimit} + /> : null} {isShowViewModal && ( diff --git a/web/app/components/app/annotation/style.module.css b/web/app/components/app/annotation/style.module.css deleted file mode 100644 index 24179c1ca1..0000000000 --- a/web/app/components/app/annotation/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.pagination li { - list-style: none; -} \ No newline at end of file diff --git a/web/app/components/app/annotation/view-annotation-modal/index.tsx b/web/app/components/app/annotation/view-annotation-modal/index.tsx index daa8434ff7..0fb8bbc31e 100644 --- a/web/app/components/app/annotation/view-annotation-modal/index.tsx +++ b/web/app/components/app/annotation/view-annotation-modal/index.tsx @@ -2,13 +2,12 @@ import type { FC } from 'react' import React, { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' -import { Pagination } from 'react-headless-pagination' -import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import EditItem, { EditItemType } from '../edit-annotation-modal/edit-item' import type { AnnotationItem, HitHistoryItem } from '../type' import s from './style.module.css' import HitHistoryNoData from './hit-history-no-data' import cn from '@/utils/classnames' +import Pagination from '@/app/components/base/pagination' import Drawer from '@/app/components/base/drawer-plus' import { MessageCheckRemove } from '@/app/components/base/icons/src/vender/line/communication' import Confirm from '@/app/components/base/confirm' @@ -150,35 +149,10 @@ const ViewAnnotationModal: FC = ({ {(total && total > APP_PAGE_LIMIT) ? - - - {t('appLog.table.pagination.previous')} - -
- -
- - {t('appLog.table.pagination.next')} - - -
+ current={currPage} + onChange={setCurrPage} + total={total} + /> : null} diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx index 947abb853d..1144c323d1 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/chat-item.tsx @@ -29,6 +29,7 @@ import { useAppContext } from '@/context/app-context' import { ModelFeatureEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useFeatures } from '@/app/components/base/features/hooks' import type { InputForm } from '@/app/components/base/chat/chat/type' +import { getLastAnswer } from '@/app/components/base/chat/utils' type ChatItemProps = { modelAndParameter: ModelAndParameter @@ -101,7 +102,7 @@ const ChatItem: FC = ({ query: message, inputs, model_config: configData, - parent_message_id: chatListRef.current.at(-1)?.id || null, + parent_message_id: getLastAnswer(chatListRef.current)?.id || null, } if ((config.file_upload as any).enabled && files?.length && supportVision) diff --git a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx index e652579cfc..6d16660e81 100644 --- a/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx +++ b/web/app/components/app/configuration/features/chat-group/opening-statement/index.tsx @@ -22,7 +22,7 @@ import { getNewVar } from '@/utils/var' import { varHighlightHTML } from '@/app/components/app/configuration/base/var-highlight' import Toast from '@/app/components/base/toast' -const MAX_QUESTION_NUM = 5 +const MAX_QUESTION_NUM = 10 export type IOpeningStatementProps = { value: string diff --git a/web/app/components/app/log-annotation/index.tsx b/web/app/components/app/log-annotation/index.tsx index c84d941143..3fa13019f9 100644 --- a/web/app/components/app/log-annotation/index.tsx +++ b/web/app/components/app/log-annotation/index.tsx @@ -52,7 +52,7 @@ const LogAnnotation: FC = ({ options={options} /> )} -
+
{pageType === PageType.log && appDetail.mode !== 'workflow' && ()} {pageType === PageType.annotation && ()} {pageType === PageType.log && appDetail.mode === 'workflow' && ()} diff --git a/web/app/components/app/log/index.tsx b/web/app/components/app/log/index.tsx index e076f587ea..592233facd 100644 --- a/web/app/components/app/log/index.tsx +++ b/web/app/components/app/log/index.tsx @@ -2,17 +2,15 @@ import type { FC, SVGProps } from 'react' import React, { useState } from 'react' import useSWR from 'swr' +import Link from 'next/link' import { usePathname } from 'next/navigation' -import { Pagination } from 'react-headless-pagination' import { useDebounce } from 'ahooks' import { omit } from 'lodash-es' import dayjs from 'dayjs' -import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import { Trans, useTranslation } from 'react-i18next' -import Link from 'next/link' import List from './list' import Filter, { TIME_PERIOD_MAPPING } from './filter' -import s from './style.module.css' +import Pagination from '@/app/components/base/pagination' import Loading from '@/app/components/base/loading' import { fetchChatConversations, fetchCompletionConversations } from '@/service/log' import { APP_PAGE_LIMIT } from '@/config' @@ -60,6 +58,7 @@ const Logs: FC = ({ appDetail }) => { sort_by: '-created_at', }) const [currPage, setCurrPage] = React.useState(0) + const [limit, setLimit] = React.useState(APP_PAGE_LIMIT) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) // Get the app type first @@ -67,7 +66,7 @@ const Logs: FC = ({ appDetail }) => { const query = { page: currPage + 1, - limit: APP_PAGE_LIMIT, + limit, ...((debouncedQueryParams.period !== '9') ? { start: dayjs().subtract(TIME_PERIOD_MAPPING[debouncedQueryParams.period].value, 'day').startOf('day').format('YYYY-MM-DD HH:mm'), @@ -102,9 +101,9 @@ const Logs: FC = ({ appDetail }) => { const total = isChatMode ? chatConversations?.total : completionConversations?.total return ( -
-

{t('appLog.description')}

-
+
+

{t('appLog.description')}

+
{total === undefined ? @@ -115,35 +114,12 @@ const Logs: FC = ({ appDetail }) => { {/* Show Pagination only if the total is more than the limit */} {(total && total > APP_PAGE_LIMIT) ? - - - {t('appLog.table.pagination.previous')} - -
- -
- - {t('appLog.table.pagination.next')} - - -
+ current={currPage} + onChange={setCurrPage} + total={total} + limit={limit} + onLimitChange={setLimit} + /> : null}
diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 9a8cca4378..978e83737b 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -318,7 +318,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const targetTone = TONE_LIST.find((item: any) => { let res = true validatedParams.forEach((param) => { - res = item.config?.[param] === detail.model_config?.configs?.completion_params?.[param] + res = item.config?.[param] === detail?.model_config.model?.completion_params?.[param] }) return res })?.name ?? 'custom' diff --git a/web/app/components/app/log/style.module.css b/web/app/components/app/log/style.module.css deleted file mode 100644 index adb32a39db..0000000000 --- a/web/app/components/app/log/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.pagination li { - list-style: none; -} diff --git a/web/app/components/app/workflow-log/index.tsx b/web/app/components/app/workflow-log/index.tsx index 7a891f5895..453f2cd75a 100644 --- a/web/app/components/app/workflow-log/index.tsx +++ b/web/app/components/app/workflow-log/index.tsx @@ -3,14 +3,12 @@ import type { FC, SVGProps } from 'react' import React, { useState } from 'react' import useSWR from 'swr' import { usePathname } from 'next/navigation' -import { Pagination } from 'react-headless-pagination' import { useDebounce } from 'ahooks' -import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import { Trans, useTranslation } from 'react-i18next' import Link from 'next/link' import List from './list' import Filter from './filter' -import s from './style.module.css' +import Pagination from '@/app/components/base/pagination' import Loading from '@/app/components/base/loading' import { fetchWorkflowLogs } from '@/service/log' import { APP_PAGE_LIMIT } from '@/config' @@ -53,10 +51,11 @@ const Logs: FC = ({ appDetail }) => { const [queryParams, setQueryParams] = useState({ status: 'all' }) const [currPage, setCurrPage] = React.useState(0) const debouncedQueryParams = useDebounce(queryParams, { wait: 500 }) + const [limit, setLimit] = React.useState(APP_PAGE_LIMIT) const query = { page: currPage + 1, - limit: APP_PAGE_LIMIT, + limit, ...(debouncedQueryParams.status !== 'all' ? { status: debouncedQueryParams.status } : {}), ...(debouncedQueryParams.keyword ? { keyword: debouncedQueryParams.keyword } : {}), } @@ -77,7 +76,7 @@ const Logs: FC = ({ appDetail }) => {

{t('appLog.workflowTitle')}

{t('appLog.workflowSubtitle')}

-
+
{/* workflow log */} {total === undefined @@ -89,35 +88,12 @@ const Logs: FC = ({ appDetail }) => { {/* Show Pagination only if the total is more than the limit */} {(total && total > APP_PAGE_LIMIT) ? - - - {t('appLog.table.pagination.previous')} - -
- -
- - {t('appLog.table.pagination.next')} - - -
+ current={currPage} + onChange={setCurrPage} + total={total} + limit={limit} + onLimitChange={setLimit} + /> : null}
diff --git a/web/app/components/app/workflow-log/list.tsx b/web/app/components/app/workflow-log/list.tsx index e43d95d5ad..e3de4a957f 100644 --- a/web/app/components/app/workflow-log/list.tsx +++ b/web/app/components/app/workflow-log/list.tsx @@ -2,9 +2,7 @@ import type { FC } from 'react' import React, { useState } from 'react' import { useTranslation } from 'react-i18next' -// import s from './style.module.css' import DetailPanel from './detail' -import cn from '@/utils/classnames' import type { WorkflowAppLogDetail, WorkflowLogsResponse } from '@/models/log' import type { App } from '@/types/app' import Loading from '@/app/components/base/loading' @@ -12,6 +10,7 @@ import Drawer from '@/app/components/base/drawer' import Indicator from '@/app/components/header/indicator' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import useTimestamp from '@/hooks/use-timestamp' +import cn from '@/utils/classnames' type ILogs = { logs?: WorkflowLogsResponse diff --git a/web/app/components/app/workflow-log/style.module.css b/web/app/components/app/workflow-log/style.module.css deleted file mode 100644 index adb32a39db..0000000000 --- a/web/app/components/app/workflow-log/style.module.css +++ /dev/null @@ -1,3 +0,0 @@ -.pagination li { - list-style: none; -} diff --git a/web/app/components/base/badge.tsx b/web/app/components/base/badge.tsx index c3300a1e67..722fde3237 100644 --- a/web/app/components/base/badge.tsx +++ b/web/app/components/base/badge.tsx @@ -15,7 +15,7 @@ const Badge = ({ return (
-const Divider: FC = ({ type, className = '', style }) => { +const Divider: FC = ({ type, bgStyle, className = '', style }) => { return ( -
+
) } diff --git a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx index 9f25d0fa11..41ed043e93 100644 --- a/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx +++ b/web/app/components/base/features/new-feature-panel/conversation-opener/modal.tsx @@ -22,7 +22,7 @@ type OpeningSettingModalProps = { onAutoAddPromptVariable?: (variable: PromptVariable[]) => void } -const MAX_QUESTION_NUM = 5 +const MAX_QUESTION_NUM = 10 const OpeningSettingModal = ({ data, diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx index 4595280893..6e9ef489fc 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-image-item.tsx @@ -84,7 +84,7 @@ const FileImageItem = ({ className='absolute bottom-0.5 right-0.5 flex items-center justify-center w-6 h-6 rounded-lg bg-components-actionbar-bg shadow-md' onClick={(e) => { e.stopPropagation() - downloadFile(url || '', name) + downloadFile(url || base64Url || '', name) }} > diff --git a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx index fcf665643c..dcf4082780 100644 --- a/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx +++ b/web/app/components/base/file-uploader/file-uploader-in-chat-input/file-item.tsx @@ -80,7 +80,7 @@ const FileItem = ({ }
{ - showDownloadAction && ( + showDownloadAction && url && ( + {/* Cookie banner */} + ) diff --git a/web/app/components/base/input/index.tsx b/web/app/components/base/input/index.tsx index 191dbde452..bf8efdb65a 100644 --- a/web/app/components/base/input/index.tsx +++ b/web/app/components/base/input/index.tsx @@ -64,7 +64,9 @@ const Input = ({ destructive && 'bg-components-input-bg-destructive border-components-input-border-destructive text-components-input-text-filled hover:bg-components-input-bg-destructive hover:border-components-input-border-destructive focus:bg-components-input-bg-destructive focus:border-components-input-border-destructive', className, )} - placeholder={placeholder ?? (showLeftIcon ? t('common.operation.search') ?? '' : t('common.placeholder.input'))} + placeholder={placeholder ?? (showLeftIcon + ? (t('common.operation.search') || '') + : (t('common.placeholder.input') || ''))} value={value} onChange={onChange} disabled={disabled} diff --git a/web/app/components/base/list-empty/horizontal-line.tsx b/web/app/components/base/list-empty/horizontal-line.tsx new file mode 100644 index 0000000000..cb8edb8dee --- /dev/null +++ b/web/app/components/base/list-empty/horizontal-line.tsx @@ -0,0 +1,21 @@ +type HorizontalLineProps = { + className?: string +} +const HorizontalLine = ({ + className, +}: HorizontalLineProps) => { + return ( + + + + + + + + + + + ) +} + +export default HorizontalLine diff --git a/web/app/components/base/list-empty/index.tsx b/web/app/components/base/list-empty/index.tsx new file mode 100644 index 0000000000..e925878bc1 --- /dev/null +++ b/web/app/components/base/list-empty/index.tsx @@ -0,0 +1,35 @@ +import React from 'react' +import { Variable02 } from '../icons/src/vender/solid/development' +import VerticalLine from './vertical-line' +import HorizontalLine from './horizontal-line' + +type ListEmptyProps = { + title?: string + description?: React.ReactNode +} + +const ListEmpty = ({ + title, + description, +}: ListEmptyProps) => { + return ( +
+
+
+ + + + + +
+
+
+
{title}
+ {description} +
+
+ ) +} + +export default ListEmpty diff --git a/web/app/components/base/list-empty/vertical-line.tsx b/web/app/components/base/list-empty/vertical-line.tsx new file mode 100644 index 0000000000..63e57447bf --- /dev/null +++ b/web/app/components/base/list-empty/vertical-line.tsx @@ -0,0 +1,21 @@ +type VerticalLineProps = { + className?: string +} +const VerticalLine = ({ + className, +}: VerticalLineProps) => { + return ( + + + + + + + + + + + ) +} + +export default VerticalLine diff --git a/web/app/components/base/pagination/hook.ts b/web/app/components/base/pagination/hook.ts new file mode 100644 index 0000000000..6501d6f457 --- /dev/null +++ b/web/app/components/base/pagination/hook.ts @@ -0,0 +1,95 @@ +import React, { useCallback } from 'react' +import type { IPaginationProps, IUsePagination } from './type' + +const usePagination = ({ + currentPage, + setCurrentPage, + truncableText = '...', + truncableClassName = '', + totalPages, + edgePageCount, + middlePagesSiblingCount, +}: IPaginationProps): IUsePagination => { + const pages = Array(totalPages) + .fill(0) + .map((_, i) => i + 1) + + const hasPreviousPage = currentPage > 1 + const hasNextPage = currentPage < totalPages + + const isReachedToFirst = currentPage <= middlePagesSiblingCount + const isReachedToLast = currentPage + middlePagesSiblingCount >= totalPages + + const middlePages = React.useMemo(() => { + const middlePageCount = middlePagesSiblingCount * 2 + 1 + if (isReachedToFirst) + return pages.slice(0, middlePageCount) + + if (isReachedToLast) + return pages.slice(-middlePageCount) + + return pages.slice( + currentPage - middlePagesSiblingCount, + currentPage + middlePagesSiblingCount + 1, + ) + }, [currentPage, isReachedToFirst, isReachedToLast, middlePagesSiblingCount, pages]) + + const getAllPreviousPages = useCallback(() => { + return pages.slice(0, middlePages[0] - 1) + }, [middlePages, pages]) + + const previousPages = React.useMemo(() => { + if (isReachedToFirst || getAllPreviousPages().length < 1) + return [] + + return pages + .slice(0, edgePageCount) + .filter(p => !middlePages.includes(p)) + }, [edgePageCount, getAllPreviousPages, isReachedToFirst, middlePages, pages]) + + const getAllNextPages = React.useMemo(() => { + return pages.slice( + middlePages[middlePages.length - 1], + pages[pages.length], + ) + }, [pages, middlePages]) + + const nextPages = React.useMemo(() => { + if (isReachedToLast) + return [] + + if (getAllNextPages.length < 1) + return [] + + return pages + .slice(pages.length - edgePageCount, pages.length) + .filter(p => !middlePages.includes(p)) + }, [edgePageCount, getAllNextPages.length, isReachedToLast, middlePages, pages]) + + const isPreviousTruncable = React.useMemo(() => { + // Is truncable if first value of middlePage is larger than last value of previousPages + return middlePages[0] > previousPages[previousPages.length - 1] + 1 + }, [previousPages, middlePages]) + + const isNextTruncable = React.useMemo(() => { + // Is truncable if last value of middlePage is larger than first value of previousPages + return middlePages[middlePages.length - 1] + 1 < nextPages[0] + }, [nextPages, middlePages]) + + return { + currentPage, + setCurrentPage, + truncableText, + truncableClassName, + pages, + hasPreviousPage, + hasNextPage, + previousPages, + isPreviousTruncable, + middlePages, + isNextTruncable, + nextPages, + } +} + +export default usePagination diff --git a/web/app/components/base/pagination/index.tsx b/web/app/components/base/pagination/index.tsx index f8c5684b55..b64c712425 100644 --- a/web/app/components/base/pagination/index.tsx +++ b/web/app/components/base/pagination/index.tsx @@ -1,50 +1,165 @@ import type { FC } from 'react' import React from 'react' -import { Pagination } from 'react-headless-pagination' -import { ArrowLeftIcon, ArrowRightIcon } from '@heroicons/react/24/outline' import { useTranslation } from 'react-i18next' -import s from './style.module.css' +import { RiArrowLeftLine, RiArrowRightLine } from '@remixicon/react' +import { useDebounceFn } from 'ahooks' +import { Pagination } from './pagination' +import Button from '@/app/components/base/button' +import Input from '@/app/components/base/input' +import cn from '@/utils/classnames' type Props = { + className?: string current: number onChange: (cur: number) => void total: number limit?: number + onLimitChange?: (limit: number) => void } -const CustomizedPagination: FC = ({ current, onChange, total, limit = 10 }) => { +const CustomizedPagination: FC = ({ + className, + current, + onChange, + total, + limit = 10, + onLimitChange, +}) => { const { t } = useTranslation() const totalPages = Math.ceil(total / limit) + const inputRef = React.useRef(null) + const [showInput, setShowInput] = React.useState(false) + const [inputValue, setInputValue] = React.useState(current + 1) + const [showPerPageTip, setShowPerPageTip] = React.useState(false) + + const { run: handlePaging } = useDebounceFn((value: string) => { + if (parseInt(value) > totalPages) { + setInputValue(totalPages) + onChange(totalPages - 1) + setShowInput(false) + return + } + if (parseInt(value) < 1) { + setInputValue(1) + onChange(0) + setShowInput(false) + return + } + onChange(parseInt(value) - 1) + setInputValue(parseInt(value)) + setShowInput(false) + }, { wait: 500 }) + + const handleInputChange = (e: React.ChangeEvent) => { + const value = e.target.value + if (!value) + return setInputValue('') + if (isNaN(parseInt(value))) + return setInputValue('') + setInputValue(parseInt(value)) + handlePaging(value) + } + return ( - - - {t('appLog.table.pagination.previous')} - -
+
+
} + disabled={current === 0} + > + + + {!showInput && ( +
setShowInput(true)} + > +
{current + 1}
+
/
+
{totalPages}
+
+ )} + {showInput && ( + setShowInput(false)} + /> + )} +
} + disabled={current === totalPages - 1} + > + + +
+
- - {t('appLog.table.pagination.next')} - - + {onLimitChange && ( +
+
{showPerPageTip ? t('common.pagination.perPage') : ''}
+
setShowPerPageTip(true)} + onMouseLeave={() => setShowPerPageTip(false)} + > +
onLimitChange?.(10)} + >10
+
onLimitChange?.(25)} + >25
+
onLimitChange?.(50)} + >50
+
+
+ )} ) } diff --git a/web/app/components/base/pagination/pagination.tsx b/web/app/components/base/pagination/pagination.tsx new file mode 100644 index 0000000000..5898c4e924 --- /dev/null +++ b/web/app/components/base/pagination/pagination.tsx @@ -0,0 +1,189 @@ +import React from 'react' +import clsx from 'clsx' +import usePagination from './hook' +import type { + ButtonProps, + IPagination, + IPaginationProps, + PageButtonProps, +} from './type' + +const defaultState: IPagination = { + currentPage: 0, + setCurrentPage: () => {}, + truncableText: '...', + truncableClassName: '', + pages: [], + hasPreviousPage: false, + hasNextPage: false, + previousPages: [], + isPreviousTruncable: false, + middlePages: [], + isNextTruncable: false, + nextPages: [], +} + +const PaginationContext: React.Context = React.createContext(defaultState) + +export const PrevButton = ({ + className, + children, + dataTestId, + as =
diff --git a/web/app/components/develop/template/template.en.mdx b/web/app/components/develop/template/template.en.mdx index c923ea30db..f469076bf3 100755 --- a/web/app/components/develop/template/template.en.mdx +++ b/web/app/components/develop/template/template.en.mdx @@ -379,10 +379,107 @@ The text generation application offers non-session support and is ideal for tran --- + + + + Text to speech. + + ### Request Body + + + + For text messages generated by Dify, simply pass the generated message-id directly. The backend will use the message-id to look up the corresponding content and synthesize the voice information directly. If both message_id and text are provided simultaneously, the message_id is given priority. + + + Speech generated content。 + + + The user identifier, defined by the developer, must ensure uniqueness within the app. + + + + + + + + ```bash {{ title: 'cURL' }} + curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", + "text": "Hello Dify", + "user": "abc-123" + }' + ``` + + + + + ```json {{ title: 'headers' }} + { + "Content-Type": "audio/wav" + } + ``` + + + +--- + + + + + Used to get basic information about this application + ### Query + + + + User identifier, defined by the developer's rules, must be unique within the application. + + + ### Response + - `name` (string) application name + - `description` (string) application description + - `tags` (array[string]) application tags + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + + +--- + @@ -497,56 +594,3 @@ The text generation application offers non-session support and is ideal for tran - ---- - - - - - Text to speech. - - ### Request Body - - - - For text messages generated by Dify, simply pass the generated message-id directly. The backend will use the message-id to look up the corresponding content and synthesize the voice information directly. If both message_id and text are provided simultaneously, the message_id is given priority. - - - Speech generated content。 - - - The user identifier, defined by the developer, must ensure uniqueness within the app. - - - - - - - - ```bash {{ title: 'cURL' }} - curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' \ - --data-raw '{ - "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", - "text": "Hello Dify", - "user": "abc-123" - }' - ``` - - - - - ```json {{ title: 'headers' }} - { - "Content-Type": "audio/wav" - } - ``` - - - diff --git a/web/app/components/develop/template/template.ja.mdx b/web/app/components/develop/template/template.ja.mdx index a6ab109229..bd92bd7f36 100755 --- a/web/app/components/develop/template/template.ja.mdx +++ b/web/app/components/develop/template/template.ja.mdx @@ -375,13 +375,109 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from +--- + + + + + テキストを音声に変換します。 + + ### リクエストボディ + + + + Difyが生成したテキストメッセージの場合、生成されたmessage-idを直接渡すだけです。バックエンドはmessage-idを使用して対応するコンテンツを検索し、音声情報を直接合成します。message_idとtextの両方が同時に提供された場合、message_idが優先されます。 + + + 音声生成コンテンツ。 + + + 開発者が定義したユーザー識別子。アプリ内で一意性を確保する必要があります。 + + + + + + + + ```bash {{ title: 'cURL' }} + curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", + "text": "Hello Dify", + "user": "abc-123" + }' + ``` + + + + + ```json {{ title: 'headers' }} + { + "Content-Type": "audio/wav" + } + ``` + + + +--- + + + + + このアプリケーションの基本情報を取得するために使用されます + ### Query + + + + ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + + ### Response + - `name` (string) アプリケーションの名前 + - `description` (string) アプリケーションの説明 + - `tags` (array[string]) アプリケーションのタグ + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- @@ -496,56 +592,3 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - ---- - - - - - テキストを音声に変換します。 - - ### リクエストボディ - - - - Difyが生成したテキストメッセージの場合、生成されたmessage-idを直接渡すだけです。バックエンドはmessage-idを使用して対応するコンテンツを検索し、音声情報を直接合成します。message_idとtextの両方が同時に提供された場合、message_idが優先されます。 - - - 音声生成コンテンツ。 - - - 開発者が定義したユーザー識別子。アプリ内で一意性を確保する必要があります。 - - - - - - - - ```bash {{ title: 'cURL' }} - curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' \ - --data-raw '{ - "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", - "text": "Hello Dify", - "user": "abc-123" - }' - ``` - - - - - ```json {{ title: 'headers' }} - { - "Content-Type": "audio/wav" - } - ``` - - - diff --git a/web/app/components/develop/template/template.zh.mdx b/web/app/components/develop/template/template.zh.mdx index d193b91816..7b1bec3546 100755 --- a/web/app/components/develop/template/template.zh.mdx +++ b/web/app/components/develop/template/template.zh.mdx @@ -353,10 +353,108 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --- + + + + 文字转语音。 + + ### Request Body + + + + Dify 生成的文本消息,那么直接传递生成的message-id 即可,后台会通过 message_id 查找相应的内容直接合成语音信息。如果同时传 message_id 和 text,优先使用 message_id。 + + + 语音生成内容。如果没有传 message-id的话,则会使用这个字段的内容 + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + + + + + + ```bash {{ title: 'cURL' }} + curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ + --header 'Authorization: Bearer {api_key}' \ + --header 'Content-Type: application/json' \ + --data-raw '{ + "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", + "text": "你好Dify", + "user": "abc-123", + "streaming": false + }' + ``` + + + + + ```json {{ title: 'headers' }} + { + "Content-Type": "audio/wav" + } + ``` + + + +--- + + + + + 用于获取应用的基本信息 + ### Query + + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + ### Response + - `name` (string) 应用名称 + - `description` (string) 应用描述 + - `tags` (array[string]) 应用标签 + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + + +--- + @@ -461,57 +559,3 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - ---- - - - - - 文字转语音。 - - ### Request Body - - - - Dify 生成的文本消息,那么直接传递生成的message-id 即可,后台会通过 message_id 查找相应的内容直接合成语音信息。如果同时传 message_id 和 text,优先使用 message_id。 - - - 语音生成内容。如果没有传 message-id的话,则会使用这个字段的内容 - - - 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 - - - - - - - - ```bash {{ title: 'cURL' }} - curl -o text-to-audio.mp3 -X POST '${props.appDetail.api_base_url}/text-to-audio' \ - --header 'Authorization: Bearer {api_key}' \ - --header 'Content-Type: application/json' \ - --data-raw '{ - "message_id": "5ad4cb98-f0c7-4085-b384-88c403be6290", - "text": "你好Dify", - "user": "abc-123", - "streaming": false - }' - ``` - - - - - ```json {{ title: 'headers' }} - { - "Content-Type": "audio/wav" - } - ``` - - - diff --git a/web/app/components/develop/template/template_advanced_chat.en.mdx b/web/app/components/develop/template/template_advanced_chat.en.mdx index 7d64caa769..1d12a045ea 100644 --- a/web/app/components/develop/template/template_advanced_chat.en.mdx +++ b/web/app/components/develop/template/template_advanced_chat.en.mdx @@ -161,7 +161,7 @@ Chat applications support session persistence, allowing previous chat history to - `title` (string) name of node - `index` (int) Execution sequence number, used to display Tracing Node sequence - `predecessor_node_id` (string) optional Prefix node ID, used for canvas display execution path - - `inputs` (array[object]) Contents of all preceding node variables used in the node + - `inputs` (object) Contents of all preceding node variables used in the node - `created_at` (timestamp) timestamp of start, e.g., 1705395332 - `event: node_finished` node execution ends, success or failure in different states in the same event - `task_id` (string) Task ID, used for request tracking and the below Stop Generate API @@ -174,7 +174,7 @@ Chat applications support session persistence, allowing previous chat history to - `title` (string) name of node - `index` (int) Execution sequence number, used to display Tracing Node sequence - `predecessor_node_id` (string) optional Prefix node ID, used for canvas display execution path - - `inputs` (array[object]) Contents of all preceding node variables used in the node + - `inputs` (object) Contents of all preceding node variables used in the node - `process_data` (json) Optional node process data - `outputs` (json) Optional content of output - `status` (string) status of execution, `running` / `succeeded` / `failed` / `stopped` @@ -564,7 +564,7 @@ Chat applications support session persistence, allowing previous chat history to - `data` (array[object]) Message list - `id` (string) Message ID - `conversation_id` (string) Conversation ID - - `inputs` (array[object]) User input parameters. + - `inputs` (object) User input parameters. - `query` (string) User input / question content. - `message_files` (array[object]) Message files - `id` (string) ID @@ -648,16 +648,13 @@ Chat applications support session persistence, allowing previous chat history to Should be uniquely defined by the developer within the application. - The ID of the last record on the current page, default is null. + (Optional) The ID of the last record on the current page, default is null. - How many records to return in one request, default is the most recent 20 entries. - - - Return only pinned conversations as `true`, only non-pinned as `false` + (Optional) How many records to return in one request, default is the most recent 20 entries. Maximum 100, minimum 1. - Sorting Field (Optional), Default: -updated_at (sorted in descending order by update time) + (Optional) Sorting Field, Default: -updated_at (sorted in descending order by update time) - Available Values: created_at, -created_at, updated_at, -updated_at - The symbol before the field represents the order or reverse, "-" represents reverse order. @@ -667,9 +664,11 @@ Chat applications support session persistence, allowing previous chat history to - `data` (array[object]) List of conversations - `id` (string) Conversation ID - `name` (string) Conversation name, by default, is generated by LLM. - - `inputs` (array[object]) User input parameters. + - `inputs` (object) User input parameters. + - `status` (string) Conversation status - `introduction` (string) Introduction - `created_at` (timestamp) Creation timestamp, e.g., 1705395332 + - `updated_at` (timestamp) Update timestamp, e.g., 1705395332 - `has_more` (bool) - `limit` (int) Number of entries returned, if input exceeds system limit, system limit number is returned @@ -699,7 +698,8 @@ Chat applications support session persistence, allowing previous chat history to "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -781,10 +781,10 @@ Chat applications support session persistence, allowing previous chat history to - The name of the conversation. This parameter can be omitted if `auto_generate` is set to `true`. + (Optional) The name of the conversation. This parameter can be omitted if `auto_generate` is set to `true`. - Automatically generate the title, default is `false` + (Optional) Automatically generate the title, default is `false` The user identifier, defined by the developer, must ensure uniqueness within the application. @@ -794,13 +794,15 @@ Chat applications support session persistence, allowing previous chat history to ### Response - `id` (string) Conversation ID - `name` (string) Conversation name - - `inputs` array[object] User input parameters. + - `inputs` (object) User input parameters + - `status` (string) Conversation status - `introduction` (string) Introduction - `created_at` (timestamp) Creation timestamp, e.g., 1705395332 + - `updated_at` (timestamp) Update timestamp, e.g., 1705395332 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -808,6 +810,7 @@ Chat applications support session persistence, allowing previous chat history to --header 'Authorization: Bearer {api_key}' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -820,8 +823,10 @@ Chat applications support session persistence, allowing previous chat history to "id": "cd78daf6-f9e4-4463-9ff2-54257230a0ce", "name": "Chat vs AI", "inputs": {}, + "status": "normal", "introduction": "", - "created_at": 1705569238 + "created_at": 1705569238, + "updated_at": 1705569238 } ``` @@ -931,13 +936,57 @@ Chat applications support session persistence, allowing previous chat history to +--- + + + + + Used to get basic information about this application + ### Query + + + + User identifier, defined by the developer's rules, must be unique within the application. + + + ### Response + - `name` (string) application name + - `description` (string) application description + - `tags` (array[string]) application tags + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- @@ -1091,14 +1140,14 @@ Chat applications support session persistence, allowing previous chat history to ```json {{ title: 'Response' }} { - "tool_icons": { + "tool_icons": { "dalle2": "https://cloud.dify.ai/console/api/workspaces/current/tool-provider/builtin/dalle/icon", "api_tool": { - "background": "#252525", - "content": "\ud83d\ude01" + "background": "#252525", + "content": "\ud83d\ude01" } + } } - } ``` diff --git a/web/app/components/develop/template/template_advanced_chat.ja.mdx b/web/app/components/develop/template/template_advanced_chat.ja.mdx index b4c252e19a..2fc17d1ca9 100644 --- a/web/app/components/develop/template/template_advanced_chat.ja.mdx +++ b/web/app/components/develop/template/template_advanced_chat.ja.mdx @@ -161,7 +161,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `title` (string) ノードの名前 - `index` (int) 実行シーケンス番号、トレースノードシーケンスを表示するために使用 - `predecessor_node_id` (string) オプションのプレフィックスノードID、キャンバス表示実行パスに使用 - - `inputs` (array[object]) ノードで使用されるすべての前のノード変数の内容 + - `inputs` (object) ノードで使用されるすべての前のノード変数の内容 - `created_at` (timestamp) 開始のタイムスタンプ、例:1705395332 - `event: node_finished` ノード実行が終了、成功または失敗は同じイベント内で異なる状態で示されます - `task_id` (string) タスクID、リクエスト追跡と以下のStop Generate APIに使用 @@ -174,7 +174,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `title` (string) ノードの名前 - `index` (int) 実行シーケンス番号、トレースノードシーケンスを表示するために使用 - `predecessor_node_id` (string) オプションのプレフィックスノードID、キャンバス表示実行パスに使用 - - `inputs` (array[object]) ノードで使用されるすべての前のノード変数の内容 + - `inputs` (object) ノードで使用されるすべての前のノード変数の内容 - `process_data` (json) オプションのノードプロセスデータ - `outputs` (json) オプションの出力内容 - `status` (string) 実行の状態、`running` / `succeeded` / `failed` / `stopped` @@ -564,7 +564,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `data` (array[object]) メッセージリスト - `id` (string) メッセージID - `conversation_id` (string) 会話ID - - `inputs` (array[object]) ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ。 - `query` (string) ユーザー入力/質問内容。 - `message_files` (array[object]) メッセージファイル - `id` (string) ID @@ -648,16 +648,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from アプリケーション内で開発者によって一意に定義されるべきです。 - 現在のページの最後の記録のID、デフォルトはnullです。 + (Optional)現在のページの最後の記録のID、デフォルトはnullです。 - 1回のリクエストで返す記録の数、デフォルトは最新の20件です。 - - - ピン留めされた会話のみを`true`として返し、非ピン留めを`false`として返します + (Optional)1回のリクエストで返す記録の数、デフォルトは最新の20件です。最大100、最小1。 - ソートフィールド(オプション)、デフォルト:-updated_at(更新時間で降順にソート) + (Optional)ソートフィールド、デフォルト:-updated_at(更新時間で降順にソート) - 利用可能な値:created_at, -created_at, updated_at, -updated_at - フィールドの前の記号は順序または逆順を表し、"-"は逆順を表します。 @@ -667,9 +664,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `data` (array[object]) 会話のリスト - `id` (string) 会話ID - `name` (string) 会話名、デフォルトではLLMによって生成されます。 - - `inputs` (array[object]) ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ。 - `introduction` (string) 紹介 - `created_at` (timestamp) 作成タイムスタンプ、例:1705395332 + - `updated_at` (timestamp) 更新タイムスタンプ、例:1705395332 - `has_more` (bool) - `limit` (int) 返されたエントリ数、入力がシステム制限を超える場合、システム制限数が返されます @@ -699,7 +697,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -781,10 +780,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - 会話の名前。`auto_generate`が`true`に設定されている場合、このパラメータは省略できます。 + (Optional)会話の名前。`auto_generate`が`true`に設定されている場合、このパラメータは省略できます。 - タイトルを自動生成、デフォルトは`false` + (Optional)タイトルを自動生成、デフォルトは`false` ユーザー識別子、開発者によって定義され、アプリケーション内で一意であることを保証しなければなりません。 @@ -794,13 +793,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### 応答 - `id` (string) 会話ID - `name` (string) 会話名 - - `inputs` array[object] ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ + - `status` (string) 会話状態 - `introduction` (string) 紹介 - `created_at` (timestamp) 作成タイムスタンプ、例:1705395332 + - `updated_at` (timestamp) 更新タイムスタンプ、例:1705395332 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -808,6 +809,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Authorization: Bearer {api_key}' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -820,8 +822,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from "id": "cd78daf6-f9e4-4463-9ff2-54257230a0ce", "name": "チャット vs AI", "inputs": {}, + "status": "normal", "introduction": "", - "created_at": 1705569238 + "created_at": 1705569238, + "updated_at": 1705569238 } ``` @@ -931,13 +935,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from +--- + + + + + このアプリケーションの基本情報を取得するために使用されます + ### Query + + + + ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + + ### Response + - `name` (string) アプリケーションの名前 + - `description` (string) アプリケーションの説明 + - `tags` (array[string]) アプリケーションのタグ + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- @@ -1057,7 +1105,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from @@ -1091,14 +1139,14 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```json {{ title: '応答' }} { - "tool_icons": { + "tool_icons": { "dalle2": "https://cloud.dify.ai/console/api/workspaces/current/tool-provider/builtin/dalle/icon", "api_tool": { - "background": "#252525", - "content": "\ud83d\ude01" + "background": "#252525", + "content": "\ud83d\ude01" } + } } - } ``` diff --git a/web/app/components/develop/template/template_advanced_chat.zh.mdx b/web/app/components/develop/template/template_advanced_chat.zh.mdx index f3ddd6933c..734e52ae58 100755 --- a/web/app/components/develop/template/template_advanced_chat.zh.mdx +++ b/web/app/components/develop/template/template_advanced_chat.zh.mdx @@ -162,7 +162,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `title` (string) 节点名称 - `index` (int) 执行序号,用于展示 Tracing Node 顺序 - `predecessor_node_id` (string) 前置节点 ID,用于画布展示执行路径 - - `inputs` (array[object]) 节点中所有使用到的前置节点变量内容 + - `inputs` (object) 节点中所有使用到的前置节点变量内容 - `created_at` (timestamp) 开始时间 - `event: node_finished` node 执行结束,成功失败同一事件中不同状态 - `task_id` (string) 任务 ID,用于请求跟踪和下方的停止响应接口 @@ -173,7 +173,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `node_id` (string) 节点 ID - `index` (int) 执行序号,用于展示 Tracing Node 顺序 - `predecessor_node_id` (string) optional 前置节点 ID,用于画布展示执行路径 - - `inputs` (array[object]) 节点中所有使用到的前置节点变量内容 + - `inputs` (object) 节点中所有使用到的前置节点变量内容 - `process_data` (json) Optional 节点过程数据 - `outputs` (json) Optional 输出内容 - `status` (string) 执行状态 `running` / `succeeded` / `failed` / `stopped` @@ -570,7 +570,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `data` (array[object]) 消息列表 - `id` (string) 消息 ID - `conversation_id` (string) 会话 ID - - `inputs` (array[object]) 用户输入参数。 + - `inputs` (object) 用户输入参数。 - `query` (string) 用户输入 / 提问内容。 - `message_files` (array[object]) 消息文件 - `id` (string) ID @@ -683,16 +683,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 - 当前页最后面一条记录的 ID,默认 null + (选填)当前页最后面一条记录的 ID,默认 null - 一次请求返回多少条记录 - - - 只返回置顶 true,只返回非置顶 false + (选填)一次请求返回多少条记录,默认 20 条,最大 100 条,最小 1 条。 - 排序字段(选题),默认 -updated_at(按更新时间倒序排列) + (选填)排序字段,默认 -updated_at(按更新时间倒序排列) - 可选值:created_at, -created_at, updated_at, -updated_at - 字段前面的符号代表顺序或倒序,-代表倒序 @@ -702,9 +699,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `data` (array[object]) 会话列表 - `id` (string) 会话 ID - `name` (string) 会话名称,默认由大语言模型生成。 - - `inputs` (array[object]) 用户输入参数。 + - `inputs` (object) 用户输入参数。 + - `status` (string) 会话状态 - `introduction` (string) 开场白 - `created_at` (timestamp) 创建时间 + - `updated_at` (timestamp) 更新时间 - `has_more` (bool) - `limit` (int) 返回条数,若传入超过系统限制,返回系统限制数量 @@ -734,7 +733,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -817,10 +817,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - 名称,若 `auto_generate` 为 `true` 时,该参数可不传。 + (选填)名称,若 `auto_generate` 为 `true` 时,该参数可不传。 - - 自动生成标题,默认 false。 + + (选填)自动生成标题,默认 false。 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 @@ -830,13 +830,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ### Response - `id` (string) 会话 ID - `name` (string) 会话名称 - - `inputs` array[object] 用户输入参数。 + - `inputs` (object) 用户输入参数 + - `status` (string) 会话状态 - `introduction` (string) 开场白 - `created_at` (timestamp) 创建时间 + - `updated_at` (timestamp) 更新时间 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -844,6 +846,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -853,7 +856,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```json {{ title: 'Response' }} { - "result": "success" + "id": "34d511d5-56de-4f16-a997-57b379508443", + "name": "hello", + "inputs": {}, + "status": "normal", + "introduction": "", + "created_at": 1732731141, + "updated_at": 1732734510 } ``` @@ -960,13 +969,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' +--- + + + + + 用于获取应用的基本信息 + ### Query + + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + ### Response + - `name` (string) 应用名称 + - `description` (string) 应用描述 + - `tags` (array[string]) 应用标签 + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- diff --git a/web/app/components/develop/template/template_chat.en.mdx b/web/app/components/develop/template/template_chat.en.mdx index ac8ee9d657..4e873b3294 100644 --- a/web/app/components/develop/template/template_chat.en.mdx +++ b/web/app/components/develop/template/template_chat.en.mdx @@ -528,7 +528,7 @@ Chat applications support session persistence, allowing previous chat history to - `data` (array[object]) Message list - `id` (string) Message ID - `conversation_id` (string) Conversation ID - - `inputs` (array[object]) User input parameters. + - `inputs` (object) User input parameters. - `query` (string) User input / question content. - `message_files` (array[object]) Message files - `id` (string) ID @@ -682,16 +682,13 @@ Chat applications support session persistence, allowing previous chat history to Should be uniquely defined by the developer within the application. - The ID of the last record on the current page, default is null. + (Optional) The ID of the last record on the current page, default is null. - How many records to return in one request, default is the most recent 20 entries. - - - Return only pinned conversations as `true`, only non-pinned as `false` + (Optional) How many records to return in one request, default is the most recent 20 entries. Maximum 100, minimum 1. - Sorting Field (Optional), Default: -updated_at (sorted in descending order by update time) + (Optional) Sorting Field, Default: -updated_at (sorted in descending order by update time) - Available Values: created_at, -created_at, updated_at, -updated_at - The symbol before the field represents the order or reverse, "-" represents reverse order. @@ -701,9 +698,11 @@ Chat applications support session persistence, allowing previous chat history to - `data` (array[object]) List of conversations - `id` (string) Conversation ID - `name` (string) Conversation name, by default, is a snippet of the first question asked by the user in the conversation. - - `inputs` (array[object]) User input parameters. + - `inputs` (object) User input parameters. + - `status` (string) Conversation status - `introduction` (string) Introduction - `created_at` (timestamp) Creation timestamp, e.g., 1705395332 + - `updated_at` (timestamp) Update timestamp, e.g., 1705395332 - `has_more` (bool) - `limit` (int) Number of entries returned, if input exceeds system limit, system limit number is returned @@ -733,7 +732,8 @@ Chat applications support session persistence, allowing previous chat history to "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -815,10 +815,10 @@ Chat applications support session persistence, allowing previous chat history to - The name of the conversation. This parameter can be omitted if `auto_generate` is set to `true`. + (Optional) The name of the conversation. This parameter can be omitted if `auto_generate` is set to `true`. - Automatically generate the title, default is `false` + (Optional) Automatically generate the title, default is `false` The user identifier, defined by the developer, must ensure uniqueness within the application. @@ -828,13 +828,15 @@ Chat applications support session persistence, allowing previous chat history to ### Response - `id` (string) Conversation ID - `name` (string) Conversation name - - `inputs` array[object] User input parameters. + - `inputs` (object) User input parameters + - `status` (string) Conversation status - `introduction` (string) Introduction - `created_at` (timestamp) Creation timestamp, e.g., 1705395332 + - `updated_at` (timestamp) Update timestamp, e.g., 1705395332 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -842,6 +844,7 @@ Chat applications support session persistence, allowing previous chat history to --header 'Authorization: Bearer {api_key}' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -854,8 +857,10 @@ Chat applications support session persistence, allowing previous chat history to "id": "cd78daf6-f9e4-4463-9ff2-54257230a0ce", "name": "Chat vs AI", "inputs": {}, + "status": "normal", "introduction": "", - "created_at": 1705569238 + "created_at": 1705569238, + "updated_at": 1705569238 } ``` @@ -960,13 +965,57 @@ Chat applications support session persistence, allowing previous chat history to +--- + + + + + Used to get basic information about this application + ### Query + + + + User identifier, defined by the developer's rules, must be unique within the application. + + + ### Response + - `name` (string) application name + - `description` (string) application description + - `tags` (array[string]) application tags + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- @@ -1120,14 +1169,14 @@ Chat applications support session persistence, allowing previous chat history to ```json {{ title: 'Response' }} { - "tool_icons": { + "tool_icons": { "dalle2": "https://cloud.dify.ai/console/api/workspaces/current/tool-provider/builtin/dalle/icon", "api_tool": { - "background": "#252525", - "content": "\ud83d\ude01" + "background": "#252525", + "content": "\ud83d\ude01" } + } } - } ``` diff --git a/web/app/components/develop/template/template_chat.ja.mdx b/web/app/components/develop/template/template_chat.ja.mdx index a962177f0e..b8914a4749 100644 --- a/web/app/components/develop/template/template_chat.ja.mdx +++ b/web/app/components/develop/template/template_chat.ja.mdx @@ -528,7 +528,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `data` (array[object]) メッセージリスト - `id` (string) メッセージID - `conversation_id` (string) 会話ID - - `inputs` (array[object]) ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ。 - `query` (string) ユーザー入力/質問内容。 - `message_files` (array[object]) メッセージファイル - `id` (string) ID @@ -682,16 +682,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from アプリケーション内で開発者によって一意に定義される必要があります。 - 現在のページの最後のレコードのID、デフォルトはnullです。 + (Optional)現在のページの最後のレコードのID、デフォルトはnullです。 - 1回のリクエストで返すレコードの数、デフォルトは最新の20件です。 - - - ピン留めされた会話のみを`true`として返し、ピン留めされていないもののみを`false`として返します + (Optional)1回のリクエストで返すレコードの数、デフォルトは最新の20件です。最大100、最小1。 - ソートフィールド(オプション)、デフォルト:-updated_at(更新時間で降順にソート) + (Optional)ソートフィールド、デフォルト:-updated_at(更新時間で降順にソート) - 利用可能な値:created_at, -created_at, updated_at, -updated_at - フィールドの前の記号は順序または逆順を表し、"-"は逆順を表します。 @@ -701,9 +698,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `data` (array[object]) 会話のリスト - `id` (string) 会話ID - `name` (string) 会話名、デフォルトでは、ユーザーが会話で最初に尋ねた質問のスニペットです。 - - `inputs` (array[object]) ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ。 - `introduction` (string) 紹介 - `created_at` (timestamp) 作成タイムスタンプ、例:1705395332 + - `updated_at` (timestamp) 更新タイムスタンプ、例:1705395332 - `has_more` (bool) - `limit` (int) 返されたエントリの数、入力がシステム制限を超える場合、システム制限の数を返します @@ -733,7 +731,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -815,10 +814,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - 会話の名前。このパラメータは、`auto_generate`が`true`に設定されている場合、省略できます。 + (Optional)会話の名前。このパラメータは、`auto_generate`が`true`に設定されている場合、省略できます。 - タイトルを自動生成します。デフォルトは`false`です。 + (Optional)タイトルを自動生成します。デフォルトは`false`です。 ユーザー識別子、開発者によって定義され、アプリケーション内で一意である必要があります。 @@ -828,13 +827,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ### 応答 - `id` (string) 会話ID - `name` (string) 会話名 - - `inputs` array[object] ユーザー入力パラメータ。 + - `inputs` (object) ユーザー入力パラメータ + - `status` (string) 会話状態 - `introduction` (string) 紹介 - `created_at` (timestamp) 作成タイムスタンプ、例:1705395332 + - `updated_at` (timestamp) 更新タイムスタンプ、例:1705395332 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -842,6 +843,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --header 'Authorization: Bearer {api_key}' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -855,7 +857,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from "name": "Chat vs AI", "inputs": {}, "introduction": "", - "created_at": 1705569238 + "created_at": 1705569238, + "updated_at": 1705569238 } ``` @@ -960,13 +963,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from +--- + + + + + このアプリケーションの基本情報を取得するために使用されます + ### Query + + + + ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + + ### Response + - `name` (string) アプリケーションの名前 + - `description` (string) アプリケーションの説明 + - `tags` (array[string]) アプリケーションのタグ + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- @@ -1086,7 +1133,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from @@ -1120,14 +1167,14 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ```json {{ title: '応答' }} { - "tool_icons": { + "tool_icons": { "dalle2": "https://cloud.dify.ai/console/api/workspaces/current/tool-provider/builtin/dalle/icon", "api_tool": { - "background": "#252525", - "content": "\ud83d\ude01" + "background": "#252525", + "content": "\ud83d\ude01" } + } } - } ``` diff --git a/web/app/components/develop/template/template_chat.zh.mdx b/web/app/components/develop/template/template_chat.zh.mdx index c786d56980..70242623b7 100644 --- a/web/app/components/develop/template/template_chat.zh.mdx +++ b/web/app/components/develop/template/template_chat.zh.mdx @@ -543,7 +543,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `data` (array[object]) 消息列表 - `id` (string) 消息 ID - `conversation_id` (string) 会话 ID - - `inputs` (array[object]) 用户输入参数。 + - `inputs` (object) 用户输入参数。 - `query` (string) 用户输入 / 提问内容。 - `message_files` (array[object]) 消息文件 - `id` (string) ID @@ -697,16 +697,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 - 当前页最后面一条记录的 ID,默认 null + (选填)当前页最后面一条记录的 ID,默认 null - 一次请求返回多少条记录 - - - 只返回置顶 true,只返回非置顶 false + (选填)一次请求返回多少条记录,默认 20 条,最大 100 条,最小 1 条。 - 排序字段(选题),默认 -updated_at(按更新时间倒序排列) + (选填)排序字段,默认 -updated_at(按更新时间倒序排列) - 可选值:created_at, -created_at, updated_at, -updated_at - 字段前面的符号代表顺序或倒序,-代表倒序 @@ -716,9 +713,11 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - `data` (array[object]) 会话列表 - `id` (string) 会话 ID - `name` (string) 会话名称,默认为会话中用户最开始问题的截取。 - - `inputs` (array[object]) 用户输入参数。 + - `inputs` (object) 用户输入参数。 + - `status` (string) 会话状态 - `introduction` (string) 开场白 - `created_at` (timestamp) 创建时间 + - `updated_at` (timestamp) 更新时间 - `has_more` (bool) - `limit` (int) 返回条数,若传入超过系统限制,返回系统限制数量 @@ -748,7 +747,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' "myName": "Lucy" }, "status": "normal", - "created_at": 1679667915 + "created_at": 1679667915, + "updated_at": 1679667915 }, { "id": "hSIhXBhNe8X1d8Et" @@ -831,10 +831,10 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' - 名称,若 `auto_generate` 为 `true` 时,该参数可不传。 + (选填)名称,若 `auto_generate` 为 `true` 时,该参数可不传。 - - 自动生成标题,默认 false。 + + (选填)自动生成标题,默认 false。 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 @@ -844,13 +844,15 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ### Response - `id` (string) 会话 ID - `name` (string) 会话名称 - - `inputs` array[object] 用户输入参数。 + - `inputs` (object) 用户输入参数 + - `status` (string) 会话状态 - `introduction` (string) 开场白 - `created_at` (timestamp) 创建时间 + - `updated_at` (timestamp) 更新时间 - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/conversations/{conversation_id}/name' \ @@ -858,6 +860,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' --header 'Content-Type: application/json' \ --data-raw '{ "name": "", + "auto_generate": true, "user": "abc-123" }' ``` @@ -867,7 +870,13 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' ```json {{ title: 'Response' }} { - "result": "success" + "id": "34d511d5-56de-4f16-a997-57b379508443", + "name": "hello", + "inputs": {}, + "status": "normal", + "introduction": "", + "created_at": 1732731141, + "updated_at": 1732734510 } ``` @@ -969,13 +978,57 @@ import { Row, Col, Properties, Property, Heading, SubProperty } from '../md.tsx' +--- + + + + + 用于获取应用的基本信息 + ### Query + + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + ### Response + - `name` (string) 应用名称 + - `description` (string) 应用描述 + - `tags` (array[string]) 应用标签 + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + --- diff --git a/web/app/components/develop/template/template_workflow.en.mdx b/web/app/components/develop/template/template_workflow.en.mdx index be2ef54743..97519611aa 100644 --- a/web/app/components/develop/template/template_workflow.en.mdx +++ b/web/app/components/develop/template/template_workflow.en.mdx @@ -54,7 +54,7 @@ Workflow applications offers non-session support and is ideal for translation, a User identifier, used to define the identity of the end-user for retrieval and statistics. Should be uniquely defined by the developer within the application. - `files` (array[object]) Optional - File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports Vision capability. + File list, suitable for inputting files combined with text understanding and answering questions, available only when the model supports file parsing and understanding capability. - `type` (string) Supported type: - `document` ('TXT', 'MD', 'MARKDOWN', 'PDF', 'HTML', 'XLSX', 'XLS', 'DOCX', 'CSV', 'EML', 'MSG', 'PPTX', 'PPT', 'XML', 'EPUB') - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') @@ -113,7 +113,7 @@ Workflow applications offers non-session support and is ideal for translation, a - `title` (string) name of node - `index` (int) Execution sequence number, used to display Tracing Node sequence - `predecessor_node_id` (string) optional Prefix node ID, used for canvas display execution path - - `inputs` (array[object]) Contents of all preceding node variables used in the node + - `inputs` (object) Contents of all preceding node variables used in the node - `created_at` (timestamp) timestamp of start, e.g., 1705395332 - `event: node_finished` node execution ends, success or failure in different states in the same event - `task_id` (string) Task ID, used for request tracking and the below Stop Generate API @@ -126,7 +126,7 @@ Workflow applications offers non-session support and is ideal for translation, a - `title` (string) name of node - `index` (int) Execution sequence number, used to display Tracing Node sequence - `predecessor_node_id` (string) optional Prefix node ID, used for canvas display execution path - - `inputs` (array[object]) Contents of all preceding node variables used in the node + - `inputs` (object) Contents of all preceding node variables used in the node - `process_data` (json) Optional node process data - `outputs` (json) Optional content of output - `status` (string) status of execution, `running` / `succeeded` / `failed` / `stopped` @@ -188,6 +188,19 @@ Workflow applications offers non-session support and is ideal for translation, a }' ``` + + + ```json {{ title: 'File variable example' }} + { + "inputs": { + "{variable_name}": { + "transfer_method": "local_file", + "upload_file_id": "{upload_file_id}", + "type": "{document_type}" + } + } + } + ``` ### Blocking Mode @@ -223,7 +236,88 @@ Workflow applications offers non-session support and is ideal for translation, a data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + + ```json {{ title: 'File upload sample code' }} + { + import requests + import json + + def upload_file(file_path, user): + upload_url = "https://api.dify.ai/v1/files/upload" + headers = { + "Authorization": "Bearer app-xxxxxxxx", + } + + try: + print("Upload file...") + with open(file_path, 'rb') as file: + files = { + 'file': (file_path, file, 'text/plain') # Make sure the file is uploaded with the appropriate MIME type + } + data = { + "user": user, + "type": "TXT" # Set the file type to TXT + } + + response = requests.post(upload_url, headers=headers, files=files, data=data) + if response.status_code == 201: # 201 means creation is successful + print("File uploaded successfully") + return response.json().get("id") # Get the uploaded file ID + else: + print(f"File upload failed, status code: {response.status_code}") + return None + except Exception as e: + print(f"Error occurred: {str(e)}") + return None + + def run_workflow(file_id, user, response_mode="blocking"): + workflow_url = "https://api.dify.ai/v1/workflows/run" + headers = { + "Authorization": "Bearer app-xxxxxxxxx", + "Content-Type": "application/json" + } + data = { + "inputs": { + "orig_mail": { + "transfer_method": "local_file", + "upload_file_id": file_id, + "type": "document" + } + }, + "response_mode": response_mode, + "user": user + } + + try: + print("Run Workflow...") + response = requests.post(workflow_url, headers=headers, json=data) + if response.status_code == 200: + print("Workflow execution successful") + return response.json() + else: + print(f"Workflow execution failed, status code: {response.status_code}") + return {"status": "error", "message": f"Failed to execute workflow, status code: {response.status_code}"} + except Exception as e: + print(f"Error occurred: {str(e)}") + return {"status": "error", "message": str(e)} + + # Usage Examples + file_path = "{your_file_path}" + user = "difyuser" + + # Upload files + file_id = upload_file(file_path, user) + if file_id: + # The file was uploaded successfully, and the workflow continues to run + result = run_workflow(file_id, user) + print(result) + else: + print("File upload failed and workflow cannot be executed") + + } + ``` + @@ -404,104 +498,6 @@ Workflow applications offers non-session support and is ideal for translation, a --- - - - - Used at the start of entering the page to obtain information such as features, input parameter names, types, and default values. - - ### Query - - - - User identifier, defined by the developer's rules, must be unique within the application. - - - - ### Response - - `user_input_form` (array[object]) User input form configuration - - `text-input` (object) Text input control - - `label` (string) Variable display label name - - `variable` (string) Variable ID - - `required` (bool) Whether it is required - - `default` (string) Default value - - `paragraph` (object) Paragraph text input control - - `label` (string) Variable display label name - - `variable` (string) Variable ID - - `required` (bool) Whether it is required - - `default` (string) Default value - - `select` (object) Dropdown control - - `label` (string) Variable display label name - - `variable` (string) Variable ID - - `required` (bool) Whether it is required - - `default` (string) Default value - - `options` (array[string]) Option values - - `file_upload` (object) File upload configuration - - `image` (object) Image settings - Currently only supports image types: `png`, `jpg`, `jpeg`, `webp`, `gif` - - `enabled` (bool) Whether it is enabled - - `number_limits` (int) Image number limit, default is 3 - - `transfer_methods` (array[string]) List of transfer methods, remote_url, local_file, must choose one - - `system_parameters` (object) System parameters - - `file_size_limit` (int) Document upload size limit (MB) - - `image_file_size_limit` (int) Image file upload size limit (MB) - - `audio_file_size_limit` (int) Audio file upload size limit (MB) - - `video_file_size_limit` (int) Video file upload size limit (MB) - - - - - - - ```bash {{ title: 'cURL' }} - curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ - --header 'Authorization: Bearer {api_key}' - ``` - - - - - ```json {{ title: 'Response' }} - { - "user_input_form": [ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": true, - "default": "" - } - } - ], - "file_upload": { - "image": { - "enabled": false, - "number_limits": 3, - "detail": "high", - "transfer_methods": [ - "remote_url", - "local_file" - ] - } - }, - "system_parameters": { - "file_size_limit": 15, - "image_file_size_limit": 10, - "audio_file_size_limit": 50, - "video_file_size_limit": 100 - } - } - ``` - - - - ---- - +--- + + + + + Used to get basic information about this application + ### Query + + + + User identifier, defined by the developer's rules, must be unique within the application. + + + ### Response + - `name` (string) application name + - `description` (string) application description + - `tags` (array[string]) application tags + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + + +--- + + + + + Used at the start of entering the page to obtain information such as features, input parameter names, types, and default values. + + ### Query + + + + User identifier, defined by the developer's rules, must be unique within the application. + + + + ### Response + - `user_input_form` (array[object]) User input form configuration + - `text-input` (object) Text input control + - `label` (string) Variable display label name + - `variable` (string) Variable ID + - `required` (bool) Whether it is required + - `default` (string) Default value + - `paragraph` (object) Paragraph text input control + - `label` (string) Variable display label name + - `variable` (string) Variable ID + - `required` (bool) Whether it is required + - `default` (string) Default value + - `select` (object) Dropdown control + - `label` (string) Variable display label name + - `variable` (string) Variable ID + - `required` (bool) Whether it is required + - `default` (string) Default value + - `options` (array[string]) Option values + - `file_upload` (object) File upload configuration + - `image` (object) Image settings + Currently only supports image types: `png`, `jpg`, `jpeg`, `webp`, `gif` + - `enabled` (bool) Whether it is enabled + - `number_limits` (int) Image number limit, default is 3 + - `transfer_methods` (array[string]) List of transfer methods, remote_url, local_file, must choose one + - `system_parameters` (object) System parameters + - `file_size_limit` (int) Document upload size limit (MB) + - `image_file_size_limit` (int) Image file upload size limit (MB) + - `audio_file_size_limit` (int) Audio file upload size limit (MB) + - `video_file_size_limit` (int) Video file upload size limit (MB) + + + + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ + --header 'Authorization: Bearer {api_key}' + ``` + + + + + ```json {{ title: 'Response' }} + { + "user_input_form": [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": true, + "default": "" + } + } + ], + "file_upload": { + "image": { + "enabled": false, + "number_limits": 3, + "detail": "high", + "transfer_methods": [ + "remote_url", + "local_file" + ] + } + }, + "system_parameters": { + "file_size_limit": 15, + "image_file_size_limit": 10, + "audio_file_size_limit": 50, + "video_file_size_limit": 100 + } + } + ``` + + + diff --git a/web/app/components/develop/template/template_workflow.ja.mdx b/web/app/components/develop/template/template_workflow.ja.mdx index ad669430f2..56eaeda2d7 100644 --- a/web/app/components/develop/template/template_workflow.ja.mdx +++ b/web/app/components/develop/template/template_workflow.ja.mdx @@ -54,7 +54,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from ユーザー識別子、エンドユーザーのアイデンティティを定義するために使用されます。 アプリケーション内で開発者によって一意に定義される必要があります。 - `files` (array[object]) オプション - ファイルリスト、テキストの理解と質問への回答を組み合わせたファイルの入力に適しており、モデルがビジョン機能をサポートしている場合にのみ利用可能です。 + ファイルリストは、テキスト理解と質問への回答を組み合わせたファイルの入力に適しています。モデルがファイルの解析と理解機能をサポートしている場合にのみ使用できます。 - `type` (string) サポートされているタイプ: - `document` ('TXT', 'MD', 'MARKDOWN', 'PDF', 'HTML', 'XLSX', 'XLS', 'DOCX', 'CSV', 'EML', 'MSG', 'PPTX', 'PPT', 'XML', 'EPUB') - `image` ('JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG') @@ -113,7 +113,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `title` (string) ノードの名前 - `index` (int) 実行シーケンス番号、トレースノードシーケンスを表示するために使用 - `predecessor_node_id` (string) オプションのプレフィックスノードID、キャンバス表示実行パスに使用 - - `inputs` (array[object]) ノードで使用されるすべての前のノード変数の内容 + - `inputs` (object) ノードで使用されるすべての前のノード変数の内容 - `created_at` (timestamp) 開始のタイムスタンプ、例:1705395332 - `event: node_finished` ノード実行終了、同じイベントで異なる状態で成功または失敗 - `task_id` (string) タスクID、リクエスト追跡と以下のStop Generate APIに使用 @@ -126,7 +126,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from - `title` (string) ノードの名前 - `index` (int) 実行シーケンス番号、トレースノードシーケンスを表示するために使用 - `predecessor_node_id` (string) オプションのプレフィックスノードID、キャンバス表示実行パスに使用 - - `inputs` (array[object]) ノードで使用されるすべての前のノード変数の内容 + - `inputs` (object) ノードで使用されるすべての前のノード変数の内容 - `process_data` (json) オプションのノードプロセスデータ - `outputs` (json) オプションの出力内容 - `status` (string) 実行のステータス、`running` / `succeeded` / `failed` / `stopped` @@ -188,6 +188,19 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from }' ``` + + + ```json {{ title: 'ファイル変数の例' }} + { + "inputs": { + "{variable_name}": { + "transfer_method": "local_file", + "upload_file_id": "{upload_file_id}", + "type": "{document_type}" + } + } + } + ``` ### ブロッキングモード @@ -223,7 +236,88 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + + ```json {{ title: 'ファイルアップロードのサンプルコード' }} + { + import requests + import json + + def upload_file(file_path, user): + upload_url = "https://api.dify.ai/v1/files/upload" + headers = { + "Authorization": "Bearer app-xxxxxxxx", + } + + try: + print("ファイルをアップロードしています...") + with open(file_path, 'rb') as file: + files = { + 'file': (file_path, file, 'text/plain') # ファイルが適切な MIME タイプでアップロードされていることを確認してください + } + data = { + "user": user, + "type": "TXT" # ファイルタイプをTXTに設定します + } + + response = requests.post(upload_url, headers=headers, files=files, data=data) + if response.status_code == 201: # 201 は作成が成功したことを意味します + print("ファイルが正常にアップロードされました") + return response.json().get("id") # アップロードされたファイルIDを取得する + else: + print(f"ファイルのアップロードに失敗しました。ステータス コード: {response.status_code}") + return None + except Exception as e: + print(f"エラーが発生しました: {str(e)}") + return None + + def run_workflow(file_id, user, response_mode="blocking"): + workflow_url = "https://api.dify.ai/v1/workflows/run" + headers = { + "Authorization": "Bearer app-xxxxxxxxx", + "Content-Type": "application/json" + } + data = { + "inputs": { + "orig_mail": { + "transfer_method": "local_file", + "upload_file_id": file_id, + "type": "document" + } + }, + "response_mode": response_mode, + "user": user + } + + try: + print("ワークフローを実行...") + response = requests.post(workflow_url, headers=headers, json=data) + if response.status_code == 200: + print("ワークフローが正常に実行されました") + return response.json() + else: + print(f"ワークフローの実行がステータス コードで失敗しました: {response.status_code}") + return {"status": "error", "message": f"Failed to execute workflow, status code: {response.status_code}"} + except Exception as e: + print(f"エラーが発生しました: {str(e)}") + return {"status": "error", "message": str(e)} + + # 使用例 + file_path = "{your_file_path}" + user = "difyuser" + + # ファイルをアップロードする + file_id = upload_file(file_path, user) + if file_id: + # ファイルは正常にアップロードされました。ワークフローの実行を続行します + result = run_workflow(file_id, user) + print(result) + else: + print("ファイルのアップロードに失敗し、ワークフローを実行できません") + + } + ``` + @@ -404,104 +498,6 @@ import { Row, Col, Properties, Property, Heading, SubProperty, Paragraph } from --- - - - - ページに入る際に、機能、入力パラメータ名、タイプ、デフォルト値などの情報を取得するために使用されます。 - - ### クエリ - - - - ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。 - - - - ### 応答 - - `user_input_form` (array[object]) ユーザー入力フォームの設定 - - `text-input` (object) テキスト入力コントロール - - `label` (string) 変数表示ラベル名 - - `variable` (string) 変数ID - - `required` (bool) 必須かどうか - - `default` (string) デフォルト値 - - `paragraph` (object) 段落テキスト入力コントロール - - `label` (string) 変数表示ラベル名 - - `variable` (string) 変数ID - - `required` (bool) 必須かどうか - - `default` (string) デフォルト値 - - `select` (object) ドロップダウンコントロール - - `label` (string) 変数表示ラベル名 - - `variable` (string) 変数ID - - `required` (bool) 必須かどうか - - `default` (string) デフォルト値 - - `options` (array[string]) オプション値 - - `file_upload` (object) ファイルアップロード設定 - - `image` (object) 画像設定 - 現在サポートされている画像タイプのみ:`png`, `jpg`, `jpeg`, `webp`, `gif` - - `enabled` (bool) 有効かどうか - - `number_limits` (int) 画像数の制限、デフォルトは3 - - `transfer_methods` (array[string]) 転送方法のリスト、remote_url, local_file、いずれかを選択する必要があります - - `system_parameters` (object) システムパラメータ - - `file_size_limit` (int) ドキュメントアップロードサイズ制限(MB) - - `image_file_size_limit` (int) 画像ファイルアップロードサイズ制限(MB) - - `audio_file_size_limit` (int) オーディオファイルアップロードサイズ制限(MB) - - `video_file_size_limit` (int) ビデオファイルアップロードサイズ制限(MB) - - - - - - - ```bash {{ title: 'cURL' }} - curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ - --header 'Authorization: Bearer {api_key}' - ``` - - - - - ```json {{ title: '応答' }} - { - "user_input_form": [ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": true, - "default": "" - } - } - ], - "file_upload": { - "image": { - "enabled": false, - "number_limits": 3, - "detail": "high", - "transfer_methods": [ - "remote_url", - "local_file" - ] - } - }, - "system_parameters": { - "file_size_limit": 15, - "image_file_size_limit": 10, - "audio_file_size_limit": 50, - "video_file_size_limit": 100 - } - } - ``` - - - - ---- - +--- + + + + + このアプリケーションの基本情報を取得するために使用されます + ### Query + + + + ユーザー識別子、開発者のルールによって定義され、アプリケーション内で一意でなければなりません。 + + + ### Response + - `name` (string) アプリケーションの名前 + - `description` (string) アプリケーションの説明 + - `tags` (array[string]) アプリケーションのタグ + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + + +--- + + + + + ページに入る際に、機能、入力パラメータ名、タイプ、デフォルト値などの情報を取得するために使用されます。 + + ### クエリ + + + + ユーザー識別子、開発者のルールで定義され、アプリケーション内で一意でなければなりません。 + + + + ### 応答 + - `user_input_form` (array[object]) ユーザー入力フォームの設定 + - `text-input` (object) テキスト入力コントロール + - `label` (string) 変数表示ラベル名 + - `variable` (string) 変数ID + - `required` (bool) 必須かどうか + - `default` (string) デフォルト値 + - `paragraph` (object) 段落テキスト入力コントロール + - `label` (string) 変数表示ラベル名 + - `variable` (string) 変数ID + - `required` (bool) 必須かどうか + - `default` (string) デフォルト値 + - `select` (object) ドロップダウンコントロール + - `label` (string) 変数表示ラベル名 + - `variable` (string) 変数ID + - `required` (bool) 必須かどうか + - `default` (string) デフォルト値 + - `options` (array[string]) オプション値 + - `file_upload` (object) ファイルアップロード設定 + - `image` (object) 画像設定 + 現在サポートされている画像タイプのみ:`png`, `jpg`, `jpeg`, `webp`, `gif` + - `enabled` (bool) 有効かどうか + - `number_limits` (int) 画像数の制限、デフォルトは3 + - `transfer_methods` (array[string]) 転送方法のリスト、remote_url, local_file、いずれかを選択する必要があります + - `system_parameters` (object) システムパラメータ + - `file_size_limit` (int) ドキュメントアップロードサイズ制限(MB) + - `image_file_size_limit` (int) 画像ファイルアップロードサイズ制限(MB) + - `audio_file_size_limit` (int) オーディオファイルアップロードサイズ制限(MB) + - `video_file_size_limit` (int) ビデオファイルアップロードサイズ制限(MB) + + + + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ + --header 'Authorization: Bearer {api_key}' + ``` + + + + + ```json {{ title: '応答' }} + { + "user_input_form": [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": true, + "default": "" + } + } + ], + "file_upload": { + "image": { + "enabled": false, + "number_limits": 3, + "detail": "high", + "transfer_methods": [ + "remote_url", + "local_file" + ] + } + }, + "system_parameters": { + "file_size_limit": 15, + "image_file_size_limit": 10, + "audio_file_size_limit": 50, + "video_file_size_limit": 100 + } + } + ``` + + + diff --git a/web/app/components/develop/template/template_workflow.zh.mdx b/web/app/components/develop/template/template_workflow.zh.mdx index 5b6fdb1381..cfebb0e319 100644 --- a/web/app/components/develop/template/template_workflow.zh.mdx +++ b/web/app/components/develop/template/template_workflow.zh.mdx @@ -52,7 +52,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 用户标识,用于定义终端用户的身份,方便检索、统计。 由开发者定义规则,需保证用户标识在应用内唯一。 - `files` (array[object]) Optional - 文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持 Vision 能力时可用。 + 文件列表,适用于传入文件结合文本理解并回答问题,仅当模型支持该类型文件解析能力时可用。 - `type` (string) 支持类型: - `document` 具体类型包含:'TXT', 'MD', 'MARKDOWN', 'PDF', 'HTML', 'XLSX', 'XLS', 'DOCX', 'CSV', 'EML', 'MSG', 'PPTX', 'PPT', 'XML', 'EPUB' - `image` 具体类型包含:'JPG', 'JPEG', 'PNG', 'GIF', 'WEBP', 'SVG' @@ -111,7 +111,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - `title` (string) 节点名称 - `index` (int) 执行序号,用于展示 Tracing Node 顺序 - `predecessor_node_id` (string) 前置节点 ID,用于画布展示执行路径 - - `inputs` (array[object]) 节点中所有使用到的前置节点变量内容 + - `inputs` (object) 节点中所有使用到的前置节点变量内容 - `created_at` (timestamp) 开始时间 - `event: node_finished` node 执行结束,成功失败同一事件中不同状态 - `task_id` (string) 任务 ID,用于请求跟踪和下方的停止响应接口 @@ -122,7 +122,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - `node_id` (string) 节点 ID - `index` (int) 执行序号,用于展示 Tracing Node 顺序 - `predecessor_node_id` (string) optional 前置节点 ID,用于画布展示执行路径 - - `inputs` (array[object]) 节点中所有使用到的前置节点变量内容 + - `inputs` (object) 节点中所有使用到的前置节点变量内容 - `process_data` (json) Optional 节点过程数据 - `outputs` (json) Optional 输出内容 - `status` (string) 执行状态 `running` / `succeeded` / `failed` / `stopped` @@ -171,8 +171,7 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 - - + ```bash {{ title: 'cURL' }} curl -X POST '${props.appDetail.api_base_url}/workflows/run' \ --header 'Authorization: Bearer {api_key}' \ @@ -183,7 +182,19 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 "user": "abc-123" }' ``` - + + + ```json {{ title: 'File variable example' }} + { + "inputs": { + "{variable_name}": { + "transfer_method": "local_file", + "upload_file_id": "{upload_file_id}", + "type": "{document_type}" + } + } + } + ``` ### Blocking Mode @@ -219,7 +230,88 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 data: {"event": "tts_message_end", "conversation_id": "23dd85f3-1a41-4ea0-b7a9-062734ccfaf9", "message_id": "a8bdc41c-13b2-4c18-bfd9-054b9803038c", "created_at": 1721205487, "task_id": "3bf8a0bb-e73b-4690-9e66-4e429bad8ee7", "audio": ""} ``` + + ```json {{ title: 'File upload sample code' }} + { + import requests + import json + + def upload_file(file_path, user): + upload_url = "https://api.dify.ai/v1/files/upload" + headers = { + "Authorization": "Bearer app-xxxxxxxx", + } + + try: + print("上传文件中...") + with open(file_path, 'rb') as file: + files = { + 'file': (file_path, file, 'text/plain') # 确保文件以适当的MIME类型上传 + } + data = { + "user": user, + "type": "TXT" # 设置文件类型为TXT + } + + response = requests.post(upload_url, headers=headers, files=files, data=data) + if response.status_code == 201: # 201 表示创建成功 + print("文件上传成功") + return response.json().get("id") # 获取上传的文件 ID + else: + print(f"文件上传失败,状态码: {response.status_code}") + return None + except Exception as e: + print(f"发生错误: {str(e)}") + return None + + def run_workflow(file_id, user, response_mode="blocking"): + workflow_url = "https://api.dify.ai/v1/workflows/run" + headers = { + "Authorization": "Bearer app-xxxxxxxxx", + "Content-Type": "application/json" + } + + data = { + "inputs": { + "orig_mail": { + "transfer_method": "local_file", + "upload_file_id": file_id, + "type": "document" + } + }, + "response_mode": response_mode, + "user": user + } + try: + print("运行工作流...") + response = requests.post(workflow_url, headers=headers, json=data) + if response.status_code == 200: + print("工作流执行成功") + return response.json() + else: + print(f"工作流执行失败,状态码: {response.status_code}") + return {"status": "error", "message": f"Failed to execute workflow, status code: {response.status_code}"} + except Exception as e: + print(f"发生错误: {str(e)}") + return {"status": "error", "message": str(e)} + + # 使用示例 + file_path = "{your_file_path}" + user = "difyuser" + + # 上传文件 + file_id = upload_file(file_path, user) + if file_id: + # 文件上传成功,继续运行工作流 + result = run_workflow(file_id, user) + print(result) + else: + print("文件上传失败,无法执行工作流") + + } + ``` + @@ -398,104 +490,6 @@ Workflow 应用无会话支持,适合用于翻译/文章写作/总结 AI 等 --- - - - - 用于进入页面一开始,获取功能开关、输入参数名称、类型及默认值等使用。 - - ### Query - - - - 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 - - - - ### Response - - `user_input_form` (array[object]) 用户输入表单配置 - - `text-input` (object) 文本输入控件 - - `label` (string) 控件展示标签名 - - `variable` (string) 控件 ID - - `required` (bool) 是否必填 - - `default` (string) 默认值 - - `paragraph` (object) 段落文本输入控件 - - `label` (string) 控件展示标签名 - - `variable` (string) 控件 ID - - `required` (bool) 是否必填 - - `default` (string) 默认值 - - `select` (object) 下拉控件 - - `label` (string) 控件展示标签名 - - `variable` (string) 控件 ID - - `required` (bool) 是否必填 - - `default` (string) 默认值 - - `options` (array[string]) 选项值 - - `file_upload` (object) 文件上传配置 - - `image` (object) 图片设置 - 当前仅支持图片类型:`png`, `jpg`, `jpeg`, `webp`, `gif` - - `enabled` (bool) 是否开启 - - `number_limits` (int) 图片数量限制,默认 3 - - `transfer_methods` (array[string]) 传递方式列表,remote_url , local_file,必选一个 - - `system_parameters` (object) 系统参数 - - `file_size_limit` (int) 文档上传大小限制 (MB) - - `image_file_size_limit` (int) 图片文件上传大小限制(MB) - - `audio_file_size_limit` (int) 音频文件上传大小限制 (MB) - - `video_file_size_limit` (int) 视频文件上传大小限制 (MB) - - - - - - - ```bash {{ title: 'cURL' }} - curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ - --header 'Authorization: Bearer {api_key}' - ``` - - - - - ```json {{ title: 'Response' }} - { - "user_input_form": [ - { - "paragraph": { - "label": "Query", - "variable": "query", - "required": true, - "default": "" - } - } - ], - "file_upload": { - "image": { - "enabled": false, - "number_limits": 3, - "detail": "high", - "transfer_methods": [ - "remote_url", - "local_file" - ] - } - }, - "system_parameters": { - "file_size_limit": 15, - "image_file_size_limit": 10, - "audio_file_size_limit": 50, - "video_file_size_limit": 100 - } - } - ``` - - - - ---- - +--- + + + + + 用于获取应用的基本信息 + ### Query + + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + ### Response + - `name` (string) 应用名称 + - `description` (string) 应用描述 + - `tags` (array[string]) 应用标签 + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/info?user=abc-123' \ + -H 'Authorization: Bearer {api_key}' + ``` + + + ```json {{ title: 'Response' }} + { + "name": "My App", + "description": "This is my app.", + "tags": [ + "tag1", + "tag2" + ] + } + ``` + + + + +--- + + + + + 用于进入页面一开始,获取功能开关、输入参数名称、类型及默认值等使用。 + + ### Query + + + + 用户标识,由开发者定义规则,需保证用户标识在应用内唯一。 + + + + ### Response + - `user_input_form` (array[object]) 用户输入表单配置 + - `text-input` (object) 文本输入控件 + - `label` (string) 控件展示标签名 + - `variable` (string) 控件 ID + - `required` (bool) 是否必填 + - `default` (string) 默认值 + - `paragraph` (object) 段落文本输入控件 + - `label` (string) 控件展示标签名 + - `variable` (string) 控件 ID + - `required` (bool) 是否必填 + - `default` (string) 默认值 + - `select` (object) 下拉控件 + - `label` (string) 控件展示标签名 + - `variable` (string) 控件 ID + - `required` (bool) 是否必填 + - `default` (string) 默认值 + - `options` (array[string]) 选项值 + - `file_upload` (object) 文件上传配置 + - `image` (object) 图片设置 + 当前仅支持图片类型:`png`, `jpg`, `jpeg`, `webp`, `gif` + - `enabled` (bool) 是否开启 + - `number_limits` (int) 图片数量限制,默认 3 + - `transfer_methods` (array[string]) 传递方式列表,remote_url , local_file,必选一个 + - `system_parameters` (object) 系统参数 + - `file_size_limit` (int) 文档上传大小限制 (MB) + - `image_file_size_limit` (int) 图片文件上传大小限制(MB) + - `audio_file_size_limit` (int) 音频文件上传大小限制 (MB) + - `video_file_size_limit` (int) 视频文件上传大小限制 (MB) + + + + + + + ```bash {{ title: 'cURL' }} + curl -X GET '${props.appDetail.api_base_url}/parameters?user=abc-123' \ + --header 'Authorization: Bearer {api_key}' + ``` + + + + + ```json {{ title: 'Response' }} + { + "user_input_form": [ + { + "paragraph": { + "label": "Query", + "variable": "query", + "required": true, + "default": "" + } + } + ], + "file_upload": { + "image": { + "enabled": false, + "number_limits": 3, + "detail": "high", + "transfer_methods": [ + "remote_url", + "local_file" + ] + } + }, + "system_parameters": { + "file_size_limit": 15, + "image_file_size_limit": 10, + "audio_file_size_limit": 50, + "video_file_size_limit": 100 + } + } + ``` + + + diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index 6186d8164d..b8e7939328 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -14,7 +14,7 @@ import type { App } from '@/models/explore' import Category from '@/app/components/explore/category' import AppCard from '@/app/components/explore/app-card' import { fetchAppDetail, fetchAppList } from '@/service/explore' -import { importApp } from '@/service/apps' +import { importDSL } from '@/service/apps' import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import CreateAppModal from '@/app/components/explore/create-app-modal' import AppTypeSelector from '@/app/components/app/type-selector' @@ -24,6 +24,7 @@ import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { useAppContext } from '@/context/app-context' import { getRedirection } from '@/utils/app-redirection' import Input from '@/app/components/base/input' +import { DSLImportMode } from '@/models/app' type AppsProps = { pageType?: PageType @@ -127,8 +128,9 @@ const Apps = ({ currApp?.app.id as string, ) try { - const app = await importApp({ - data: export_data, + const app = await importDSL({ + mode: DSLImportMode.YAML_CONTENT, + yaml_content: export_data, name, icon_type, icon, @@ -143,7 +145,7 @@ const Apps = ({ if (onSuccess) onSuccess() localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - getRedirection(isCurrentWorkspaceEditor, app, push) + getRedirection(isCurrentWorkspaceEditor, { id: app.app_id }, push) } catch (e) { Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) diff --git a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx index 1a910aba08..1e43439d15 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-selector/popup.tsx @@ -25,22 +25,18 @@ const Popup: FC = ({ const language = useLanguage() const [searchText, setSearchText] = useState('') - const filteredModelList = modelList.filter( - model => model.models.filter( - (modelItem) => { - if (modelItem.label[language] !== undefined) - return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase()) + const filteredModelList = modelList.map((model) => { + const filteredModels = model.models.filter((modelItem) => { + if (modelItem.label[language] !== undefined) + return modelItem.label[language].toLowerCase().includes(searchText.toLowerCase()) - let found = false - Object.keys(modelItem.label).forEach((key) => { - if (modelItem.label[key].toLowerCase().includes(searchText.toLowerCase())) - found = true - }) + return Object.values(modelItem.label).some(label => + label.toLowerCase().includes(searchText.toLowerCase()), + ) + }) - return found - }, - ).length, - ) + return { ...model, models: filteredModels } + }).filter(model => model.models.length > 0) return (
diff --git a/web/app/components/header/index.tsx b/web/app/components/header/index.tsx index 3757d552df..1d7349ccd0 100644 --- a/web/app/components/header/index.tsx +++ b/web/app/components/header/index.tsx @@ -4,6 +4,7 @@ import Link from 'next/link' import { useBoolean } from 'ahooks' import { useSelectedLayoutSegment } from 'next/navigation' import { Bars3Icon } from '@heroicons/react/20/solid' +import { useContextSelector } from 'use-context-selector' import HeaderBillingBtn from '../billing/header-billing-btn' import AccountDropdown from './account-dropdown' import AppNav from './app-nav' @@ -14,11 +15,12 @@ import ToolsNav from './tools-nav' import GithubStar from './github-star' import LicenseNav from './license-env' import { WorkspaceProvider } from '@/context/workspace-context' -import { useAppContext } from '@/context/app-context' +import AppContext, { useAppContext } from '@/context/app-context' import LogoSite from '@/app/components/base/logo/logo-site' import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints' import { useProviderContext } from '@/context/provider-context' import { useModalContext } from '@/context/modal-context' +import { LicenseStatus } from '@/types/feature' const navClassName = ` flex items-center relative mr-0 sm:mr-3 px-3 h-8 rounded-xl @@ -28,7 +30,7 @@ const navClassName = ` const Header = () => { const { isCurrentWorkspaceEditor, isCurrentWorkspaceDatasetOperator } = useAppContext() - + const systemFeatures = useContextSelector(AppContext, v => v.systemFeatures) const selectedSegment = useSelectedLayoutSegment() const media = useBreakpoints() const isMobile = media === MediaType.mobile @@ -60,7 +62,7 @@ const Header = () => { - + {systemFeatures.license.status === LicenseStatus.NONE && } }
{isMobile && ( @@ -68,7 +70,7 @@ const Header = () => { - + {systemFeatures.license.status === LicenseStatus.NONE && }
)} {!isMobile && ( diff --git a/web/app/components/workflow/block-icon.tsx b/web/app/components/workflow/block-icon.tsx index b115a7b3c3..1001e981c5 100644 --- a/web/app/components/workflow/block-icon.tsx +++ b/web/app/components/workflow/block-icon.tsx @@ -48,6 +48,7 @@ const getIcon = (type: BlockEnum, className: string) => { [BlockEnum.VariableAggregator]: , [BlockEnum.Assigner]: , [BlockEnum.Tool]: , + [BlockEnum.IterationStart]: , [BlockEnum.Iteration]: , [BlockEnum.ParameterExtractor]: , [BlockEnum.DocExtractor]: , diff --git a/web/app/components/workflow/hooks/use-workflow-run.ts b/web/app/components/workflow/hooks/use-workflow-run.ts index eab3535505..24b20b5274 100644 --- a/web/app/components/workflow/hooks/use-workflow-run.ts +++ b/web/app/components/workflow/hooks/use-workflow-run.ts @@ -271,13 +271,18 @@ export const useWorkflowRun = () => { } as any) } else { - if (!iterParallelLogMap.has(data.parallel_run_id)) - iterParallelLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) + const nodeId = iterations?.node_id as string + if (!iterParallelLogMap.has(nodeId as string)) + iterParallelLogMap.set(iterations?.node_id as string, new Map()) + + const currentIterLogMap = iterParallelLogMap.get(nodeId)! + if (!currentIterLogMap.has(data.parallel_run_id)) + currentIterLogMap.set(data.parallel_run_id, [{ ...data, status: NodeRunningStatus.Running } as any]) else - iterParallelLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) + currentIterLogMap.get(data.parallel_run_id)!.push({ ...data, status: NodeRunningStatus.Running } as any) setIterParallelLogMap(iterParallelLogMap) if (iterations) - iterations.details = Array.from(iterParallelLogMap.values()) + iterations.details = Array.from(currentIterLogMap.values()) } })) } @@ -373,7 +378,7 @@ export const useWorkflowRun = () => { if (iterations && iterations.details) { const iterRunID = data.execution_metadata?.parallel_mode_run_id - const currIteration = iterParallelLogMap.get(iterRunID) + const currIteration = iterParallelLogMap.get(iterations.node_id)?.get(iterRunID) const nodeIndex = currIteration?.findIndex(node => node.node_id === data.node_id && ( node?.parallel_run_id === data.execution_metadata?.parallel_mode_run_id), @@ -392,7 +397,9 @@ export const useWorkflowRun = () => { } } setIterParallelLogMap(iterParallelLogMap) - iterations.details = Array.from(iterParallelLogMap.values()) + const iterLogMap = iterParallelLogMap.get(iterations.node_id) + if (iterLogMap) + iterations.details = Array.from(iterLogMap.values()) } })) } diff --git a/web/app/components/workflow/hooks/use-workflow-template.ts b/web/app/components/workflow/hooks/use-workflow-template.ts index e36f0b61f9..c2dc956b63 100644 --- a/web/app/components/workflow/hooks/use-workflow-template.ts +++ b/web/app/components/workflow/hooks/use-workflow-template.ts @@ -22,6 +22,7 @@ export const useWorkflowTemplate = () => { ...nodesInitialData.llm, memory: { window: { enabled: false, size: 10 }, + query_prompt_template: '{{#sys.query#}}', }, selected: true, }, diff --git a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx index 6a3da3cf24..79d9c5b4dd 100644 --- a/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/before-run-form/index.tsx @@ -16,6 +16,7 @@ import { InputVarType, NodeRunningStatus } from '@/app/components/workflow/types import ResultPanel from '@/app/components/workflow/run/result-panel' import Toast from '@/app/components/base/toast' import { TransferMethod } from '@/types/app' +import { getProcessedFiles } from '@/app/components/base/file-uploader/utils' const i18nPrefix = 'workflow.singleRun' @@ -39,6 +40,11 @@ function formatValue(value: string | any, type: InputVarType) { return JSON.parse(item) }) } + if (type === InputVarType.multiFiles) + return getProcessedFiles(value) + + if (type === InputVarType.singleFile) + return getProcessedFiles([value])[0] return value } diff --git a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx index 28d07936d3..2d75679b08 100644 --- a/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx +++ b/web/app/components/workflow/nodes/_base/components/editor/code-editor/index.tsx @@ -33,6 +33,7 @@ export type Props = { showFileList?: boolean onGenerated?: (value: string) => void showCodeGenerator?: boolean + className?: string } export const languageMap = { @@ -67,6 +68,7 @@ const CodeEditor: FC = ({ showFileList, onGenerated, showCodeGenerator = false, + className, }) => { const [isFocus, setIsFocus] = React.useState(false) const [isMounted, setIsMounted] = React.useState(false) @@ -187,7 +189,7 @@ const CodeEditor: FC = ({ ) return ( -
+
{noWrapper ?
= ({ triggerClassName='w-4 h-4 ml-1' /> )} -
{operations &&
{operations}
} diff --git a/web/app/components/workflow/nodes/_base/components/list-no-data-placeholder.tsx b/web/app/components/workflow/nodes/_base/components/list-no-data-placeholder.tsx index 4ec9d27f50..bf592deaec 100644 --- a/web/app/components/workflow/nodes/_base/components/list-no-data-placeholder.tsx +++ b/web/app/components/workflow/nodes/_base/components/list-no-data-placeholder.tsx @@ -10,7 +10,7 @@ const ListNoDataPlaceholder: FC = ({ children, }) => { return ( -
+
{children}
) diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx index c108608739..476f5b738c 100644 --- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx +++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx @@ -53,7 +53,7 @@ type Props = { const MEMORY_DEFAULT: Memory = { window: { enabled: false, size: WINDOW_SIZE_DEFAULT }, - query_prompt_template: '', + query_prompt_template: '{{#sys.query#}}', } const MemoryConfig: FC = ({ diff --git a/web/app/components/workflow/nodes/_base/components/variable/assigned-var-reference-popup.tsx b/web/app/components/workflow/nodes/_base/components/variable/assigned-var-reference-popup.tsx new file mode 100644 index 0000000000..9ad5ad4a5a --- /dev/null +++ b/web/app/components/workflow/nodes/_base/components/variable/assigned-var-reference-popup.tsx @@ -0,0 +1,39 @@ +'use client' +import type { FC } from 'react' +import React from 'react' +import { useTranslation } from 'react-i18next' +import VarReferenceVars from './var-reference-vars' +import type { NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types' +import ListEmpty from '@/app/components/base/list-empty' + +type Props = { + vars: NodeOutPutVar[] + onChange: (value: ValueSelector, varDetail: Var) => void + itemWidth?: number +} +const AssignedVarReferencePopup: FC = ({ + vars, + onChange, + itemWidth, +}) => { + const { t } = useTranslation() + // max-h-[300px] overflow-y-auto todo: use portal to handle long list + return ( +
+ {(!vars || vars.length === 0) + ? + : + } +
+ ) +} +export default React.memo(AssignedVarReferencePopup) diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx index 0c553a2738..e4d354a615 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx @@ -60,6 +60,9 @@ type Props = { onRemove?: () => void typePlaceHolder?: string isSupportFileVar?: boolean + placeholder?: string + minWidth?: number + popupFor?: 'assigned' | 'toAssigned' } const VarReferencePicker: FC = ({ @@ -83,6 +86,9 @@ const VarReferencePicker: FC = ({ onRemove, typePlaceHolder, isSupportFileVar = true, + placeholder, + minWidth, + popupFor, }) => { const { t } = useTranslation() const store = useStoreApi() @@ -261,7 +267,7 @@ const VarReferencePicker: FC = ({ { }}>
) - : (
+ : (
{isSupportConstantValue ?
{ e.stopPropagation() @@ -285,7 +291,7 @@ const VarReferencePicker: FC = ({ />
: (!hasValue &&
- +
)} {isConstant ? ( @@ -329,17 +335,17 @@ const VarReferencePicker: FC = ({ {!hasValue && } {isEnv && } {isChatVar && } -
{varName}
-
{type}
{!isValidVar && } ) - :
{t('workflow.common.setVarValuePlaceholder')}
} + :
{placeholder ?? t('workflow.common.setVarValuePlaceholder')}
}
@@ -378,12 +384,13 @@ const VarReferencePicker: FC = ({ + }} className='mt-1'> {!isConstant && ( )} diff --git a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx index cd03da1556..d9a4d2c946 100644 --- a/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx +++ b/web/app/components/workflow/nodes/_base/components/variable/var-reference-popup.tsx @@ -1,33 +1,64 @@ 'use client' import type { FC } from 'react' import React from 'react' +import { useTranslation } from 'react-i18next' +import { useContext } from 'use-context-selector' import VarReferenceVars from './var-reference-vars' import type { NodeOutPutVar, ValueSelector, Var } from '@/app/components/workflow/types' +import ListEmpty from '@/app/components/base/list-empty' +import { LanguagesSupported } from '@/i18n/language' +import I18n from '@/context/i18n' type Props = { vars: NodeOutPutVar[] + popupFor?: 'assigned' | 'toAssigned' onChange: (value: ValueSelector, varDetail: Var) => void itemWidth?: number isSupportFileVar?: boolean } const VarReferencePopup: FC = ({ vars, + popupFor, onChange, itemWidth, isSupportFileVar = true, }) => { + const { t } = useTranslation() + const { locale } = useContext(I18n) // max-h-[300px] overflow-y-auto todo: use portal to handle long list return (
- + {((!vars || vars.length === 0) && popupFor) + ? (popupFor === 'toAssigned' + ? ( + + {t('workflow.variableReference.noVarsForOperation')} +
} + /> + ) + : ( + + {t('workflow.variableReference.assignedVarsDescription')} +
{t('workflow.variableReference.conversationVars')} +
} + /> + )) + : + }
) } diff --git a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts index c500f0c8cf..6791a2f746 100644 --- a/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts +++ b/web/app/components/workflow/nodes/_base/hooks/use-one-step-run.ts @@ -24,6 +24,7 @@ import QuestionClassifyDefault from '@/app/components/workflow/nodes/question-cl import HTTPDefault from '@/app/components/workflow/nodes/http/default' import ToolDefault from '@/app/components/workflow/nodes/tool/default' import VariableAssigner from '@/app/components/workflow/nodes/variable-assigner/default' +import Assigner from '@/app/components/workflow/nodes/assigner/default' import ParameterExtractorDefault from '@/app/components/workflow/nodes/parameter-extractor/default' import IterationDefault from '@/app/components/workflow/nodes/iteration/default' import { ssePost } from '@/service/base' @@ -39,6 +40,7 @@ const { checkValid: checkQuestionClassifyValid } = QuestionClassifyDefault const { checkValid: checkHttpValid } = HTTPDefault const { checkValid: checkToolValid } = ToolDefault const { checkValid: checkVariableAssignerValid } = VariableAssigner +const { checkValid: checkAssignerValid } = Assigner const { checkValid: checkParameterExtractorValid } = ParameterExtractorDefault const { checkValid: checkIterationValid } = IterationDefault @@ -51,7 +53,7 @@ const checkValidFns: Record = { [BlockEnum.QuestionClassifier]: checkQuestionClassifyValid, [BlockEnum.HttpRequest]: checkHttpValid, [BlockEnum.Tool]: checkToolValid, - [BlockEnum.VariableAssigner]: checkVariableAssignerValid, + [BlockEnum.VariableAssigner]: checkAssignerValid, [BlockEnum.VariableAggregator]: checkVariableAssignerValid, [BlockEnum.ParameterExtractor]: checkParameterExtractorValid, [BlockEnum.Iteration]: checkIterationValid, diff --git a/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx b/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx new file mode 100644 index 0000000000..8542bb4829 --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/operation-selector.tsx @@ -0,0 +1,128 @@ +import type { FC } from 'react' +import { useState } from 'react' +import { + RiArrowDownSLine, + RiCheckLine, +} from '@remixicon/react' +import classNames from 'classnames' +import { useTranslation } from 'react-i18next' +import type { WriteMode } from '../types' +import { getOperationItems } from '../utils' +import { + PortalToFollowElem, + PortalToFollowElemContent, + PortalToFollowElemTrigger, +} from '@/app/components/base/portal-to-follow-elem' +import type { VarType } from '@/app/components/workflow/types' +import Divider from '@/app/components/base/divider' + +type Item = { + value: string | number + name: string +} + +type OperationSelectorProps = { + value: string | number + onSelect: (value: Item) => void + placeholder?: string + disabled?: boolean + className?: string + popupClassName?: string + assignedVarType?: VarType + writeModeTypes?: WriteMode[] + writeModeTypesArr?: WriteMode[] + writeModeTypesNum?: WriteMode[] +} + +const i18nPrefix = 'workflow.nodes.assigner' + +const OperationSelector: FC = ({ + value, + onSelect, + disabled = false, + className, + popupClassName, + assignedVarType, + writeModeTypes, + writeModeTypesArr, + writeModeTypesNum, +}) => { + const { t } = useTranslation() + const [open, setOpen] = useState(false) + + const items = getOperationItems(assignedVarType, writeModeTypes, writeModeTypesArr, writeModeTypesNum) + + const selectedItem = items.find(item => item.value === value) + + return ( + + !disabled && setOpen(v => !v)} + > +
+
+ + {selectedItem?.name ? t(`${i18nPrefix}.operations.${selectedItem?.name}`) : t(`${i18nPrefix}.operations.title`)} + +
+ +
+
+ + +
+
+
+
{t(`${i18nPrefix}.operations.title`)}
+
+ {items.map(item => ( + item.value === 'divider' + ? ( + + ) + : ( +
{ + onSelect(item) + setOpen(false) + }} + > +
+ {t(`${i18nPrefix}.operations.${item.name}`)} +
+ {item.value === value && ( +
+ +
+ )} +
+ ) + ))} +
+
+
+
+ ) +} + +export default OperationSelector diff --git a/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx b/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx new file mode 100644 index 0000000000..42ee9845dd --- /dev/null +++ b/web/app/components/workflow/nodes/assigner/components/var-list/index.tsx @@ -0,0 +1,227 @@ +'use client' +import type { FC } from 'react' +import { useTranslation } from 'react-i18next' +import React, { useCallback } from 'react' +import produce from 'immer' +import { RiDeleteBinLine } from '@remixicon/react' +import OperationSelector from '../operation-selector' +import { AssignerNodeInputType, WriteMode } from '../../types' +import type { AssignerNodeOperation } from '../../types' +import ListNoDataPlaceholder from '@/app/components/workflow/nodes/_base/components/list-no-data-placeholder' +import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker' +import type { ValueSelector, Var, VarType } from '@/app/components/workflow/types' +import { CodeLanguage } from '@/app/components/workflow/nodes/code/types' +import ActionButton from '@/app/components/base/action-button' +import Input from '@/app/components/base/input' +import Textarea from '@/app/components/base/textarea' +import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor' + +type Props = { + readonly: boolean + nodeId: string + list: AssignerNodeOperation[] + onChange: (list: AssignerNodeOperation[], value?: ValueSelector) => void + onOpen?: (index: number) => void + filterVar?: (payload: Var, valueSelector: ValueSelector) => boolean + filterToAssignedVar?: (payload: Var, assignedVarType: VarType, write_mode: WriteMode) => boolean + getAssignedVarType?: (valueSelector: ValueSelector) => VarType + getToAssignedVarType?: (assignedVarType: VarType, write_mode: WriteMode) => VarType + writeModeTypes?: WriteMode[] + writeModeTypesArr?: WriteMode[] + writeModeTypesNum?: WriteMode[] +} + +const VarList: FC = ({ + readonly, + nodeId, + list, + onChange, + onOpen = () => { }, + filterVar, + filterToAssignedVar, + getAssignedVarType, + getToAssignedVarType, + writeModeTypes, + writeModeTypesArr, + writeModeTypesNum, +}) => { + const { t } = useTranslation() + const handleAssignedVarChange = useCallback((index: number) => { + return (value: ValueSelector | string) => { + const newList = produce(list, (draft) => { + draft[index].variable_selector = value as ValueSelector + draft[index].operation = WriteMode.overwrite + draft[index].value = undefined + }) + onChange(newList, value as ValueSelector) + } + }, [list, onChange]) + + const handleOperationChange = useCallback((index: number) => { + return (item: { value: string | number }) => { + const newList = produce(list, (draft) => { + draft[index].operation = item.value as WriteMode + draft[index].value = '' // Clear value when operation changes + if (item.value === WriteMode.set || item.value === WriteMode.increment || item.value === WriteMode.decrement + || item.value === WriteMode.multiply || item.value === WriteMode.divide) + draft[index].input_type = AssignerNodeInputType.constant + else + draft[index].input_type = AssignerNodeInputType.variable + }) + onChange(newList) + } + }, [list, onChange]) + + const handleToAssignedVarChange = useCallback((index: number) => { + return (value: ValueSelector | string | number) => { + const newList = produce(list, (draft) => { + draft[index].value = value as ValueSelector + }) + onChange(newList, value as ValueSelector) + } + }, [list, onChange]) + + const handleVarRemove = useCallback((index: number) => { + return () => { + const newList = produce(list, (draft) => { + draft.splice(index, 1) + }) + onChange(newList) + } + }, [list, onChange]) + + const handleOpen = useCallback((index: number) => { + return () => onOpen(index) + }, [onOpen]) + + const handleFilterToAssignedVar = useCallback((index: number) => { + return (payload: Var, valueSelector: ValueSelector) => { + const item = list[index] + const assignedVarType = item.variable_selector ? getAssignedVarType?.(item.variable_selector) : undefined + + if (!filterToAssignedVar || !item.variable_selector || !assignedVarType || !item.operation) + return true + + return filterToAssignedVar( + payload, + assignedVarType, + item.operation, + ) + } + }, [list, filterToAssignedVar, getAssignedVarType]) + + if (list.length === 0) { + return ( + + {t('workflow.nodes.assigner.noVarTip')} + + ) + } + + return ( +
+ {list.map((item, index) => { + const assignedVarType = item.variable_selector ? getAssignedVarType?.(item.variable_selector) : undefined + const toAssignedVarType = (assignedVarType && item.operation && getToAssignedVarType) + ? getToAssignedVarType(assignedVarType, item.operation) + : undefined + + return ( +
+
+
+ + +
+ {item.operation !== WriteMode.clear && item.operation !== WriteMode.set + && !writeModeTypesNum?.includes(item.operation) + && ( + + ) + } + {item.operation === WriteMode.set && assignedVarType && ( + <> + {assignedVarType === 'number' && ( + handleToAssignedVarChange(index)(Number(e.target.value))} + className='w-full' + /> + )} + {assignedVarType === 'string' && ( +