diff --git a/.devcontainer/README.md b/.devcontainer/README.md
index df12a3c2d6..2b18630a21 100644
--- a/.devcontainer/README.md
+++ b/.devcontainer/README.md
@@ -34,4 +34,4 @@ if you see such error message when you open this project in codespaces:

a simple workaround is change `/signin` endpoint into another one, then login with GitHub account and close the tab, then change it back to `/signin` endpoint. Then all things will be fine.
-The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204)
\ No newline at end of file
+The reason is `signin` endpoint is not allowed in codespaces, details can be found [here](https://github.com/orgs/community/discussions/5204)
diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
index 339ad60ce0..8246544061 100644
--- a/.devcontainer/devcontainer.json
+++ b/.devcontainer/devcontainer.json
@@ -2,7 +2,7 @@
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{
"name": "Python 3.12",
- "build": {
+ "build": {
"context": "..",
"dockerfile": "Dockerfile"
},
diff --git a/.devcontainer/noop.txt b/.devcontainer/noop.txt
index dde8dc3c10..49de88dbd4 100644
--- a/.devcontainer/noop.txt
+++ b/.devcontainer/noop.txt
@@ -1,3 +1,3 @@
This file copied into the container along with environment.yml* from the parent
-folder. This file is included to prevents the Dockerfile COPY instruction from
-failing if no environment.yml is found.
\ No newline at end of file
+folder. This file is included to prevents the Dockerfile COPY instruction from
+failing if no environment.yml is found.
diff --git a/web/.editorconfig b/.editorconfig
similarity index 51%
rename from web/.editorconfig
rename to .editorconfig
index e1d3f0b992..374da0b5d2 100644
--- a/web/.editorconfig
+++ b/.editorconfig
@@ -5,18 +5,35 @@ root = true
# Unix-style newlines with a newline ending every file
[*]
+charset = utf-8
end_of_line = lf
insert_final_newline = true
+trim_trailing_whitespace = true
+
+[*.py]
+indent_size = 4
+indent_style = space
+
+[*.{yml,yaml}]
+indent_style = space
+indent_size = 2
+
+[*.toml]
+indent_size = 4
+indent_style = space
+
+# Markdown and MDX are whitespace sensitive languages.
+# Do not remove trailing spaces.
+[*.{md,mdx}]
+trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation
# Set default charset
[*.{js,tsx}]
-charset = utf-8
indent_style = space
indent_size = 2
-
-# Matches the exact files either package.json or .travis.yml
-[{package.json,.travis.yml}]
+# Matches the exact files package.json
+[package.json]
indent_style = space
indent_size = 2
diff --git a/.gitattributes b/.gitattributes
index a10da53408..a32a39f65c 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -1,5 +1,5 @@
# Ensure that .sh scripts use LF as line separator, even if they are checked out
-# to Windows(NTFS) file-system, by a user of Docker for Windows.
+# to Windows(NTFS) file-system, by a user of Docker for Windows.
# These .sh scripts will be run from the Container after `docker compose up -d`.
# If they appear to be CRLF style, Dash from the Container will fail to execute
# them.
diff --git a/.github/linters/editorconfig-checker.json b/.github/linters/editorconfig-checker.json
new file mode 100644
index 0000000000..ce6e9ae341
--- /dev/null
+++ b/.github/linters/editorconfig-checker.json
@@ -0,0 +1,22 @@
+{
+ "Verbose": false,
+ "Debug": false,
+ "IgnoreDefaults": false,
+ "SpacesAfterTabs": false,
+ "NoColor": false,
+ "Exclude": [
+ "^web/public/vs/",
+ "^web/public/pdf.worker.min.mjs$",
+ "web/app/components/base/icons/src/vender/"
+ ],
+ "AllowedContentTypes": [],
+ "PassedFiles": [],
+ "Disable": {
+ "EndOfLine": false,
+ "Indentation": false,
+ "IndentSize": true,
+ "InsertFinalNewline": false,
+ "TrimTrailingWhitespace": false,
+ "MaxLineLength": false
+ }
+}
diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml
index 02583cda06..f08befefb8 100644
--- a/.github/workflows/api-tests.yml
+++ b/.github/workflows/api-tests.yml
@@ -88,3 +88,6 @@ jobs:
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
+
+ - name: Run Tool
+ run: uv run --project api bash dev/pytest/pytest_tools.sh
diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml
index 98e5fd5150..30c0ff000d 100644
--- a/.github/workflows/style.yml
+++ b/.github/workflows/style.yml
@@ -9,6 +9,12 @@ concurrency:
group: style-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
+permissions:
+ checks: write
+ statuses: write
+ contents: read
+
+
jobs:
python-style:
name: Python Style
@@ -43,8 +49,8 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
- uv run --directory api ruff check ./
- uv run --directory api ruff format --check ./
+ uv run --directory api ruff check --diff ./
+ uv run --directory api ruff format --check --diff ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
@@ -163,3 +169,14 @@ jobs:
VALIDATE_DOCKERFILE_HADOLINT: true
VALIDATE_XML: true
VALIDATE_YAML: true
+
+ - name: EditorConfig checks
+ uses: super-linter/super-linter/slim@v7
+ env:
+ DEFAULT_BRANCH: main
+ GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+ IGNORE_GENERATED_FILES: true
+ IGNORE_GITIGNORED_FILES: true
+ # EditorConfig validation
+ VALIDATE_EDITORCONFIG: true
+ EDITORCONFIG_FILE_NAME: editorconfig-checker.json
diff --git a/CONTRIBUTING_CN.md b/CONTRIBUTING_CN.md
index 0478d2e1fa..69ae7071bb 100644
--- a/CONTRIBUTING_CN.md
+++ b/CONTRIBUTING_CN.md
@@ -6,7 +6,7 @@
本指南和 Dify 一样在不断完善中。如果有任何滞后于项目实际情况的地方,恳请谅解,我们也欢迎任何改进建议。
-关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。社区同时也遵循[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
+关于许可证,请花一分钟阅读我们简短的[许可和贡献者协议](./LICENSE)。同时也请遵循社区[行为准则](https://github.com/langgenius/.github/blob/main/CODE_OF_CONDUCT.md)。
## 开始之前
diff --git a/CONTRIBUTING_ES.md b/CONTRIBUTING_ES.md
index 261aa0fda1..98cbb5b457 100644
--- a/CONTRIBUTING_ES.md
+++ b/CONTRIBUTING_ES.md
@@ -90,4 +90,4 @@ Recomendamos revisar este documento cuidadosamente antes de proceder con la conf
No dudes en contactarnos si encuentras algún problema durante el proceso de configuración.
## Obteniendo Ayuda
-Si alguna vez te quedas atascado o tienes una pregunta urgente mientras contribuyes, simplemente envíanos tus consultas a través del issue relacionado de GitHub, o únete a nuestro [Discord](https://discord.gg/8Tpq4AcN9c) para una charla rápida.
\ No newline at end of file
+Si alguna vez te quedas atascado o tienes una pregunta urgente mientras contribuyes, simplemente envíanos tus consultas a través del issue relacionado de GitHub, o únete a nuestro [Discord](https://discord.gg/8Tpq4AcN9c) para una charla rápida.
diff --git a/CONTRIBUTING_FR.md b/CONTRIBUTING_FR.md
index c3418f86cc..fc8410dfd6 100644
--- a/CONTRIBUTING_FR.md
+++ b/CONTRIBUTING_FR.md
@@ -90,4 +90,4 @@ Nous recommandons de revoir attentivement ce document avant de procéder à la c
N'hésitez pas à nous contacter si vous rencontrez des problèmes pendant le processus de configuration.
## Obtenir de l'aide
-Si jamais vous êtes bloqué ou avez une question urgente en contribuant, envoyez-nous simplement vos questions via le problème GitHub concerné, ou rejoignez notre [Discord](https://discord.gg/8Tpq4AcN9c) pour une discussion rapide.
\ No newline at end of file
+Si jamais vous êtes bloqué ou avez une question urgente en contribuant, envoyez-nous simplement vos questions via le problème GitHub concerné, ou rejoignez notre [Discord](https://discord.gg/8Tpq4AcN9c) pour une discussion rapide.
diff --git a/CONTRIBUTING_KR.md b/CONTRIBUTING_KR.md
index fcf44d495a..78d3f38c47 100644
--- a/CONTRIBUTING_KR.md
+++ b/CONTRIBUTING_KR.md
@@ -90,4 +90,4 @@ PR 설명에 기존 이슈를 연결하거나 새 이슈를 여는 것을 잊지
설정 과정에서 문제가 발생하면 언제든지 연락해 주세요.
## 도움 받기
-기여하는 동안 막히거나 긴급한 질문이 있으면, 관련 GitHub 이슈를 통해 질문을 보내거나, 빠른 대화를 위해 우리의 [Discord](https://discord.gg/8Tpq4AcN9c)에 참여하세요.
\ No newline at end of file
+기여하는 동안 막히거나 긴급한 질문이 있으면, 관련 GitHub 이슈를 통해 질문을 보내거나, 빠른 대화를 위해 우리의 [Discord](https://discord.gg/8Tpq4AcN9c)에 참여하세요.
diff --git a/CONTRIBUTING_PT.md b/CONTRIBUTING_PT.md
index bba76c17ee..7347fd7f9c 100644
--- a/CONTRIBUTING_PT.md
+++ b/CONTRIBUTING_PT.md
@@ -90,4 +90,4 @@ Recomendamos revisar este documento cuidadosamente antes de prosseguir com a con
Sinta-se à vontade para entrar em contato se encontrar quaisquer problemas durante o processo de configuração.
## Obtendo Ajuda
-Se você ficar preso ou tiver uma dúvida urgente enquanto contribui, simplesmente envie suas perguntas através do problema relacionado no GitHub, ou entre no nosso [Discord](https://discord.gg/8Tpq4AcN9c) para uma conversa rápida.
\ No newline at end of file
+Se você ficar preso ou tiver uma dúvida urgente enquanto contribui, simplesmente envie suas perguntas através do problema relacionado no GitHub, ou entre no nosso [Discord](https://discord.gg/8Tpq4AcN9c) para uma conversa rápida.
diff --git a/CONTRIBUTING_TR.md b/CONTRIBUTING_TR.md
index 4e216d22a4..681f05689b 100644
--- a/CONTRIBUTING_TR.md
+++ b/CONTRIBUTING_TR.md
@@ -90,4 +90,4 @@ Kuruluma geçmeden önce bu belgeyi dikkatlice incelemenizi öneririz, çünkü
Kurulum süreci sırasında herhangi bir sorunla karşılaşırsanız bizimle iletişime geçmekten çekinmeyin.
## Yardım Almak
-Katkıda bulunurken takılırsanız veya yanıcı bir sorunuz olursa, sorularınızı ilgili GitHub sorunu aracılığıyla bize gönderin veya hızlı bir sohbet için [Discord'umuza](https://discord.gg/8Tpq4AcN9c) katılın.
\ No newline at end of file
+Katkıda bulunurken takılırsanız veya yanıcı bir sorunuz olursa, sorularınızı ilgili GitHub sorunu aracılığıyla bize gönderin veya hızlı bir sohbet için [Discord'umuza](https://discord.gg/8Tpq4AcN9c) katılın.
diff --git a/README.md b/README.md
index 87ebc9bafc..efb37d6083 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
Dify Cloud ·
Self-hosting ·
Documentation ·
- Enterprise inquiry
+ Dify edition overview
@@ -54,7 +54,7 @@
-Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features and more, letting you quickly go from prototype to production.
+Dify is an open-source LLM app development platform. Its intuitive interface combines agentic AI workflow, RAG pipeline, agent capabilities, model management, observability features, and more, allowing you to quickly move from prototype to production.
## Quick start
@@ -188,7 +188,7 @@ All of Dify's offerings come with corresponding APIs, so you could effortlessly
- **Dify for enterprise / organizations**
We provide additional enterprise-centric features. [Log your questions for us through this chatbot](https://udify.app/chat/22L1zSxg6yW1cWQg) or [send us an email](mailto:business@dify.ai?subject=[GitHub]Business%20License%20Inquiry) to discuss enterprise needs.
- > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one-click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
+ > For startups and small businesses using AWS, check out [Dify Premium on AWS Marketplace](https://aws.amazon.com/marketplace/pp/prodview-t22mebxzwjhu6) and deploy it to your own AWS VPC with one click. It's an affordable AMI offering with the option to create apps with custom logo and branding.
## Staying ahead
@@ -233,7 +233,7 @@ Deploy Dify to AWS with [CDK](https://aws.amazon.com/cdk/)
For those who'd like to contribute code, see our [Contribution Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
At the same time, please consider supporting Dify by sharing it on social media and at events and conferences.
-> We are looking for contributors to help with translating Dify to languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
+> We are looking for contributors to help translate Dify into languages other than Mandarin or English. If you are interested in helping, please see the [i18n README](https://github.com/langgenius/dify/blob/main/web/i18n/README.md) for more information, and leave us a comment in the `global-users` channel of our [Discord Community Server](https://discord.gg/8Tpq4AcN9c).
## Community & contact
diff --git a/README_AR.md b/README_AR.md
index e58f59da5d..4f93802fda 100644
--- a/README_AR.md
+++ b/README_AR.md
@@ -4,7 +4,7 @@
Dify Cloud ·
الاستضافة الذاتية ·
التوثيق ·
- استفسار الشركات (للإنجليزية فقط)
+ نظرة عامة على منتجات Dify
diff --git a/README_BN.md b/README_BN.md
index 3ebc81af5d..7599fae9ff 100644
--- a/README_BN.md
+++ b/README_BN.md
@@ -8,7 +8,7 @@
ডিফাই ক্লাউড ·
সেল্ফ-হোস্টিং ·
ডকুমেন্টেশন ·
- ব্যাবসায়িক অনুসন্ধান
+ Dify পণ্যের রূপভেদ
diff --git a/README_CN.md b/README_CN.md
index 6d3c601100..973629f459 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -4,7 +4,7 @@
Dify 云服务 ·
自托管 ·
文档 ·
- (需用英文)常见问题解答 / 联系团队
+ Dify 产品形态总览
diff --git a/README_DE.md b/README_DE.md
index b3b9bf3221..738c0e3b67 100644
--- a/README_DE.md
+++ b/README_DE.md
@@ -8,7 +8,7 @@
Dify Cloud ·
Selbstgehostetes ·
Dokumentation ·
- Anfrage an Unternehmen
+ Überblick über die Dify-Produkte
diff --git a/README_ES.md b/README_ES.md
index d14afdd2eb..212268b73d 100644
--- a/README_ES.md
+++ b/README_ES.md
@@ -4,7 +4,7 @@
Dify Cloud ·
Auto-alojamiento ·
Documentación ·
- Consultas empresariales (en inglés)
+ Resumen de las ediciones de Dify
diff --git a/README_FR.md b/README_FR.md
index 031196303e..89eea7d058 100644
--- a/README_FR.md
+++ b/README_FR.md
@@ -4,7 +4,7 @@
Dify Cloud ·
Auto-hébergement ·
Documentation ·
- Demande d’entreprise (en anglais seulement)
+ Présentation des différentes offres Dify
diff --git a/README_JA.md b/README_JA.md
index 3b7a6f50db..adca219753 100644
--- a/README_JA.md
+++ b/README_JA.md
@@ -4,7 +4,7 @@
Dify Cloud ·
セルフホスティング ·
ドキュメント ·
- 企業のお問い合わせ(英語のみ)
+ Difyの各種エディションについて
diff --git a/README_KL.md b/README_KL.md
index ccadb77274..17e6c9d509 100644
--- a/README_KL.md
+++ b/README_KL.md
@@ -4,7 +4,7 @@
Dify Cloud ·
Self-hosting ·
Documentation ·
- Commercial enquiries
+ Dify product editions
diff --git a/README_KR.md b/README_KR.md
index c1a98f8b68..d44723f9b6 100644
--- a/README_KR.md
+++ b/README_KR.md
@@ -4,7 +4,7 @@
Dify 클라우드 ·
셀프-호스팅 ·
문서 ·
- 기업 문의 (영어만 가능)
+ Dify 제품 에디션 안내
diff --git a/README_PT.md b/README_PT.md
index 5b3c782645..9dc2207279 100644
--- a/README_PT.md
+++ b/README_PT.md
@@ -8,7 +8,7 @@
Dify Cloud ·
Auto-hospedagem ·
Documentação ·
- Consultas empresariais
+ Visão geral das edições do Dify
diff --git a/README_SI.md b/README_SI.md
index 7c0867c776..9a38b558b4 100644
--- a/README_SI.md
+++ b/README_SI.md
@@ -1,259 +1,259 @@
-
-
-
- 📌 Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast
-
-
-
- Dify Cloud ·
- Samostojno gostovanje ·
- Dokumentacija ·
- Povpraševanje za podjetja
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje.
-
-## Hitri začetek
-> Preden namestite Dify, se prepričajte, da vaša naprava izpolnjuje naslednje minimalne sistemske zahteve:
->
->- CPU >= 2 Core
->- RAM >= 4 GiB
-
-
-
-Najlažji način za zagon strežnika Dify je prek docker compose . Preden zaženete Dify z naslednjimi ukazi, se prepričajte, da sta Docker in Docker Compose nameščena na vašem računalniku:
-
-```bash
-cd dify
-cd docker
-cp .env.example .env
-docker compose up -d
-```
-
-Po zagonu lahko dostopate do nadzorne plošče Dify v brskalniku na [http://localhost/install](http://localhost/install) in začnete postopek inicializacije.
-
-#### Iskanje pomoči
-Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) če naletite na težave pri nastavitvi Dify. Če imate še vedno težave, se obrnite na [skupnost ali nas](#community--contact).
-
-> Če želite prispevati k Difyju ali narediti dodaten razvoj, glejte naš vodnik za [uvajanje iz izvorne kode](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
-
-## Ključne značilnosti
-**1. Potek dela**:
- Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več.
-
-
- https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
-
-
-
-**2. Celovita podpora za modele**:
- Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers).
-
-
-
-
-**3. Prompt IDE**:
- intuitivni vmesnik za ustvarjanje pozivov, primerjavo zmogljivosti modela in dodajanje dodatnih funkcij, kot je pretvorba besedila v govor, aplikaciji, ki temelji na klepetu.
-
-**4. RAG Pipeline**:
- E Obsežne zmogljivosti RAG, ki pokrivajo vse od vnosa dokumenta do priklica, s podporo za ekstrakcijo besedila iz datotek PDF, PPT in drugih običajnih formatov dokumentov.
-
-**5. Agent capabilities**:
- definirate lahko agente, ki temeljijo na klicanju funkcij LLM ali ReAct, in dodate vnaprej izdelana orodja ali orodja po meri za agenta. Dify ponuja več kot 50 vgrajenih orodij za agente AI, kot so Google Search, DALL·E, Stable Diffusion in WolframAlpha.
-
-**6. LLMOps**:
- Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozive, nabore podatkov in modele lahko nenehno izboljšujete na podlagi proizvodnih podatkov in opomb.
-
-**7. Backend-as-a-Service**:
- AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
-
-## Primerjava Funkcij
-
-
-
- Funkcija
- Dify.AI
- LangChain
- Flowise
- OpenAI Assistants API
-
-
- Programski pristop
- API + usmerjeno v aplikacije
- Python koda
- Usmerjeno v aplikacije
- Usmerjeno v API
-
-
- Podprti LLM-ji
- Bogata izbira
- Bogata izbira
- Bogata izbira
- Samo OpenAI
-
-
- RAG pogon
- ✅
- ✅
- ✅
- ✅
-
-
- Agent
- ✅
- ✅
- ❌
- ✅
-
-
- Potek dela
- ✅
- ❌
- ✅
- ❌
-
-
- Spremljanje
- ✅
- ✅
- ❌
- ❌
-
-
- Funkcija za podjetja (SSO/nadzor dostopa)
- ✅
- ❌
- ❌
- ❌
-
-
- Lokalna namestitev
- ✅
- ✅
- ✅
- ❌
-
-
-
-## Uporaba Dify
-
-- **Cloud **
-Gostimo storitev Dify Cloud za vsakogar, ki jo lahko preizkusite brez nastavitev. Zagotavlja vse zmožnosti različice za samostojno namestitev in vključuje 200 brezplačnih klicev GPT-4 v načrtu peskovnika.
-
-- **Self-hosting Dify Community Edition**
-Hitro zaženite Dify v svojem okolju s tem [začetnim vodnikom](#quick-start) . Za dodatne reference in podrobnejša navodila uporabite našo [dokumentacijo](https://docs.dify.ai) .
-
-
-- **Dify za podjetja/organizacije**
-Ponujamo dodatne funkcije, osredotočene na podjetja. Zabeležite svoja vprašanja prek tega klepetalnega robota ali nam pošljite e-pošto, da se pogovorimo o potrebah podjetja.
- > Za novoustanovljena podjetja in mala podjetja, ki uporabljajo AWS, si oglejte Dify Premium na AWS Marketplace in ga z enim klikom uvedite v svoj AWS VPC. To je cenovno ugodna ponudba AMI z možnostjo ustvarjanja aplikacij z logotipom in blagovno znamko po meri.
-
-
-## Staying ahead
-
-Star Dify on GitHub and be instantly notified of new releases.
-
-
-
-
-## Napredne nastavitve
-
-Če morate prilagoditi konfiguracijo, si oglejte komentarje v naši datoteki .env.example in posodobite ustrezne vrednosti v svoji .env datoteki. Poleg tega boste morda morali prilagoditi docker-compose.yamlsamo datoteko, na primer spremeniti različice slike, preslikave vrat ali namestitve nosilca, glede na vaše specifično okolje in zahteve za uvajanje. Po kakršnih koli spremembah ponovno zaženite docker-compose up -d. Celoten seznam razpoložljivih spremenljivk okolja najdete tukaj .
-
-Če želite konfigurirati visoko razpoložljivo nastavitev, so na voljo Helm Charts in datoteke YAML, ki jih prispeva skupnost, ki omogočajo uvedbo Difyja v Kubernetes.
-
-- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
-- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
-- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
-- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
-
-#### Uporaba Terraform za uvajanje
-
-namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www.terraform.io/)
-
-##### Azure Global
-- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
-
-##### Google Cloud
-- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
-
-#### Uporaba AWS CDK za uvajanje
-
-Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
-
-##### AWS
-- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
-
-## Prispevam
-
-Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
-
-
-
-> Iščemo sodelavce za pomoč pri prevajanju Difyja v jezike, ki niso mandarinščina ali angleščina. Če želite pomagati, si oglejte i18n README za več informacij in nam pustite komentar v global-userskanalu našega strežnika skupnosti Discord .
-
-## Skupnost in stik
-
-* [Github Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj.
-* [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
-* [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
-* [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
-
-**Contributors**
-
-
-
-
-
-## Star history
-
-[](https://star-history.com/#langgenius/dify&Date)
-
-
-## Varnostno razkritje
-
-Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj na GitHub. Namesto tega pošljite vprašanja na security@dify.ai in zagotovili vam bomo podrobnejši odgovor.
-
-## Licenca
-
-To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
+
+
+
+ 📌 Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast
+
+
+
+ Dify Cloud ·
+ Samostojno gostovanje ·
+ Dokumentacija ·
+ Pregled ponudb izdelkov Dify
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje.
+
+## Hitri začetek
+> Preden namestite Dify, se prepričajte, da vaša naprava izpolnjuje naslednje minimalne sistemske zahteve:
+>
+>- CPU >= 2 Core
+>- RAM >= 4 GiB
+
+
+
+Najlažji način za zagon strežnika Dify je prek docker compose . Preden zaženete Dify z naslednjimi ukazi, se prepričajte, da sta Docker in Docker Compose nameščena na vašem računalniku:
+
+```bash
+cd dify
+cd docker
+cp .env.example .env
+docker compose up -d
+```
+
+Po zagonu lahko dostopate do nadzorne plošče Dify v brskalniku na [http://localhost/install](http://localhost/install) in začnete postopek inicializacije.
+
+#### Iskanje pomoči
+Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-started/install-self-hosted/faqs) če naletite na težave pri nastavitvi Dify. Če imate še vedno težave, se obrnite na [skupnost ali nas](#community--contact).
+
+> Če želite prispevati k Difyju ali narediti dodaten razvoj, glejte naš vodnik za [uvajanje iz izvorne kode](https://docs.dify.ai/getting-started/install-self-hosted/local-source-code)
+
+## Ključne značilnosti
+**1. Potek dela**:
+ Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več.
+
+
+ https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa
+
+
+
+**2. Celovita podpora za modele**:
+ Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers).
+
+
+
+
+**3. Prompt IDE**:
+ intuitivni vmesnik za ustvarjanje pozivov, primerjavo zmogljivosti modela in dodajanje dodatnih funkcij, kot je pretvorba besedila v govor, aplikaciji, ki temelji na klepetu.
+
+**4. RAG Pipeline**:
+ E Obsežne zmogljivosti RAG, ki pokrivajo vse od vnosa dokumenta do priklica, s podporo za ekstrakcijo besedila iz datotek PDF, PPT in drugih običajnih formatov dokumentov.
+
+**5. Agent capabilities**:
+ definirate lahko agente, ki temeljijo na klicanju funkcij LLM ali ReAct, in dodate vnaprej izdelana orodja ali orodja po meri za agenta. Dify ponuja več kot 50 vgrajenih orodij za agente AI, kot so Google Search, DALL·E, Stable Diffusion in WolframAlpha.
+
+**6. LLMOps**:
+ Spremljajte in analizirajte dnevnike aplikacij in učinkovitost skozi čas. Pozive, nabore podatkov in modele lahko nenehno izboljšujete na podlagi proizvodnih podatkov in opomb.
+
+**7. Backend-as-a-Service**:
+ AVse ponudbe Difyja so opremljene z ustreznimi API-ji, tako da lahko Dify brez težav integrirate v svojo poslovno logiko.
+
+## Primerjava Funkcij
+
+
+
+ Funkcija
+ Dify.AI
+ LangChain
+ Flowise
+ OpenAI Assistants API
+
+
+ Programski pristop
+ API + usmerjeno v aplikacije
+ Python koda
+ Usmerjeno v aplikacije
+ Usmerjeno v API
+
+
+ Podprti LLM-ji
+ Bogata izbira
+ Bogata izbira
+ Bogata izbira
+ Samo OpenAI
+
+
+ RAG pogon
+ ✅
+ ✅
+ ✅
+ ✅
+
+
+ Agent
+ ✅
+ ✅
+ ❌
+ ✅
+
+
+ Potek dela
+ ✅
+ ❌
+ ✅
+ ❌
+
+
+ Spremljanje
+ ✅
+ ✅
+ ❌
+ ❌
+
+
+ Funkcija za podjetja (SSO/nadzor dostopa)
+ ✅
+ ❌
+ ❌
+ ❌
+
+
+ Lokalna namestitev
+ ✅
+ ✅
+ ✅
+ ❌
+
+
+
+## Uporaba Dify
+
+- **Cloud **
+Gostimo storitev Dify Cloud za vsakogar, ki jo lahko preizkusite brez nastavitev. Zagotavlja vse zmožnosti različice za samostojno namestitev in vključuje 200 brezplačnih klicev GPT-4 v načrtu peskovnika.
+
+- **Self-hosting Dify Community Edition**
+Hitro zaženite Dify v svojem okolju s tem [začetnim vodnikom](#quick-start) . Za dodatne reference in podrobnejša navodila uporabite našo [dokumentacijo](https://docs.dify.ai) .
+
+
+- **Dify za podjetja/organizacije**
+Ponujamo dodatne funkcije, osredotočene na podjetja. Zabeležite svoja vprašanja prek tega klepetalnega robota ali nam pošljite e-pošto, da se pogovorimo o potrebah podjetja.
+ > Za novoustanovljena podjetja in mala podjetja, ki uporabljajo AWS, si oglejte Dify Premium na AWS Marketplace in ga z enim klikom uvedite v svoj AWS VPC. To je cenovno ugodna ponudba AMI z možnostjo ustvarjanja aplikacij z logotipom in blagovno znamko po meri.
+
+
+## Staying ahead
+
+Star Dify on GitHub and be instantly notified of new releases.
+
+
+
+
+## Napredne nastavitve
+
+Če morate prilagoditi konfiguracijo, si oglejte komentarje v naši datoteki .env.example in posodobite ustrezne vrednosti v svoji .env datoteki. Poleg tega boste morda morali prilagoditi docker-compose.yamlsamo datoteko, na primer spremeniti različice slike, preslikave vrat ali namestitve nosilca, glede na vaše specifično okolje in zahteve za uvajanje. Po kakršnih koli spremembah ponovno zaženite docker-compose up -d. Celoten seznam razpoložljivih spremenljivk okolja najdete tukaj .
+
+Če želite konfigurirati visoko razpoložljivo nastavitev, so na voljo Helm Charts in datoteke YAML, ki jih prispeva skupnost, ki omogočajo uvedbo Difyja v Kubernetes.
+
+- [Helm Chart by @LeoQuote](https://github.com/douban/charts/tree/master/charts/dify)
+- [Helm Chart by @BorisPolonsky](https://github.com/BorisPolonsky/dify-helm)
+- [YAML file by @Winson-030](https://github.com/Winson-030/dify-kubernetes)
+- [YAML file by @wyy-holding](https://github.com/wyy-holding/dify-k8s)
+
+#### Uporaba Terraform za uvajanje
+
+namestite Dify v Cloud Platform z enim klikom z uporabo [terraform](https://www.terraform.io/)
+
+##### Azure Global
+- [Azure Terraform by @nikawang](https://github.com/nikawang/dify-azure-terraform)
+
+##### Google Cloud
+- [Google Cloud Terraform by @sotazum](https://github.com/DeNA/dify-google-cloud-terraform)
+
+#### Uporaba AWS CDK za uvajanje
+
+Uvedite Dify v AWS z uporabo [CDK](https://aws.amazon.com/cdk/)
+
+##### AWS
+- [AWS CDK by @KevinZhao](https://github.com/aws-samples/solution-for-deploying-dify-on-aws)
+
+## Prispevam
+
+Za tiste, ki bi radi prispevali kodo, si oglejte naš vodnik za prispevke . Hkrati vas prosimo, da podprete Dify tako, da ga delite na družbenih medijih ter na dogodkih in konferencah.
+
+
+
+> Iščemo sodelavce za pomoč pri prevajanju Difyja v jezike, ki niso mandarinščina ali angleščina. Če želite pomagati, si oglejte i18n README za več informacij in nam pustite komentar v global-userskanalu našega strežnika skupnosti Discord .
+
+## Skupnost in stik
+
+* [Github Discussion](https://github.com/langgenius/dify/discussions). Najboljše za: izmenjavo povratnih informacij in postavljanje vprašanj.
+* [GitHub Issues](https://github.com/langgenius/dify/issues). Najboljše za: hrošče, na katere naletite pri uporabi Dify.AI, in predloge funkcij. Oglejte si naš [vodnik za prispevke](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md).
+* [Discord](https://discord.gg/FngNHpbcY7). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
+* [X(Twitter)](https://twitter.com/dify_ai). Najboljše za: deljenje vaših aplikacij in druženje s skupnostjo.
+
+**Contributors**
+
+
+
+
+
+## Star history
+
+[](https://star-history.com/#langgenius/dify&Date)
+
+
+## Varnostno razkritje
+
+Zaradi zaščite vaše zasebnosti se izogibajte objavljanju varnostnih vprašanj na GitHub. Namesto tega pošljite vprašanja na security@dify.ai in zagotovili vam bomo podrobnejši odgovor.
+
+## Licenca
+
+To skladišče je na voljo pod [odprtokodno licenco Dify](LICENSE) , ki je v bistvu Apache 2.0 z nekaj dodatnimi omejitvami.
diff --git a/README_TR.md b/README_TR.md
index f8890b00ef..ab2853a019 100644
--- a/README_TR.md
+++ b/README_TR.md
@@ -4,7 +4,7 @@
Dify Bulut ·
Kendi Sunucunuzda Barındırma ·
Dokümantasyon ·
- Yalnızca İngilizce: Kurumsal Sorgulama
+ Dify ürün seçeneklerine genel bakış
diff --git a/README_TW.md b/README_TW.md
index 260f1e80ac..8263a22b64 100644
--- a/README_TW.md
+++ b/README_TW.md
@@ -8,7 +8,7 @@
Dify 雲端服務 ·
自行託管 ·
說明文件 ·
- 企業諮詢
+ 產品方案概覽
diff --git a/README_VI.md b/README_VI.md
index 15d2d5ae80..852ed7aaa0 100644
--- a/README_VI.md
+++ b/README_VI.md
@@ -4,7 +4,7 @@
Dify Cloud ·
Tự triển khai ·
Tài liệu ·
- Yêu cầu doanh nghiệp
+ Tổng quan các lựa chọn sản phẩm Dify
diff --git a/api/.dockerignore b/api/.dockerignore
index 447edcda08..a0ce59d221 100644
--- a/api/.dockerignore
+++ b/api/.dockerignore
@@ -16,4 +16,4 @@ logs
.ruff_cache
# venv
-.venv
\ No newline at end of file
+.venv
diff --git a/api/.env.example b/api/.env.example
index 01ddb4adfd..2cc6410cdd 100644
--- a/api/.env.example
+++ b/api/.env.example
@@ -297,6 +297,7 @@ LINDORM_URL=http://ld-*******************-proxy-search-pub.lindorm.aliyuncs.com:
LINDORM_USERNAME=admin
LINDORM_PASSWORD=admin
USING_UGC_INDEX=False
+LINDORM_QUERY_TIMEOUT=1
# OceanBase Vector configuration
OCEANBASE_VECTOR_HOST=127.0.0.1
@@ -482,4 +483,7 @@ OTEL_MAX_QUEUE_SIZE=2048
OTEL_MAX_EXPORT_BATCH_SIZE=512
OTEL_METRIC_EXPORT_INTERVAL=60000
OTEL_BATCH_EXPORT_TIMEOUT=10000
-OTEL_METRIC_EXPORT_TIMEOUT=30000
\ No newline at end of file
+OTEL_METRIC_EXPORT_TIMEOUT=30000
+
+# Prevent Clickjacking
+ALLOW_EMBED=false
diff --git a/api/README.md b/api/README.md
index c542f11b16..9308d5dc44 100644
--- a/api/README.md
+++ b/api/README.md
@@ -90,3 +90,4 @@
```bash
uv run -P api bash dev/pytest/pytest_all_tests.sh
```
+
diff --git a/api/app.py b/api/app.py
index 9830a80904..4f393f6c20 100644
--- a/api/app.py
+++ b/api/app.py
@@ -18,7 +18,7 @@ else:
# so we need to disable gevent in debug mode.
# If you are using debugpy and set GEVENT_SUPPORT=True, you can debug with gevent.
if (flask_debug := os.environ.get("FLASK_DEBUG", "0")) and flask_debug.lower() in {"false", "0", "no"}:
- from gevent import monkey # type: ignore
+ from gevent import monkey
# gevent
monkey.patch_all()
diff --git a/api/app_factory.py b/api/app_factory.py
index 586f2ded9e..1c886ac5c7 100644
--- a/api/app_factory.py
+++ b/api/app_factory.py
@@ -54,7 +54,6 @@ def initialize_extensions(app: DifyApp):
ext_otel,
ext_proxy_fix,
ext_redis,
- ext_repositories,
ext_sentry,
ext_set_secretkey,
ext_storage,
@@ -75,7 +74,6 @@ def initialize_extensions(app: DifyApp):
ext_migrate,
ext_redis,
ext_storage,
- ext_repositories,
ext_celery,
ext_login,
ext_mail,
diff --git a/api/commands.py b/api/commands.py
index e70d6e0b49..dc31dc0d80 100644
--- a/api/commands.py
+++ b/api/commands.py
@@ -17,6 +17,7 @@ from core.rag.models.document import Document
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from extensions.ext_storage import storage
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
@@ -271,6 +272,7 @@ def migrate_knowledge_vector_database():
upper_collection_vector_types = {
VectorType.MILVUS,
VectorType.PGVECTOR,
+ VectorType.VASTBASE,
VectorType.RELYT,
VectorType.WEAVIATE,
VectorType.ORACLE,
@@ -442,13 +444,13 @@ def convert_to_agent_apps():
WHERE a.mode = 'chat'
AND am.agent_mode is not null
AND (
- am.agent_mode like '%"strategy": "function_call"%'
+ am.agent_mode like '%"strategy": "function_call"%'
OR am.agent_mode like '%"strategy": "react"%'
- )
+ )
AND (
- am.agent_mode like '{"enabled": true%'
+ am.agent_mode like '{"enabled": true%'
OR am.agent_mode like '{"max_iteration": %'
- ) ORDER BY a.created_at DESC LIMIT 1000
+ ) ORDER BY a.created_at DESC LIMIT 1000
"""
with db.engine.begin() as conn:
@@ -666,7 +668,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green"))
# run db migration
- import flask_migrate # type: ignore
+ import flask_migrate
flask_migrate.upgrade()
@@ -814,3 +816,331 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids)
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
+
+
+@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
+@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
+def clear_orphaned_file_records(force: bool):
+ """
+ Clear orphaned file records in the database.
+ """
+
+ # define tables and columns to process
+ files_tables = [
+ {"table": "upload_files", "id_column": "id", "key_column": "key"},
+ {"table": "tool_files", "id_column": "id", "key_column": "file_key"},
+ ]
+ ids_tables = [
+ {"type": "uuid", "table": "message_files", "column": "upload_file_id"},
+ {"type": "text", "table": "documents", "column": "data_source_info"},
+ {"type": "text", "table": "document_segments", "column": "content"},
+ {"type": "text", "table": "messages", "column": "answer"},
+ {"type": "text", "table": "workflow_node_executions", "column": "inputs"},
+ {"type": "text", "table": "workflow_node_executions", "column": "process_data"},
+ {"type": "text", "table": "workflow_node_executions", "column": "outputs"},
+ {"type": "text", "table": "conversations", "column": "introduction"},
+ {"type": "text", "table": "conversations", "column": "system_instruction"},
+ {"type": "json", "table": "messages", "column": "inputs"},
+ {"type": "json", "table": "messages", "column": "message"},
+ ]
+
+ # notify user and ask for confirmation
+ click.echo(
+ click.style(
+ "This command will first find and delete orphaned file records from the message_files table,", fg="yellow"
+ )
+ )
+ click.echo(
+ click.style(
+ "and then it will find and delete orphaned file records in the following tables:",
+ fg="yellow",
+ )
+ )
+ for files_table in files_tables:
+ click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
+ click.echo(
+ click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow")
+ )
+ for ids_table in ids_tables:
+ click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow"))
+ click.echo("")
+
+ click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
+ click.echo(
+ click.style(
+ (
+ "Since not all patterns have been fully tested, "
+ "please note that this command may delete unintended file records."
+ ),
+ fg="yellow",
+ )
+ )
+ click.echo(
+ click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow")
+ )
+ click.echo(
+ click.style(
+ (
+ "It is also recommended to run this during the maintenance window, "
+ "as this may cause high load on your instance."
+ ),
+ fg="yellow",
+ )
+ )
+ if not force:
+ click.confirm("Do you want to proceed?", abort=True)
+
+ # start the cleanup process
+ click.echo(click.style("Starting orphaned file records cleanup.", fg="white"))
+
+ # clean up the orphaned records in the message_files table where message_id doesn't exist in messages table
+ try:
+ click.echo(
+ click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white")
+ )
+ query = (
+ "SELECT mf.id, mf.message_id "
+ "FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id "
+ "WHERE m.id IS NULL"
+ )
+ orphaned_message_files = []
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
+
+ if orphaned_message_files:
+ click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white"))
+ for record in orphaned_message_files:
+ click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black"))
+
+ if not force:
+ click.confirm(
+ (
+ f"Do you want to proceed "
+ f"to delete all {len(orphaned_message_files)} orphaned message_files records?"
+ ),
+ abort=True,
+ )
+
+ click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
+ query = "DELETE FROM message_files WHERE id IN :ids"
+ with db.engine.begin() as conn:
+ conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
+ click.echo(
+ click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
+ )
+ else:
+ click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green"))
+ except Exception as e:
+ click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red"))
+
+ # clean up the orphaned records in the rest of the *_files tables
+ try:
+ # fetch file id and keys from each table
+ all_files_in_tables = []
+ for files_table in files_tables:
+ click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white"))
+ query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}"
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]})
+ click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
+
+ # fetch referred table and columns
+ guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
+ all_ids_in_tables = []
+ for ids_table in ids_tables:
+ query = ""
+ if ids_table["type"] == "uuid":
+ click.echo(
+ click.style(
+ f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
+ )
+ )
+ query = (
+ f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
+ elif ids_table["type"] == "text":
+ click.echo(
+ click.style(
+ f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
+ fg="white",
+ )
+ )
+ query = (
+ f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
+ f"FROM {ids_table['table']}"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ for j in i[0]:
+ all_ids_in_tables.append({"table": ids_table["table"], "id": j})
+ elif ids_table["type"] == "json":
+ click.echo(
+ click.style(
+ (
+ f"- Listing file-id-like JSON string in column {ids_table['column']} "
+ f"in table {ids_table['table']}"
+ ),
+ fg="white",
+ )
+ )
+ query = (
+ f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
+ f"FROM {ids_table['table']}"
+ )
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ for j in i[0]:
+ all_ids_in_tables.append({"table": ids_table["table"], "id": j})
+ click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
+
+ except Exception as e:
+ click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
+ return
+
+ # find orphaned files
+ all_files = [file["id"] for file in all_files_in_tables]
+ all_ids = [file["id"] for file in all_ids_in_tables]
+ orphaned_files = list(set(all_files) - set(all_ids))
+ if not orphaned_files:
+ click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green"))
+ return
+ click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white"))
+ for file in orphaned_files:
+ click.echo(click.style(f"- orphaned file id: {file}", fg="black"))
+ if not force:
+ click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True)
+
+ # delete orphaned records for each file
+ try:
+ for files_table in files_tables:
+ click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white"))
+ query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids"
+ with db.engine.begin() as conn:
+ conn.execute(db.text(query), {"ids": tuple(orphaned_files)})
+ except Exception as e:
+ click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red"))
+ return
+ click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green"))
+
+
+@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
+@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.")
+def remove_orphaned_files_on_storage(force: bool):
+ """
+ Remove orphaned files on the storage.
+ """
+
+ # define tables and columns to process
+ files_tables = [
+ {"table": "upload_files", "key_column": "key"},
+ {"table": "tool_files", "key_column": "file_key"},
+ ]
+ storage_paths = ["image_files", "tools", "upload_files"]
+
+ # notify user and ask for confirmation
+ click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow"))
+ click.echo(
+ click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow")
+ )
+ for files_table in files_tables:
+ click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
+ click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow"))
+ for storage_path in storage_paths:
+ click.echo(click.style(f"- {storage_path}", fg="yellow"))
+ click.echo("")
+
+ click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red"))
+ click.echo(
+ click.style(
+ "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow"
+ )
+ )
+ click.echo(
+ click.style(
+ "Since not all patterns have been fully tested, please note that this command may delete unintended files.",
+ fg="yellow",
+ )
+ )
+ click.echo(
+ click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow")
+ )
+ click.echo(
+ click.style(
+ (
+ "It is also recommended to run this during the maintenance window, "
+ "as this may cause high load on your instance."
+ ),
+ fg="yellow",
+ )
+ )
+ if not force:
+ click.confirm("Do you want to proceed?", abort=True)
+
+ # start the cleanup process
+ click.echo(click.style("Starting orphaned files cleanup.", fg="white"))
+
+ # fetch file id and keys from each table
+ all_files_in_tables = []
+ try:
+ for files_table in files_tables:
+ click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white"))
+ query = f"SELECT {files_table['key_column']} FROM {files_table['table']}"
+ with db.engine.begin() as conn:
+ rs = conn.execute(db.text(query))
+ for i in rs:
+ all_files_in_tables.append(str(i[0]))
+ click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white"))
+ except Exception as e:
+ click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red"))
+
+ all_files_on_storage = []
+ for storage_path in storage_paths:
+ try:
+ click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
+ files = storage.scan(path=storage_path, files=True, directories=False)
+ all_files_on_storage.extend(files)
+ except FileNotFoundError as e:
+ click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow"))
+ continue
+ except Exception as e:
+ click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red"))
+ continue
+ click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white"))
+
+ # find orphaned files
+ orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables))
+ if not orphaned_files:
+ click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green"))
+ return
+ click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white"))
+ for file in orphaned_files:
+ click.echo(click.style(f"- orphaned file: {file}", fg="black"))
+ if not force:
+ click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True)
+
+ # delete orphaned files
+ removed_files = 0
+ error_files = 0
+ for file in orphaned_files:
+ try:
+ storage.delete(file)
+ removed_files += 1
+ click.echo(click.style(f"- Removing orphaned file: {file}", fg="white"))
+ except Exception as e:
+ error_files += 1
+ click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red"))
+ continue
+ if error_files == 0:
+ click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green"))
+ else:
+ click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow"))
diff --git a/api/configs/app_config.py b/api/configs/app_config.py
index cb0adb751c..3a3ad35ee7 100644
--- a/api/configs/app_config.py
+++ b/api/configs/app_config.py
@@ -13,6 +13,7 @@ from .observability import ObservabilityConfig
from .packaging import PackagingInfo
from .remote_settings_sources import RemoteSettingsSource, RemoteSettingsSourceConfig, RemoteSettingsSourceName
from .remote_settings_sources.apollo import ApolloSettingsSource
+from .remote_settings_sources.nacos import NacosSettingsSource
logger = logging.getLogger(__name__)
@@ -34,6 +35,8 @@ class RemoteSettingsSourceFactory(PydanticBaseSettingsSource):
match remote_source_name:
case RemoteSettingsSourceName.APOLLO:
remote_source = ApolloSettingsSource(current_state)
+ case RemoteSettingsSourceName.NACOS:
+ remote_source = NacosSettingsSource(current_state)
case _:
logger.warning(f"Unsupported remote source: {remote_source_name}")
return {}
diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py
index f498dccbbc..4890b5f746 100644
--- a/api/configs/feature/__init__.py
+++ b/api/configs/feature/__init__.py
@@ -398,6 +398,11 @@ class InnerAPIConfig(BaseSettings):
default=False,
)
+ INNER_API_KEY: Optional[str] = Field(
+ description="API key for accessing the internal API",
+ default=None,
+ )
+
class LoggingConfig(BaseSettings):
"""
diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py
index 15dfe0063b..d285515998 100644
--- a/api/configs/middleware/__init__.py
+++ b/api/configs/middleware/__init__.py
@@ -22,6 +22,7 @@ from .vdb.baidu_vector_config import BaiduVectorDBConfig
from .vdb.chroma_config import ChromaConfig
from .vdb.couchbase_config import CouchbaseConfig
from .vdb.elasticsearch_config import ElasticsearchConfig
+from .vdb.huawei_cloud_config import HuaweiCloudConfig
from .vdb.lindorm_config import LindormConfig
from .vdb.milvus_config import MilvusConfig
from .vdb.myscale_config import MyScaleConfig
@@ -38,6 +39,7 @@ from .vdb.tencent_vector_config import TencentVectorDBConfig
from .vdb.tidb_on_qdrant_config import TidbOnQdrantConfig
from .vdb.tidb_vector_config import TiDBVectorConfig
from .vdb.upstash_config import UpstashConfig
+from .vdb.vastbase_vector_config import VastbaseVectorConfig
from .vdb.vikingdb_config import VikingDBConfig
from .vdb.weaviate_config import WeaviateConfig
@@ -263,11 +265,13 @@ class MiddlewareConfig(
VectorStoreConfig,
AnalyticdbConfig,
ChromaConfig,
+ HuaweiCloudConfig,
MilvusConfig,
MyScaleConfig,
OpenSearchConfig,
OracleConfig,
PGVectorConfig,
+ VastbaseVectorConfig,
PGVectoRSConfig,
QdrantConfig,
RelytConfig,
diff --git a/api/configs/middleware/vdb/huawei_cloud_config.py b/api/configs/middleware/vdb/huawei_cloud_config.py
new file mode 100644
index 0000000000..2290c60499
--- /dev/null
+++ b/api/configs/middleware/vdb/huawei_cloud_config.py
@@ -0,0 +1,25 @@
+from typing import Optional
+
+from pydantic import Field
+from pydantic_settings import BaseSettings
+
+
+class HuaweiCloudConfig(BaseSettings):
+ """
+ Configuration settings for Huawei cloud search service
+ """
+
+ HUAWEI_CLOUD_HOSTS: Optional[str] = Field(
+ description="Hostname or IP address of the Huawei cloud search service instance",
+ default=None,
+ )
+
+ HUAWEI_CLOUD_USER: Optional[str] = Field(
+ description="Username for authenticating with Huawei cloud search service",
+ default=None,
+ )
+
+ HUAWEI_CLOUD_PASSWORD: Optional[str] = Field(
+ description="Password for authenticating with Huawei cloud search service",
+ default=None,
+ )
diff --git a/api/configs/middleware/vdb/lindorm_config.py b/api/configs/middleware/vdb/lindorm_config.py
index 95e1d1cfca..e80e3f4a35 100644
--- a/api/configs/middleware/vdb/lindorm_config.py
+++ b/api/configs/middleware/vdb/lindorm_config.py
@@ -32,3 +32,4 @@ class LindormConfig(BaseSettings):
description="Using UGC index will store the same type of Index in a single index but can retrieve separately.",
default=False,
)
+ LINDORM_QUERY_TIMEOUT: Optional[float] = Field(description="The lindorm search request timeout (s)", default=2.0)
diff --git a/api/configs/middleware/vdb/opensearch_config.py b/api/configs/middleware/vdb/opensearch_config.py
index 81dde4c04d..96f478e9a6 100644
--- a/api/configs/middleware/vdb/opensearch_config.py
+++ b/api/configs/middleware/vdb/opensearch_config.py
@@ -1,4 +1,5 @@
-from typing import Optional
+import enum
+from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
+ class AuthMethod(enum.StrEnum):
+ """
+ Authentication method for OpenSearch
+ """
+
+ BASIC = "basic"
+ AWS_MANAGED_IAM = "aws_managed_iam"
+
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
@@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200,
)
+ OPENSEARCH_SECURE: bool = Field(
+ description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
+ default=False,
+ )
+
+ OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
+ description="Authentication method for OpenSearch connection (default is 'basic')",
+ default=AuthMethod.BASIC,
+ )
+
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
@@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
default=None,
)
- OPENSEARCH_SECURE: bool = Field(
- description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
- default=False,
+ OPENSEARCH_AWS_REGION: Optional[str] = Field(
+ description="AWS region for OpenSearch (e.g. 'us-west-2')",
+ default=None,
+ )
+
+ OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
+ description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
)
diff --git a/api/configs/middleware/vdb/vastbase_vector_config.py b/api/configs/middleware/vdb/vastbase_vector_config.py
new file mode 100644
index 0000000000..816d6df90a
--- /dev/null
+++ b/api/configs/middleware/vdb/vastbase_vector_config.py
@@ -0,0 +1,45 @@
+from typing import Optional
+
+from pydantic import Field, PositiveInt
+from pydantic_settings import BaseSettings
+
+
+class VastbaseVectorConfig(BaseSettings):
+ """
+ Configuration settings for Vector (Vastbase with vector extension)
+ """
+
+ VASTBASE_HOST: Optional[str] = Field(
+ description="Hostname or IP address of the Vastbase server with Vector extension (e.g., 'localhost')",
+ default=None,
+ )
+
+ VASTBASE_PORT: PositiveInt = Field(
+ description="Port number on which the Vastbase server is listening (default is 5432)",
+ default=5432,
+ )
+
+ VASTBASE_USER: Optional[str] = Field(
+ description="Username for authenticating with the Vastbase database",
+ default=None,
+ )
+
+ VASTBASE_PASSWORD: Optional[str] = Field(
+ description="Password for authenticating with the Vastbase database",
+ default=None,
+ )
+
+ VASTBASE_DATABASE: Optional[str] = Field(
+ description="Name of the Vastbase database to connect to",
+ default=None,
+ )
+
+ VASTBASE_MIN_CONNECTION: PositiveInt = Field(
+ description="Min connection of the Vastbase database",
+ default=1,
+ )
+
+ VASTBASE_MAX_CONNECTION: PositiveInt = Field(
+ description="Max connection of the Vastbase database",
+ default=5,
+ )
diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py
index c7aedc5b8a..c7960e1356 100644
--- a/api/configs/packaging/__init__.py
+++ b/api/configs/packaging/__init__.py
@@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings):
CURRENT_VERSION: str = Field(
description="Dify version",
- default="1.2.0",
+ default="1.3.1",
)
COMMIT_SHA: str = Field(
diff --git a/api/configs/remote_settings_sources/enums.py b/api/configs/remote_settings_sources/enums.py
index 3081f2950f..dd998cac64 100644
--- a/api/configs/remote_settings_sources/enums.py
+++ b/api/configs/remote_settings_sources/enums.py
@@ -3,3 +3,4 @@ from enum import StrEnum
class RemoteSettingsSourceName(StrEnum):
APOLLO = "apollo"
+ NACOS = "nacos"
diff --git a/api/configs/remote_settings_sources/nacos/__init__.py b/api/configs/remote_settings_sources/nacos/__init__.py
new file mode 100644
index 0000000000..b1ce8e87bc
--- /dev/null
+++ b/api/configs/remote_settings_sources/nacos/__init__.py
@@ -0,0 +1,52 @@
+import logging
+import os
+from collections.abc import Mapping
+from typing import Any
+
+from pydantic.fields import FieldInfo
+
+from .http_request import NacosHttpClient
+
+logger = logging.getLogger(__name__)
+
+from configs.remote_settings_sources.base import RemoteSettingsSource
+
+from .utils import _parse_config
+
+
+class NacosSettingsSource(RemoteSettingsSource):
+ def __init__(self, configs: Mapping[str, Any]):
+ self.configs = configs
+ self.remote_configs: dict[str, Any] = {}
+ self.async_init()
+
+ def async_init(self):
+ data_id = os.getenv("DIFY_ENV_NACOS_DATA_ID", "dify-api-env.properties")
+ group = os.getenv("DIFY_ENV_NACOS_GROUP", "nacos-dify")
+ tenant = os.getenv("DIFY_ENV_NACOS_NAMESPACE", "")
+
+ params = {"dataId": data_id, "group": group, "tenant": tenant}
+ try:
+ content = NacosHttpClient().http_request("/nacos/v1/cs/configs", method="GET", headers={}, params=params)
+ self.remote_configs = self._parse_config(content)
+ except Exception as e:
+ logger.exception("[get-access-token] exception occurred")
+ raise
+
+ def _parse_config(self, content: str) -> dict:
+ if not content:
+ return {}
+ try:
+ return _parse_config(self, content)
+ except Exception as e:
+ raise RuntimeError(f"Failed to parse config: {e}")
+
+ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
+ if not isinstance(self.remote_configs, dict):
+ raise ValueError(f"remote configs is not dict, but {type(self.remote_configs)}")
+
+ field_value = self.remote_configs.get(field_name)
+ if field_value is None:
+ return None, field_name, False
+
+ return field_value, field_name, False
diff --git a/api/configs/remote_settings_sources/nacos/http_request.py b/api/configs/remote_settings_sources/nacos/http_request.py
new file mode 100644
index 0000000000..2785bd955b
--- /dev/null
+++ b/api/configs/remote_settings_sources/nacos/http_request.py
@@ -0,0 +1,83 @@
+import base64
+import hashlib
+import hmac
+import logging
+import os
+import time
+
+import requests
+
+logger = logging.getLogger(__name__)
+
+
+class NacosHttpClient:
+ def __init__(self):
+ self.username = os.getenv("DIFY_ENV_NACOS_USERNAME")
+ self.password = os.getenv("DIFY_ENV_NACOS_PASSWORD")
+ self.ak = os.getenv("DIFY_ENV_NACOS_ACCESS_KEY")
+ self.sk = os.getenv("DIFY_ENV_NACOS_SECRET_KEY")
+ self.server = os.getenv("DIFY_ENV_NACOS_SERVER_ADDR", "localhost:8848")
+ self.token = None
+ self.token_ttl = 18000
+ self.token_expire_time: float = 0
+
+ def http_request(self, url, method="GET", headers=None, params=None):
+ try:
+ self._inject_auth_info(headers, params)
+ response = requests.request(method, url="http://" + self.server + url, headers=headers, params=params)
+ response.raise_for_status()
+ return response.text
+ except requests.exceptions.RequestException as e:
+ return f"Request to Nacos failed: {e}"
+
+ def _inject_auth_info(self, headers, params, module="config"):
+ headers.update({"User-Agent": "Nacos-Http-Client-In-Dify:v0.0.1"})
+
+ if module == "login":
+ return
+
+ ts = str(int(time.time() * 1000))
+
+ if self.ak and self.sk:
+ sign_str = self.get_sign_str(params["group"], params["tenant"], ts)
+ headers["Spas-AccessKey"] = self.ak
+ headers["Spas-Signature"] = self.__do_sign(sign_str, self.sk)
+ headers["timeStamp"] = ts
+ if self.username and self.password:
+ self.get_access_token(force_refresh=False)
+ params["accessToken"] = self.token
+
+ def __do_sign(self, sign_str, sk):
+ return (
+ base64.encodebytes(hmac.new(sk.encode(), sign_str.encode(), digestmod=hashlib.sha1).digest())
+ .decode()
+ .strip()
+ )
+
+ def get_sign_str(self, group, tenant, ts):
+ sign_str = ""
+ if tenant:
+ sign_str = tenant + "+"
+ if group:
+ sign_str = sign_str + group + "+"
+ if sign_str:
+ sign_str += ts
+ return sign_str
+
+ def get_access_token(self, force_refresh=False):
+ current_time = time.time()
+ if self.token and not force_refresh and self.token_expire_time > current_time:
+ return self.token
+
+ params = {"username": self.username, "password": self.password}
+ url = "http://" + self.server + "/nacos/v1/auth/login"
+ try:
+ resp = requests.request("POST", url, headers=None, params=params)
+ resp.raise_for_status()
+ response_data = resp.json()
+ self.token = response_data.get("accessToken")
+ self.token_ttl = response_data.get("tokenTtl", 18000)
+ self.token_expire_time = current_time + self.token_ttl - 10
+ except Exception as e:
+ logger.exception("[get-access-token] exception occur")
+ raise
diff --git a/api/configs/remote_settings_sources/nacos/utils.py b/api/configs/remote_settings_sources/nacos/utils.py
new file mode 100644
index 0000000000..f3372563b1
--- /dev/null
+++ b/api/configs/remote_settings_sources/nacos/utils.py
@@ -0,0 +1,31 @@
+def _parse_config(self, content: str) -> dict[str, str]:
+ config: dict[str, str] = {}
+ if not content:
+ return config
+
+ for line in content.splitlines():
+ cleaned_line = line.strip()
+ if not cleaned_line or cleaned_line.startswith(("#", "!")):
+ continue
+
+ separator_index = -1
+ for i, c in enumerate(cleaned_line):
+ if c in ("=", ":") and (i == 0 or cleaned_line[i - 1] != "\\"):
+ separator_index = i
+ break
+
+ if separator_index == -1:
+ continue
+
+ key = cleaned_line[:separator_index].strip()
+ raw_value = cleaned_line[separator_index + 1 :].strip()
+
+ try:
+ decoded_value = bytes(raw_value, "utf-8").decode("unicode_escape")
+ decoded_value = decoded_value.replace(r"\=", "=").replace(r"\:", ":")
+ except UnicodeDecodeError:
+ decoded_value = raw_value
+
+ config[key] = decoded_value
+
+ return config
diff --git a/api/constants/__init__.py b/api/constants/__init__.py
index 9162357466..a84de0a451 100644
--- a/api/constants/__init__.py
+++ b/api/constants/__init__.py
@@ -16,11 +16,25 @@ AUDIO_EXTENSIONS.extend([ext.upper() for ext in AUDIO_EXTENSIONS])
if dify_config.ETL_TYPE == "Unstructured":
- DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls"]
+ DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "vtt", "properties"]
DOCUMENT_EXTENSIONS.extend(("doc", "docx", "csv", "eml", "msg", "pptx", "xml", "epub"))
if dify_config.UNSTRUCTURED_API_URL:
DOCUMENT_EXTENSIONS.append("ppt")
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
else:
- DOCUMENT_EXTENSIONS = ["txt", "markdown", "md", "mdx", "pdf", "html", "htm", "xlsx", "xls", "docx", "csv"]
+ DOCUMENT_EXTENSIONS = [
+ "txt",
+ "markdown",
+ "md",
+ "mdx",
+ "pdf",
+ "html",
+ "htm",
+ "xlsx",
+ "xls",
+ "docx",
+ "csv",
+ "vtt",
+ "properties",
+ ]
DOCUMENT_EXTENSIONS.extend([ext.upper() for ext in DOCUMENT_EXTENSIONS])
diff --git a/api/constants/mimetypes.py b/api/constants/mimetypes.py
new file mode 100644
index 0000000000..38988cdd24
--- /dev/null
+++ b/api/constants/mimetypes.py
@@ -0,0 +1,7 @@
+# The two constants below should keep in sync.
+# Default content type for files which have no explicit content type.
+
+DEFAULT_MIME_TYPE = "application/octet-stream"
+# Default file extension for files which have no explicit content type, should
+# correspond to the `DEFAULT_MIME_TYPE` above.
+DEFAULT_EXTENSION = ".bin"
diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py
index b1ebc444a5..79869916ed 100644
--- a/api/controllers/common/fields.py
+++ b/api/controllers/common/fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
parameters__system_parameters = {
"image_file_size_limit": fields.Integer,
diff --git a/api/controllers/console/admin.py b/api/controllers/console/admin.py
index 6e3273f5d4..8cb7ad9f5b 100644
--- a/api/controllers/console/admin.py
+++ b/api/controllers/console/admin.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound, Unauthorized
diff --git a/api/controllers/console/apikey.py b/api/controllers/console/apikey.py
index eb42507c63..47c93a15c6 100644
--- a/api/controllers/console/apikey.py
+++ b/api/controllers/console/apikey.py
@@ -1,7 +1,7 @@
from typing import Any
-import flask_restful # type: ignore
-from flask_login import current_user # type: ignore
+import flask_restful
+from flask_login import current_user
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/app/advanced_prompt_template.py b/api/controllers/console/app/advanced_prompt_template.py
index 8d0c5b84af..c228743fa5 100644
--- a/api/controllers/console/app/advanced_prompt_template.py
+++ b/api/controllers/console/app/advanced_prompt_template.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/app/agent.py b/api/controllers/console/app/agent.py
index 920cae0d85..d433415894 100644
--- a/api/controllers/console/app/agent.py
+++ b/api/controllers/console/app/agent.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py
index fcd8ed1882..91058767eb 100644
--- a/api/controllers/console/app/annotation.py
+++ b/api/controllers/console/app/annotation.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -186,7 +186,7 @@ class AnnotationUpdateDeleteApi(Resource):
app_id = str(app_id)
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_id, annotation_id)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class AnnotationBatchImportApi(Resource):
diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py
index 3e908b76a7..f97209c369 100644
--- a/api/controllers/console/app/app.py
+++ b/api/controllers/console/app/app.py
@@ -1,8 +1,8 @@
import uuid
from typing import cast
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
diff --git a/api/controllers/console/app/app_import.py b/api/controllers/console/app/app_import.py
index a159d4c5c4..5dc6515ce0 100644
--- a/api/controllers/console/app/app_import.py
+++ b/api/controllers/console/app/app_import.py
@@ -1,7 +1,7 @@
from typing import cast
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py
index 12d9157dda..5f2def8d8e 100644
--- a/api/controllers/console/app/audio.py
+++ b/api/controllers/console/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
@@ -80,8 +80,6 @@ class ChatMessageTextApi(Resource):
@account_initialization_required
@get_app_model
def post(self, app_model: App):
- from werkzeug.exceptions import InternalServerError
-
try:
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, location="json")
diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py
index c9820f70f7..732f5b799a 100644
--- a/api/controllers/console/app/completion.py
+++ b/api/controllers/console/app/completion.py
@@ -1,7 +1,7 @@
import logging
-import flask_login # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+import flask_login
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py
index 8827f129d9..70d6216497 100644
--- a/api/controllers/console/app/conversation.py
+++ b/api/controllers/console/app/conversation.py
@@ -1,9 +1,9 @@
from datetime import UTC, datetime
import pytz # pip install pytz
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy import func, or_
from sqlalchemy.orm import joinedload
from werkzeug.exceptions import Forbidden, NotFound
diff --git a/api/controllers/console/app/conversation_variables.py b/api/controllers/console/app/conversation_variables.py
index c0a20b7160..d49f433ba1 100644
--- a/api/controllers/console/app/conversation_variables.py
+++ b/api/controllers/console/app/conversation_variables.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py
index 4046417076..790369c052 100644
--- a/api/controllers/console/app/generator.py
+++ b/api/controllers/console/app/generator.py
@@ -1,7 +1,7 @@
import os
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.error import (
diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py
index b5828b6b4b..b7a4c31a15 100644
--- a/api/controllers/console/app/message.py
+++ b/api/controllers/console/app/message.py
@@ -1,8 +1,8 @@
import logging
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
from controllers.console import api
diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py
index 8ecc8a9db5..f30e3e893c 100644
--- a/api/controllers/console/app/model_config.py
+++ b/api/controllers/console/app/model_config.py
@@ -2,8 +2,8 @@ import json
from typing import cast
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py
index dd25af8ebf..978c02412c 100644
--- a/api/controllers/console/app/ops_trace.py
+++ b/api/controllers/console/app/ops_trace.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import BadRequest
from controllers.console import api
@@ -84,7 +84,7 @@ class TraceAppConfigApi(Resource):
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"])
if not result:
raise TracingConfigNotExist()
- return {"result": "success"}
+ return {"result": "success"}, 204
except Exception as e:
raise BadRequest(str(e))
diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py
index f15f9d4dae..3c3a359eeb 100644
--- a/api/controllers/console/app/site.py
+++ b/api/controllers/console/app/site.py
@@ -1,7 +1,7 @@
from datetime import UTC, datetime
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from constants.languages import supported_language
diff --git a/api/controllers/console/app/statistic.py b/api/controllers/console/app/statistic.py
index a37d26b989..86aed77412 100644
--- a/api/controllers/console/app/statistic.py
+++ b/api/controllers/console/app/statistic.py
@@ -3,8 +3,8 @@ from decimal import Decimal
import pytz
from flask import jsonify
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py
index 2e077d2095..0c13adce9b 100644
--- a/api/controllers/console/app/workflow.py
+++ b/api/controllers/console/app/workflow.py
@@ -3,7 +3,7 @@ import logging
from typing import cast
from flask import abort, request
-from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
+from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py
index d863747995..c475aea9fc 100644
--- a/api/controllers/console/app/workflow_app_log.py
+++ b/api/controllers/console/app/workflow_app_log.py
@@ -1,6 +1,6 @@
from dateutil.parser import isoparse
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import Resource, marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from controllers.console import api
diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py
index 25a99c1e15..08ab61bbb9 100644
--- a/api/controllers/console/app/workflow_run.py
+++ b/api/controllers/console/app/workflow_run.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import Resource, marshal_with, reqparse
+from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/app/workflow_statistic.py b/api/controllers/console/app/workflow_statistic.py
index 097bf7d188..6c7c73707b 100644
--- a/api/controllers/console/app/workflow_statistic.py
+++ b/api/controllers/console/app/workflow_statistic.py
@@ -3,8 +3,8 @@ from decimal import Decimal
import pytz
from flask import jsonify
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.app.wraps import get_app_model
diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py
index c56f551d49..1795563ff7 100644
--- a/api/controllers/console/auth/activate.py
+++ b/api/controllers/console/auth/activate.py
@@ -1,7 +1,7 @@
import datetime
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from constants.languages import supported_language
from controllers.console import api
diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py
index ea00c2b8c2..b8c3c8f012 100644
--- a/api/controllers/console/auth/data_source_bearer_auth.py
+++ b/api/controllers/console/auth/data_source_bearer_auth.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -65,7 +65,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource):
ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source")
diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py
index b4bd80fe2f..1049f864c3 100644
--- a/api/controllers/console/auth/data_source_oauth.py
+++ b/api/controllers/console/auth/data_source_oauth.py
@@ -2,8 +2,8 @@ import logging
import requests
from flask import current_app, redirect, request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py
index d4a33645ab..d73d8ce701 100644
--- a/api/controllers/console/auth/forgot_password.py
+++ b/api/controllers/console/auth/forgot_password.py
@@ -2,7 +2,7 @@ import base64
import secrets
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py
index 16c1dcc441..27864bab3d 100644
--- a/api/controllers/console/auth/login.py
+++ b/api/controllers/console/auth/login.py
@@ -1,8 +1,8 @@
from typing import cast
-import flask_login # type: ignore
+import flask_login
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
import services
from configs import dify_config
diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py
index 33bafbf463..f5284cc43b 100644
--- a/api/controllers/console/auth/oauth.py
+++ b/api/controllers/console/auth/oauth.py
@@ -4,7 +4,7 @@ from typing import Optional
import requests
from flask import current_app, redirect, request
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
diff --git a/api/controllers/console/billing/billing.py b/api/controllers/console/billing/billing.py
index fd7b7bd8cb..4b0c82ae6c 100644
--- a/api/controllers/console/billing/billing.py
+++ b/api/controllers/console/billing/billing.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
diff --git a/api/controllers/console/billing/compliance.py b/api/controllers/console/billing/compliance.py
index 6d5d668709..9679632ac7 100644
--- a/api/controllers/console/billing/compliance.py
+++ b/api/controllers/console/billing/compliance.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from libs.helper import extract_remote_ip
from libs.login import login_required
diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py
index 70bfb217eb..7b0d9373cf 100644
--- a/api/controllers/console/datasets/data_source.py
+++ b/api/controllers/console/datasets/data_source.py
@@ -2,8 +2,8 @@ import datetime
import json
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py
index 4644ac6299..571a395780 100644
--- a/api/controllers/console/datasets/datasets.py
+++ b/api/controllers/console/datasets/datasets.py
@@ -1,7 +1,7 @@
-import flask_restful # type: ignore
+import flask_restful
from flask import request
-from flask_login import current_user # type: ignore # type: ignore
-from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -657,6 +657,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.ELASTICSEARCH
| VectorType.ELASTICSEARCH_JA
| VectorType.PGVECTOR
+ | VectorType.VASTBASE
| VectorType.TIDB_ON_QDRANT
| VectorType.LINDORM
| VectorType.COUCHBASE
@@ -664,6 +665,7 @@ class DatasetRetrievalSettingApi(Resource):
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
+ | VectorType.HUAWEI_CLOUD
| VectorType.TENCENT
):
return {
@@ -705,11 +707,13 @@ class DatasetRetrievalSettingMockApi(Resource):
| VectorType.ELASTICSEARCH_JA
| VectorType.COUCHBASE
| VectorType.PGVECTOR
+ | VectorType.VASTBASE
| VectorType.LINDORM
| VectorType.OPENGAUSS
| VectorType.OCEANBASE
| VectorType.TABLESTORE
| VectorType.TENCENT
+ | VectorType.HUAWEI_CLOUD
):
return {
"retrieval_method": [
diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py
index 0b40312368..68601adfed 100644
--- a/api/controllers/console/datasets/datasets_document.py
+++ b/api/controllers/console/datasets/datasets_document.py
@@ -4,8 +4,8 @@ from datetime import UTC, datetime
from typing import cast
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, fields, marshal, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc
from werkzeug.exceptions import Forbidden, NotFound
@@ -40,7 +40,7 @@ from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError
-from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.plugin.impl.exc import PluginDaemonClientSideError
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from extensions.ext_redis import redis_client
diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py
index 696aaa94db..a145038672 100644
--- a/api/controllers/console/datasets/datasets_segments.py
+++ b/api/controllers/console/datasets/datasets_segments.py
@@ -2,8 +2,8 @@ import uuid
import pandas as pd
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services
@@ -131,7 +131,7 @@ class DatasetDocumentSegmentListApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class DatasetDocumentSegmentApi(Resource):
@@ -333,7 +333,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segment(segment, document, dataset)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class DatasetDocumentSegmentBatchImportApi(Resource):
@@ -590,7 +590,7 @@ class ChildChunkUpdateApi(Resource):
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
@setup_required
@login_required
diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py
index 2c031172bf..cf9081e154 100644
--- a/api/controllers/console/datasets/external.py
+++ b/api/controllers/console/datasets/external.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@@ -135,7 +135,7 @@ class ExternalApiTemplateApi(Resource):
raise Forbidden()
ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class ExternalApiUseCheckApi(Resource):
@@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
+ parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
args = parser.parse_args()
HitTestingService.hit_testing_args_check(args)
@@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
query=args["query"],
account=current_user,
external_retrieval_model=args["external_retrieval_model"],
+ metadata_filtering_conditions=args["metadata_filtering_conditions"],
)
return response
diff --git a/api/controllers/console/datasets/hit_testing.py b/api/controllers/console/datasets/hit_testing.py
index d344e9d126..fba5d4c0f3 100644
--- a/api/controllers/console/datasets/hit_testing.py
+++ b/api/controllers/console/datasets/hit_testing.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from controllers.console import api
from controllers.console.datasets.hit_testing_base import DatasetsHitTestingBase
diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py
index bd944602c1..3b4c076863 100644
--- a/api/controllers/console/datasets/hit_testing_base.py
+++ b/api/controllers/console/datasets/hit_testing_base.py
@@ -1,7 +1,7 @@
import logging
-from flask_login import current_user # type: ignore
-from flask_restful import marshal, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services.dataset_service
diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py
index fc9711169f..b1a83aa371 100644
--- a/api/controllers/console/datasets/metadata.py
+++ b/api/controllers/console/datasets/metadata.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -82,7 +82,7 @@ class DatasetMetadataApi(Resource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
- return 200
+ return {"result": "success"}, 204
class DatasetMetadataBuiltInFieldApi(Resource):
diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py
index 33c926b4c9..4200a51709 100644
--- a/api/controllers/console/datasets/website.py
+++ b/api/controllers/console/datasets/website.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.datasets.error import WebsiteCrawlError
diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py
index c7f9fec326..54bc590677 100644
--- a/api/controllers/console/explore/audio.py
+++ b/api/controllers/console/explore/audio.py
@@ -66,7 +66,7 @@ class ChatAudioApi(InstalledAppResource):
class ChatTextApi(InstalledAppResource):
def post(self, installed_app):
- from flask_restful import reqparse # type: ignore
+ from flask_restful import reqparse
app_model = installed_app.app
try:
diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py
index e693a5a71b..4367da1162 100644
--- a/api/controllers/console/explore/completion.py
+++ b/api/controllers/console/explore/completion.py
@@ -1,8 +1,8 @@
import logging
from datetime import UTC, datetime
-from flask_login import current_user # type: ignore
-from flask_restful import reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/explore/conversation.py b/api/controllers/console/explore/conversation.py
index 600e78e09e..d7c161cc6d 100644
--- a/api/controllers/console/explore/conversation.py
+++ b/api/controllers/console/explore/conversation.py
@@ -1,6 +1,6 @@
-from flask_login import current_user # type: ignore
-from flask_restful import marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_login import current_user
+from flask_restful import marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py
index 86550b2bdf..9336c35a0d 100644
--- a/api/controllers/console/explore/installed_app.py
+++ b/api/controllers/console/explore/installed_app.py
@@ -2,8 +2,8 @@ from datetime import UTC, datetime
from typing import Any
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, inputs, marshal_with, reqparse
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
@@ -113,7 +113,7 @@ class InstalledAppApi(InstalledAppResource):
db.session.delete(installed_app)
db.session.commit()
- return {"result": "success", "message": "App uninstalled successfully"}
+ return {"result": "success", "message": "App uninstalled successfully"}, 204
def patch(self, installed_app):
parser = reqparse.RequestParser()
diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py
index ff12959a65..822777604a 100644
--- a/api/controllers/console/explore/message.py
+++ b/api/controllers/console/explore/message.py
@@ -1,8 +1,8 @@
import logging
-from flask_login import current_user # type: ignore
-from flask_restful import marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_login import current_user
+from flask_restful import marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/console/explore/parameter.py b/api/controllers/console/explore/parameter.py
index bf9f0d6b28..a1280d91d1 100644
--- a/api/controllers/console/explore/parameter.py
+++ b/api/controllers/console/explore/parameter.py
@@ -1,4 +1,4 @@
-from flask_restful import marshal_with # type: ignore
+from flask_restful import marshal_with
from controllers.common import fields
from controllers.console import api
diff --git a/api/controllers/console/explore/recommended_app.py b/api/controllers/console/explore/recommended_app.py
index be6b1f5d21..ce85f495aa 100644
--- a/api/controllers/console/explore/recommended_app.py
+++ b/api/controllers/console/explore/recommended_app.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, fields, marshal_with, reqparse
from constants.languages import languages
from controllers.console import api
diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py
index 9f0c496645..339e7007a0 100644
--- a/api/controllers/console/explore/saved_message.py
+++ b/api/controllers/console/explore/saved_message.py
@@ -1,6 +1,6 @@
-from flask_login import current_user # type: ignore
-from flask_restful import fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_login import current_user
+from flask_restful import fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -72,7 +72,7 @@ class SavedMessageApi(InstalledAppResource):
SavedMessageService.delete(app_model, current_user, message_id)
- return {"result": "success"}
+ return {"result": "success"}, 204
api.add_resource(
diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py
index a2653a94f6..3f625e6609 100644
--- a/api/controllers/console/explore/workflow.py
+++ b/api/controllers/console/explore/workflow.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse # type: ignore
+from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError
from controllers.console.app.error import (
diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py
index b7ba81fba2..49ea81a8a0 100644
--- a/api/controllers/console/explore/wraps.py
+++ b/api/controllers/console/explore/wraps.py
@@ -1,7 +1,7 @@
from functools import wraps
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from werkzeug.exceptions import NotFound
from controllers.console.wraps import account_initialization_required
diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py
index ed6cedb220..07a241ef86 100644
--- a/api/controllers/console/extension.py
+++ b/api/controllers/console/extension.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from constants import HIDDEN_VALUE
from controllers.console import api
@@ -99,7 +99,7 @@ class APIBasedExtensionDetailAPI(Resource):
APIBasedExtensionService.delete(extension_data_from_db)
- return {"result": "success"}
+ return {"result": "success"}, 204
api.add_resource(CodeBasedExtensionAPI, "/code-based-extension")
diff --git a/api/controllers/console/feature.py b/api/controllers/console/feature.py
index da1171412f..70ab4ff865 100644
--- a/api/controllers/console/feature.py
+++ b/api/controllers/console/feature.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from libs.login import login_required
from services.feature_service import FeatureService
diff --git a/api/controllers/console/files.py b/api/controllers/console/files.py
index 8cf754bbd6..66b6214f82 100644
--- a/api/controllers/console/files.py
+++ b/api/controllers/console/files.py
@@ -1,8 +1,8 @@
from typing import Literal
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
diff --git a/api/controllers/console/init_validate.py b/api/controllers/console/init_validate.py
index cfed5fe7a4..b19e331d2e 100644
--- a/api/controllers/console/init_validate.py
+++ b/api/controllers/console/init_validate.py
@@ -1,7 +1,7 @@
import os
from flask import session
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session
diff --git a/api/controllers/console/ping.py b/api/controllers/console/ping.py
index 2a116112a3..cd28cc946e 100644
--- a/api/controllers/console/ping.py
+++ b/api/controllers/console/ping.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from controllers.console import api
diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py
index 30afc930a8..b8cf019e4f 100644
--- a/api/controllers/console/remote_files.py
+++ b/api/controllers/console/remote_files.py
@@ -2,8 +2,8 @@ import urllib.parse
from typing import cast
import httpx
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
import services
from controllers.common import helpers
diff --git a/api/controllers/console/setup.py b/api/controllers/console/setup.py
index 3b47f8f12f..e1f19a87a3 100644
--- a/api/controllers/console/setup.py
+++ b/api/controllers/console/setup.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from configs import dify_config
from libs.helper import StrLen, email, extract_remote_ip
diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py
index da83f64019..cb5dedca21 100644
--- a/api/controllers/console/tag/tags.py
+++ b/api/controllers/console/tag/tags.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
@@ -86,7 +86,7 @@ class TagUpdateDeleteApi(Resource):
TagService.delete_tag(tag_id)
- return 200
+ return 204
class TagBindingCreateApi(Resource):
diff --git a/api/controllers/console/version.py b/api/controllers/console/version.py
index 7773c99944..7dea8e554e 100644
--- a/api/controllers/console/version.py
+++ b/api/controllers/console/version.py
@@ -2,7 +2,7 @@ import json
import logging
import requests
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from packaging import version
from configs import dify_config
diff --git a/api/controllers/console/workspace/__init__.py b/api/controllers/console/workspace/__init__.py
index 7af2b44a4a..072e904caf 100644
--- a/api/controllers/console/workspace/__init__.py
+++ b/api/controllers/console/workspace/__init__.py
@@ -1,6 +1,6 @@
from functools import wraps
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py
index e9c25e6c5b..a9dbf44456 100644
--- a/api/controllers/console/workspace/account.py
+++ b/api/controllers/console/workspace/account.py
@@ -2,8 +2,8 @@ import datetime
import pytz
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, fields, marshal_with, reqparse
from configs import dify_config
from constants.languages import supported_language
diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py
index a41d6c501c..88c37767e3 100644
--- a/api/controllers/console/workspace/agent_providers.py
+++ b/api/controllers/console/workspace/agent_providers.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py
index a5bd2a4bcf..eb53dcb16e 100644
--- a/api/controllers/console/workspace/endpoint.py
+++ b/api/controllers/console/workspace/endpoint.py
@@ -1,10 +1,11 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
+from core.plugin.impl.exc import PluginPermissionDeniedError
from libs.login import login_required
from services.plugin.endpoint_service import EndpointService
@@ -28,15 +29,18 @@ class EndpointCreateApi(Resource):
settings = args["settings"]
name = args["name"]
- return {
- "success": EndpointService.create_endpoint(
- tenant_id=user.current_tenant_id,
- user_id=user.id,
- plugin_unique_identifier=plugin_unique_identifier,
- name=name,
- settings=settings,
- )
- }
+ try:
+ return {
+ "success": EndpointService.create_endpoint(
+ tenant_id=user.current_tenant_id,
+ user_id=user.id,
+ plugin_unique_identifier=plugin_unique_identifier,
+ name=name,
+ settings=settings,
+ )
+ }
+ except PluginPermissionDeniedError as e:
+ raise ValueError(e.description) from e
class EndpointListApi(Resource):
diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py
index 6e1d87cb12..ba74e2c074 100644
--- a/api/controllers/console/workspace/load_balancing_config.py
+++ b/api/controllers/console/workspace/load_balancing_config.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py
index a2b41c1d38..b9918b0d32 100644
--- a/api/controllers/console/workspace/members.py
+++ b/api/controllers/console/workspace/members.py
@@ -1,7 +1,7 @@
from urllib import parse
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, abort, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, abort, marshal_with, reqparse
import services
from configs import dify_config
diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py
index d7d1cc8d00..ff0fcbda6e 100644
--- a/api/controllers/console/workspace/model_providers.py
+++ b/api/controllers/console/workspace/model_providers.py
@@ -1,8 +1,8 @@
import io
from flask import send_file
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py
index 8b72a1ea3d..37d0f6c764 100644
--- a/api/controllers/console/workspace/models.py
+++ b/api/controllers/console/workspace/models.py
@@ -1,7 +1,7 @@
import logging
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from controllers.console import api
diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py
index e9c1884c60..fda5a7d3bb 100644
--- a/api/controllers/console/workspace/plugin.py
+++ b/api/controllers/console/workspace/plugin.py
@@ -1,8 +1,8 @@
import io
from flask import request, send_file
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -10,7 +10,7 @@ from controllers.console import api
from controllers.console.workspace import plugin_permission_required
from controllers.console.wraps import account_initialization_required, setup_required
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from services.plugin.plugin_permission_service import PluginPermissionService
diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py
index 39ab454922..2b1379bfb2 100644
--- a/api/controllers/console/workspace/tool_providers.py
+++ b/api/controllers/console/workspace/tool_providers.py
@@ -1,8 +1,8 @@
import io
from flask import send_file
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, reqparse
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden
diff --git a/api/controllers/console/workspace/workspace.py b/api/controllers/console/workspace/workspace.py
index 332ed00222..71e6f9178f 100644
--- a/api/controllers/console/workspace/workspace.py
+++ b/api/controllers/console/workspace/workspace.py
@@ -1,8 +1,8 @@
import logging
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from werkzeug.exceptions import Unauthorized
import services
diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py
index e5e8038ad7..360cbd9246 100644
--- a/api/controllers/console/wraps.py
+++ b/api/controllers/console/wraps.py
@@ -4,12 +4,13 @@ import time
from functools import wraps
from flask import abort, request
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from configs import dify_config
from controllers.console.workspace.error import AccountNotInitializedError
from extensions.ext_database import db
from extensions.ext_redis import redis_client
+from models.account import AccountStatus
from models.dataset import RateLimitLog
from models.model import DifySetup
from services.feature_service import FeatureService, LicenseStatus
@@ -24,7 +25,7 @@ def account_initialization_required(view):
# check account initialization
account = current_user
- if account.status == "uninitialized":
+ if account.status == AccountStatus.UNINITIALIZED:
raise AccountNotInitializedError()
return view(*args, **kwargs)
diff --git a/api/controllers/files/image_preview.py b/api/controllers/files/image_preview.py
index 5adfe16a79..46c19e1fbb 100644
--- a/api/controllers/files/image_preview.py
+++ b/api/controllers/files/image_preview.py
@@ -1,7 +1,7 @@
from urllib.parse import quote
from flask import Response, request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import NotFound
import services
@@ -70,12 +70,26 @@ class FilePreviewApi(Resource):
direct_passthrough=True,
headers={},
)
+ # add Accept-Ranges header for audio/video files
+ if upload_file.mime_type in [
+ "audio/mpeg",
+ "audio/wav",
+ "audio/mp4",
+ "audio/ogg",
+ "audio/flac",
+ "audio/aac",
+ "video/mp4",
+ "video/webm",
+ "video/quicktime",
+ "audio/x-m4a",
+ ]:
+ response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:
encoded_filename = quote(upload_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
- response.headers["Content-Type"] = "application/octet-stream"
+ response.headers["Content-Type"] = "application/octet-stream"
return response
diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py
index cfcce81247..1c3430ef4f 100644
--- a/api/controllers/files/tool_files.py
+++ b/api/controllers/files/tool_files.py
@@ -1,10 +1,14 @@
+from urllib.parse import quote
+
from flask import Response
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
+from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
+from models import db as global_db
class ToolFilePreviewApi(Resource):
@@ -19,17 +23,14 @@ class ToolFilePreviewApi(Resource):
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
-
- if not ToolFileManager.verify_file(
- file_id=file_id,
- timestamp=args["timestamp"],
- nonce=args["nonce"],
- sign=args["sign"],
+ if not verify_tool_file_signature(
+ file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
raise Forbidden("Invalid request.")
try:
- stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
+ tool_file_manager = ToolFileManager(engine=global_db.engine)
+ stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id,
)
@@ -47,7 +48,8 @@ class ToolFilePreviewApi(Resource):
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]:
- response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
+ encoded_filename = quote(tool_file.name)
+ response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
return response
diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py
index ca5ea54435..6641632169 100644
--- a/api/controllers/files/upload.py
+++ b/api/controllers/files/upload.py
@@ -1,5 +1,7 @@
+from mimetypes import guess_extension
+
from flask import request
-from flask_restful import Resource, marshal_with # type: ignore
+from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
import services
@@ -9,8 +11,8 @@ from controllers.files.error import UnsupportedFileTypeError
from controllers.inner_api.plugin.wraps import get_user
from controllers.service_api.app.error import FileTooLargeError
from core.file.helpers import verify_plugin_file_signature
+from core.tools.tool_file_manager import ToolFileManager
from fields.file_fields import file_fields
-from services.file_service import FileService
class PluginUploadFileApi(Resource):
@@ -51,19 +53,26 @@ class PluginUploadFileApi(Resource):
raise Forbidden("Invalid request.")
try:
- upload_file = FileService.upload_file(
- filename=filename,
- content=file.read(),
+ tool_file = ToolFileManager().create_file_by_raw(
+ user_id=user.id,
+ tenant_id=tenant_id,
+ file_binary=file.read(),
mimetype=mimetype,
- user=user,
- source=None,
+ filename=filename,
+ conversation_id=None,
)
+
+ extension = guess_extension(tool_file.mimetype) or ".bin"
+ preview_url = ToolFileManager.sign_file(tool_file_id=tool_file.id, extension=extension)
+ tool_file.mime_type = mimetype
+ tool_file.extension = extension
+ tool_file.preview_url = preview_url
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
- return upload_file, 201
+ return tool_file, 201
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py
index 061ad62a4a..f3a1bd8fa5 100644
--- a/api/controllers/inner_api/plugin/plugin.py
+++ b/api/controllers/inner_api/plugin/plugin.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/inner_api/plugin/wraps.py b/api/controllers/inner_api/plugin/wraps.py
index c31f9d22ed..709bba3f30 100644
--- a/api/controllers/inner_api/plugin/wraps.py
+++ b/api/controllers/inner_api/plugin/wraps.py
@@ -3,7 +3,7 @@ from functools import wraps
from typing import Optional
from flask import request
-from flask_restful import reqparse # type: ignore
+from flask_restful import reqparse
from pydantic import BaseModel
from sqlalchemy.orm import Session
diff --git a/api/controllers/inner_api/workspace/workspace.py b/api/controllers/inner_api/workspace/workspace.py
index 9dfa5d23c3..a2fc2d4675 100644
--- a/api/controllers/inner_api/workspace/workspace.py
+++ b/api/controllers/inner_api/workspace/workspace.py
@@ -1,6 +1,6 @@
import json
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from controllers.console.wraps import setup_required
from controllers.inner_api import api
diff --git a/api/controllers/inner_api/wraps.py b/api/controllers/inner_api/wraps.py
index 86d3ad3dc5..f3a9312dd0 100644
--- a/api/controllers/inner_api/wraps.py
+++ b/api/controllers/inner_api/wraps.py
@@ -18,7 +18,7 @@ def enterprise_inner_api_only(view):
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
- if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
+ if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
abort(401)
return view(*args, **kwargs)
diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py
index cffa3665b1..bd1a23b723 100644
--- a/api/controllers/service_api/app/annotation.py
+++ b/api/controllers/service_api/app/annotation.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, marshal, marshal_with, reqparse # type: ignore
+from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden
from controllers.service_api import api
@@ -79,7 +79,7 @@ class AnnotationListApi(Resource):
class AnnotationUpdateDeleteApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.JSON))
@marshal_with(annotation_fields)
- def post(self, app_model: App, end_user: EndUser, annotation_id):
+ def put(self, app_model: App, end_user: EndUser, annotation_id):
if not current_user.is_editor:
raise Forbidden()
@@ -98,7 +98,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation_id = str(annotation_id)
AppAnnotationService.delete_app_annotation(app_model.id, annotation_id)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/")
diff --git a/api/controllers/service_api/app/app.py b/api/controllers/service_api/app/app.py
index 7131e8a310..2c03aba33d 100644
--- a/api/controllers/service_api/app/app.py
+++ b/api/controllers/service_api/app/app.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource, marshal_with # type: ignore
+from flask_restful import Resource, marshal_with
from controllers.common import fields
from controllers.service_api import api
@@ -47,7 +47,7 @@ class AppInfoApi(Resource):
def get(self, app_model: App):
"""Get app information"""
tags = [tag.name for tag in app_model.tags]
- return {"name": app_model.name, "description": app_model.description, "tags": tags}
+ return {"name": app_model.name, "description": app_model.description, "tags": tags, "mode": app_model.mode}
api.add_resource(AppParameterApi, "/parameters")
diff --git a/api/controllers/service_api/app/audio.py b/api/controllers/service_api/app/audio.py
index e6bcc0bfd2..2682c2e7f1 100644
--- a/api/controllers/service_api/app/audio.py
+++ b/api/controllers/service_api/app/audio.py
@@ -1,7 +1,7 @@
import logging
from flask import request
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError
import services
diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py
index 38a65b7a90..1d9890199d 100644
--- a/api/controllers/service_api/app/completion.py
+++ b/api/controllers/service_api/app/completion.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import Resource, reqparse # type: ignore
+from flask_restful import Resource, reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py
index 334f2c5620..36a7905572 100644
--- a/api/controllers/service_api/app/conversation.py
+++ b/api/controllers/service_api/app/conversation.py
@@ -1,5 +1,5 @@
-from flask_restful import Resource, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import Resource, marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
@@ -14,6 +14,9 @@ from fields.conversation_fields import (
conversation_infinite_scroll_pagination_fields,
simple_conversation_fields,
)
+from fields.conversation_variable_fields import (
+ conversation_variable_infinite_scroll_pagination_fields,
+)
from libs.helper import uuid_value
from models.model import App, AppMode, EndUser
from services.conversation_service import ConversationService
@@ -69,7 +72,7 @@ class ConversationDetailApi(Resource):
ConversationService.delete(app_model, conversation_id, end_user)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class ConversationRenameApi(Resource):
@@ -93,6 +96,31 @@ class ConversationRenameApi(Resource):
raise NotFound("Conversation Not Exists.")
+class ConversationVariablesApi(Resource):
+ @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY))
+ @marshal_with(conversation_variable_infinite_scroll_pagination_fields)
+ def get(self, app_model: App, end_user: EndUser, c_id):
+ # conversational variable only for chat app
+ app_mode = AppMode.value_of(app_model.mode)
+ if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
+ raise NotChatAppError()
+
+ conversation_id = str(c_id)
+
+ parser = reqparse.RequestParser()
+ parser.add_argument("last_id", type=uuid_value, location="args")
+ parser.add_argument("limit", type=int_range(1, 100), required=False, default=20, location="args")
+ args = parser.parse_args()
+
+ try:
+ return ConversationService.get_conversational_variable(
+ app_model, conversation_id, end_user, args["limit"], args["last_id"]
+ )
+ except services.errors.conversation.ConversationNotExistsError:
+ raise NotFound("Conversation Not Exists.")
+
+
api.add_resource(ConversationRenameApi, "/conversations//name", endpoint="conversation_name")
api.add_resource(ConversationApi, "/conversations")
api.add_resource(ConversationDetailApi, "/conversations/", endpoint="conversation_detail")
+api.add_resource(ConversationVariablesApi, "/conversations//variables", endpoint="conversation_variables")
diff --git a/api/controllers/service_api/app/file.py b/api/controllers/service_api/app/file.py
index 27b21b9f50..b0fd8e65ef 100644
--- a/api/controllers/service_api/app/file.py
+++ b/api/controllers/service_api/app/file.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import Resource, marshal_with # type: ignore
+from flask_restful import Resource, marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
diff --git a/api/controllers/service_api/app/message.py b/api/controllers/service_api/app/message.py
index 95e538f4c7..1b148a9756 100644
--- a/api/controllers/service_api/app/message.py
+++ b/api/controllers/service_api/app/message.py
@@ -1,8 +1,8 @@
import json
import logging
-from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import BadRequest, InternalServerError, NotFound
import services
diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py
index 8b10a028f3..e9bb2b046a 100644
--- a/api/controllers/service_api/app/workflow.py
+++ b/api/controllers/service_api/app/workflow.py
@@ -1,8 +1,8 @@
import logging
from dateutil.parser import isoparse
-from flask_restful import Resource, fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import Resource, fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError
@@ -59,7 +59,7 @@ class WorkflowRunDetailApi(Resource):
Get a workflow task running detail
"""
app_mode = AppMode.value_of(app_model.mode)
- if app_mode != AppMode.WORKFLOW:
+ if app_mode not in [AppMode.WORKFLOW, AppMode.ADVANCED_CHAT]:
raise NotWorkflowAppError()
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first()
diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py
index e1e6f3168f..ee190245d5 100644
--- a/api/controllers/service_api/dataset/dataset.py
+++ b/api/controllers/service_api/dataset/dataset.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import marshal, reqparse # type: ignore
+from flask_restful import marshal, reqparse
from werkzeug.exceptions import Forbidden, NotFound
import services.dataset_service
diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py
index eec6afc9ef..33eda37014 100644
--- a/api/controllers/service_api/dataset/document.py
+++ b/api/controllers/service_api/dataset/document.py
@@ -1,7 +1,7 @@
import json
from flask import request
-from flask_restful import marshal, reqparse # type: ignore
+from flask_restful import marshal, reqparse
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
@@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource):
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
class DocumentListApi(DatasetApiResource):
diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py
index 298c8a8df8..35582feea0 100644
--- a/api/controllers/service_api/dataset/metadata.py
+++ b/api/controllers/service_api/dataset/metadata.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore # type: ignore
-from flask_restful import marshal, reqparse # type: ignore
+from flask_login import current_user # type: ignore
+from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
from controllers.service_api import api
@@ -63,7 +63,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
DatasetService.check_dataset_permission(dataset, current_user)
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
- return 200
+ return 204
class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource):
diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py
index 2a79e15cc5..fb3ca1e15f 100644
--- a/api/controllers/service_api/dataset/segment.py
+++ b/api/controllers/service_api/dataset/segment.py
@@ -1,6 +1,6 @@
from flask import request
-from flask_login import current_user # type: ignore
-from flask_restful import marshal, reqparse # type: ignore
+from flask_login import current_user
+from flask_restful import marshal, reqparse
from werkzeug.exceptions import NotFound
from controllers.service_api import api
@@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource):
if not segment:
raise NotFound("Segment not found.")
SegmentService.delete_segment(segment, document, dataset)
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
def post(self, tenant_id, dataset_id, document_id, segment_id):
@@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource):
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
- return {"result": "success"}, 200
+ return {"result": "success"}, 204
@cloud_edition_billing_resource_check("vector_space", "dataset")
@cloud_edition_billing_knowledge_limit_check("add_segment", "dataset")
diff --git a/api/controllers/service_api/index.py b/api/controllers/service_api/index.py
index 75d9141a6d..d24c4597e2 100644
--- a/api/controllers/service_api/index.py
+++ b/api/controllers/service_api/index.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from configs import dify_config
from controllers.service_api import api
diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py
index 373f8019f9..3f18474674 100644
--- a/api/controllers/service_api/workspace/models.py
+++ b/api/controllers/service_api/workspace/models.py
@@ -1,5 +1,5 @@
-from flask_login import current_user # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_login import current_user
+from flask_restful import Resource
from controllers.service_api import api
from controllers.service_api.wraps import validate_dataset_token
diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py
index 7facb03358..cd35ceac1d 100644
--- a/api/controllers/service_api/wraps.py
+++ b/api/controllers/service_api/wraps.py
@@ -7,7 +7,7 @@ from typing import Optional
from flask import current_app, request
from flask_login import user_logged_in # type: ignore
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from pydantic import BaseModel
from sqlalchemy import select, update
from sqlalchemy.orm import Session
diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py
index a84b846112..c9a37af5ed 100644
--- a/api/controllers/web/app.py
+++ b/api/controllers/web/app.py
@@ -1,4 +1,4 @@
-from flask_restful import marshal_with # type: ignore
+from flask_restful import marshal_with
from controllers.common import fields
from controllers.web import api
diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py
index 97d980d07c..06d9ad7564 100644
--- a/api/controllers/web/audio.py
+++ b/api/controllers/web/audio.py
@@ -65,7 +65,7 @@ class AudioApi(WebApiResource):
class TextApi(WebApiResource):
def post(self, app_model: App, end_user):
- from flask_restful import reqparse # type: ignore
+ from flask_restful import reqparse
try:
parser = reqparse.RequestParser()
diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py
index 9677401490..fd3b9aa804 100644
--- a/api/controllers/web/completion.py
+++ b/api/controllers/web/completion.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse # type: ignore
+from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/web/conversation.py b/api/controllers/web/conversation.py
index 419247ea14..98cea3974f 100644
--- a/api/controllers/web/conversation.py
+++ b/api/controllers/web/conversation.py
@@ -1,5 +1,5 @@
-from flask_restful import marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import marshal_with, reqparse
+from flask_restful.inputs import int_range
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/controllers/web/feature.py b/api/controllers/web/feature.py
index ce841a8814..0563ed2238 100644
--- a/api/controllers/web/feature.py
+++ b/api/controllers/web/feature.py
@@ -1,4 +1,4 @@
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from controllers.web import api
from services.feature_service import FeatureService
diff --git a/api/controllers/web/files.py b/api/controllers/web/files.py
index 1d4474015a..df06a73a85 100644
--- a/api/controllers/web/files.py
+++ b/api/controllers/web/files.py
@@ -1,5 +1,5 @@
from flask import request
-from flask_restful import marshal_with # type: ignore
+from flask_restful import marshal_with
import services
from controllers.common.errors import FilenameNotExistsError
diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py
index 17e9a3990f..f2e1873601 100644
--- a/api/controllers/web/message.py
+++ b/api/controllers/web/message.py
@@ -1,7 +1,7 @@
import logging
-from flask_restful import fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import InternalServerError, NotFound
import services
diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py
index e30998c803..267dac223d 100644
--- a/api/controllers/web/passport.py
+++ b/api/controllers/web/passport.py
@@ -1,7 +1,7 @@
import uuid
from flask import request
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from controllers.web import api
diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py
index d559ab8e07..ae68df6bdc 100644
--- a/api/controllers/web/remote_files.py
+++ b/api/controllers/web/remote_files.py
@@ -1,7 +1,7 @@
import urllib.parse
import httpx
-from flask_restful import marshal_with, reqparse # type: ignore
+from flask_restful import marshal_with, reqparse
import services
from controllers.common import helpers
diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py
index 6a9b818907..d7188ef0b3 100644
--- a/api/controllers/web/saved_message.py
+++ b/api/controllers/web/saved_message.py
@@ -1,5 +1,5 @@
-from flask_restful import fields, marshal_with, reqparse # type: ignore
-from flask_restful.inputs import int_range # type: ignore
+from flask_restful import fields, marshal_with, reqparse
+from flask_restful.inputs import int_range
from werkzeug.exceptions import NotFound
from controllers.web import api
@@ -67,7 +67,7 @@ class SavedMessageApi(WebApiResource):
SavedMessageService.delete(app_model, end_user, message_id)
- return {"result": "success"}
+ return {"result": "success"}, 204
api.add_resource(SavedMessageListApi, "/saved-messages")
diff --git a/api/controllers/web/site.py b/api/controllers/web/site.py
index e68dc7aa4a..0564b15ea3 100644
--- a/api/controllers/web/site.py
+++ b/api/controllers/web/site.py
@@ -1,4 +1,4 @@
-from flask_restful import fields, marshal_with # type: ignore
+from flask_restful import fields, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py
index d2e183be78..590fd3f2c7 100644
--- a/api/controllers/web/workflow.py
+++ b/api/controllers/web/workflow.py
@@ -1,6 +1,6 @@
import logging
-from flask_restful import reqparse # type: ignore
+from flask_restful import reqparse
from werkzeug.exceptions import InternalServerError
from controllers.web import api
diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py
index 1b4d263bee..c327c3df18 100644
--- a/api/controllers/web/wraps.py
+++ b/api/controllers/web/wraps.py
@@ -1,7 +1,7 @@
from functools import wraps
from flask import request
-from flask_restful import Resource # type: ignore
+from flask_restful import Resource
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from controllers.web.error import WebSSOAuthRequiredError
diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py
index 48c92ea2db..6998e4d29a 100644
--- a/api/core/agent/base_agent_runner.py
+++ b/api/core/agent/base_agent_runner.py
@@ -21,14 +21,13 @@ from core.model_runtime.entities import (
AssistantPromptMessage,
LLMUsage,
PromptMessage,
- PromptMessageContent,
PromptMessageTool,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.utils.extract_thread_messages import extract_thread_messages
@@ -92,6 +91,8 @@ class BaseAgentRunner(AppRunner):
return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback,
+ user_id=user_id,
+ inputs=cast(dict, application_generate_entity.inputs),
)
# get how many agent thoughts have been created
self.agent_thought_count = (
@@ -501,7 +502,7 @@ class BaseAgentRunner(AppRunner):
)
if not file_objs:
return UserPromptMessage(content=message.query)
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message_contents.append(
diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py
index 7d407a4976..5ff89bdacb 100644
--- a/api/core/agent/cot_chat_agent_runner.py
+++ b/api/core/agent/cot_chat_agent_runner.py
@@ -5,12 +5,11 @@ from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
- PromptMessageContent,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
@@ -40,7 +39,7 @@ class CotChatAgentRunner(CotAgentRunner):
Organize user query
"""
if self.files:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config
diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py
index f45fa5c66e..a1110e7709 100644
--- a/api/core/agent/fc_agent_runner.py
+++ b/api/core/agent/fc_agent_runner.py
@@ -15,14 +15,13 @@ from core.model_runtime.entities import (
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
- PromptMessageContent,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
@@ -395,7 +394,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
Organize user query
"""
if self.files:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
# get image detail config
diff --git a/api/core/agent/prompt/template.py b/api/core/agent/prompt/template.py
index ef64fd29fc..f5ba2119f4 100644
--- a/api/core/agent/prompt/template.py
+++ b/api/core/agent/prompt/template.py
@@ -1,4 +1,4 @@
-ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
+ENGLISH_REACT_COMPLETION_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
@@ -47,7 +47,7 @@ Thought:""" # noqa: E501
ENGLISH_REACT_COMPLETION_AGENT_SCRATCHPAD_TEMPLATES = """Observation: {{observation}}
Thought:"""
-ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
+ENGLISH_REACT_CHAT_PROMPT_TEMPLATES = """Respond to the human as helpfully and accurately as possible.
{{instruction}}
diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py
index a4b25f46e6..79b074cf95 100644
--- a/api/core/agent/strategy/plugin.py
+++ b/api/core/agent/strategy/plugin.py
@@ -4,7 +4,7 @@ from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
from core.agent.strategy.base import BaseAgentStrategy
-from core.plugin.manager.agent import PluginAgentManager
+from core.plugin.impl.agent import PluginAgentClient
from core.plugin.utils.converter import convert_parameters_to_plugin_format
@@ -42,7 +42,7 @@ class PluginAgentStrategy(BaseAgentStrategy):
"""
Invoke the agent strategy.
"""
- manager = PluginAgentManager()
+ manager = PluginAgentClient()
initialized_params = self.initialize_parameters(params)
params = convert_parameters_to_plugin_format(initialized_params)
diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py
index ef582d28e0..4b0e64130b 100644
--- a/api/core/app/apps/advanced_chat/app_generator.py
+++ b/api/core/app/apps/advanced_chat/app_generator.py
@@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -24,6 +25,8 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from extensions.ext_database import db
from factories import file_factory
from models.account import Account
@@ -158,11 +161,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
workflow=workflow,
user=user,
invoke_from=invoke_from,
application_generate_entity=application_generate_entity,
+ workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
)
@@ -215,11 +227,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
+ workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
)
@@ -270,11 +291,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
+ workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None,
stream=streaming,
)
@@ -286,6 +316,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
invoke_from: InvokeFrom,
application_generate_entity: AdvancedChatAppGenerateEntity,
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None,
stream: bool = True,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
@@ -296,6 +327,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param user: account or end user
:param invoke_from: invoke from source
:param application_generate_entity: application generate entity
+ :param workflow_node_execution_repository: repository for workflow node execution
:param conversation: conversation
:param stream: is stream
"""
@@ -348,6 +380,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation,
message=message,
user=user,
+ workflow_node_execution_repository=workflow_node_execution_repository,
stream=stream,
)
@@ -419,6 +452,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation: Conversation,
message: Message,
user: Union[Account, EndUser],
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
@@ -430,6 +464,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param message: message
:param user: account or end user
:param stream: is stream
+ :param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -442,6 +477,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user=user,
stream=stream,
dialogue_count=self._dialogue_count,
+ workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
index 3bf6c330db..f71c49d112 100644
--- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py
+++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py
@@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
@@ -58,13 +57,15 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
-from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes import NodeType
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from events.message_event import message_was_created
from extensions.ext_database import db
from models import Conversation, EndUser, Message, MessageFile
@@ -93,6 +94,7 @@ class AdvancedChatAppGenerateTaskPipeline:
user: Union[Account, EndUser],
stream: bool,
dialogue_count: int,
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@@ -111,7 +113,7 @@ class AdvancedChatAppGenerateTaskPipeline:
else:
raise NotImplementedError(f"User type not supported: {type(user)}")
- self._workflow_cycle_manager = WorkflowCycleManage(
+ self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.QUERY: message.query,
@@ -123,6 +125,7 @@ class AdvancedChatAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
+ workflow_node_execution_repository=workflow_node_execution_repository,
)
self._task_state = WorkflowTaskState()
@@ -684,7 +687,9 @@ class AdvancedChatAppGenerateTaskPipeline:
)
elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation
- yield self._message_cycle_manager._message_replace_to_stream_response(answer=event.text)
+ yield self._message_cycle_manager._message_replace_to_stream_response(
+ answer=event.text, reason=event.reason
+ )
elif isinstance(event, QueueAdvancedChatMessageEndEvent):
if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.")
@@ -695,7 +700,8 @@ class AdvancedChatAppGenerateTaskPipeline:
if output_moderation_answer:
self._task_state.answer = output_moderation_answer
yield self._message_cycle_manager._message_replace_to_stream_response(
- answer=output_moderation_answer
+ answer=output_moderation_answer,
+ reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py
index b1f527c0f2..995082b79d 100644
--- a/api/core/app/apps/message_based_app_generator.py
+++ b/api/core/app/apps/message_based_app_generator.py
@@ -153,6 +153,8 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
else:
query = next(iter(application_generate_entity.inputs.values()), "New conversation")
+ if isinstance(query, int):
+ query = str(query)
query = query or "New conversation"
conversation_name = (query[:20] + "…") if len(query) > 20 else query
diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py
index 08986b16f0..1d67671974 100644
--- a/api/core/app/apps/workflow/app_generator.py
+++ b/api/core/app/apps/workflow/app_generator.py
@@ -7,6 +7,7 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
+from sqlalchemy.orm import sessionmaker
import contexts
from configs import dify_config
@@ -17,11 +18,13 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager
from core.app.apps.workflow.app_runner import WorkflowAppRunner
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
-from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from extensions.ext_database import db
from factories import file_factory
from models import Account, App, EndUser, Workflow
@@ -133,12 +136,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=invoke_from,
+ workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
@@ -151,6 +163,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
invoke_from: InvokeFrom,
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
@@ -162,6 +175,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param user: account or end user
:param application_generate_entity: application generate entity
:param invoke_from: invoke from source
+ :param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
:param workflow_thread_pool_id: workflow thread pool id
"""
@@ -193,6 +207,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow=workflow,
queue_manager=queue_manager,
user=user,
+ workflow_node_execution_repository=workflow_node_execution_repository,
stream=streaming,
)
@@ -245,12 +260,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
+ workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@@ -299,12 +323,21 @@ class WorkflowAppGenerator(BaseAppGenerator):
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
+ # Create workflow node execution repository
+ session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory,
+ tenant_id=application_generate_entity.app_config.tenant_id,
+ app_id=application_generate_entity.app_config.app_id,
+ )
+
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
invoke_from=InvokeFrom.DEBUGGER,
application_generate_entity=application_generate_entity,
+ workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
)
@@ -361,6 +394,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow: Workflow,
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
stream: bool = False,
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
"""
@@ -370,6 +404,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param queue_manager: queue manager
:param user: account or end user
:param stream: is stream
+ :param workflow_node_execution_repository: optional repository for workflow node execution
:return:
"""
# init generate task pipeline
@@ -379,6 +414,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager=queue_manager,
user=user,
stream=stream,
+ workflow_node_execution_repository=workflow_node_execution_repository,
)
try:
diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py
index 3702326406..7228020e9b 100644
--- a/api/core/app/entities/queue_entities.py
+++ b/api/core/app/entities/queue_entities.py
@@ -264,8 +264,16 @@ class QueueMessageReplaceEvent(AppQueueEvent):
QueueMessageReplaceEvent entity
"""
+ class MessageReplaceReason(StrEnum):
+ """
+ Reason for message replace event
+ """
+
+ OUTPUT_MODERATION = "output_moderation"
+
event: QueueEvent = QueueEvent.MESSAGE_REPLACE
text: str
+ reason: str
class QueueRetrieverResourcesEvent(AppQueueEvent):
diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py
index f23ee1b9fd..817699bd20 100644
--- a/api/core/app/entities/task_entities.py
+++ b/api/core/app/entities/task_entities.py
@@ -148,6 +148,7 @@ class MessageReplaceStreamResponse(StreamResponse):
event: StreamEvent = StreamEvent.MESSAGE_REPLACE
answer: str
+ reason: str
class AgentThoughtStreamResponse(StreamResponse):
diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py
index a2e06d4e1f..5331c0cc94 100644
--- a/api/core/app/task_pipeline/based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py
@@ -126,12 +126,12 @@ class BasedGenerateTaskPipeline:
if self._output_moderation_handler:
self._output_moderation_handler.stop_thread()
- completion = self._output_moderation_handler.moderation_completion(
+ completion, flagged = self._output_moderation_handler.moderation_completion(
completion=completion, public_event=False
)
self._output_moderation_handler = None
-
- return completion
+ if flagged:
+ return completion
return None
diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
index 8c9c26d36e..a98a42f5df 100644
--- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
+++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
@@ -9,7 +9,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
@@ -45,6 +44,7 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manage import MessageCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
diff --git a/api/core/app/task_pipeline/message_cycle_manage.py b/api/core/app/task_pipeline/message_cycle_manage.py
index 6223b33b67..a6d826f08b 100644
--- a/api/core/app/task_pipeline/message_cycle_manage.py
+++ b/api/core/app/task_pipeline/message_cycle_manage.py
@@ -24,7 +24,7 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
-from core.tools.tool_file_manager import ToolFileManager
+from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@@ -154,7 +154,7 @@ class MessageCycleManage:
if message_file.url.startswith("http"):
url = message_file.url
else:
- url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
+ url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
return MessageFileStreamResponse(
task_id=self._application_generate_entity.task_id,
@@ -182,10 +182,12 @@ class MessageCycleManage:
from_variable_selector=from_variable_selector,
)
- def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
+ def _message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.
:param answer: answer
:return:
"""
- return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
+ return MessageReplaceStreamResponse(
+ task_id=self._application_generate_entity.task_id, answer=answer, reason=reason
+ )
diff --git a/api/core/base/__init__.py b/api/core/base/__init__.py
new file mode 100644
index 0000000000..3f4bd3b771
--- /dev/null
+++ b/api/core/base/__init__.py
@@ -0,0 +1 @@
+# Core base package
diff --git a/api/core/base/tts/__init__.py b/api/core/base/tts/__init__.py
new file mode 100644
index 0000000000..37b6eeebb0
--- /dev/null
+++ b/api/core/base/tts/__init__.py
@@ -0,0 +1,6 @@
+from core.base.tts.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
+
+__all__ = [
+ "AppGeneratorTTSPublisher",
+ "AudioTrunk",
+]
diff --git a/api/core/app/apps/advanced_chat/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py
similarity index 100%
rename from api/core/app/apps/advanced_chat/app_generator_tts_publisher.py
rename to api/core/base/tts/app_generator_tts_publisher.py
diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py
index b3affc91a6..86887c9b4a 100644
--- a/api/core/entities/provider_configuration.py
+++ b/api/core/entities/provider_configuration.py
@@ -798,7 +798,25 @@ class ProviderConfiguration(BaseModel):
provider_models = [m for m in provider_models if m.status == ModelStatus.ACTIVE]
# resort provider_models
- return sorted(provider_models, key=lambda x: x.model_type.value)
+ # Optimize sorting logic: first sort by provider.position order, then by model_type.value
+ # Get the position list for model types (retrieve only once for better performance)
+ model_type_positions = {}
+ if hasattr(self.provider, "position") and self.provider.position:
+ model_type_positions = self.provider.position
+
+ def get_sort_key(model: ModelWithProviderEntity):
+ # Get the position list for the current model type
+ positions = model_type_positions.get(model.model_type.value, [])
+
+ # If the model name is in the position list, use its index for sorting
+ # Otherwise use a large value (list length) to place undefined models at the end
+ position_index = positions.index(model.model) if model.model in positions else len(positions)
+
+ # Return composite sort key: (model_type value, model position index)
+ return (model.model_type.value, position_index)
+
+ # Sort using the composite sort key
+ return sorted(provider_models, key=get_sort_key)
def _get_system_provider_models(
self,
diff --git a/api/core/external_data_tool/api/__builtin__ b/api/core/external_data_tool/api/__builtin__
index 56a6051ca2..d00491fd7e 100644
--- a/api/core/external_data_tool/api/__builtin__
+++ b/api/core/external_data_tool/api/__builtin__
@@ -1 +1 @@
-1
\ No newline at end of file
+1
diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py
index 4ebe997ac5..ada19ef8ce 100644
--- a/api/core/file/file_manager.py
+++ b/api/core/file/file_manager.py
@@ -7,15 +7,15 @@ from core.model_runtime.entities import (
AudioPromptMessageContent,
DocumentPromptMessageContent,
ImagePromptMessageContent,
- MultiModalPromptMessageContent,
VideoPromptMessageContent,
)
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
-from .tool_file_parser import ToolFileParser
def get_attr(*, file: File, attr: FileAttribute):
@@ -43,7 +43,7 @@ def to_prompt_message_content(
/,
*,
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
-) -> MultiModalPromptMessageContent:
+) -> PromptMessageContentUnionTypes:
if f.extension is None:
raise ValueError("Missing file extension")
if f.mime_type is None:
@@ -58,7 +58,7 @@ def to_prompt_message_content(
if f.type == FileType.IMAGE:
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
- prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
+ prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
FileType.IMAGE: ImagePromptMessageContent,
FileType.AUDIO: AudioPromptMessageContent,
FileType.VIDEO: VideoPromptMessageContent,
@@ -130,6 +130,6 @@ def _to_url(f: File, /):
# add sign url
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
- return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
+ return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
diff --git a/api/core/file/models.py b/api/core/file/models.py
index f5db6c2d74..aa3b5f629c 100644
--- a/api/core/file/models.py
+++ b/api/core/file/models.py
@@ -4,11 +4,11 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.tools.signature import sign_tool_file
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
-from .tool_file_parser import ToolFileParser
class ImageConfig(BaseModel):
@@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel):
class File(BaseModel):
+ # NOTE: dify_model_identity is a special identifier used to distinguish between
+ # new and old data formats during serialization and deserialization.
dify_model_identity: str = FILE_MODEL_IDENTITY
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
+ # If `transfer_method` is `FileTransferMethod.remote_url`, the
+ # `remote_url` attribute must not be `None`.
remote_url: Optional[str] = None # remote url
+ # If `transfer_method` is `FileTransferMethod.local_file` or
+ # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
+ #
+ # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
@@ -110,9 +118,7 @@ class File(BaseModel):
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
- return ToolFileParser.get_tool_file_manager().sign_file(
- tool_file_id=self.related_id, extension=self.extension
- )
+ return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
def to_plugin_parameter(self) -> dict[str, Any]:
return {
diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py
index 6fa101cf36..656c9d48ed 100644
--- a/api/core/file/tool_file_parser.py
+++ b/api/core/file/tool_file_parser.py
@@ -1,12 +1,19 @@
-from typing import TYPE_CHECKING, Any, cast
+from collections.abc import Callable
+from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
-tool_file_manager: dict[str, Any] = {"manager": None}
+_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> "ToolFileManager":
- return cast("ToolFileManager", tool_file_manager["manager"])
+ assert _tool_file_manager_factory is not None
+ return _tool_file_manager_factory()
+
+
+def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
+ global _tool_file_manager_factory
+ _tool_file_manager_factory = factory
diff --git a/api/core/helper/code_executor/javascript/javascript_transformer.py b/api/core/helper/code_executor/javascript/javascript_transformer.py
index d67a0903aa..62489cdf29 100644
--- a/api/core/helper/code_executor/javascript/javascript_transformer.py
+++ b/api/core/helper/code_executor/javascript/javascript_transformer.py
@@ -10,13 +10,13 @@ class NodeJsTemplateTransformer(TemplateTransformer):
f"""
// declare main function
{cls._code_placeholder}
-
+
// decode and prepare input object
var inputs_obj = JSON.parse(Buffer.from('{cls._inputs_placeholder}', 'base64').toString('utf-8'))
-
+
// execute main function
var output_obj = main(inputs_obj)
-
+
// convert output to json and print
var output_json = JSON.stringify(output_obj)
var result = `<>${{output_json}}<>`
diff --git a/api/core/helper/code_executor/jinja2/jinja2_transformer.py b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
index 63d58edbc7..54c78cdf92 100644
--- a/api/core/helper/code_executor/jinja2/jinja2_transformer.py
+++ b/api/core/helper/code_executor/jinja2/jinja2_transformer.py
@@ -21,20 +21,20 @@ class Jinja2TemplateTransformer(TemplateTransformer):
import jinja2
template = jinja2.Template('''{cls._code_placeholder}''')
return template.render(**inputs)
-
+
import json
from base64 import b64decode
-
+
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
-
+
# execute main function
output = main(**inputs_obj)
-
+
# convert output and print
result = f'''<>{{output}}<>'''
print(result)
-
+
""")
return runner_script
@@ -43,15 +43,15 @@ class Jinja2TemplateTransformer(TemplateTransformer):
preload_script = dedent("""
import jinja2
from base64 import b64decode
-
+
def _jinja2_preload_():
# prepare jinja2 environment, load template and render before to avoid sandbox issue
template = jinja2.Template('{{s}}')
template.render(s='a')
-
+
if __name__ == '__main__':
_jinja2_preload_()
-
+
""")
return preload_script
diff --git a/api/core/helper/code_executor/python3/python3_transformer.py b/api/core/helper/code_executor/python3/python3_transformer.py
index 75a5a44d08..836fd273ae 100644
--- a/api/core/helper/code_executor/python3/python3_transformer.py
+++ b/api/core/helper/code_executor/python3/python3_transformer.py
@@ -9,16 +9,16 @@ class Python3TemplateTransformer(TemplateTransformer):
runner_script = dedent(f"""
# declare main function
{cls._code_placeholder}
-
+
import json
from base64 import b64decode
-
+
# decode and prepare input dict
inputs_obj = json.loads(b64decode('{cls._inputs_placeholder}').decode('utf-8'))
-
+
# execute main function
output_obj = main(**inputs_obj)
-
+
# convert output to json and print
output_json = json.dumps(output_obj, indent=4)
result = f'''<>{{output_json}}<>'''
diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py
index a75a4c22d1..81bf59b2b6 100644
--- a/api/core/indexing_runner.py
+++ b/api/core/indexing_runner.py
@@ -9,7 +9,7 @@ import uuid
from typing import Any, Optional, cast
from flask import current_app
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config
diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py
index d5d2ca60fa..e5dbc30689 100644
--- a/api/core/llm_generator/llm_generator.py
+++ b/api/core/llm_generator/llm_generator.py
@@ -3,6 +3,8 @@ import logging
import re
from typing import Optional, cast
+import json_repair
+
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -366,7 +368,20 @@ class LLMGenerator:
),
)
- generated_json_schema = cast(str, response.message.content)
+ raw_content = response.message.content
+
+ if not isinstance(raw_content, str):
+ raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
+
+ try:
+ parsed_content = json.loads(raw_content)
+ except json.JSONDecodeError:
+ parsed_content = json_repair.loads(raw_content)
+
+ if not isinstance(parsed_content, dict | list):
+ raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
+
+ generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:
diff --git a/api/core/llm_generator/prompts.py b/api/core/llm_generator/prompts.py
index 82d22d7f89..34ea3aec26 100644
--- a/api/core/llm_generator/prompts.py
+++ b/api/core/llm_generator/prompts.py
@@ -1,7 +1,7 @@
# Written by YORKI MINAKO🤡, Edited by Xiaoyi
-CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
-Notice: the language type user use could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
-MAKE SURE your output is the SAME language as the user's input!
+CONVERSATION_TITLE_PROMPT = """You need to decompose the user's input into "subject" and "intention" in order to accurately figure out what the user's input language actually is.
+Notice: the language type user uses could be diverse, which can be English, Chinese, Italian, Español, Arabic, Japanese, French, and etc.
+ENSURE your output is in the SAME language as the user's input!
Your output is restricted only to: (Input language) Intention + Subject(short as possible)
Your output MUST be a valid JSON.
@@ -19,7 +19,7 @@ User Input: hi, yesterday i had some burgers.
example 2:
User Input: hello
{
- "Language Type": "The user's input is written in pure English",
+ "Language Type": "The user's input is pure English",
"Your Reasoning": "The language of my output must be pure English.",
"Your Output": "Greeting myself☺️"
}
@@ -46,7 +46,7 @@ example 5:
User Input: why小红的年龄is老than小明?
{
"Language Type": "The user's input is English-Chinese mixed",
- "Your Reasoning": "The English parts are subjective particles, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
+ "Your Reasoning": "The English parts are filler words, the main intention is written in Chinese, besides, Chinese occupies a greater \"actual meaning\" than English, so the language of my output must be using Chinese.",
"Your Output": "询问小红和小明的年龄"
}
@@ -58,7 +58,7 @@ User Input: yo, 你今天咋样?
"Your Output": "查询今日我的状态☺️"
}
-User Input:
+User Input:
""" # noqa: E501
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE = (
@@ -114,6 +114,13 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
"4. The returned object should contain at least one key-value pair.\n\n"
"5. The returned object should always be in the format: {result: ...}\n\n"
"Example:\n"
+ "/**\n"
+ " * Multiplies two numbers together.\n"
+ " *\n"
+ " * @param {number} arg1 - The first number to multiply.\n"
+ " * @param {number} arg2 - The second number to multiply.\n"
+ " * @returns {{ result: number }} The result of the multiplication.\n"
+ " */\n"
"function main(arg1, arg2) {\n"
" return {\n"
" result: arg1 * arg2\n"
@@ -130,7 +137,7 @@ JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE = (
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
- "and keeping each question under 20 characters.\n"
+ "and keep each question under 20 characters.\n"
"MAKE SURE your output is the SAME language as the Assistant's latest response. "
"The output must be an array in JSON format following the specified schema:\n"
'["question1","question2","question3"]\n'
@@ -156,11 +163,11 @@ Here is a task description for which I would like you to create a high-quality p
{{TASK_DESCRIPTION}}
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
-- Do not include or section and variables in the prompt, assume user will add them at their own will.
-- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
-- Relevant examples if needed to clarify the task further, demarcated with tags. Do not include variables in the prompt. Give three pairs of input and output examples.
-- Include other relevant sections demarcated with appropriate XML tags like , .
-- Use the same language as task description.
+- Do not include or section and variables in the prompt, assume user will add them at their own will.
+- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
+- Relevant examples if needed to clarify the task further, demarcated with tags. Do not include variables in the prompt. Give three pairs of input and output examples.
+- Include other relevant sections demarcated with appropriate XML tags like , .
+- Use the same language as task description.
- Output in ``` xml ``` and start with
Please generate the full prompt template with at least 300 words and output only the prompt template.
""" # noqa: E501
@@ -171,28 +178,28 @@ Here is a task description for which I would like you to create a high-quality p
{{TASK_DESCRIPTION}}
Based on task description, please create a well-structured prompt template that another AI could use to consistently complete the task. The prompt template should include:
-- Descriptive variable names surrounded by {{ }} (two curly brackets) to indicate where the actual values will be substituted in. Choose variable names that clearly indicate the type of value expected. Variable names have to be composed of number, english alphabets and underline and nothing else.
-- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
-- Relevant examples if needed to clarify the task further, demarcated with tags. Do not use curly brackets any other than in section.
+- Descriptive variable names surrounded by {{ }} (two curly brackets) to indicate where the actual values will be substituted in. Choose variable names that clearly indicate the type of value expected. Variable names have to be composed of number, english alphabets and underline and nothing else.
+- Clear instructions for the AI that will be using this prompt, demarcated with tags. The instructions should provide step-by-step directions on how to complete the task using the input variables. Also Specifies in the instructions that the output should not contain any xml tag.
+- Relevant examples if needed to clarify the task further, demarcated with tags. Do not use curly brackets any other than in section.
- Any other relevant sections demarcated with appropriate XML tags like , , etc.
-- Use the same language as task description.
+- Use the same language as task description.
- Output in ``` xml ``` and start with
Please generate the full prompt template and output only the prompt template.
""" # noqa: E501
RULE_CONFIG_PARAMETER_GENERATE_TEMPLATE = """
-I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted.
+I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted.
-variables name bounded two double curly brackets. Variable name has to be composed of number, english alphabets and underline and nothing else.
+variables name bounded two double curly brackets. Variable name has to be composed of number, english alphabets and underline and nothing else.
Step 1: Carefully read the input and understand the structure of the expected output.
-Step 2: Extract relevant parameters from the provided text based on the name and description of object.
+Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in .
-Step 4: Ensure that the list of variable_names is properly formatted and valid. The output should not contain any XML tags. Output an empty list if there is no valid variable name in input text.
+Step 4: Ensure that the list of variable_names is properly formatted and valid. The output should not contain any XML tags. Output an empty list if there is no valid variable name in input text.
### Structure
-Here is the structure of the expected output, I should always follow the output structure.
+Here is the structure of the expected output, I should always follow the output structure.
["variable_name_1", "variable_name_2"]
### Input Text
@@ -207,13 +214,13 @@ I should always output a valid list. Output nothing other than the list of varia
RULE_CONFIG_STATEMENT_GENERATE_TEMPLATE = """
-Step 1: Identify the purpose of the chatbot from the variable {{TASK_DESCRIPTION}} and infer chatbot's tone (e.g., friendly, professional, etc.) to add personality traits.
+Step 1: Identify the purpose of the chatbot from the variable {{TASK_DESCRIPTION}} and infer chatbot's tone (e.g., friendly, professional, etc.) to add personality traits.
Step 2: Create a coherent and engaging opening statement.
Step 3: Ensure the output is welcoming and clearly explains what the chatbot is designed to do. Do not include any XML tags in the output.
-Please use the same language as the user's input language. If user uses chinese then generate opening statement in chinese, if user uses english then generate opening statement in english.
-Example Input:
+Please use the same language as the user's input language. If user uses chinese then generate opening statement in chinese, if user uses english then generate opening statement in english.
+Example Input:
Provide customer support for an e-commerce website
-Example Output:
+Example Output:
Welcome! I'm here to assist you with any questions or issues you might have with your shopping experience. Whether you're looking for product information, need help with your order, or have any other inquiries, feel free to ask. I'm friendly, helpful, and ready to support you in any way I can.
Here is the task description: {{INPUT_TEXT}}
@@ -269,15 +276,15 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
{
"type": "object",
"properties": {
- "email": {
+ "email": {
"type": "string",
"format": "email"
},
- "password": {
+ "password": {
"type": "string",
"minLength": 8
},
- "age": {
+ "age": {
"type": "integer",
"minimum": 18
}
@@ -291,32 +298,30 @@ Your task is to convert simple user descriptions into properly formatted JSON Sc
{
"type": "object",
"properties": {
- "properties": {
- "songs": {
- "type": "array",
- "items": {
- "type": "object",
- "properties": {
- "name": {
- "type": "string"
- },
- "id": {
- "type": "string"
- },
- "duration": {
- "type": "string"
- },
- "aritst": {
- "type": "string"
- }
+ "songs": {
+ "type": "array",
+ "items": {
+ "type": "object",
+ "properties": {
+ "name": {
+ "type": "string"
+ },
+ "id": {
+ "type": "string"
+ },
+ "duration": {
+ "type": "string"
},
- "required": [
- "name",
- "id",
- "duration",
- "aritst"
- ]
- }
+ "aritst": {
+ "type": "string"
+ }
+ },
+ "required": [
+ "name",
+ "id",
+ "duration",
+ "aritst"
+ ]
}
}
},
diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py
index 3c90dd22a2..2254b3d4d5 100644
--- a/api/core/memory/token_buffer_memory.py
+++ b/api/core/memory/token_buffer_memory.py
@@ -8,11 +8,11 @@ from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
- PromptMessageContent,
PromptMessageRole,
TextPromptMessageContent,
UserPromptMessage,
)
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from factories import file_factory
@@ -100,7 +100,7 @@ class TokenBufferMemory:
if not file_objs:
prompt_messages.append(UserPromptMessage(content=message.query))
else:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
for file in file_objs:
prompt_message = file_manager.to_prompt_message_content(
diff --git a/api/core/model_manager.py b/api/core/model_manager.py
index 0845ef206e..995a30d44c 100644
--- a/api/core/model_manager.py
+++ b/api/core/model_manager.py
@@ -101,7 +101,7 @@ class ModelInstance:
@overload
def invoke_llm(
self,
- prompt_messages: list[PromptMessage],
+ prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,
diff --git a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
index 2d71e99fce..d845c4bd09 100644
--- a/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
+++ b/api/core/model_runtime/docs/en_US/customizable_model_scale_out.md
@@ -307,4 +307,4 @@ Runtime Errors:
"""
```
-For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file
+For interface method details, see: [Interfaces](./interfaces.md). For specific implementations, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
diff --git a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
index 3e16257452..a770ed157b 100644
--- a/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
+++ b/api/core/model_runtime/docs/en_US/predefined_model_scale_out.md
@@ -170,4 +170,4 @@ Runtime Errors:
"""
```
-For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
\ No newline at end of file
+For interface method explanations, see: [Interfaces](./interfaces.md). For detailed implementation, refer to: [llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py).
diff --git a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
index 88ec6861fe..7d30655469 100644
--- a/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
+++ b/api/core/model_runtime/docs/zh_Hans/customizable_model_scale_out.md
@@ -294,4 +294,4 @@ provider_credential_schema:
"""
```
-接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
\ No newline at end of file
+接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
diff --git a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
index b33dc7c94b..80e7982e9f 100644
--- a/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
+++ b/api/core/model_runtime/docs/zh_Hans/predefined_model_scale_out.md
@@ -169,4 +169,4 @@ pricing: # 价格信息
"""
```
-接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
\ No newline at end of file
+接口方法说明见:[Interfaces](./interfaces.md),具体实现可参考:[llm.py](https://github.com/langgenius/dify-runtime/blob/main/lib/model_providers/anthropic/llm/llm.py)。
diff --git a/api/core/model_runtime/entities/message_entities.py b/api/core/model_runtime/entities/message_entities.py
index 977678b893..9d010ae28d 100644
--- a/api/core/model_runtime/entities/message_entities.py
+++ b/api/core/model_runtime/entities/message_entities.py
@@ -1,8 +1,9 @@
-from collections.abc import Sequence
+from abc import ABC
+from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
-from typing import Optional
+from typing import Annotated, Any, Literal, Optional, Union
-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, Field, field_serializer, field_validator
class PromptMessageRole(Enum):
@@ -60,7 +61,7 @@ class PromptMessageContentType(StrEnum):
DOCUMENT = "document"
-class PromptMessageContent(BaseModel):
+class PromptMessageContent(ABC, BaseModel):
"""
Model class for prompt message content.
"""
@@ -73,7 +74,7 @@ class TextPromptMessageContent(PromptMessageContent):
Model class for text prompt message content.
"""
- type: PromptMessageContentType = PromptMessageContentType.TEXT
+ type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
data: str
@@ -82,7 +83,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
Model class for multi-modal prompt message content.
"""
- type: PromptMessageContentType
format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +94,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
class VideoPromptMessageContent(MultiModalPromptMessageContent):
- type: PromptMessageContentType = PromptMessageContentType.VIDEO
+ type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
class AudioPromptMessageContent(MultiModalPromptMessageContent):
- type: PromptMessageContentType = PromptMessageContentType.AUDIO
+ type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,21 +110,42 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
LOW = "low"
HIGH = "high"
- type: PromptMessageContentType = PromptMessageContentType.IMAGE
+ type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
- type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
+ type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
+
+
+PromptMessageContentUnionTypes = Annotated[
+ Union[
+ TextPromptMessageContent,
+ ImagePromptMessageContent,
+ DocumentPromptMessageContent,
+ AudioPromptMessageContent,
+ VideoPromptMessageContent,
+ ],
+ Field(discriminator="type"),
+]
+
+CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
+ PromptMessageContentType.TEXT: TextPromptMessageContent,
+ PromptMessageContentType.IMAGE: ImagePromptMessageContent,
+ PromptMessageContentType.AUDIO: AudioPromptMessageContent,
+ PromptMessageContentType.VIDEO: VideoPromptMessageContent,
+ PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
+}
-class PromptMessage(BaseModel):
+
+class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
role: PromptMessageRole
- content: Optional[str | Sequence[PromptMessageContent]] = None
+ content: Optional[str | list[PromptMessageContentUnionTypes]] = None
name: Optional[str] = None
def is_empty(self) -> bool:
@@ -135,6 +156,33 @@ class PromptMessage(BaseModel):
"""
return not self.content
+ @field_validator("content", mode="before")
+ @classmethod
+ def validate_content(cls, v):
+ if isinstance(v, list):
+ prompts = []
+ for prompt in v:
+ if isinstance(prompt, PromptMessageContent):
+ if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
+ prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
+ elif isinstance(prompt, dict):
+ prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
+ else:
+ raise ValueError(f"invalid prompt message {prompt}")
+ prompts.append(prompt)
+ return prompts
+ return v
+
+ @field_serializer("content")
+ def serialize_content(
+ self, content: Optional[Union[str, Sequence[PromptMessageContent]]]
+ ) -> Optional[str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent]]:
+ if content is None or isinstance(content, str):
+ return content
+ if isinstance(content, list):
+ return [item.model_dump() if hasattr(item, "model_dump") else item for item in content]
+ return content
+
class UserPromptMessage(PromptMessage):
"""
diff --git a/api/core/model_runtime/entities/provider_entities.py b/api/core/model_runtime/entities/provider_entities.py
index 85321bed94..d0f9ee13e5 100644
--- a/api/core/model_runtime/entities/provider_entities.py
+++ b/api/core/model_runtime/entities/provider_entities.py
@@ -134,6 +134,9 @@ class ProviderEntity(BaseModel):
# pydantic configs
model_config = ConfigDict(protected_namespaces=())
+ # position from plugin _position.yaml
+ position: Optional[dict[str, list[str]]] = {}
+
@field_validator("models", mode="before")
@classmethod
def validate_models(cls, v):
diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py
index bd05590018..7d5ce1e47e 100644
--- a/api/core/model_runtime/model_providers/__base/ai_model.py
+++ b/api/core/model_runtime/model_providers/__base/ai_model.py
@@ -24,9 +24,8 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
-from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
class AIModel(BaseModel):
@@ -141,7 +140,7 @@ class AIModel(BaseModel):
:param credentials: model credentials
:return: model schema
"""
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}"
# sort credentials
sorted_credentials = sorted(credentials.items()) if credentials else []
@@ -253,15 +252,3 @@ class AIModel(BaseModel):
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule
-
- def _get_num_tokens_by_gpt2(self, text: str) -> int:
- """
- Get number of tokens for given prompt messages by gpt2
- Some provider models do not provide an interface for obtaining the number of tokens.
- Here, the gpt2 tokenizer is used to calculate the number of tokens.
- This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
-
- :param text: plain text of prompt. You need to convert the original message to plain text
- :return: number of tokens
- """
- return GPT2Tokenizer.get_num_tokens(text)
diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py
index 1b799131e7..e2cc576f83 100644
--- a/api/core/model_runtime/model_providers/__base/large_language_model.py
+++ b/api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -13,14 +13,16 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
+ PromptMessageContentUnionTypes,
PromptMessageTool,
+ TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import (
ModelType,
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@@ -140,7 +142,7 @@ class LargeLanguageModel(AIModel):
result: Union[LLMResult, Generator[LLMResultChunk, None, None]]
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
result = plugin_model_manager.invoke_llm(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@@ -237,7 +239,7 @@ class LargeLanguageModel(AIModel):
def _invoke_result_generator(
self,
model: str,
- result: Generator,
+ result: Generator[LLMResultChunk, None, None],
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
@@ -254,11 +256,21 @@ class LargeLanguageModel(AIModel):
:return: result generator
"""
callbacks = callbacks or []
- assistant_message = AssistantPromptMessage(content="")
+ message_content: list[PromptMessageContentUnionTypes] = []
usage = None
system_fingerprint = None
real_model = model
+ def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
+ if not content:
+ return
+ if isinstance(content, list):
+ message_content.extend(content)
+ return
+ if isinstance(content, str):
+ message_content.append(TextPromptMessageContent(data=content))
+ return
+
try:
for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
@@ -280,7 +292,8 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
- assistant_message.content += chunk.delta.message.content
+ _update_message_content(chunk.delta.message.content)
+
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
@@ -290,6 +303,7 @@ class LargeLanguageModel(AIModel):
except Exception as e:
raise self._transform_invoke_error(e)
+ assistant_message = AssistantPromptMessage(content=message_content)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(
@@ -326,7 +340,7 @@ class LargeLanguageModel(AIModel):
:return:
"""
if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_llm_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py
index f98d7572c7..19dc1d599a 100644
--- a/api/core/model_runtime/model_providers/__base/moderation_model.py
+++ b/api/core/model_runtime/model_providers/__base/moderation_model.py
@@ -5,7 +5,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
class ModerationModel(AIModel):
@@ -31,7 +31,7 @@ class ModerationModel(AIModel):
self.started_at = time.perf_counter()
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_moderation(
tenant_id=self.tenant_id,
user_id=user or "unknown",
diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py
index e905cb18d4..569e756a3b 100644
--- a/api/core/model_runtime/model_providers/__base/rerank_model.py
+++ b/api/core/model_runtime/model_providers/__base/rerank_model.py
@@ -3,7 +3,7 @@ from typing import Optional
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
class RerankModel(AIModel):
@@ -36,7 +36,7 @@ class RerankModel(AIModel):
:return: rerank result
"""
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_rerank(
tenant_id=self.tenant_id,
user_id=user or "unknown",
diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py
index 97ff322f09..c69f65b681 100644
--- a/api/core/model_runtime/model_providers/__base/speech2text_model.py
+++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py
@@ -4,7 +4,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
class Speech2TextModel(AIModel):
@@ -28,7 +28,7 @@ class Speech2TextModel(AIModel):
:return: text for given audio file
"""
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_speech_to_text(
tenant_id=self.tenant_id,
user_id=user or "unknown",
diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
index c4c1f92177..f7bba0eba1 100644
--- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py
+++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -6,7 +6,7 @@ from core.entities.embedding_type import EmbeddingInputType
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
class TextEmbeddingModel(AIModel):
@@ -38,7 +38,7 @@ class TextEmbeddingModel(AIModel):
:return: embeddings result
"""
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_text_embedding(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@@ -61,7 +61,7 @@ class TextEmbeddingModel(AIModel):
:param texts: texts to embed
:return:
"""
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_text_embedding_num_tokens(
tenant_id=self.tenant_id,
user_id="unknown",
diff --git a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
index 2f6f4fbbef..b7db0b78bc 100644
--- a/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
+++ b/api/core/model_runtime/model_providers/__base/tokenizers/gpt2_tokenzier.py
@@ -30,6 +30,8 @@ class GPT2Tokenizer:
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
+ if _tokenizer is not None:
+ return _tokenizer
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster
diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py
index 1f248d11ac..d51831900c 100644
--- a/api/core/model_runtime/model_providers/__base/tts_model.py
+++ b/api/core/model_runtime/model_providers/__base/tts_model.py
@@ -6,7 +6,7 @@ from pydantic import ConfigDict
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ class TTSModel(AIModel):
:return: translated audio file
"""
try:
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.invoke_tts(
tenant_id=self.tenant_id,
user_id=user or "unknown",
@@ -65,7 +65,7 @@ class TTSModel(AIModel):
:param credentials: The credentials required to access the TTS model.
:return: A list of voices supported by the TTS model.
"""
- plugin_model_manager = PluginModelManager()
+ plugin_model_manager = PluginModelClient()
return plugin_model_manager.get_tts_model_voices(
tenant_id=self.tenant_id,
user_id="unknown",
diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py
index d2fd4916a4..ad46f64ec3 100644
--- a/api/core/model_runtime/model_providers/model_provider_factory.py
+++ b/api/core/model_runtime/model_providers/model_provider_factory.py
@@ -22,8 +22,8 @@ from core.model_runtime.schema_validators.model_credential_schema_validator impo
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.plugin.entities.plugin import ModelProviderID
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
-from core.plugin.manager.asset import PluginAssetManager
-from core.plugin.manager.model import PluginModelManager
+from core.plugin.impl.asset import PluginAssetManager
+from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@@ -40,7 +40,7 @@ class ModelProviderFactory:
self.provider_position_map = {}
self.tenant_id = tenant_id
- self.plugin_model_manager = PluginModelManager()
+ self.plugin_model_manager = PluginModelClient()
if not self.provider_position_map:
# get the path of current classes
diff --git a/api/core/moderation/api/__builtin__ b/api/core/moderation/api/__builtin__
index e440e5c842..00750edc07 100644
--- a/api/core/moderation/api/__builtin__
+++ b/api/core/moderation/api/__builtin__
@@ -1 +1 @@
-3
\ No newline at end of file
+3
diff --git a/api/core/moderation/keywords/__builtin__ b/api/core/moderation/keywords/__builtin__
index d8263ee986..0cfbf08886 100644
--- a/api/core/moderation/keywords/__builtin__
+++ b/api/core/moderation/keywords/__builtin__
@@ -1 +1 @@
-2
\ No newline at end of file
+2
diff --git a/api/core/moderation/openai_moderation/__builtin__ b/api/core/moderation/openai_moderation/__builtin__
index 56a6051ca2..d00491fd7e 100644
--- a/api/core/moderation/openai_moderation/__builtin__
+++ b/api/core/moderation/openai_moderation/__builtin__
@@ -1 +1 @@
-1
\ No newline at end of file
+1
diff --git a/api/core/moderation/output_moderation.py b/api/core/moderation/output_moderation.py
index e595be126c..2ec315417f 100644
--- a/api/core/moderation/output_moderation.py
+++ b/api/core/moderation/output_moderation.py
@@ -46,14 +46,14 @@ class OutputModeration(BaseModel):
if not self.thread:
self.thread = self.start_thread()
- def moderation_completion(self, completion: str, public_event: bool = False) -> str:
+ def moderation_completion(self, completion: str, public_event: bool = False) -> tuple[str, bool]:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(tenant_id=self.tenant_id, app_id=self.app_id, moderation_buffer=completion)
if not result or not result.flagged:
- return completion
+ return completion, False
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
@@ -61,9 +61,14 @@ class OutputModeration(BaseModel):
final_output = result.text
if public_event:
- self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
+ self.queue_manager.publish(
+ QueueMessageReplaceEvent(
+ text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
+ ),
+ PublishFrom.TASK_PIPELINE,
+ )
- return final_output
+ return final_output, True
def start_thread(self) -> threading.Thread:
buffer_size = dify_config.MODERATION_BUFFER_SIZE
@@ -112,7 +117,12 @@ class OutputModeration(BaseModel):
# trigger replace event
if self.thread_running:
- self.queue_manager.publish(QueueMessageReplaceEvent(text=final_output), PublishFrom.TASK_PIPELINE)
+ self.queue_manager.publish(
+ QueueMessageReplaceEvent(
+ text=final_output, reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION
+ ),
+ PublishFrom.TASK_PIPELINE,
+ )
if result.action == ModerationAction.DIRECT_OUTPUT:
break
diff --git a/api/core/ops/entities/config_entity.py b/api/core/ops/entities/config_entity.py
index b484242b61..874b2800b2 100644
--- a/api/core/ops/entities/config_entity.py
+++ b/api/core/ops/entities/config_entity.py
@@ -7,6 +7,7 @@ class TracingProviderEnum(Enum):
LANGFUSE = "langfuse"
LANGSMITH = "langsmith"
OPIK = "opik"
+ WEAVE = "weave"
class BaseTracingConfig(BaseModel):
@@ -88,5 +89,26 @@ class OpikConfig(BaseTracingConfig):
return v
+class WeaveConfig(BaseTracingConfig):
+ """
+ Model class for Weave tracing config.
+ """
+
+ api_key: str
+ entity: str | None = None
+ project: str
+ endpoint: str = "https://trace.wandb.ai"
+
+ @field_validator("endpoint")
+ @classmethod
+ def set_value(cls, v, info: ValidationInfo):
+ if v is None or v == "":
+ v = "https://trace.wandb.ai"
+ if not v.startswith("https://"):
+ raise ValueError("endpoint must start with https://")
+
+ return v
+
+
OPS_FILE_PATH = "ops_trace/"
OPS_TRACE_FAILED_KEY = "FAILED_OPS_TRACE"
diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py
index fa78b7b8e9..c74617e558 100644
--- a/api/core/ops/langfuse_trace/langfuse_trace.py
+++ b/api/core/ops/langfuse_trace/langfuse_trace.py
@@ -29,7 +29,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum,
)
from core.ops.utils import filter_none_values
-from core.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser
@@ -113,8 +113,8 @@ class LangFuseDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={"tenant_id": trace_info.tenant_id, "session_factory": session_factory},
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory, tenant_id=trace_info.tenant_id
)
# Get all executions for this workflow run
diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py
index 85a0eafdc1..d1e16d3152 100644
--- a/api/core/ops/langsmith_trace/langsmith_trace.py
+++ b/api/core/ops/langsmith_trace/langsmith_trace.py
@@ -28,7 +28,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel,
)
from core.ops.utils import filter_none_values, generate_dotted_order
-from core.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -137,12 +137,8 @@ class LangSmithDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": trace_info.tenant_id,
- "app_id": trace_info.metadata.get("app_id"),
- "session_factory": session_factory,
- },
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run
diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py
index 923b9a24ed..1484041447 100644
--- a/api/core/ops/opik_trace/opik_trace.py
+++ b/api/core/ops/opik_trace/opik_trace.py
@@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName,
WorkflowTraceInfo,
)
-from core.repository.repository_factory import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.model import EndUser, MessageFile
@@ -150,12 +150,8 @@ class OpikDataTrace(BaseTraceInstance):
# through workflow_run_id get all_nodes_execution using repository
session_factory = sessionmaker(bind=db.engine)
- workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": trace_info.tenant_id,
- "app_id": trace_info.metadata.get("app_id"),
- "session_factory": session_factory,
- },
+ workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id")
)
# Get all executions for this workflow run
diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py
index 6fc02393fe..2c68055f87 100644
--- a/api/core/ops/ops_trace_manager.py
+++ b/api/core/ops/ops_trace_manager.py
@@ -20,6 +20,7 @@ from core.ops.entities.config_entity import (
LangSmithConfig,
OpikConfig,
TracingProviderEnum,
+ WeaveConfig,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
@@ -34,7 +35,9 @@ from core.ops.entities.trace_entity import (
)
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
+from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data
+from core.ops.weave_trace.weave_trace import WeaveDataTrace
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@@ -43,8 +46,6 @@ from tasks.ops_trace_task import process_trace_tasks
def build_opik_trace_instance(config: OpikConfig):
- from core.ops.opik_trace.opik_trace import OpikDataTrace
-
return OpikDataTrace(config)
@@ -67,6 +68,12 @@ provider_config_map: dict[str, dict[str, Any]] = {
"other_keys": ["project", "url", "workspace"],
"trace_instance": lambda config: build_opik_trace_instance(config),
},
+ TracingProviderEnum.WEAVE.value: {
+ "config_class": WeaveConfig,
+ "secret_keys": ["api_key"],
+ "other_keys": ["project", "entity", "endpoint"],
+ "trace_instance": WeaveDataTrace,
+ },
}
diff --git a/api/core/ops/weave_trace/__init__.py b/api/core/ops/weave_trace/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/entities/__init__.py b/api/core/ops/weave_trace/entities/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py
new file mode 100644
index 0000000000..e423f5ccbb
--- /dev/null
+++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py
@@ -0,0 +1,97 @@
+from typing import Any, Optional, Union
+
+from pydantic import BaseModel, Field, field_validator
+from pydantic_core.core_schema import ValidationInfo
+
+from core.ops.utils import replace_text_with_content
+
+
+class WeaveTokenUsage(BaseModel):
+ input_tokens: Optional[int] = None
+ output_tokens: Optional[int] = None
+ total_tokens: Optional[int] = None
+
+
+class WeaveMultiModel(BaseModel):
+ file_list: Optional[list[str]] = Field(None, description="List of files")
+
+
+class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel):
+ id: str = Field(..., description="ID of the trace")
+ op: str = Field(..., description="Name of the operation")
+ inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace")
+ outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace")
+ attributes: Optional[Union[str, dict[str, Any], list, None]] = Field(
+ None, description="Metadata and attributes associated with trace"
+ )
+ exception: Optional[str] = Field(None, description="Exception message of the trace")
+
+ @field_validator("inputs", "outputs")
+ @classmethod
+ def ensure_dict(cls, v, info: ValidationInfo):
+ field_name = info.field_name
+ values = info.data
+ if v == {} or v is None:
+ return v
+ usage_metadata = {
+ "input_tokens": values.get("input_tokens", 0),
+ "output_tokens": values.get("output_tokens", 0),
+ "total_tokens": values.get("total_tokens", 0),
+ }
+ file_list = values.get("file_list", [])
+ if isinstance(v, str):
+ if field_name == "inputs":
+ return {
+ "messages": {
+ "role": "user",
+ "content": v,
+ "usage_metadata": usage_metadata,
+ "file_list": file_list,
+ },
+ }
+ elif field_name == "outputs":
+ return {
+ "choices": {
+ "role": "ai",
+ "content": v,
+ "usage_metadata": usage_metadata,
+ "file_list": file_list,
+ },
+ }
+ elif isinstance(v, list):
+ data = {}
+ if len(v) > 0 and isinstance(v[0], dict):
+ # rename text to content
+ v = replace_text_with_content(data=v)
+ if field_name == "inputs":
+ data = {
+ "messages": [
+ dict(msg, **{"usage_metadata": usage_metadata, "file_list": file_list}) for msg in v
+ ]
+ if isinstance(v, list)
+ else v,
+ }
+ elif field_name == "outputs":
+ data = {
+ "choices": {
+ "role": "ai",
+ "content": v,
+ "usage_metadata": usage_metadata,
+ "file_list": file_list,
+ },
+ }
+ return data
+ else:
+ return {
+ "choices": {
+ "role": "ai" if field_name == "outputs" else "user",
+ "content": str(v),
+ "usage_metadata": usage_metadata,
+ "file_list": file_list,
+ },
+ }
+ if isinstance(v, dict):
+ v["usage_metadata"] = usage_metadata
+ v["file_list"] = file_list
+ return v
+ return v
diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py
new file mode 100644
index 0000000000..49594cb0f1
--- /dev/null
+++ b/api/core/ops/weave_trace/weave_trace.py
@@ -0,0 +1,420 @@
+import json
+import logging
+import os
+import uuid
+from datetime import datetime, timedelta
+from typing import Any, Optional, cast
+
+import wandb
+import weave
+
+from core.ops.base_trace_instance import BaseTraceInstance
+from core.ops.entities.config_entity import WeaveConfig
+from core.ops.entities.trace_entity import (
+ BaseTraceInfo,
+ DatasetRetrievalTraceInfo,
+ GenerateNameTraceInfo,
+ MessageTraceInfo,
+ ModerationTraceInfo,
+ SuggestedQuestionTraceInfo,
+ ToolTraceInfo,
+ TraceTaskName,
+ WorkflowTraceInfo,
+)
+from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
+from extensions.ext_database import db
+from models.model import EndUser, MessageFile
+from models.workflow import WorkflowNodeExecution
+
+logger = logging.getLogger(__name__)
+
+
+class WeaveDataTrace(BaseTraceInstance):
+ def __init__(
+ self,
+ weave_config: WeaveConfig,
+ ):
+ super().__init__(weave_config)
+ self.weave_api_key = weave_config.api_key
+ self.project_name = weave_config.project
+ self.entity = weave_config.entity
+
+ # Login with API key first
+ login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
+ if not login_status:
+ logger.error("Failed to login to Weights & Biases with the provided API key")
+ raise ValueError("Weave login failed")
+
+ # Then initialize weave client
+ self.weave_client = weave.init(
+ project_name=(f"{self.entity}/{self.project_name}" if self.entity else self.project_name)
+ )
+ self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
+ self.calls: dict[str, Any] = {}
+
+ def get_project_url(
+ self,
+ ):
+ try:
+ project_url = f"https://wandb.ai/{self.weave_client._project_id()}"
+ return project_url
+ except Exception as e:
+ logger.debug(f"Weave get run url failed: {str(e)}")
+ raise ValueError(f"Weave get run url failed: {str(e)}")
+
+ def trace(self, trace_info: BaseTraceInfo):
+ logger.debug(f"Trace info: {trace_info}")
+ if isinstance(trace_info, WorkflowTraceInfo):
+ self.workflow_trace(trace_info)
+ if isinstance(trace_info, MessageTraceInfo):
+ self.message_trace(trace_info)
+ if isinstance(trace_info, ModerationTraceInfo):
+ self.moderation_trace(trace_info)
+ if isinstance(trace_info, SuggestedQuestionTraceInfo):
+ self.suggested_question_trace(trace_info)
+ if isinstance(trace_info, DatasetRetrievalTraceInfo):
+ self.dataset_retrieval_trace(trace_info)
+ if isinstance(trace_info, ToolTraceInfo):
+ self.tool_trace(trace_info)
+ if isinstance(trace_info, GenerateNameTraceInfo):
+ self.generate_name_trace(trace_info)
+
+ def workflow_trace(self, trace_info: WorkflowTraceInfo):
+ trace_id = trace_info.message_id or trace_info.workflow_run_id
+ if trace_info.start_time is None:
+ trace_info.start_time = datetime.now()
+
+ if trace_info.message_id:
+ message_attributes = trace_info.metadata
+ message_attributes["workflow_app_log_id"] = trace_info.workflow_app_log_id
+
+ message_attributes["message_id"] = trace_info.message_id
+ message_attributes["workflow_run_id"] = trace_info.workflow_run_id
+ message_attributes["trace_id"] = trace_id
+ message_attributes["start_time"] = trace_info.start_time
+ message_attributes["end_time"] = trace_info.end_time
+ message_attributes["tags"] = ["message", "workflow"]
+
+ message_run = WeaveTraceModel(
+ id=trace_info.message_id,
+ op=str(TraceTaskName.MESSAGE_TRACE.value),
+ inputs=dict(trace_info.workflow_run_inputs),
+ outputs=dict(trace_info.workflow_run_outputs),
+ total_tokens=trace_info.total_tokens,
+ attributes=message_attributes,
+ exception=trace_info.error,
+ file_list=[],
+ )
+ self.start_call(message_run, parent_run_id=trace_info.workflow_run_id)
+ self.finish_call(message_run)
+
+ workflow_attributes = trace_info.metadata
+ workflow_attributes["workflow_run_id"] = trace_info.workflow_run_id
+ workflow_attributes["trace_id"] = trace_id
+ workflow_attributes["start_time"] = trace_info.start_time
+ workflow_attributes["end_time"] = trace_info.end_time
+ workflow_attributes["tags"] = ["workflow"]
+
+ workflow_run = WeaveTraceModel(
+ file_list=trace_info.file_list,
+ total_tokens=trace_info.total_tokens,
+ id=trace_info.workflow_run_id,
+ op=str(TraceTaskName.WORKFLOW_TRACE.value),
+ inputs=dict(trace_info.workflow_run_inputs),
+ outputs=dict(trace_info.workflow_run_outputs),
+ attributes=workflow_attributes,
+ exception=trace_info.error,
+ )
+
+ self.start_call(workflow_run, parent_run_id=trace_info.message_id)
+
+ # through workflow_run_id get all_nodes_execution
+ workflow_nodes_execution_id_records = (
+ db.session.query(WorkflowNodeExecution.id)
+ .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id)
+ .all()
+ )
+
+ for node_execution_id_record in workflow_nodes_execution_id_records:
+ node_execution = (
+ db.session.query(
+ WorkflowNodeExecution.id,
+ WorkflowNodeExecution.tenant_id,
+ WorkflowNodeExecution.app_id,
+ WorkflowNodeExecution.title,
+ WorkflowNodeExecution.node_type,
+ WorkflowNodeExecution.status,
+ WorkflowNodeExecution.inputs,
+ WorkflowNodeExecution.outputs,
+ WorkflowNodeExecution.created_at,
+ WorkflowNodeExecution.elapsed_time,
+ WorkflowNodeExecution.process_data,
+ WorkflowNodeExecution.execution_metadata,
+ )
+ .filter(WorkflowNodeExecution.id == node_execution_id_record.id)
+ .first()
+ )
+
+ if not node_execution:
+ continue
+
+ node_execution_id = node_execution.id
+ tenant_id = node_execution.tenant_id
+ app_id = node_execution.app_id
+ node_name = node_execution.title
+ node_type = node_execution.node_type
+ status = node_execution.status
+ if node_type == "llm":
+ inputs = (
+ json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {}
+ )
+ else:
+ inputs = json.loads(node_execution.inputs) if node_execution.inputs else {}
+ outputs = json.loads(node_execution.outputs) if node_execution.outputs else {}
+ created_at = node_execution.created_at or datetime.now()
+ elapsed_time = node_execution.elapsed_time
+ finished_at = created_at + timedelta(seconds=elapsed_time)
+
+ execution_metadata = (
+ json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {}
+ )
+ node_total_tokens = execution_metadata.get("total_tokens", 0)
+ attributes = execution_metadata.copy()
+ attributes.update(
+ {
+ "workflow_run_id": trace_info.workflow_run_id,
+ "node_execution_id": node_execution_id,
+ "tenant_id": tenant_id,
+ "app_id": app_id,
+ "app_name": node_name,
+ "node_type": node_type,
+ "status": status,
+ }
+ )
+
+ process_data = json.loads(node_execution.process_data) if node_execution.process_data else {}
+ if process_data and process_data.get("model_mode") == "chat":
+ attributes.update(
+ {
+ "ls_provider": process_data.get("model_provider", ""),
+ "ls_model_name": process_data.get("model_name", ""),
+ }
+ )
+ attributes["tags"] = ["node_execution"]
+ attributes["start_time"] = created_at
+ attributes["end_time"] = finished_at
+ attributes["elapsed_time"] = elapsed_time
+ attributes["workflow_run_id"] = trace_info.workflow_run_id
+ attributes["trace_id"] = trace_id
+ node_run = WeaveTraceModel(
+ total_tokens=node_total_tokens,
+ op=node_type,
+ inputs=inputs,
+ outputs=outputs,
+ file_list=trace_info.file_list,
+ attributes=attributes,
+ id=node_execution_id,
+ exception=None,
+ )
+
+ self.start_call(node_run, parent_run_id=trace_info.workflow_run_id)
+ self.finish_call(node_run)
+
+ self.finish_call(workflow_run)
+
+ def message_trace(self, trace_info: MessageTraceInfo):
+ # get message file data
+ file_list = cast(list[str], trace_info.file_list) or []
+ message_file_data: Optional[MessageFile] = trace_info.message_file_data
+ file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else ""
+ file_list.append(file_url)
+ attributes = trace_info.metadata
+ message_data = trace_info.message_data
+ if message_data is None:
+ return
+ message_id = message_data.id
+
+ user_id = message_data.from_account_id
+ attributes["user_id"] = user_id
+
+ if message_data.from_end_user_id:
+ end_user_data: Optional[EndUser] = (
+ db.session.query(EndUser).filter(EndUser.id == message_data.from_end_user_id).first()
+ )
+ if end_user_data is not None:
+ end_user_id = end_user_data.session_id
+ attributes["end_user_id"] = end_user_id
+
+ attributes["message_id"] = message_id
+ attributes["start_time"] = trace_info.start_time
+ attributes["end_time"] = trace_info.end_time
+ attributes["tags"] = ["message", str(trace_info.conversation_mode)]
+ message_run = WeaveTraceModel(
+ id=message_id,
+ op=str(TraceTaskName.MESSAGE_TRACE.value),
+ input_tokens=trace_info.message_tokens,
+ output_tokens=trace_info.answer_tokens,
+ total_tokens=trace_info.total_tokens,
+ inputs=trace_info.inputs,
+ outputs=trace_info.outputs,
+ exception=trace_info.error,
+ file_list=file_list,
+ attributes=attributes,
+ )
+ self.start_call(message_run)
+
+ # create llm run parented to message run
+ llm_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ input_tokens=trace_info.message_tokens,
+ output_tokens=trace_info.answer_tokens,
+ total_tokens=trace_info.total_tokens,
+ op="llm",
+ inputs=trace_info.inputs,
+ outputs=trace_info.outputs,
+ attributes=attributes,
+ file_list=[],
+ exception=None,
+ )
+ self.start_call(
+ llm_run,
+ parent_run_id=message_id,
+ )
+ self.finish_call(llm_run)
+ self.finish_call(message_run)
+
+ def moderation_trace(self, trace_info: ModerationTraceInfo):
+ if trace_info.message_data is None:
+ return
+
+ attributes = trace_info.metadata
+ attributes["tags"] = ["moderation"]
+ attributes["message_id"] = trace_info.message_id
+ attributes["start_time"] = trace_info.start_time or trace_info.message_data.created_at
+ attributes["end_time"] = trace_info.end_time or trace_info.message_data.updated_at
+
+ moderation_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ op=str(TraceTaskName.MODERATION_TRACE.value),
+ inputs=trace_info.inputs,
+ outputs={
+ "action": trace_info.action,
+ "flagged": trace_info.flagged,
+ "preset_response": trace_info.preset_response,
+ "inputs": trace_info.inputs,
+ },
+ attributes=attributes,
+ exception=getattr(trace_info, "error", None),
+ file_list=[],
+ )
+ self.start_call(moderation_run, parent_run_id=trace_info.message_id)
+ self.finish_call(moderation_run)
+
+ def suggested_question_trace(self, trace_info: SuggestedQuestionTraceInfo):
+ message_data = trace_info.message_data
+ if message_data is None:
+ return
+ attributes = trace_info.metadata
+ attributes["message_id"] = trace_info.message_id
+ attributes["tags"] = ["suggested_question"]
+ attributes["start_time"] = (trace_info.start_time or message_data.created_at,)
+ attributes["end_time"] = (trace_info.end_time or message_data.updated_at,)
+
+ suggested_question_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ op=str(TraceTaskName.SUGGESTED_QUESTION_TRACE.value),
+ inputs=trace_info.inputs,
+ outputs=trace_info.suggested_question,
+ attributes=attributes,
+ exception=trace_info.error,
+ file_list=[],
+ )
+
+ self.start_call(suggested_question_run, parent_run_id=trace_info.message_id)
+ self.finish_call(suggested_question_run)
+
+ def dataset_retrieval_trace(self, trace_info: DatasetRetrievalTraceInfo):
+ if trace_info.message_data is None:
+ return
+ attributes = trace_info.metadata
+ attributes["message_id"] = trace_info.message_id
+ attributes["tags"] = ["dataset_retrieval"]
+ attributes["start_time"] = (trace_info.start_time or trace_info.message_data.created_at,)
+ attributes["end_time"] = (trace_info.end_time or trace_info.message_data.updated_at,)
+
+ dataset_retrieval_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ op=str(TraceTaskName.DATASET_RETRIEVAL_TRACE.value),
+ inputs=trace_info.inputs,
+ outputs={"documents": trace_info.documents},
+ attributes=attributes,
+ exception=getattr(trace_info, "error", None),
+ file_list=[],
+ )
+
+ self.start_call(dataset_retrieval_run, parent_run_id=trace_info.message_id)
+ self.finish_call(dataset_retrieval_run)
+
+ def tool_trace(self, trace_info: ToolTraceInfo):
+ attributes = trace_info.metadata
+ attributes["tags"] = ["tool", trace_info.tool_name]
+ attributes["start_time"] = trace_info.start_time
+ attributes["end_time"] = trace_info.end_time
+
+ tool_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ op=trace_info.tool_name,
+ inputs=trace_info.tool_inputs,
+ outputs=trace_info.tool_outputs,
+ file_list=[cast(str, trace_info.file_url)] if trace_info.file_url else [],
+ attributes=attributes,
+ exception=trace_info.error,
+ )
+ message_id = trace_info.message_id or getattr(trace_info, "conversation_id", None)
+ message_id = message_id or None
+ self.start_call(tool_run, parent_run_id=message_id)
+ self.finish_call(tool_run)
+
+ def generate_name_trace(self, trace_info: GenerateNameTraceInfo):
+ attributes = trace_info.metadata
+ attributes["tags"] = ["generate_name"]
+ attributes["start_time"] = trace_info.start_time
+ attributes["end_time"] = trace_info.end_time
+
+ name_run = WeaveTraceModel(
+ id=str(uuid.uuid4()),
+ op=str(TraceTaskName.GENERATE_NAME_TRACE.value),
+ inputs=trace_info.inputs,
+ outputs=trace_info.outputs,
+ attributes=attributes,
+ exception=getattr(trace_info, "error", None),
+ file_list=[],
+ )
+
+ self.start_call(name_run)
+ self.finish_call(name_run)
+
+ def api_check(self):
+ try:
+ login_status = wandb.login(key=self.weave_api_key, verify=True, relogin=True)
+ if not login_status:
+ raise ValueError("Weave login failed")
+ else:
+ print("Weave login successful")
+ return True
+ except Exception as e:
+ logger.debug(f"Weave API check failed: {str(e)}")
+ raise ValueError(f"Weave API check failed: {str(e)}")
+
+ def start_call(self, run_data: WeaveTraceModel, parent_run_id: Optional[str] = None):
+ call = self.weave_client.create_call(op=run_data.op, inputs=run_data.inputs, attributes=run_data.attributes)
+ self.calls[run_data.id] = call
+ if parent_run_id:
+ self.calls[run_data.id].parent_id = parent_run_id
+
+ def finish_call(self, run_data: WeaveTraceModel):
+ call = self.calls.get(run_data.id)
+ if call:
+ self.weave_client.finish_call(call=call, output=run_data.outputs, exception=run_data.exception)
+ else:
+ raise ValueError(f"Call with id {run_data.id} not found")
diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py
index 484f52e33c..4e43561a15 100644
--- a/api/core/plugin/backwards_invocation/app.py
+++ b/api/core/plugin/backwards_invocation/app.py
@@ -72,7 +72,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
raise ValueError("missing query")
return cls.invoke_chat_app(app, user, conversation_id, query, stream, inputs, files)
- elif app.mode == AppMode.WORKFLOW.value:
+ elif app.mode == AppMode.WORKFLOW:
return cls.invoke_workflow_app(app, user, stream, inputs, files)
elif app.mode == AppMode.COMPLETION:
return cls.invoke_completion_app(app, user, stream, inputs, files)
diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py
index 490a475c16..5ec9620f22 100644
--- a/api/core/plugin/backwards_invocation/model.py
+++ b/api/core/plugin/backwards_invocation/model.py
@@ -239,8 +239,8 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
content = payload.text
SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
-and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
-retain the original meaning and keep the key points.
+and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
+retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py
index 1588cbc3c7..2bea07bea0 100644
--- a/api/core/plugin/entities/plugin_daemon.py
+++ b/api/core/plugin/entities/plugin_daemon.py
@@ -1,6 +1,7 @@
+from collections.abc import Mapping
from datetime import datetime
from enum import StrEnum
-from typing import Generic, Optional, TypeVar
+from typing import Any, Generic, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, Field
@@ -158,3 +159,11 @@ class PluginInstallTaskStartResponse(BaseModel):
class PluginUploadResponse(BaseModel):
unique_identifier: str = Field(description="The unique identifier of the plugin.")
manifest: PluginDeclaration
+
+
+class PluginOAuthAuthorizationUrlResponse(BaseModel):
+ authorization_url: str = Field(description="The URL of the authorization.")
+
+
+class PluginOAuthCredentialsResponse(BaseModel):
+ credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")
diff --git a/api/core/plugin/manager/agent.py b/api/core/plugin/impl/agent.py
similarity index 97%
rename from api/core/plugin/manager/agent.py
rename to api/core/plugin/impl/agent.py
index 50172f12f2..66b77c7489 100644
--- a/api/core/plugin/manager/agent.py
+++ b/api/core/plugin/impl/agent.py
@@ -6,10 +6,10 @@ from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginAgentManager(BasePluginManager):
+class PluginAgentClient(BasePluginClient):
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
"""
Fetch agent providers for the given tenant.
diff --git a/api/core/plugin/manager/asset.py b/api/core/plugin/impl/asset.py
similarity index 76%
rename from api/core/plugin/manager/asset.py
rename to api/core/plugin/impl/asset.py
index 17755d3561..b9bfe2d2cf 100644
--- a/api/core/plugin/manager/asset.py
+++ b/api/core/plugin/impl/asset.py
@@ -1,7 +1,7 @@
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginAssetManager(BasePluginManager):
+class PluginAssetManager(BasePluginClient):
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
"""
Fetch an asset by id.
diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/impl/base.py
similarity index 99%
rename from api/core/plugin/manager/base.py
rename to api/core/plugin/impl/base.py
index d8d7b3e860..4f1d808a3e 100644
--- a/api/core/plugin/manager/base.py
+++ b/api/core/plugin/impl/base.py
@@ -18,7 +18,7 @@ from core.model_runtime.errors.invoke import (
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
-from core.plugin.manager.exc import (
+from core.plugin.impl.exc import (
PluginDaemonBadRequestError,
PluginDaemonInternalServerError,
PluginDaemonNotFoundError,
@@ -37,7 +37,7 @@ T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
logger = logging.getLogger(__name__)
-class BasePluginManager:
+class BasePluginClient:
def _request(
self,
method: str,
diff --git a/api/core/plugin/manager/debugging.py b/api/core/plugin/impl/debugging.py
similarity index 78%
rename from api/core/plugin/manager/debugging.py
rename to api/core/plugin/impl/debugging.py
index fb6bad7fa3..523377895c 100644
--- a/api/core/plugin/manager/debugging.py
+++ b/api/core/plugin/impl/debugging.py
@@ -1,9 +1,9 @@
from pydantic import BaseModel
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginDebuggingManager(BasePluginManager):
+class PluginDebuggingClient(BasePluginClient):
def get_debugging_key(self, tenant_id: str) -> str:
"""
Get the debugging key for the given tenant.
diff --git a/api/core/plugin/manager/endpoint.py b/api/core/plugin/impl/endpoint.py
similarity index 97%
rename from api/core/plugin/manager/endpoint.py
rename to api/core/plugin/impl/endpoint.py
index 415b981ffb..5b88742be5 100644
--- a/api/core/plugin/manager/endpoint.py
+++ b/api/core/plugin/impl/endpoint.py
@@ -1,8 +1,8 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginEndpointManager(BasePluginManager):
+class PluginEndpointClient(BasePluginClient):
def create_endpoint(
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
) -> bool:
diff --git a/api/core/plugin/manager/exc.py b/api/core/plugin/impl/exc.py
similarity index 100%
rename from api/core/plugin/manager/exc.py
rename to api/core/plugin/impl/exc.py
diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/impl/model.py
similarity index 99%
rename from api/core/plugin/manager/model.py
rename to api/core/plugin/impl/model.py
index 5ebc0c2320..f7607eef8d 100644
--- a/api/core/plugin/manager/model.py
+++ b/api/core/plugin/impl/model.py
@@ -18,10 +18,10 @@ from core.plugin.entities.plugin_daemon import (
PluginTextEmbeddingNumTokensResponse,
PluginVoicesResponse,
)
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginModelManager(BasePluginManager):
+class PluginModelClient(BasePluginClient):
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
"""
Fetch model providers for the given tenant.
diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py
new file mode 100644
index 0000000000..91774984c8
--- /dev/null
+++ b/api/core/plugin/impl/oauth.py
@@ -0,0 +1,98 @@
+from collections.abc import Mapping
+from typing import Any
+
+from werkzeug import Request
+
+from core.plugin.entities.plugin_daemon import PluginOAuthAuthorizationUrlResponse, PluginOAuthCredentialsResponse
+from core.plugin.impl.base import BasePluginClient
+
+
+class OAuthHandler(BasePluginClient):
+ def get_authorization_url(
+ self,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ system_credentials: Mapping[str, Any],
+ ) -> PluginOAuthAuthorizationUrlResponse:
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url",
+ PluginOAuthAuthorizationUrlResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "system_credentials": system_credentials,
+ },
+ },
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
+ },
+ )
+
+ def get_credentials(
+ self,
+ tenant_id: str,
+ user_id: str,
+ plugin_id: str,
+ provider: str,
+ system_credentials: Mapping[str, Any],
+ request: Request,
+ ) -> PluginOAuthCredentialsResponse:
+ """
+ Get credentials from the given request.
+ """
+
+ # encode request to raw http request
+ raw_request_bytes = self._convert_request_to_raw_data(request)
+
+ return self._request_with_plugin_daemon_response(
+ "POST",
+ f"plugin/{tenant_id}/dispatch/oauth/get_credentials",
+ PluginOAuthCredentialsResponse,
+ data={
+ "user_id": user_id,
+ "data": {
+ "provider": provider,
+ "system_credentials": system_credentials,
+ "raw_request_bytes": raw_request_bytes,
+ },
+ },
+ headers={
+ "X-Plugin-ID": plugin_id,
+ "Content-Type": "application/json",
+ },
+ )
+
+ def _convert_request_to_raw_data(self, request: Request) -> bytes:
+ """
+ Convert a Request object to raw HTTP data.
+
+ Args:
+ request: The Request object to convert.
+
+ Returns:
+ The raw HTTP data as bytes.
+ """
+ # Start with the request line
+ method = request.method
+ path = request.path
+ protocol = request.headers.get("HTTP_VERSION", "HTTP/1.1")
+ raw_data = f"{method} {path} {protocol}\r\n".encode()
+
+ # Add headers
+ for header_name, header_value in request.headers.items():
+ raw_data += f"{header_name}: {header_value}\r\n".encode()
+
+ # Add empty line to separate headers from body
+ raw_data += b"\r\n"
+
+ # Add body if exists
+ body = request.get_data(as_text=False)
+ if body:
+ raw_data += body
+
+ return raw_data
diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/impl/plugin.py
similarity index 98%
rename from api/core/plugin/manager/plugin.py
rename to api/core/plugin/impl/plugin.py
index 15dcd6cb34..3349463ce5 100644
--- a/api/core/plugin/manager/plugin.py
+++ b/api/core/plugin/impl/plugin.py
@@ -10,10 +10,10 @@ from core.plugin.entities.plugin import (
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
-class PluginInstallationManager(BasePluginManager):
+class PluginInstaller(BasePluginClient):
def fetch_plugin_by_identifier(
self,
tenant_id: str,
diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/impl/tool.py
similarity index 98%
rename from api/core/plugin/manager/tool.py
rename to api/core/plugin/impl/tool.py
index 7592f867e1..19b26c8fe3 100644
--- a/api/core/plugin/manager/tool.py
+++ b/api/core/plugin/impl/tool.py
@@ -5,11 +5,11 @@ from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
-from core.plugin.manager.base import BasePluginManager
+from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
-class PluginToolManager(BasePluginManager):
+class PluginToolManager(BasePluginClient):
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
"""
Fetch tool providers for the given tenant.
diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py
index c7427f797e..25964ae063 100644
--- a/api/core/prompt/advanced_prompt_transform.py
+++ b/api/core/prompt/advanced_prompt_transform.py
@@ -9,13 +9,12 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
- PromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
-from core.model_runtime.entities.message_entities import ImagePromptMessageContent
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from core.prompt.prompt_transform import PromptTransform
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
if files:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(
@@ -201,7 +200,7 @@ class AdvancedPromptTransform(PromptTransform):
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
if files and query is not None:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=query))
for file in files:
prompt_message_contents.append(
diff --git a/api/core/prompt/prompt_templates/baichuan_chat.json b/api/core/prompt/prompt_templates/baichuan_chat.json
index 03b6a53cff..b3f7cdaa18 100644
--- a/api/core/prompt/prompt_templates/baichuan_chat.json
+++ b/api/core/prompt/prompt_templates/baichuan_chat.json
@@ -10,4 +10,4 @@
],
"query_prompt": "\n\n用户:{{#query#}}",
"stops": ["用户:"]
-}
\ No newline at end of file
+}
diff --git a/api/core/prompt/prompt_templates/baichuan_completion.json b/api/core/prompt/prompt_templates/baichuan_completion.json
index ae8c0dac53..cee9ea47cd 100644
--- a/api/core/prompt/prompt_templates/baichuan_completion.json
+++ b/api/core/prompt/prompt_templates/baichuan_completion.json
@@ -6,4 +6,4 @@
],
"query_prompt": "{{#query#}}",
"stops": null
-}
\ No newline at end of file
+}
diff --git a/api/core/prompt/prompt_templates/common_completion.json b/api/core/prompt/prompt_templates/common_completion.json
index c148772010..706a8140d1 100644
--- a/api/core/prompt/prompt_templates/common_completion.json
+++ b/api/core/prompt/prompt_templates/common_completion.json
@@ -6,4 +6,4 @@
],
"query_prompt": "{{#query#}}",
"stops": null
-}
\ No newline at end of file
+}
diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py
index ad56d84cb6..47808928f7 100644
--- a/api/core/prompt/simple_prompt_transform.py
+++ b/api/core/prompt/simple_prompt_transform.py
@@ -11,7 +11,7 @@ from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
- PromptMessageContent,
+ PromptMessageContentUnionTypes,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
@@ -277,7 +277,7 @@ class SimplePromptTransform(PromptTransform):
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
) -> UserPromptMessage:
if files:
- prompt_message_contents: list[PromptMessageContent] = []
+ prompt_message_contents: list[PromptMessageContentUnionTypes] = []
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
for file in files:
prompt_message_contents.append(
diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py
index 46a5330bdb..01f74b4a22 100644
--- a/api/core/rag/datasource/retrieval_service.py
+++ b/api/core/rag/datasource/retrieval_service.py
@@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
+from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode
@@ -119,12 +120,25 @@ class RetrievalService:
return all_documents
@classmethod
- def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None):
+ def external_retrieve(
+ cls,
+ dataset_id: str,
+ query: str,
+ external_retrieval_model: Optional[dict] = None,
+ metadata_filtering_conditions: Optional[dict] = None,
+ ):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
return []
+ metadata_condition = (
+ MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
+ )
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
- dataset.tenant_id, dataset_id, query, external_retrieval_model or {}
+ dataset.tenant_id,
+ dataset_id,
+ query,
+ external_retrieval_model or {},
+ metadata_condition=metadata_condition,
)
return all_documents
diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
index c1792943bb..14481b1f10 100644
--- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
+++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_sql.py
@@ -156,8 +156,8 @@ class AnalyticdbVectorBySql:
values = []
id_prefix = str(uuid.uuid4()) + "_"
sql = f"""
- INSERT INTO {self.table_name}
- (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
+ INSERT INTO {self.table_name}
+ (id, ref_doc_id, vector, page_content, metadata_, to_tsvector)
VALUES (%s, %s, %s, %s, %s, to_tsvector('zh_cn', %s));
"""
for i, doc in enumerate(documents):
@@ -242,7 +242,7 @@ class AnalyticdbVectorBySql:
where_clause += f"AND metadata_->>'document_id' IN ({document_ids})"
with self._get_cursor() as cur:
cur.execute(
- f"""SELECT id, vector, page_content, metadata_,
+ f"""SELECT id, vector, page_content, metadata_,
ts_rank(to_tsvector, to_tsquery_from_text(%s, 'zh_cn'), 32) AS score
FROM {self.table_name}
WHERE to_tsvector@@to_tsquery_from_text(%s, 'zh_cn') {where_clause}
diff --git a/api/core/rag/datasource/vdb/huawei/__init__.py b/api/core/rag/datasource/vdb/huawei/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
new file mode 100644
index 0000000000..89423eb160
--- /dev/null
+++ b/api/core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
@@ -0,0 +1,215 @@
+import json
+import logging
+import ssl
+from typing import Any, Optional
+
+from elasticsearch import Elasticsearch
+from pydantic import BaseModel, model_validator
+
+from configs import dify_config
+from core.rag.datasource.vdb.field import Field
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+logger = logging.getLogger(__name__)
+
+
+def create_ssl_context() -> ssl.SSLContext:
+ ssl_context = ssl.create_default_context()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ return ssl_context
+
+
+class HuaweiCloudVectorConfig(BaseModel):
+ hosts: str
+ username: str | None
+ password: str | None
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_config(cls, values: dict) -> dict:
+ if not values["hosts"]:
+ raise ValueError("config HOSTS is required")
+ return values
+
+ def to_elasticsearch_params(self) -> dict[str, Any]:
+ params = {
+ "hosts": self.hosts.split(","),
+ "verify_certs": False,
+ "ssl_show_warn": False,
+ "request_timeout": 30000,
+ "retry_on_timeout": True,
+ "max_retries": 10,
+ }
+ if self.username and self.password:
+ params["basic_auth"] = (self.username, self.password)
+ return params
+
+
+class HuaweiCloudVector(BaseVector):
+ def __init__(self, index_name: str, config: HuaweiCloudVectorConfig):
+ super().__init__(index_name.lower())
+ self._client = Elasticsearch(**config.to_elasticsearch_params())
+
+ def get_type(self) -> str:
+ return VectorType.HUAWEI_CLOUD
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ uuids = self._get_uuids(documents)
+ for i in range(len(documents)):
+ self._client.index(
+ index=self._collection_name,
+ id=uuids[i],
+ document={
+ Field.CONTENT_KEY.value: documents[i].page_content,
+ Field.VECTOR.value: embeddings[i] or None,
+ Field.METADATA_KEY.value: documents[i].metadata or {},
+ },
+ )
+ self._client.indices.refresh(index=self._collection_name)
+ return uuids
+
+ def text_exists(self, id: str) -> bool:
+ return bool(self._client.exists(index=self._collection_name, id=id))
+
+ def delete_by_ids(self, ids: list[str]) -> None:
+ if not ids:
+ return
+ for id in ids:
+ self._client.delete(index=self._collection_name, id=id)
+
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ query_str = {"query": {"match": {f"metadata.{key}": f"{value}"}}}
+ results = self._client.search(index=self._collection_name, body=query_str)
+ ids = [hit["_id"] for hit in results["hits"]["hits"]]
+ if ids:
+ self.delete_by_ids(ids)
+
+ def delete(self) -> None:
+ self._client.indices.delete(index=self._collection_name)
+
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ top_k = kwargs.get("top_k", 4)
+
+ query = {
+ "size": top_k,
+ "query": {
+ "vector": {
+ Field.VECTOR.value: {
+ "vector": query_vector,
+ "topk": top_k,
+ }
+ }
+ },
+ }
+
+ results = self._client.search(index=self._collection_name, body=query)
+
+ docs_and_scores = []
+ for hit in results["hits"]["hits"]:
+ docs_and_scores.append(
+ (
+ Document(
+ page_content=hit["_source"][Field.CONTENT_KEY.value],
+ vector=hit["_source"][Field.VECTOR.value],
+ metadata=hit["_source"][Field.METADATA_KEY.value],
+ ),
+ hit["_score"],
+ )
+ )
+
+ docs = []
+ for doc, score in docs_and_scores:
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ if score > score_threshold:
+ if doc.metadata is not None:
+ doc.metadata["score"] = score
+ docs.append(doc)
+
+ return docs
+
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ query_str = {"match": {Field.CONTENT_KEY.value: query}}
+ results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4))
+ docs = []
+ for hit in results["hits"]["hits"]:
+ docs.append(
+ Document(
+ page_content=hit["_source"][Field.CONTENT_KEY.value],
+ vector=hit["_source"][Field.VECTOR.value],
+ metadata=hit["_source"][Field.METADATA_KEY.value],
+ )
+ )
+
+ return docs
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+ metadatas = [d.metadata if d.metadata is not None else {} for d in texts]
+ self.create_collection(embeddings, metadatas)
+ self.add_texts(texts, embeddings, **kwargs)
+
+ def create_collection(
+ self,
+ embeddings: list[list[float]],
+ metadatas: Optional[list[dict[Any, Any]]] = None,
+ index_params: Optional[dict] = None,
+ ):
+ lock_name = f"vector_indexing_lock_{self._collection_name}"
+ with redis_client.lock(lock_name, timeout=20):
+ collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+ if redis_client.get(collection_exist_cache_key):
+ logger.info(f"Collection {self._collection_name} already exists.")
+ return
+
+ if not self._client.indices.exists(index=self._collection_name):
+ dim = len(embeddings[0])
+ mappings = {
+ "properties": {
+ Field.CONTENT_KEY.value: {"type": "text"},
+ Field.VECTOR.value: { # Make sure the dimension is correct here
+ "type": "vector",
+ "dimension": dim,
+ "indexing": True,
+ "algorithm": "GRAPH",
+ "metric": "cosine",
+ "neighbors": 32,
+ "efc": 128,
+ },
+ Field.METADATA_KEY.value: {
+ "type": "object",
+ "properties": {
+ "doc_id": {"type": "keyword"} # Map doc_id to keyword type
+ },
+ },
+ }
+ }
+ settings = {"index.vector": True}
+ self._client.indices.create(index=self._collection_name, mappings=mappings, settings=settings)
+
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class HuaweiCloudVectorFactory(AbstractVectorFactory):
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> HuaweiCloudVector:
+ if dataset.index_struct_dict:
+ class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+ collection_name = class_prefix.lower()
+ else:
+ dataset_id = dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id).lower()
+ dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.HUAWEI_CLOUD, collection_name))
+
+ return HuaweiCloudVector(
+ index_name=collection_name,
+ config=HuaweiCloudVectorConfig(
+ hosts=dify_config.HUAWEI_CLOUD_HOSTS or "http://localhost:9200",
+ username=dify_config.HUAWEI_CLOUD_USER,
+ password=dify_config.HUAWEI_CLOUD_PASSWORD,
+ ),
+ )
diff --git a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
index 643ac2df4e..e9ff1ce43d 100644
--- a/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
+++ b/api/core/rag/datasource/vdb/lindorm/lindorm_vector.py
@@ -32,6 +32,7 @@ class LindormVectorStoreConfig(BaseModel):
username: Optional[str] = None
password: Optional[str] = None
using_ugc: Optional[bool] = False
+ request_timeout: Optional[float] = 1.0 # timeout units: s
@model_validator(mode="before")
@classmethod
@@ -251,9 +252,9 @@ class LindormVectorStore(BaseVector):
query = default_vector_search_query(query_vector=query_vector, k=top_k, filters=filters, **kwargs)
try:
- params = {}
+ params = {"timeout": self._client_config.request_timeout}
if self._using_ugc:
- params["routing"] = self._routing
+ params["routing"] = self._routing # type: ignore
response = self._client.search(index=self._collection_name, body=query, params=params)
except Exception:
logger.exception(f"Error executing vector search, query: {query}")
@@ -304,8 +305,8 @@ class LindormVectorStore(BaseVector):
routing=routing,
routing_field=self._routing_field,
)
-
- response = self._client.search(index=self._collection_name, body=full_text_query)
+ params = {"timeout": self._client_config.request_timeout}
+ response = self._client.search(index=self._collection_name, body=full_text_query, params=params)
docs = []
for hit in response["hits"]["hits"]:
docs.append(
@@ -554,6 +555,7 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
username=dify_config.LINDORM_USERNAME,
password=dify_config.LINDORM_PASSWORD,
using_ugc=dify_config.USING_UGC_INDEX,
+ request_timeout=dify_config.LINDORM_QUERY_TIMEOUT,
)
using_ugc = dify_config.USING_UGC_INDEX
if using_ugc is None:
diff --git a/api/core/rag/datasource/vdb/milvus/milvus_vector.py b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
index 100bcb198c..7b3f826082 100644
--- a/api/core/rag/datasource/vdb/milvus/milvus_vector.py
+++ b/api/core/rag/datasource/vdb/milvus/milvus_vector.py
@@ -27,8 +27,8 @@ class MilvusConfig(BaseModel):
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
- user: str # Username for authentication
- password: str # Password for authentication
+ user: Optional[str] = None # Username for authentication
+ password: Optional[str] = None # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@@ -43,10 +43,11 @@ class MilvusConfig(BaseModel):
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
- if not values.get("user"):
- raise ValueError("config MILVUS_USER is required")
- if not values.get("password"):
- raise ValueError("config MILVUS_PASSWORD is required")
+ if not values.get("token"):
+ if not values.get("user"):
+ raise ValueError("config MILVUS_USER is required")
+ if not values.get("password"):
+ raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
@@ -356,11 +357,14 @@ class MilvusVector(BaseVector):
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
- def _init_client(self, config) -> MilvusClient:
+ def _init_client(self, config: MilvusConfig) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
- client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
+ if config.token:
+ client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
+ else:
+ client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client
diff --git a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
index ae6b0c51ab..2b47d179d2 100644
--- a/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
+++ b/api/core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
@@ -203,7 +203,7 @@ class OceanBaseVector(BaseVector):
full_sql = f"""SELECT metadata, text, MATCH (text) AGAINST (:query) AS score
FROM {self._collection_name}
- WHERE MATCH (text) AGAINST (:query) > 0
+ WHERE MATCH (text) AGAINST (:query) > 0
{where_clause}
ORDER BY score DESC
LIMIT {top_k}"""
diff --git a/api/core/rag/datasource/vdb/opengauss/opengauss.py b/api/core/rag/datasource/vdb/opengauss/opengauss.py
index dae908f67d..2548881b9c 100644
--- a/api/core/rag/datasource/vdb/opengauss/opengauss.py
+++ b/api/core/rag/datasource/vdb/opengauss/opengauss.py
@@ -59,12 +59,12 @@ CREATE TABLE IF NOT EXISTS {table_name} (
"""
SQL_CREATE_INDEX_PQ = """
-CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_{table_name}_pq_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64, enable_pq=on, pq_m={pq_m});
"""
SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_{table_name}_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
diff --git a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
index 6636646cff..e23b8d197f 100644
--- a/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
+++ b/api/core/rag/datasource/vdb/opensearch/opensearch_vector.py
@@ -1,10 +1,9 @@
import json
import logging
-import ssl
-from typing import Any, Optional
+from typing import Any, Literal, Optional
from uuid import uuid4
-from opensearchpy import OpenSearch, helpers
+from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
+ secure: bool = False
+ auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
- secure: bool = False
+ aws_region: Optional[str] = None
+ aws_service: Optional[str] = None
@model_validator(mode="before")
@classmethod
@@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
+ if values.get("auth_method") == "aws_managed_iam":
+ if not values.get("aws_region"):
+ raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
+ if not values.get("aws_service"):
+ raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
- def create_ssl_context(self) -> ssl.SSLContext:
- ssl_context = ssl.create_default_context()
- ssl_context.check_hostname = False
- ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
- return ssl_context
+ def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
+ import boto3 # type: ignore
+
+ return Urllib3AWSV4SignerAuth(
+ credentials=boto3.Session().get_credentials(),
+ region=self.aws_region,
+ service=self.aws_service, # type: ignore[arg-type]
+ )
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
+ "connection_class": Urllib3HttpConnection,
+ "pool_maxsize": 20,
}
- if self.user and self.password:
+
+ if self.auth_method == "basic":
+ logger.info("Using basic authentication for OpenSearch Vector DB")
+
params["http_auth"] = (self.user, self.password)
- if self.secure:
- params["ssl_context"] = self.create_ssl_context()
+ elif self.auth_method == "aws_managed_iam":
+ logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
+
+ params["http_auth"] = self.create_aws_managed_iam_auth()
+
return params
@@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
- "_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
+ # See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
+ if self._client_config.aws_service not in ["aoss"]:
+ action["_id"] = uuid4().hex
actions.append(action)
- helpers.bulk(self._client, actions)
+ helpers.bulk(
+ client=self._client,
+ actions=actions,
+ timeout=30,
+ max_retries=3,
+ )
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
+ logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
+ secure=dify_config.OPENSEARCH_SECURE,
+ auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
- secure=dify_config.OPENSEARCH_SECURE,
+ aws_region=dify_config.OPENSEARCH_AWS_REGION,
+ aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)
diff --git a/api/core/rag/datasource/vdb/oracle/oraclevector.py b/api/core/rag/datasource/vdb/oracle/oraclevector.py
index 4af2578197..0a3738ac93 100644
--- a/api/core/rag/datasource/vdb/oracle/oraclevector.py
+++ b/api/core/rag/datasource/vdb/oracle/oraclevector.py
@@ -2,12 +2,12 @@ import array
import json
import re
import uuid
-from contextlib import contextmanager
from typing import Any
import jieba.posseg as pseg # type: ignore
import numpy
import oracledb
+from oracledb.connection import Connection
from pydantic import BaseModel, model_validator
from configs import dify_config
@@ -59,8 +59,8 @@ CREATE TABLE IF NOT EXISTS {table_name} (
)
"""
SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
-INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
+CREATE INDEX IF NOT EXISTS idx_docs_{table_name} ON {table_name}(text)
+INDEXTYPE IS CTXSYS.CONTEXT PARAMETERS
('FILTER CTXSYS.NULL_FILTER SECTION GROUP CTXSYS.HTML_SECTION_GROUP LEXER world_lexer')
"""
@@ -70,6 +70,7 @@ class OracleVector(BaseVector):
super().__init__(collection_name)
self.pool = self._create_connection_pool(config)
self.table_name = f"embedding_{collection_name}"
+ self.config = config
def get_type(self) -> str:
return VectorType.ORACLE
@@ -107,16 +108,19 @@ class OracleVector(BaseVector):
outconverter=self.numpy_converter_out,
)
+ def _get_connection(self) -> Connection:
+ connection = oracledb.connect(user=self.config.user, password=self.config.password, dsn=self.config.dsn)
+ return connection
+
def _create_connection_pool(self, config: OracleVectorConfig):
pool_params = {
"user": config.user,
"password": config.password,
"dsn": config.dsn,
"min": 1,
- "max": 50,
+ "max": 5,
"increment": 1,
}
-
if config.is_autonomous:
pool_params.update(
{
@@ -125,22 +129,8 @@ class OracleVector(BaseVector):
"wallet_password": config.wallet_password,
}
)
-
return oracledb.create_pool(**pool_params)
- @contextmanager
- def _get_cursor(self):
- conn = self.pool.acquire()
- conn.inputtypehandler = self.input_type_handler
- conn.outputtypehandler = self.output_type_handler
- cur = conn.cursor()
- try:
- yield cur
- finally:
- cur.close()
- conn.commit()
- conn.close()
-
def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
dimension = len(embeddings[0])
self._create_collection(dimension)
@@ -162,41 +152,68 @@ class OracleVector(BaseVector):
numpy.array(embeddings[i]),
)
)
- # print(f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)")
- with self._get_cursor() as cur:
- cur.executemany(
- f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
- )
+ with self._get_connection() as conn:
+ conn.inputtypehandler = self.input_type_handler
+ conn.outputtypehandler = self.output_type_handler
+ # with conn.cursor() as cur:
+ # cur.executemany(
+ # f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES (:1, :2, :3, :4)", values
+ # )
+ # conn.commit()
+ for value in values:
+ with conn.cursor() as cur:
+ try:
+ cur.execute(
+ f"""INSERT INTO {self.table_name} (id, text, meta, embedding)
+ VALUES (:1, :2, :3, :4)""",
+ value,
+ )
+ conn.commit()
+ except Exception as e:
+ print(e)
+ conn.close()
return pks
def text_exists(self, id: str) -> bool:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
- return cur.fetchone() is not None
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"SELECT id FROM {self.table_name} WHERE id = '%s'" % (id,))
+ return cur.fetchone() is not None
+ conn.close()
def get_by_ids(self, ids: list[str]) -> list[Document]:
- with self._get_cursor() as cur:
- cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
- docs = []
- for record in cur:
- docs.append(Document(page_content=record[1], metadata=record[0]))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+ docs = []
+ for record in cur:
+ docs.append(Document(page_content=record[1], metadata=record[0]))
+ self.pool.release(connection=conn)
+ conn.close()
return docs
def delete_by_ids(self, ids: list[str]) -> None:
if not ids:
return
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s" % (tuple(ids),))
+ conn.commit()
+ conn.close()
def delete_by_metadata_field(self, key: str, value: str) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+ conn.commit()
+ conn.close()
def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
"""
Search the nearest neighbors to a vector.
:param query_vector: The input vector to search for similar items.
+ :param top_k: The number of nearest neighbors to return, default is 5.
:return: List of Documents that are nearest to the query vector.
"""
top_k = kwargs.get("top_k", 4)
@@ -205,20 +222,25 @@ class OracleVector(BaseVector):
if document_ids_filter:
document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
where_clause = f"WHERE metadata->>'document_id' in ({document_ids})"
- with self._get_cursor() as cur:
- cur.execute(
- f"SELECT meta, text, vector_distance(embedding,:1) AS distance FROM {self.table_name}"
- f" {where_clause} ORDER BY distance fetch first {top_k} rows only",
- [numpy.array(query_vector)],
- )
- docs = []
- score_threshold = float(kwargs.get("score_threshold") or 0.0)
- for record in cur:
- metadata, text, distance = record
- score = 1 - distance
- metadata["score"] = score
- if score > score_threshold:
- docs.append(Document(page_content=text, metadata=metadata))
+ with self._get_connection() as conn:
+ conn.inputtypehandler = self.input_type_handler
+ conn.outputtypehandler = self.output_type_handler
+ with conn.cursor() as cur:
+ cur.execute(
+ f"""SELECT meta, text, vector_distance(embedding,(select to_vector(:1) from dual),cosine)
+ AS distance FROM {self.table_name}
+ {where_clause} ORDER BY distance fetch first {top_k} rows only""",
+ [numpy.array(query_vector)],
+ )
+ docs = []
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ for record in cur:
+ metadata, text, distance = record
+ score = 1 - distance
+ metadata["score"] = score
+ if score > score_threshold:
+ docs.append(Document(page_content=text, metadata=metadata))
+ conn.close()
return docs
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
@@ -228,7 +250,7 @@ class OracleVector(BaseVector):
top_k = kwargs.get("top_k", 5)
# just not implement fetch by score_threshold now, may be later
- # score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
if len(query) > 0:
# Check which language the query is in
zh_pattern = re.compile("[\u4e00-\u9fa5]+")
@@ -239,7 +261,7 @@ class OracleVector(BaseVector):
words = pseg.cut(query)
current_entity = ""
for word, pos in words:
- if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名,ns: 地名,nt: 机构名
+ if pos in {"nr", "Ng", "eng", "nz", "n", "ORG", "v"}: # nr: 人名, ns: 地名, nt: 机构名
current_entity += word
else:
if current_entity:
@@ -260,30 +282,35 @@ class OracleVector(BaseVector):
for token in all_tokens:
if token not in stop_words:
entities.append(token)
- with self._get_cursor() as cur:
- document_ids_filter = kwargs.get("document_ids_filter")
- where_clause = ""
- if document_ids_filter:
- document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
- where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
- cur.execute(
- f"select meta, text, embedding FROM {self.table_name}"
- f"WHERE CONTAINS(text, :1, 1) > 0 {where_clause} "
- f"order by score(1) desc fetch first {top_k} rows only",
- [" ACCUM ".join(entities)],
- )
- docs = []
- for record in cur:
- metadata, text, embedding = record
- docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ document_ids_filter = kwargs.get("document_ids_filter")
+ where_clause = ""
+ if document_ids_filter:
+ document_ids = ", ".join(f"'{id}'" for id in document_ids_filter)
+ where_clause = f" AND metadata->>'document_id' in ({document_ids}) "
+ cur.execute(
+ f"""select meta, text, embedding FROM {self.table_name}
+ WHERE CONTAINS(text, :kk, 1) > 0 {where_clause}
+ order by score(1) desc fetch first {top_k} rows only""",
+ kk=" ACCUM ".join(entities),
+ )
+ docs = []
+ for record in cur:
+ metadata, text, embedding = record
+ docs.append(Document(page_content=text, vector=embedding, metadata=metadata))
+ conn.close()
return docs
else:
return [Document(page_content="", metadata={})]
return []
def delete(self) -> None:
- with self._get_cursor() as cur:
- cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(f"DROP TABLE IF EXISTS {self.table_name} cascade constraints")
+ conn.commit()
+ conn.close()
def _create_collection(self, dimension: int):
cache_key = f"vector_indexing_{self._collection_name}"
@@ -293,11 +320,14 @@ class OracleVector(BaseVector):
if redis_client.get(collection_exist_cache_key):
return
- with self._get_cursor() as cur:
- cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
- redis_client.set(collection_exist_cache_key, 1, ex=3600)
- with self._get_cursor() as cur:
- cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+ with self._get_connection() as conn:
+ with conn.cursor() as cur:
+ cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name))
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+ with conn.cursor() as cur:
+ cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+ conn.commit()
+ conn.close()
class OracleVectorFactory(AbstractVectorFactory):
diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py
index eab51ab01d..366a21c381 100644
--- a/api/core/rag/datasource/vdb/pgvector/pgvector.py
+++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py
@@ -61,7 +61,7 @@ CREATE TABLE IF NOT EXISTS {table_name} (
"""
SQL_CREATE_INDEX = """
-CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64);
"""
diff --git a/api/core/rag/datasource/vdb/pyvastbase/__init__.py b/api/core/rag/datasource/vdb/pyvastbase/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
new file mode 100644
index 0000000000..156730ff37
--- /dev/null
+++ b/api/core/rag/datasource/vdb/pyvastbase/vastbase_vector.py
@@ -0,0 +1,243 @@
+import json
+import uuid
+from contextlib import contextmanager
+from typing import Any
+
+import psycopg2.extras # type: ignore
+import psycopg2.pool # type: ignore
+from pydantic import BaseModel, model_validator
+
+from configs import dify_config
+from core.rag.datasource.vdb.vector_base import BaseVector
+from core.rag.datasource.vdb.vector_factory import AbstractVectorFactory
+from core.rag.datasource.vdb.vector_type import VectorType
+from core.rag.embedding.embedding_base import Embeddings
+from core.rag.models.document import Document
+from extensions.ext_redis import redis_client
+from models.dataset import Dataset
+
+
+class VastbaseVectorConfig(BaseModel):
+ host: str
+ port: int
+ user: str
+ password: str
+ database: str
+ min_connection: int
+ max_connection: int
+
+ @model_validator(mode="before")
+ @classmethod
+ def validate_config(cls, values: dict) -> dict:
+ if not values["host"]:
+ raise ValueError("config VASTBASE_HOST is required")
+ if not values["port"]:
+ raise ValueError("config VASTBASE_PORT is required")
+ if not values["user"]:
+ raise ValueError("config VASTBASE_USER is required")
+ if not values["password"]:
+ raise ValueError("config VASTBASE_PASSWORD is required")
+ if not values["database"]:
+ raise ValueError("config VASTBASE_DATABASE is required")
+ if not values["min_connection"]:
+ raise ValueError("config VASTBASE_MIN_CONNECTION is required")
+ if not values["max_connection"]:
+ raise ValueError("config VASTBASE_MAX_CONNECTION is required")
+ if values["min_connection"] > values["max_connection"]:
+ raise ValueError("config VASTBASE_MIN_CONNECTION should less than VASTBASE_MAX_CONNECTION")
+ return values
+
+
+SQL_CREATE_TABLE = """
+CREATE TABLE IF NOT EXISTS {table_name} (
+ id UUID PRIMARY KEY,
+ text TEXT NOT NULL,
+ meta JSONB NOT NULL,
+ embedding floatvector({dimension}) NOT NULL
+);
+"""
+
+SQL_CREATE_INDEX = """
+CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name}
+USING hnsw (embedding floatvector_cosine_ops) WITH (m = 16, ef_construction = 64);
+"""
+
+
+class VastbaseVector(BaseVector):
+ def __init__(self, collection_name: str, config: VastbaseVectorConfig):
+ super().__init__(collection_name)
+ self.pool = self._create_connection_pool(config)
+ self.table_name = f"embedding_{collection_name}"
+
+ def get_type(self) -> str:
+ return VectorType.VASTBASE
+
+ def _create_connection_pool(self, config: VastbaseVectorConfig):
+ return psycopg2.pool.SimpleConnectionPool(
+ config.min_connection,
+ config.max_connection,
+ host=config.host,
+ port=config.port,
+ user=config.user,
+ password=config.password,
+ database=config.database,
+ )
+
+ @contextmanager
+ def _get_cursor(self):
+ conn = self.pool.getconn()
+ cur = conn.cursor()
+ try:
+ yield cur
+ finally:
+ cur.close()
+ conn.commit()
+ self.pool.putconn(conn)
+
+ def create(self, texts: list[Document], embeddings: list[list[float]], **kwargs):
+ dimension = len(embeddings[0])
+ self._create_collection(dimension)
+ return self.add_texts(texts, embeddings)
+
+ def add_texts(self, documents: list[Document], embeddings: list[list[float]], **kwargs):
+ values = []
+ pks = []
+ for i, doc in enumerate(documents):
+ if doc.metadata is not None:
+ doc_id = doc.metadata.get("doc_id", str(uuid.uuid4()))
+ pks.append(doc_id)
+ values.append(
+ (
+ doc_id,
+ doc.page_content,
+ json.dumps(doc.metadata),
+ embeddings[i],
+ )
+ )
+ with self._get_cursor() as cur:
+ psycopg2.extras.execute_values(
+ cur, f"INSERT INTO {self.table_name} (id, text, meta, embedding) VALUES %s", values
+ )
+ return pks
+
+ def text_exists(self, id: str) -> bool:
+ with self._get_cursor() as cur:
+ cur.execute(f"SELECT id FROM {self.table_name} WHERE id = %s", (id,))
+ return cur.fetchone() is not None
+
+ def get_by_ids(self, ids: list[str]) -> list[Document]:
+ with self._get_cursor() as cur:
+ cur.execute(f"SELECT meta, text FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+ docs = []
+ for record in cur:
+ docs.append(Document(page_content=record[1], metadata=record[0]))
+ return docs
+
+ def delete_by_ids(self, ids: list[str]) -> None:
+ # Avoiding crashes caused by performing delete operations on empty lists in certain scenarios
+ # Scenario 1: extract a document fails, resulting in a table not being created.
+ # Then clicking the retry button triggers a delete operation on an empty list.
+ if not ids:
+ return
+ with self._get_cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE id IN %s", (tuple(ids),))
+
+ def delete_by_metadata_field(self, key: str, value: str) -> None:
+ with self._get_cursor() as cur:
+ cur.execute(f"DELETE FROM {self.table_name} WHERE meta->>%s = %s", (key, value))
+
+ def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
+ """
+ Search the nearest neighbors to a vector.
+
+ :param query_vector: The input vector to search for similar items.
+ :param top_k: The number of nearest neighbors to return, default is 5.
+ :return: List of Documents that are nearest to the query vector.
+ """
+ top_k = kwargs.get("top_k", 4)
+
+ if not isinstance(top_k, int) or top_k <= 0:
+ raise ValueError("top_k must be a positive integer")
+ with self._get_cursor() as cur:
+ cur.execute(
+ f"SELECT meta, text, embedding <=> %s AS distance FROM {self.table_name}"
+ f" ORDER BY distance LIMIT {top_k}",
+ (json.dumps(query_vector),),
+ )
+ docs = []
+ score_threshold = float(kwargs.get("score_threshold") or 0.0)
+ for record in cur:
+ metadata, text, distance = record
+ score = 1 - distance
+ metadata["score"] = score
+ if score > score_threshold:
+ docs.append(Document(page_content=text, metadata=metadata))
+ return docs
+
+ def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
+ top_k = kwargs.get("top_k", 5)
+
+ if not isinstance(top_k, int) or top_k <= 0:
+ raise ValueError("top_k must be a positive integer")
+ with self._get_cursor() as cur:
+ cur.execute(
+ f"""SELECT meta, text, ts_rank(to_tsvector(coalesce(text, '')), plainto_tsquery(%s)) AS score
+ FROM {self.table_name}
+ WHERE to_tsvector(text) @@ plainto_tsquery(%s)
+ ORDER BY score DESC
+ LIMIT {top_k}""",
+ # f"'{query}'" is required in order to account for whitespace in query
+ (f"'{query}'", f"'{query}'"),
+ )
+
+ docs = []
+
+ for record in cur:
+ metadata, text, score = record
+ metadata["score"] = score
+ docs.append(Document(page_content=text, metadata=metadata))
+
+ return docs
+
+ def delete(self) -> None:
+ with self._get_cursor() as cur:
+ cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
+
+ def _create_collection(self, dimension: int):
+ cache_key = f"vector_indexing_{self._collection_name}"
+ lock_name = f"{cache_key}_lock"
+ with redis_client.lock(lock_name, timeout=20):
+ collection_exist_cache_key = f"vector_indexing_{self._collection_name}"
+ if redis_client.get(collection_exist_cache_key):
+ return
+
+ with self._get_cursor() as cur:
+ cur.execute(SQL_CREATE_TABLE.format(table_name=self.table_name, dimension=dimension))
+ # Vastbase 支持的向量维度取值范围为 [1,16000]
+ if dimension <= 16000:
+ cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name))
+ redis_client.set(collection_exist_cache_key, 1, ex=3600)
+
+
+class VastbaseVectorFactory(AbstractVectorFactory):
+ def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> VastbaseVector:
+ if dataset.index_struct_dict:
+ class_prefix: str = dataset.index_struct_dict["vector_store"]["class_prefix"]
+ collection_name = class_prefix
+ else:
+ dataset_id = dataset.id
+ collection_name = Dataset.gen_collection_name_by_id(dataset_id)
+ dataset.index_struct = json.dumps(self.gen_index_struct_dict(VectorType.VASTBASE, collection_name))
+
+ return VastbaseVector(
+ collection_name=collection_name,
+ config=VastbaseVectorConfig(
+ host=dify_config.VASTBASE_HOST or "localhost",
+ port=dify_config.VASTBASE_PORT,
+ user=dify_config.VASTBASE_USER or "dify",
+ password=dify_config.VASTBASE_PASSWORD or "",
+ database=dify_config.VASTBASE_DATABASE or "dify",
+ min_connection=dify_config.VASTBASE_MIN_CONNECTION,
+ max_connection=dify_config.VASTBASE_MAX_CONNECTION,
+ ),
+ )
diff --git a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
index 00229ce700..61c68b939e 100644
--- a/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
+++ b/api/core/rag/datasource/vdb/tidb_vector/tidb_vector.py
@@ -205,9 +205,9 @@ class TiDBVector(BaseVector):
with Session(self._engine) as session:
select_statement = sql_text(f"""
- SELECT meta, text, distance
+ SELECT meta, text, distance
FROM (
- SELECT
+ SELECT
meta,
text,
{tidb_dist_func}(vector, :query_vector_str) AS distance
diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py
index 00601c38a1..66e002312a 100644
--- a/api/core/rag/datasource/vdb/vector_factory.py
+++ b/api/core/rag/datasource/vdb/vector_factory.py
@@ -74,6 +74,10 @@ class Vector:
from core.rag.datasource.vdb.pgvector.pgvector import PGVectorFactory
return PGVectorFactory
+ case VectorType.VASTBASE:
+ from core.rag.datasource.vdb.pyvastbase.vastbase_vector import VastbaseVectorFactory
+
+ return VastbaseVectorFactory
case VectorType.PGVECTO_RS:
from core.rag.datasource.vdb.pgvecto_rs.pgvecto_rs import PGVectoRSFactory
@@ -156,6 +160,10 @@ class Vector:
from core.rag.datasource.vdb.tablestore.tablestore_vector import TableStoreVectorFactory
return TableStoreVectorFactory
+ case VectorType.HUAWEI_CLOUD:
+ from core.rag.datasource.vdb.huawei.huawei_cloud_vector import HuaweiCloudVectorFactory
+
+ return HuaweiCloudVectorFactory
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
diff --git a/api/core/rag/datasource/vdb/vector_type.py b/api/core/rag/datasource/vdb/vector_type.py
index 940f12caef..7a81565e37 100644
--- a/api/core/rag/datasource/vdb/vector_type.py
+++ b/api/core/rag/datasource/vdb/vector_type.py
@@ -7,7 +7,9 @@ class VectorType(StrEnum):
MILVUS = "milvus"
MYSCALE = "myscale"
PGVECTOR = "pgvector"
+ VASTBASE = "vastbase"
PGVECTO_RS = "pgvecto-rs"
+
QDRANT = "qdrant"
RELYT = "relyt"
TIDB_VECTOR = "tidb_vector"
@@ -26,3 +28,4 @@ class VectorType(StrEnum):
OCEANBASE = "oceanbase"
OPENGAUSS = "opengauss"
TABLESTORE = "tablestore"
+ HUAWEI_CLOUD = "huawei_cloud"
diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py
index b8003b386b..21fbb2100f 100644
--- a/api/core/rag/extractor/watercrawl/provider.py
+++ b/api/core/rag/extractor/watercrawl/provider.py
@@ -20,7 +20,7 @@ class WaterCrawlProvider:
}
if options.get("crawl_sub_pages", True):
spider_options["page_limit"] = options.get("limit", 1)
- spider_options["max_depth"] = options.get("depth", 1)
+ spider_options["max_depth"] = options.get("max_depth", 1)
spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else []
spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else []
diff --git a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py
index ac7a3f8bb8..693535413a 100644
--- a/api/core/rag/rerank/rerank_model.py
+++ b/api/core/rag/rerank/rerank_model.py
@@ -52,14 +52,16 @@ class RerankModelRunner(BaseRerankRunner):
rerank_documents = []
for result in rerank_result.docs:
- # format document
- rerank_document = Document(
- page_content=result.text,
- metadata=documents[result.index].metadata,
- provider=documents[result.index].provider,
- )
- if rerank_document.metadata is not None:
- rerank_document.metadata["score"] = result.score
- rerank_documents.append(rerank_document)
+ if score_threshold is None or result.score >= score_threshold:
+ # format document
+ rerank_document = Document(
+ page_content=result.text,
+ metadata=documents[result.index].metadata,
+ provider=documents[result.index].provider,
+ )
+ if rerank_document.metadata is not None:
+ rerank_document.metadata["score"] = result.score
+ rerank_documents.append(rerank_document)
- return rerank_documents
+ rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
+ return rerank_documents[:top_n] if top_n else rerank_documents
diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py
index 23ea775dec..9216b31b8e 100644
--- a/api/core/rag/retrieval/dataset_retrieval.py
+++ b/api/core/rag/retrieval/dataset_retrieval.py
@@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
-from sqlalchemy import Integer, and_, or_, text
+from sqlalchemy import Float, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import (
@@ -149,7 +149,7 @@ class DatasetRetrieval:
else:
inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets]
- metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition(
+ metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
available_datasets_ids,
query,
tenant_id,
@@ -649,6 +649,8 @@ class DatasetRetrieval:
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
+ user_id: str,
+ inputs: dict,
) -> Optional[list[DatasetRetrieverBaseTool]]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
@@ -706,6 +708,9 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback],
return_resource=return_resource,
retriever_from=invoke_from.to_source(),
+ retrieve_config=retrieve_config,
+ user_id=user_id,
+ inputs=inputs,
)
tools.append(tool)
@@ -826,7 +831,7 @@ class DatasetRetrieval:
)
return filter_documents[:top_k] if top_k else filter_documents
- def _get_metadata_filter_condition(
+ def get_metadata_filter_condition(
self,
dataset_ids: list,
query: str,
@@ -869,32 +874,45 @@ class DatasetRetrieval:
)
)
metadata_condition = MetadataCondition(
- logical_operator=metadata_filtering_conditions.logical_operator, # type: ignore
+ logical_operator=metadata_filtering_conditions.logical_operator
+ if metadata_filtering_conditions
+ else "or", # type: ignore
conditions=conditions,
)
elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions:
- metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump())
+ conditions = []
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
- if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+ if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs)
- filters = self._process_metadata_filter_func(
- sequence,
- condition.comparison_operator,
- metadata_name,
- expected_value,
- filters,
+ conditions.append(
+ Condition(
+ name=metadata_name,
+ comparison_operator=condition.comparison_operator,
+ value=expected_value,
)
+ )
+ filters = self._process_metadata_filter_func(
+ sequence,
+ condition.comparison_operator,
+ metadata_name,
+ expected_value,
+ filters,
+ )
+ metadata_condition = MetadataCondition(
+ logical_operator=metadata_filtering_conditions.logical_operator,
+ conditions=conditions,
+ )
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
- if metadata_filtering_conditions.logical_operator == "or": # type: ignore
- document_query = document_query.filter(or_(*filters))
- else:
+ if metadata_filtering_conditions and metadata_filtering_conditions.logical_operator == "and": # type: ignore
document_query = document_query.filter(and_(*filters))
+ else:
+ document_query = document_query.filter(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
@@ -1003,28 +1021,24 @@ class DatasetRetrieval:
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
- filters.append(
- sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
- )
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
case "is not" | "≠":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
- filters.append(
- sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
- )
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
- filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
case "after" | ">":
- filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
case "≤" | "<=":
- filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
case "≥" | ">=":
- filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
+ filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
case _:
pass
return filters
diff --git a/api/core/rag/retrieval/template_prompts.py b/api/core/rag/retrieval/template_prompts.py
index 7abd55d798..9c945e2f52 100644
--- a/api/core/rag/retrieval/template_prompts.py
+++ b/api/core/rag/retrieval/template_prompts.py
@@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
- Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+ Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
@@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
-### Constraint
+### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside XML tags.
@@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which company’s email address test@examp
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
-
+
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py
index 34b4056cf5..b711e8434a 100644
--- a/api/core/rag/splitter/text_splitter.py
+++ b/api/core/rag/splitter/text_splitter.py
@@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
)
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
- @classmethod
- def from_tiktoken_encoder(
- cls: type[TS],
- encoding_name: str = "gpt2",
- model_name: Optional[str] = None,
- allowed_special: Union[Literal["all"], Set[str]] = set(),
- disallowed_special: Union[Literal["all"], Collection[str]] = "all",
- **kwargs: Any,
- ) -> TS:
- """Text splitter that uses tiktoken encoder to count length."""
- try:
- import tiktoken
- except ImportError:
- raise ImportError(
- "Could not import tiktoken python package. "
- "This is needed in order to calculate max_tokens_for_prompt. "
- "Please install it with `pip install tiktoken`."
- )
-
- if model_name is not None:
- enc = tiktoken.encoding_for_model(model_name)
- else:
- enc = tiktoken.get_encoding(encoding_name)
-
- def _tiktoken_encoder(text: str) -> int:
- return len(
- enc.encode(
- text,
- allowed_special=allowed_special,
- disallowed_special=disallowed_special,
- )
- )
-
- if issubclass(cls, TokenTextSplitter):
- extra_kwargs = {
- "encoding_name": encoding_name,
- "model_name": model_name,
- "allowed_special": allowed_special,
- "disallowed_special": disallowed_special,
- }
- kwargs = {**kwargs, **extra_kwargs}
-
- return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
-
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))
diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py
new file mode 100644
index 0000000000..6452317120
--- /dev/null
+++ b/api/core/repositories/__init__.py
@@ -0,0 +1,12 @@
+"""
+Repository implementations for data access.
+
+This package contains concrete implementations of the repository interfaces
+defined in the core.workflow.repository package.
+"""
+
+from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
+
+__all__ = [
+ "SQLAlchemyWorkflowNodeExecutionRepository",
+]
diff --git a/api/repositories/workflow_node_execution/sqlalchemy_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
similarity index 94%
rename from api/repositories/workflow_node_execution/sqlalchemy_repository.py
rename to api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
index 0594d816a2..8bf2ab8761 100644
--- a/api/repositories/workflow_node_execution/sqlalchemy_repository.py
+++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py
@@ -10,13 +10,13 @@ from sqlalchemy import UnaryExpression, asc, delete, desc, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker
-from core.repository.workflow_node_execution_repository import OrderConfig
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
-class SQLAlchemyWorkflowNodeExecutionRepository:
+class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
@@ -37,8 +37,12 @@ class SQLAlchemyWorkflowNodeExecutionRepository:
# If an engine is provided, create a sessionmaker from it
if isinstance(session_factory, Engine):
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
- else:
+ elif isinstance(session_factory, sessionmaker):
self._session_factory = session_factory
+ else:
+ raise ValueError(
+ f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine"
+ )
self._tenant_id = tenant_id
self._app_id = app_id
diff --git a/api/core/repository/repository_factory.py b/api/core/repository/repository_factory.py
deleted file mode 100644
index 7da7e49055..0000000000
--- a/api/core/repository/repository_factory.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""
-Repository factory for creating repository instances.
-
-This module provides a simple factory interface for creating repository instances.
-It does not contain any implementation details or dependencies on specific repositories.
-"""
-
-from collections.abc import Callable, Mapping
-from typing import Any, Literal, Optional, cast
-
-from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
-
-# Type for factory functions - takes a dict of parameters and returns any repository type
-RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
-
-# Type for workflow node execution factory function
-WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
-
-# Repository type literals
-_RepositoryType = Literal["workflow_node_execution"]
-
-
-class RepositoryFactory:
- """
- Factory class for creating repository instances.
-
- This factory delegates the actual repository creation to implementation-specific
- factory functions that are registered with the factory at runtime.
- """
-
- # Dictionary to store factory functions
- _factory_functions: dict[str, RepositoryFactoryFunc] = {}
-
- @classmethod
- def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
- """
- Register a factory function for a specific repository type.
- This is a private method and should not be called directly.
-
- Args:
- repository_type: The type of repository (e.g., 'workflow_node_execution')
- factory_func: A function that takes parameters and returns a repository instance
- """
- cls._factory_functions[repository_type] = factory_func
-
- @classmethod
- def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
- """
- Create a new repository instance with the provided parameters.
- This is a private method and should not be called directly.
-
- Args:
- repository_type: The type of repository to create
- params: A dictionary of parameters to pass to the factory function
-
- Returns:
- A new instance of the requested repository
-
- Raises:
- ValueError: If no factory function is registered for the repository type
- """
- if repository_type not in cls._factory_functions:
- raise ValueError(f"No factory function registered for repository type '{repository_type}'")
-
- # Use empty dict if params is None
- params = params or {}
-
- return cls._factory_functions[repository_type](params)
-
- @classmethod
- def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
- """
- Register a factory function for the workflow node execution repository.
-
- Args:
- factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
- """
- cls._register_factory("workflow_node_execution", factory_func)
-
- @classmethod
- def create_workflow_node_execution_repository(
- cls, params: Optional[Mapping[str, Any]] = None
- ) -> WorkflowNodeExecutionRepository:
- """
- Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
-
- Args:
- params: A dictionary of parameters to pass to the factory function
-
- Returns:
- A new instance of the WorkflowNodeExecutionRepository
-
- Raises:
- ValueError: If no factory function is registered for the workflow_node_execution repository type
- """
- # We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
- return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py
index 4f733f0ea1..cf75bd3d7e 100644
--- a/api/core/tools/builtin_tool/provider.py
+++ b/api/core/tools/builtin_tool/provider.py
@@ -35,8 +35,9 @@ class BuiltinToolProviderController(ToolProviderController):
provider_yaml["credentials_for_provider"][credential_name]["name"] = credential_name
credentials_schema = []
- for credential in provider_yaml.get("credentials_for_provider", {}).values():
- credentials_schema.append(credential)
+ for credential in provider_yaml.get("credentials_for_provider", {}):
+ credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {})
+ credentials_schema.append(credential_dict)
super().__init__(
entity=ToolProviderEntity(
diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py
index 7f37f98d0c..724a2291c6 100644
--- a/api/core/tools/builtin_tool/tool.py
+++ b/api/core/tools/builtin_tool/tool.py
@@ -6,8 +6,8 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
_SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language
-and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
-retain the original meaning and keep the key points.
+and you can quickly aimed at the main point of an webpage and reproduce it in your own words but
+retain the original meaning and keep the key points.
however, the text you got is too long, what you got is possible a part of the text.
Please summarize the text you got.
"""
diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py
index 3616e426b9..494b8e209c 100644
--- a/api/core/tools/plugin_tool/provider.py
+++ b/api/core/tools/plugin_tool/provider.py
@@ -1,6 +1,6 @@
from typing import Any
-from core.plugin.manager.tool import PluginToolManager
+from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.builtin_tool.provider import BuiltinToolProviderController
from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType
diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py
index f31a9a0d3e..d21e3d7d1c 100644
--- a/api/core/tools/plugin_tool/tool.py
+++ b/api/core/tools/plugin_tool/tool.py
@@ -1,7 +1,7 @@
from collections.abc import Generator
from typing import Any, Optional
-from core.plugin.manager.tool import PluginToolManager
+from core.plugin.impl.tool import PluginToolManager
from core.plugin.utils.converter import convert_parameters_to_plugin_format
from core.tools.__base.tool import Tool
from core.tools.__base.tool_runtime import ToolRuntime
diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py
new file mode 100644
index 0000000000..e80005d7bf
--- /dev/null
+++ b/api/core/tools/signature.py
@@ -0,0 +1,41 @@
+import base64
+import hashlib
+import hmac
+import os
+import time
+
+from configs import dify_config
+
+
+def sign_tool_file(tool_file_id: str, extension: str) -> str:
+ """
+ sign file to get a temporary url
+ """
+ base_url = dify_config.FILES_URL
+ file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
+
+ timestamp = str(int(time.time()))
+ nonce = os.urandom(16).hex()
+ data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
+ secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+ sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+ encoded_sign = base64.urlsafe_b64encode(sign).decode()
+
+ return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
+
+
+def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
+ """
+ verify signature
+ """
+ data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
+ secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
+ recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
+ recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
+
+ # verify signature
+ if sign != recalculated_encoded_sign:
+ return False
+
+ current_time = int(time.time())
+ return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py
index 997917f31c..3dce1ca293 100644
--- a/api/core/tools/tool_engine.py
+++ b/api/core/tools/tool_engine.py
@@ -246,7 +246,7 @@ class ToolEngine:
+ "you do not need to create it, just tell the user to check it now."
)
elif response.type == ToolInvokeMessage.MessageType.JSON:
- result = json.dumps(
+ result += json.dumps(
cast(ToolInvokeMessage.JsonMessage, response.message).json_object, ensure_ascii=False
)
else:
diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py
index 7e8d4280d4..b849f51064 100644
--- a/api/core/tools/tool_file_manager.py
+++ b/api/core/tools/tool_file_manager.py
@@ -4,23 +4,34 @@ import hmac
import logging
import os
import time
+from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
+from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
-from extensions.ext_database import db
+from extensions.ext_database import db as global_db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
+from sqlalchemy.engine import Engine
+
class ToolFileManager:
+ _engine: Engine
+
+ def __init__(self, engine: Engine | None = None):
+ if engine is None:
+ engine = global_db.engine
+ self._engine = engine
+
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
@@ -55,8 +66,8 @@ class ToolFileManager:
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
- @staticmethod
def create_file_by_raw(
+ self,
*,
user_id: str,
tenant_id: str,
@@ -77,24 +88,25 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
- tool_file = ToolFile(
- user_id=user_id,
- tenant_id=tenant_id,
- conversation_id=conversation_id,
- file_key=filepath,
- mimetype=mimetype,
- name=present_filename,
- size=len(file_binary),
- )
+ with Session(self._engine, expire_on_commit=False) as session:
+ tool_file = ToolFile(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=conversation_id,
+ file_key=filepath,
+ mimetype=mimetype,
+ name=present_filename,
+ size=len(file_binary),
+ )
- db.session.add(tool_file)
- db.session.commit()
- db.session.refresh(tool_file)
+ session.add(tool_file)
+ session.commit()
+ session.refresh(tool_file)
return tool_file
- @staticmethod
def create_file_by_url(
+ self,
user_id: str,
tenant_id: str,
file_url: str,
@@ -119,24 +131,24 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
- tool_file = ToolFile(
- user_id=user_id,
- tenant_id=tenant_id,
- conversation_id=conversation_id,
- file_key=filepath,
- mimetype=mimetype,
- original_url=file_url,
- name=filename,
- size=len(blob),
- )
+ with Session(self._engine, expire_on_commit=False) as session:
+ tool_file = ToolFile(
+ user_id=user_id,
+ tenant_id=tenant_id,
+ conversation_id=conversation_id,
+ file_key=filepath,
+ mimetype=mimetype,
+ original_url=file_url,
+ name=filename,
+ size=len(blob),
+ )
- db.session.add(tool_file)
- db.session.commit()
+ session.add(tool_file)
+ session.commit()
return tool_file
- @staticmethod
- def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
+ def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
@@ -144,13 +156,14 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
- tool_file: ToolFile | None = (
- db.session.query(ToolFile)
- .filter(
- ToolFile.id == id,
+ with Session(self._engine, expire_on_commit=False) as session:
+ tool_file: ToolFile | None = (
+ session.query(ToolFile)
+ .filter(
+ ToolFile.id == id,
+ )
+ .first()
)
- .first()
- )
if not tool_file:
return None
@@ -159,8 +172,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
- @staticmethod
- def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
+ def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
@@ -168,33 +180,34 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
- message_file: MessageFile | None = (
- db.session.query(MessageFile)
- .filter(
- MessageFile.id == id,
+ with Session(self._engine, expire_on_commit=False) as session:
+ message_file: MessageFile | None = (
+ session.query(MessageFile)
+ .filter(
+ MessageFile.id == id,
+ )
+ .first()
)
- .first()
- )
- # Check if message_file is not None
- if message_file is not None:
- # get tool file id
- if message_file.url is not None:
- tool_file_id = message_file.url.split("/")[-1]
- # trim extension
- tool_file_id = tool_file_id.split(".")[0]
+ # Check if message_file is not None
+ if message_file is not None:
+ # get tool file id
+ if message_file.url is not None:
+ tool_file_id = message_file.url.split("/")[-1]
+ # trim extension
+ tool_file_id = tool_file_id.split(".")[0]
+ else:
+ tool_file_id = None
else:
tool_file_id = None
- else:
- tool_file_id = None
- tool_file: ToolFile | None = (
- db.session.query(ToolFile)
- .filter(
- ToolFile.id == tool_file_id,
+ tool_file: ToolFile | None = (
+ session.query(ToolFile)
+ .filter(
+ ToolFile.id == tool_file_id,
+ )
+ .first()
)
- .first()
- )
if not tool_file:
return None
@@ -203,8 +216,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
- @staticmethod
- def get_file_generator_by_tool_file_id(tool_file_id: str):
+ def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]:
"""
get file binary
@@ -212,13 +224,14 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
- tool_file: ToolFile | None = (
- db.session.query(ToolFile)
- .filter(
- ToolFile.id == tool_file_id,
+ with Session(self._engine, expire_on_commit=False) as session:
+ tool_file: ToolFile | None = (
+ session.query(ToolFile)
+ .filter(
+ ToolFile.id == tool_file_id,
+ )
+ .first()
)
- .first()
- )
if not tool_file:
return None, None
@@ -229,6 +242,11 @@ class ToolFileManager:
# init tool_file_parser
-from core.file.tool_file_parser import tool_file_manager
+from core.file.tool_file_parser import set_tool_file_manager_factory
+
+
+def _factory() -> ToolFileManager:
+ return ToolFileManager()
+
-tool_file_manager["manager"] = ToolFileManager
+set_tool_file_manager_factory(_factory)
diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py
index f2d0b74f7c..aa2661fe63 100644
--- a/api/core/tools/tool_manager.py
+++ b/api/core/tools/tool_manager.py
@@ -10,7 +10,7 @@ from yarl import URL
import contexts
from core.plugin.entities.plugin import ToolProviderID
-from core.plugin.manager.tool import PluginToolManager
+from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.plugin_tool.provider import PluginToolProviderController
diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
index f5838c3b76..ed97b44f95 100644
--- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py
@@ -1,10 +1,12 @@
-from typing import Any
+from typing import Any, Optional, cast
from pydantic import BaseModel, Field
+from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.models.document import Document as RetrievalDocument
+from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -33,6 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
+ user_id: Optional[str] = None
+ retrieve_config: DatasetRetrieveConfigEntity
+ inputs: dict
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@@ -58,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return ""
for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id)
+ dataset_retrieval = DatasetRetrieval()
+ metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
+ [dataset.id],
+ query,
+ self.tenant_id,
+ self.user_id or "unknown",
+ cast(str, self.retrieve_config.metadata_filtering_mode),
+ cast(ModelConfig, self.retrieve_config.metadata_model_config),
+ self.retrieve_config.metadata_filtering_conditions,
+ self.inputs,
+ )
+ if metadata_filter_document_ids:
+ document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
+ else:
+ document_ids_filter = None
if dataset.provider == "external":
results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
@@ -65,6 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
+ metadata_condition=metadata_condition,
)
for external_document in external_documents:
document = RetrievalDocument(
@@ -100,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results]))
else:
+ if metadata_condition and not document_ids_filter:
+ return ""
# get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy":
# use keyword table query
documents = RetrievalService.retrieve(
- retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k
+ retrieval_method="keyword_search",
+ dataset_id=dataset.id,
+ query=query,
+ top_k=self.top_k,
+ document_ids_filter=document_ids_filter,
)
return str("\n".join([document.page_content for document in documents]))
else:
@@ -124,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"),
+ document_ids_filter=document_ids_filter,
)
else:
documents = []
diff --git a/api/core/tools/utils/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever_tool.py
index b73dec4ebc..ec0575f6c3 100644
--- a/api/core/tools/utils/dataset_retriever_tool.py
+++ b/api/core/tools/utils/dataset_retriever_tool.py
@@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
return_resource: bool,
invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler,
+ user_id: str,
+ inputs: dict,
) -> list["DatasetRetrieverTool"]:
"""
get dataset tool
@@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
return_resource=return_resource,
invoke_from=invoke_from,
hit_callback=hit_callback,
+ user_id=user_id,
+ inputs=inputs,
)
if retrieval_tools is None or len(retrieval_tools) == 0:
return []
diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py
index 6fd0c201e3..257d96133e 100644
--- a/api/core/tools/utils/message_transformer.py
+++ b/api/core/tools/utils/message_transformer.py
@@ -31,8 +31,8 @@ class ToolFileMessageTransformer:
# try to download image
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
-
- file = ToolFileManager.create_file_by_url(
+ tool_file_manager = ToolFileManager()
+ file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
@@ -60,7 +60,7 @@ class ToolFileMessageTransformer:
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
- filename = meta.get("file_name", None)
+ filename = meta.get("filename", None)
# if message is str, encode it to bytes
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
@@ -68,7 +68,8 @@ class ToolFileMessageTransformer:
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
- file = ToolFileManager.create_file_by_raw(
+ tool_file_manager = ToolFileManager()
+ file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,
diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py
index da40cbcdea..771e0ca7a5 100644
--- a/api/core/workflow/nodes/agent/agent_node.py
+++ b/api/core/workflow/nodes/agent/agent_node.py
@@ -7,8 +7,8 @@ from core.agent.plugin_entities import AgentStrategyParameter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
-from core.plugin.manager.exc import PluginDaemonClientSideError
-from core.plugin.manager.plugin import PluginInstallationManager
+from core.plugin.impl.exc import PluginDaemonClientSideError
+from core.plugin.impl.plugin import PluginInstaller
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
from core.tools.tool_manager import ToolManager
@@ -297,7 +297,7 @@ class AgentNode(ToolNode):
Get agent strategy icon
:return:
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py
index 960d0c3961..8fb1baec89 100644
--- a/api/core/workflow/nodes/document_extractor/node.py
+++ b/api/core/workflow/nodes/document_extractor/node.py
@@ -11,6 +11,7 @@ import docx
import pandas as pd
import pypandoc # type: ignore
import pypdfium2 # type: ignore
+import webvtt # type: ignore
import yaml # type: ignore
from docx.document import Document
from docx.oxml.table import CT_Tbl
@@ -132,6 +133,10 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
return _extract_text_from_json(file_content)
case "application/x-yaml" | "text/yaml":
return _extract_text_from_yaml(file_content)
+ case "text/vtt":
+ return _extract_text_from_vtt(file_content)
+ case "text/properties":
+ return _extract_text_from_properties(file_content)
case _:
raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}")
@@ -139,7 +144,7 @@ def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str:
def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str) -> str:
"""Extract text from a file based on its file extension."""
match file_extension:
- case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml" | ".vtt":
+ case ".txt" | ".markdown" | ".md" | ".html" | ".htm" | ".xml":
return _extract_text_from_plain_text(file_content)
case ".json":
return _extract_text_from_json(file_content)
@@ -165,6 +170,10 @@ def _extract_text_by_file_extension(*, file_content: bytes, file_extension: str)
return _extract_text_from_eml(file_content)
case ".msg":
return _extract_text_from_msg(file_content)
+ case ".vtt":
+ return _extract_text_from_vtt(file_content)
+ case ".properties":
+ return _extract_text_from_properties(file_content)
case _:
raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}")
@@ -214,8 +223,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
"""
from unstructured.partition.api import partition_via_api
- if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
- raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
+ if not dify_config.UNSTRUCTURED_API_URL:
+ raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
@@ -226,7 +235,7 @@ def _extract_text_from_doc(file_content: bytes) -> str:
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
- api_key=dify_config.UNSTRUCTURED_API_KEY,
+ api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
)
os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements])
@@ -462,3 +471,68 @@ def _extract_text_from_msg(file_content: bytes) -> str:
return "\n".join([str(element) for element in elements])
except Exception as e:
raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e
+
+
+def _extract_text_from_vtt(vtt_bytes: bytes) -> str:
+ text = _extract_text_from_plain_text(vtt_bytes)
+
+ # remove bom
+ text = text.lstrip("\ufeff")
+
+ raw_results = []
+ for caption in webvtt.from_string(text):
+ raw_results.append((caption.voice, caption.text))
+
+ # Merge consecutive utterances by the same speaker
+ merged_results = []
+ if raw_results:
+ current_speaker, current_text = raw_results[0]
+
+ for i in range(1, len(raw_results)):
+ spk, txt = raw_results[i]
+ if spk == None:
+ merged_results.append((None, current_text))
+ continue
+
+ if spk == current_speaker:
+ # If it is the same speaker, merge the utterances (joined by space)
+ current_text += " " + txt
+ else:
+ # If the speaker changes, register the utterance so far and move on
+ merged_results.append((current_speaker, current_text))
+ current_speaker, current_text = spk, txt
+
+ # Add the last element
+ merged_results.append((current_speaker, current_text))
+ else:
+ merged_results = raw_results
+
+ # Return the result in the specified format: Speaker "text" style
+ formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results]
+ return "\n".join(formatted)
+
+
+def _extract_text_from_properties(file_content: bytes) -> str:
+ try:
+ text = _extract_text_from_plain_text(file_content)
+ lines = text.splitlines()
+ result = []
+ for line in lines:
+ line = line.strip()
+ # Preserve comments and empty lines
+ if not line or line.startswith("#") or line.startswith("!"):
+ result.append(line)
+ continue
+
+ if "=" in line:
+ key, value = line.split("=", 1)
+ elif ":" in line:
+ key, value = line.split(":", 1)
+ else:
+ key, value = line, ""
+
+ result.append(f"{key.strip()}: {value.strip()}")
+
+ return "\n".join(result)
+ except Exception as e:
+ raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e
diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py
index 5d466e645f..2c42f5a1be 100644
--- a/api/core/workflow/nodes/http_request/executor.py
+++ b/api/core/workflow/nodes/http_request/executor.py
@@ -262,7 +262,10 @@ class Executor:
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.auth.config.type == "basic":
credentials = authorization.config.api_key
- encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+ if ":" in credentials:
+ encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
+ else:
+ encoded_credentials = credentials
headers[authorization.config.header] = f"Basic {encoded_credentials}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""
diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py
index fd2b0f9ae8..1c82637974 100644
--- a/api/core/workflow/nodes/http_request/node.py
+++ b/api/core/workflow/nodes/http_request/node.py
@@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
+ tool_file_manager = ToolFileManager()
- tool_file = ToolFileManager.create_file_by_raw(
+ tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,
diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
index 07a711cc4e..5c4cac9719 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
@@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
-from sqlalchemy import Integer, and_, func, or_, text
+from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@@ -32,11 +32,11 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
+ METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
-from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -349,17 +349,19 @@ class KnowledgeRetrievalNode(LLMNode):
)
)
metadata_condition = MetadataCondition(
- logical_operator=node_data.metadata_filtering_conditions.logical_operator, # type: ignore
+ logical_operator=node_data.metadata_filtering_conditions.logical_operator
+ if node_data.metadata_filtering_conditions
+ else "or", # type: ignore
conditions=conditions,
)
elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions:
- metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump())
+ conditions = []
if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
- if expected_value is not None or condition.comparison_operator in ("empty", "not empty"):
+ if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
@@ -370,17 +372,31 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
else:
raise ValueError("Invalid expected metadata value type")
- filters = self._process_metadata_filter_func(
- sequence,
- condition.comparison_operator,
- metadata_name,
- expected_value,
- filters,
+ conditions.append(
+ Condition(
+ name=metadata_name,
+ comparison_operator=condition.comparison_operator,
+ value=expected_value,
)
+ )
+ filters = self._process_metadata_filter_func(
+ sequence,
+ condition.comparison_operator,
+ metadata_name,
+ expected_value,
+ filters,
+ )
+ metadata_condition = MetadataCondition(
+ logical_operator=node_data.metadata_filtering_conditions.logical_operator,
+ conditions=conditions,
+ )
else:
raise ValueError("Invalid metadata filtering mode")
if filters:
- if node_data.metadata_filtering_conditions.logical_operator == "and": # type: ignore
+ if (
+ node_data.metadata_filtering_conditions
+ and node_data.metadata_filtering_conditions.logical_operator == "and"
+ ): # type: ignore
document_query = document_query.filter(and_(*filters))
else:
document_query = document_query.filter(or_(*filters))
@@ -488,24 +504,24 @@ class KnowledgeRetrievalNode(LLMNode):
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
case "is not" | "≠":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
case "after" | ">":
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
- case "≤" | ">=":
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
+ case "≤" | "<=":
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
case "≥" | ">=":
- filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
+ filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
case _:
pass
return filters
@@ -613,7 +629,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
- role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
+ role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
diff --git a/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py b/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py
index 7abd55d798..9c945e2f52 100644
--- a/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py
+++ b/api/core/workflow/nodes/knowledge_retrieval/template_prompts.py
@@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
- Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
+ Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "≠", ">", "<", "≥", "≤", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint
@@ -50,7 +50,7 @@ You are a text metadata extract engine that extract text's metadata based on use
# Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
-### Constraint
+### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside XML tags.
@@ -59,7 +59,7 @@ User:{{"input_text": ["I want to know which company’s email address test@examp
Assistant:{{"metadata_map": [{{"metadata_field_name": "email", "metadata_field_value": "test@example.com", "comparison_operator": "="}}]}}
User:{{"input_text": "What are the movies with a score of more than 9 in 2024?", "metadata_fields": ["name", "year", "rating", "country"]}}
Assistant:{{"metadata_map": [{{"metadata_field_name": "year", "metadata_field_value": "2024", "comparison_operator": "="}, {{"metadata_field_name": "rating", "metadata_field_value": "9", "comparison_operator": ">"}}]}}
-
+
### User Input
{{"input_text" : "{input_text}", "metadata_fields" : {metadata_fields}}}
### Assistant Output
diff --git a/api/core/workflow/nodes/llm/exc.py b/api/core/workflow/nodes/llm/exc.py
index 6599221691..42b8f4e6ce 100644
--- a/api/core/workflow/nodes/llm/exc.py
+++ b/api/core/workflow/nodes/llm/exc.py
@@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")
+
+
+class UnsupportedPromptContentTypeError(LLMNodeError):
+ def __init__(self, *, type_name: str) -> None:
+ super().__init__(f"Prompt content type {type_name} is not supported.")
diff --git a/api/core/workflow/nodes/llm/file_saver.py b/api/core/workflow/nodes/llm/file_saver.py
new file mode 100644
index 0000000000..c85baade03
--- /dev/null
+++ b/api/core/workflow/nodes/llm/file_saver.py
@@ -0,0 +1,160 @@
+import mimetypes
+import typing as tp
+
+from sqlalchemy import Engine
+
+from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
+from core.file import File, FileTransferMethod, FileType
+from core.helper import ssrf_proxy
+from core.tools.signature import sign_tool_file
+from core.tools.tool_file_manager import ToolFileManager
+from models import db as global_db
+
+
+class LLMFileSaver(tp.Protocol):
+ """LLMFileSaver is responsible for save multimodal output returned by
+ LLM.
+ """
+
+ def save_binary_string(
+ self,
+ data: bytes,
+ mime_type: str,
+ file_type: FileType,
+ extension_override: str | None = None,
+ ) -> File:
+ """save_binary_string saves the inline file data returned by LLM.
+
+ Currently (2025-04-30), only some of Google Gemini models will return
+ multimodal output as inline data.
+
+ :param data: the contents of the file
+ :param mime_type: the media type of the file, specified by rfc6838
+ (https://datatracker.ietf.org/doc/html/rfc6838)
+ :param file_type: The file type of the inline file.
+ :param extension_override: Override the auto-detected file extension while saving this file.
+
+ The default value is `None`, which means do not override the file extension and guessing it
+ from the `mime_type` attribute while saving the file.
+
+ Setting it to values other than `None` means override the file's extension, and
+ will bypass the extension guessing saving the file.
+
+ Specially, setting it to empty string (`""`) will leave the file extension empty.
+
+ When it is not `None` or empty string (`""`), it should be a string beginning with a
+ dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
+ and `tar.gz` are not.
+ """
+ pass
+
+ def save_remote_url(self, url: str, file_type: FileType) -> File:
+ """save_remote_url saves the file from a remote url returned by LLM.
+
+ Currently (2025-04-30), no model returns multimodel output as a url.
+
+ :param url: the url of the file.
+ :param file_type: the file type of the file, check `FileType` enum for reference.
+ """
+ pass
+
+
+EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
+
+
+class FileSaverImpl(LLMFileSaver):
+ _engine_factory: EngineFactory
+ _tenant_id: str
+ _user_id: str
+
+ def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
+ if engine_factory is None:
+
+ def _factory():
+ return global_db.engine
+
+ engine_factory = _factory
+ self._engine_factory = engine_factory
+ self._user_id = user_id
+ self._tenant_id = tenant_id
+
+ def _get_tool_file_manager(self):
+ return ToolFileManager(engine=self._engine_factory())
+
+ def save_remote_url(self, url: str, file_type: FileType) -> File:
+ http_response = ssrf_proxy.get(url)
+ http_response.raise_for_status()
+ data = http_response.content
+ mime_type_from_header = http_response.headers.get("Content-Type")
+ mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
+ return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
+
+ def save_binary_string(
+ self,
+ data: bytes,
+ mime_type: str,
+ file_type: FileType,
+ extension_override: str | None = None,
+ ) -> File:
+ tool_file_manager = self._get_tool_file_manager()
+ tool_file = tool_file_manager.create_file_by_raw(
+ user_id=self._user_id,
+ tenant_id=self._tenant_id,
+ # TODO(QuantumGhost): what is conversation id?
+ conversation_id=None,
+ file_binary=data,
+ mimetype=mime_type,
+ )
+ extension_override = _validate_extension_override(extension_override)
+ extension = _get_extension(mime_type, extension_override)
+ url = sign_tool_file(tool_file.id, extension)
+
+ return File(
+ tenant_id=self._tenant_id,
+ type=file_type,
+ transfer_method=FileTransferMethod.TOOL_FILE,
+ filename=tool_file.name,
+ extension=extension,
+ mime_type=mime_type,
+ size=len(data),
+ related_id=tool_file.id,
+ url=url,
+ # TODO(QuantumGhost): how should I set the following key?
+ # What's the difference between `remote_url` and `url`?
+ # What's the purpose of `storage_key` and `dify_model_identity`?
+ storage_key=tool_file.file_key,
+ )
+
+
+def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
+ """get_extension return the extension of file.
+
+ If the `extension_override` parameter is set, this function should honor it and
+ return its value.
+ """
+ if extension_override is not None:
+ return extension_override
+ return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
+
+
+def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
+ """_extract_content_type_and_extension tries to
+ guess content type of file from url and `Content-Type` header in response.
+ """
+ if content_type_header:
+ extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
+ return content_type_header, extension
+ content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
+ extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
+ return content_type, extension
+
+
+def _validate_extension_override(extension_override: str | None) -> str | None:
+ # `extension_override` is allow to be `None or `""`.
+ if extension_override is None:
+ return None
+ if extension_override == "":
+ return ""
+ if not extension_override.startswith("."):
+ raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
+ return extension_override
diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py
index 8db7394e54..f42bc6784d 100644
--- a/api/core/workflow/nodes/llm/node.py
+++ b/api/core/workflow/nodes/llm/node.py
@@ -1,3 +1,5 @@
+import base64
+import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@@ -21,10 +23,10 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
-from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
+from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
- PromptMessageContent,
+ PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
@@ -94,9 +96,13 @@ from .exc import (
TemplateTypeNotSupportError,
VariableNotFoundError,
)
+from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
+ from core.workflow.graph_engine.entities.graph import Graph
+ from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
+ from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
@@ -105,8 +111,45 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
+ # Instance attributes specific to LLMNode.
+ # Output variable for file
+ _file_outputs: list["File"]
+
+ _llm_file_saver: LLMFileSaver
+
+ def __init__(
+ self,
+ id: str,
+ config: Mapping[str, Any],
+ graph_init_params: "GraphInitParams",
+ graph: "Graph",
+ graph_runtime_state: "GraphRuntimeState",
+ previous_node_id: Optional[str] = None,
+ thread_pool_id: Optional[str] = None,
+ *,
+ llm_file_saver: LLMFileSaver | None = None,
+ ) -> None:
+ super().__init__(
+ id=id,
+ config=config,
+ graph_init_params=graph_init_params,
+ graph=graph,
+ graph_runtime_state=graph_runtime_state,
+ previous_node_id=previous_node_id,
+ thread_pool_id=thread_pool_id,
+ )
+ # LLM file outputs, used for MultiModal outputs.
+ self._file_outputs: list[File] = []
+
+ if llm_file_saver is None:
+ llm_file_saver = FileSaverImpl(
+ user_id=graph_init_params.user_id,
+ tenant_id=graph_init_params.tenant_id,
+ )
+ self._llm_file_saver = llm_file_saver
+
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
- def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
+ def process_structured_output(text: str) -> Optional[dict[str, Any]]:
"""Process structured output if enabled"""
if not self.node_data.structured_output_enabled or not self.node_data.structured_output:
return None
@@ -214,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
+ if self._file_outputs is not None:
+ outputs["files"] = self._file_outputs
+
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -239,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
except Exception as e:
+ logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@@ -267,55 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
return self._handle_invoke_result(invoke_result=invoke_result)
- def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
+ def _handle_invoke_result(
+ self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
+ ) -> Generator[NodeEvent, None, None]:
+ # For blocking mode
if isinstance(invoke_result, LLMResult):
- content = invoke_result.message.content
- if content is None:
- message_text = ""
- elif isinstance(content, str):
- message_text = content
- elif isinstance(content, list):
- # Assuming the list contains PromptMessageContent objects with a "data" attribute
- message_text = "".join(
- item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
- )
- else:
- message_text = str(content)
-
- yield ModelInvokeCompletedEvent(
- text=message_text,
- usage=invoke_result.usage,
- finish_reason=None,
- )
+ event = self._handle_blocking_result(invoke_result=invoke_result)
+ yield event
return
- model = None
+ # For streaming mode
+ model = ""
prompt_messages: list[PromptMessage] = []
- full_text = ""
- usage = None
+
+ usage = LLMUsage.empty_usage()
finish_reason = None
+ full_text_buffer = io.StringIO()
for result in invoke_result:
- text = result.delta.message.content
- full_text += text
-
- yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
+ contents = result.delta.message.content
+ for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
+ full_text_buffer.write(text_part)
+ yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
- if not model:
+ # Update the whole metadata
+ if not model and result.model:
model = result.model
-
- if not prompt_messages:
- prompt_messages = result.prompt_messages
-
- if not usage and result.delta.usage:
+ if len(prompt_messages) == 0:
+ # TODO(QuantumGhost): it seems that this update has no visable effect.
+ # What's the purpose of the line below?
+ prompt_messages = list(result.prompt_messages)
+ if usage.prompt_tokens == 0 and result.delta.usage:
usage = result.delta.usage
-
- if not finish_reason and result.delta.finish_reason:
+ if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
- if not usage:
- usage = LLMUsage.empty_usage()
+ yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
- yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
+ def _image_file_to_markdown(self, file: "File", /):
+ text_chunk = f"})"
+ return text_chunk
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@@ -594,8 +631,7 @@ class LLMNode(BaseNode[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
- # FIXME: fix the type error cause prompt_messages is type quick a few times
- prompt_messages: list[Any] = []
+ prompt_messages: list[PromptMessage] = []
if isinstance(prompt_template, list):
# For chat model
@@ -657,12 +693,14 @@ class LLMNode(BaseNode[LLMNodeData]):
# For issue #11247 - Check if prompt content is a string or a list
prompt_content_type = type(prompt_content)
if prompt_content_type == str:
+ prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
+ prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
if "#histories#" in content_item.data:
@@ -675,9 +713,10 @@ class LLMNode(BaseNode[LLMNodeData]):
# Add current query to the prompt message
if sys_query:
if prompt_content_type == str:
- prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
+ prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif prompt_content_type == list:
+ prompt_content = prompt_content if isinstance(prompt_content, list) else []
for content_item in prompt_content:
if content_item.type == PromptMessageContentType.TEXT:
content_item.data = sys_query + "\n" + content_item.data
@@ -707,7 +746,7 @@ class LLMNode(BaseNode[LLMNodeData]):
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
- prompt_message_content = []
+ prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
@@ -758,18 +797,22 @@ class LLMNode(BaseNode[LLMNodeData]):
stop = model_config.stop
return filtered_prompt_messages, stop
- def _parse_structured_output(self, result_text: str) -> dict[str, Any] | list[Any]:
- structured_output: dict[str, Any] | list[Any] = {}
+ def _parse_structured_output(self, result_text: str) -> dict[str, Any]:
+ structured_output: dict[str, Any] = {}
try:
parsed = json.loads(result_text)
- if not isinstance(parsed, (dict | list)):
+ if not isinstance(parsed, dict):
raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
except json.JSONDecodeError as e:
# if the result_text is not a valid json, try to repair it
parsed = json_repair.loads(result_text)
- if not isinstance(parsed, (dict | list)):
- raise LLMNodeError(f"Failed to parse structured output: {result_text}")
+ if not isinstance(parsed, dict):
+ # handle reasoning model like deepseek-r1 got '\n\n \n' prefix
+ if isinstance(parsed, list):
+ parsed = next((item for item in parsed if isinstance(item, dict)), {})
+ else:
+ raise LLMNodeError(f"Failed to parse structured output: {result_text}")
structured_output = parsed
return structured_output
@@ -971,6 +1014,42 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
+ def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
+ buffer = io.StringIO()
+ for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
+ buffer.write(text_part)
+
+ return ModelInvokeCompletedEvent(
+ text=buffer.getvalue(),
+ usage=invoke_result.usage,
+ finish_reason=None,
+ )
+
+ def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
+ """_save_multimodal_output saves multi-modal contents generated by LLM plugins.
+
+ There are two kinds of multimodal outputs:
+
+ - Inlined data encoded in base64, which would be saved to storage directly.
+ - Remote files referenced by an url, which would be downloaded and then saved to storage.
+
+ Currently, only image files are supported.
+ """
+ # Inject the saver somehow...
+ _saver = self._llm_file_saver
+
+ # If this
+ if content.url != "":
+ saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
+ else:
+ saved_file = _saver.save_binary_string(
+ data=base64.b64decode(content.base64_data),
+ mime_type=content.mime_type,
+ file_type=FileType.IMAGE,
+ )
+ self._file_outputs.append(saved_file)
+ return saved_file
+
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
@@ -1131,8 +1210,45 @@ class LLMNode(BaseNode[LLMNodeData]):
else SupportStructuredOutputStatus.UNSUPPORTED
)
+ def _save_multimodal_output_and_convert_result_to_markdown(
+ self,
+ contents: str | list[PromptMessageContentUnionTypes] | None,
+ ) -> Generator[str, None, None]:
+ """Convert intermediate prompt messages into strings and yield them to the caller.
+
+ If the messages contain non-textual content (e.g., multimedia like images or videos),
+ it will be saved separately, and the corresponding Markdown representation will
+ be yielded to the caller.
+ """
+
+ # NOTE(QuantumGhost): This function should yield results to the caller immediately
+ # whenever new content or partial content is available. Avoid any intermediate buffering
+ # of results. Additionally, do not yield empty strings; instead, yield from an empty list
+ # if necessary.
+ if contents is None:
+ yield from []
+ return
+ if isinstance(contents, str):
+ yield contents
+ elif isinstance(contents, list):
+ for item in contents:
+ if isinstance(item, TextPromptMessageContent):
+ yield item.data
+ elif isinstance(item, ImagePromptMessageContent):
+ file = self._save_multimodal_image_output(item)
+ self._file_outputs.append(file)
+ yield self._image_file_to_markdown(file)
+ else:
+ logger.warning("unknown item type encountered, type=%s", type(item))
+ yield str(item)
+ else:
+ logger.warning("unknown contents type encountered, type=%s", type(contents))
+ yield str(contents)
-def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
+
+def _combine_message_content_with_role(
+ *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
+):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
diff --git a/api/core/workflow/nodes/parameter_extractor/prompts.py b/api/core/workflow/nodes/parameter_extractor/prompts.py
index 6c3155ac9a..ab7ddcc32a 100644
--- a/api/core/workflow/nodes/parameter_extractor/prompts.py
+++ b/api/core/workflow/nodes/parameter_extractor/prompts.py
@@ -17,7 +17,7 @@ Some additional information is provided below. Always adhere to these instructio
Steps:
1. Review the chat history provided within the tags.
-2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
+2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text.
3. Generate a well-formatted output using the defined functions and arguments.
4. Use the `extract_parameter` function to create structured outputs with appropriate parameters.
5. Do not include any XML tags in your output.
@@ -89,13 +89,13 @@ Some extra information are provided below, I should always follow the instructio
### Extract parameter Workflow
-I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted.
+I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted.
{{ structure }}
Step 1: Carefully read the input and understand the structure of the expected output.
-Step 2: Extract relevant parameters from the provided text based on the name and description of object.
+Step 2: Extract relevant parameters from the provided text based on the name and description of object.
Step 3: Structure the extracted parameters to JSON object as specified in .
Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted.
@@ -106,10 +106,10 @@ Here are the chat histories between human and assistant, inside
### Structure
-Here is the structure of the expected output, I should always follow the output structure.
+Here is the structure of the expected output, I should always follow the output structure.
{{γγγ
- 'properties1': 'relevant text extracted from input',
- 'properties2': 'relevant text extracted from input',
+ 'properties1': 'relevant text extracted from input',
+ 'properties2': 'relevant text extracted from input',
}}γγγ
### Input Text
@@ -119,7 +119,7 @@ Inside XML tags, there is a text that I should extract parameters
### Answer
-I should always output a valid JSON object. Output nothing other than the JSON object.
+I should always output a valid JSON object. Output nothing other than the JSON object.
```JSON
""" # noqa: E501
diff --git a/api/core/workflow/nodes/question_classifier/template_prompts.py b/api/core/workflow/nodes/question_classifier/template_prompts.py
index 70178ed934..a615c32383 100644
--- a/api/core/workflow/nodes/question_classifier/template_prompts.py
+++ b/api/core/workflow/nodes/question_classifier/template_prompts.py
@@ -55,7 +55,7 @@ You are a text classification engine that analyzes text data and assigns categor
Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification.
### Format
The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy.
-### Constraint
+### Constraint
DO NOT include anything other than the JSON array in your response.
### Example
Here is the chat example between human and assistant, inside XML tags.
@@ -64,7 +64,7 @@ User:{{"input_text": ["I recently had a great experience with your company. The
Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}}
User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}}
Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}}
-
+
### Memory
Here are the chat histories between human and assistant, inside XML tags.
diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py
index 6f0cc3f6d2..c72ae5b69b 100644
--- a/api/core/workflow/nodes/tool/tool_node.py
+++ b/api/core/workflow/nodes/tool/tool_node.py
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod
-from core.plugin.manager.exc import PluginDaemonClientSideError
-from core.plugin.manager.plugin import PluginInstallationManager
+from core.plugin.impl.exc import PluginDaemonClientSideError
+from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.errors import ToolInvokeError
from core.tools.tool_engine import ToolEngine
@@ -307,7 +307,7 @@ class ToolNode(BaseNode[ToolNodeData]):
icon = tool_info.get("icon", "")
dict_metadata = dict(message.message.metadata)
if dict_metadata.get("provider"):
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
plugins = manager.list_plugins(self.tenant_id)
try:
current_plugin = next(
diff --git a/api/core/workflow/nodes/variable_assigner/v2/enums.py b/api/core/workflow/nodes/variable_assigner/v2/enums.py
index 36cf68aa19..291b1208d4 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/enums.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/enums.py
@@ -11,6 +11,8 @@ class Operation(StrEnum):
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
+ REMOVE_FIRST = "remove-first"
+ REMOVE_LAST = "remove-last"
class InputType(StrEnum):
diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
index a86c7eb94a..8fb2a27388 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py
@@ -23,6 +23,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
+ case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
+ # Only array variable can have elements removed
+ return variable_type in {
+ SegmentType.ARRAY_ANY,
+ SegmentType.ARRAY_OBJECT,
+ SegmentType.ARRAY_STRING,
+ SegmentType.ARRAY_NUMBER,
+ SegmentType.ARRAY_FILE,
+ }
case _:
return False
@@ -51,7 +60,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
- if operation == Operation.CLEAR:
+ if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}:
return True
match variable_type:
case SegmentType.STRING:
diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py
index 0305eb7f41..6a7ad86b51 100644
--- a/api/core/workflow/nodes/variable_assigner/v2/node.py
+++ b/api/core/workflow/nodes/variable_assigner/v2/node.py
@@ -64,7 +64,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
- and item.operation != Operation.CLEAR
+ and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
@@ -165,5 +165,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
+ case Operation.REMOVE_FIRST:
+ # If array is empty, do nothing
+ if not variable.value:
+ return variable.value
+ return variable.value[1:]
+ case Operation.REMOVE_LAST:
+ # If array is empty, do nothing
+ if not variable.value:
+ return variable.value
+ return variable.value[:-1]
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)
diff --git a/api/core/repository/__init__.py b/api/core/workflow/repository/__init__.py
similarity index 58%
rename from api/core/repository/__init__.py
rename to api/core/workflow/repository/__init__.py
index 253df1251d..672abb6583 100644
--- a/api/core/repository/__init__.py
+++ b/api/core/workflow/repository/__init__.py
@@ -6,10 +6,9 @@ for accessing and manipulating data, regardless of the underlying
storage mechanism.
"""
-from core.repository.repository_factory import RepositoryFactory
-from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository
__all__ = [
- "RepositoryFactory",
+ "OrderConfig",
"WorkflowNodeExecutionRepository",
]
diff --git a/api/core/repository/workflow_node_execution_repository.py b/api/core/workflow/repository/workflow_node_execution_repository.py
similarity index 100%
rename from api/core/repository/workflow_node_execution_repository.py
rename to api/core/workflow/repository/workflow_node_execution_repository.py
diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/workflow/workflow_app_generate_task_pipeline.py
similarity index 98%
rename from api/core/app/apps/workflow/generate_task_pipeline.py
rename to api/core/workflow/workflow_app_generate_task_pipeline.py
index 1f998edb6a..10a2d8b38b 100644
--- a/api/core/app/apps/workflow/generate_task_pipeline.py
+++ b/api/core/workflow/workflow_app_generate_task_pipeline.py
@@ -6,7 +6,6 @@ from typing import Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
-from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import (
InvokeFrom,
@@ -52,9 +51,11 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
-from core.app.task_pipeline.workflow_cycle_manage import WorkflowCycleManage
+from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
+from core.workflow.workflow_cycle_manager import WorkflowCycleManager
from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
@@ -82,6 +83,7 @@ class WorkflowAppGenerateTaskPipeline:
queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool,
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._base_task_pipeline = BasedGenerateTaskPipeline(
application_generate_entity=application_generate_entity,
@@ -100,7 +102,7 @@ class WorkflowAppGenerateTaskPipeline:
else:
raise ValueError(f"Invalid user type: {type(user)}")
- self._workflow_cycle_manager = WorkflowCycleManage(
+ self._workflow_cycle_manager = WorkflowCycleManager(
application_generate_entity=application_generate_entity,
workflow_system_variables={
SystemVariableKey.FILES: application_generate_entity.files,
@@ -109,6 +111,7 @@ class WorkflowAppGenerateTaskPipeline:
SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
},
+ workflow_node_execution_repository=workflow_node_execution_repository,
)
self._application_generate_entity = application_generate_entity
diff --git a/api/core/app/task_pipeline/workflow_cycle_manage.py b/api/core/workflow/workflow_cycle_manager.py
similarity index 97%
rename from api/core/app/task_pipeline/workflow_cycle_manage.py
rename to api/core/workflow/workflow_cycle_manager.py
index 5ce9f737d1..01d5db4303 100644
--- a/api/core/app/task_pipeline/workflow_cycle_manage.py
+++ b/api/core/workflow/workflow_cycle_manager.py
@@ -6,7 +6,7 @@ from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
-from sqlalchemy.orm import Session, sessionmaker
+from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
@@ -49,14 +49,13 @@ from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
-from core.repository import RepositoryFactory
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
+from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.workflow_entry import WorkflowEntry
-from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
@@ -70,32 +69,19 @@ from models.workflow import (
)
-class WorkflowCycleManage:
+class WorkflowCycleManager:
def __init__(
self,
*,
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
workflow_system_variables: dict[SystemVariableKey, Any],
+ workflow_node_execution_repository: WorkflowNodeExecutionRepository,
) -> None:
self._workflow_run: WorkflowRun | None = None
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
self._application_generate_entity = application_generate_entity
self._workflow_system_variables = workflow_system_variables
-
- # Initialize the session factory and repository
- # We use the global db engine instead of the session passed to methods
- # Disable expire_on_commit to avoid the need for merging objects
- self._session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
- self._workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": self._application_generate_entity.app_config.tenant_id,
- "app_id": self._application_generate_entity.app_config.app_id,
- "session_factory": self._session_factory,
- }
- )
-
- # We'll still keep the cache for backward compatibility and performance
- # but use the repository for database operations
+ self._workflow_node_execution_repository = workflow_node_execution_repository
def _handle_workflow_run_start(
self,
@@ -395,6 +381,8 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
+ self._workflow_node_execution_repository.update(workflow_node_execution)
+
return workflow_node_execution
def _handle_workflow_node_execution_retried(
diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py
index 50118a401c..7648947fca 100644
--- a/api/core/workflow/workflow_entry.py
+++ b/api/core/workflow/workflow_entry.py
@@ -9,6 +9,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
+from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@@ -364,4 +365,5 @@ class WorkflowEntry:
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
# append variable and value to variable pool
- variable_pool.add([variable_node_id] + variable_key_list, input_value)
+ if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
+ variable_pool.add([variable_node_id] + variable_key_list, input_value)
diff --git a/api/docker/entrypoint.sh b/api/docker/entrypoint.sh
index 68f3c65a4b..18d4f4885d 100755
--- a/api/docker/entrypoint.sh
+++ b/api/docker/entrypoint.sh
@@ -20,7 +20,8 @@ if [[ "${MODE}" == "worker" ]]; then
CONCURRENCY_OPTION="-c ${CELERY_WORKER_AMOUNT:-1}"
fi
- exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION --loglevel ${LOG_LEVEL:-INFO} \
+ exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
+ --max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
elif [[ "${MODE}" == "beat" ]]; then
diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py
index be43f55ea7..ddc2158a02 100644
--- a/api/extensions/ext_commands.py
+++ b/api/extensions/ext_commands.py
@@ -5,6 +5,7 @@ def init_app(app: DifyApp):
from commands import (
add_qdrant_index,
clear_free_plan_tenant_expired_logs,
+ clear_orphaned_file_records,
convert_to_agent_apps,
create_tenant,
extract_plugins,
@@ -13,6 +14,7 @@ def init_app(app: DifyApp):
install_plugins,
migrate_data_for_plugin,
old_metadata_migration,
+ remove_orphaned_files_on_storage,
reset_email,
reset_encrypt_key_pair,
reset_password,
@@ -36,6 +38,8 @@ def init_app(app: DifyApp):
install_plugins,
old_metadata_migration,
clear_free_plan_tenant_expired_logs,
+ clear_orphaned_file_records,
+ remove_orphaned_files_on_storage,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)
diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py
index 59ec0d0686..3cbdc8560b 100644
--- a/api/extensions/ext_otel.py
+++ b/api/extensions/ext_otel.py
@@ -6,31 +6,9 @@ import socket
import sys
from typing import Union
+import flask
from celery.signals import worker_init # type: ignore
from flask_login import user_loaded_from_request, user_logged_in # type: ignore
-from opentelemetry import trace
-from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
-from opentelemetry.instrumentation.celery import CeleryInstrumentor
-from opentelemetry.instrumentation.flask import FlaskInstrumentor
-from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
-from opentelemetry.metrics import get_meter_provider, set_meter_provider
-from opentelemetry.propagate import set_global_textmap
-from opentelemetry.propagators.b3 import B3Format
-from opentelemetry.propagators.composite import CompositePropagator
-from opentelemetry.sdk.metrics import MeterProvider
-from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
-from opentelemetry.sdk.resources import Resource
-from opentelemetry.sdk.trace import TracerProvider
-from opentelemetry.sdk.trace.export import (
- BatchSpanProcessor,
- ConsoleSpanExporter,
-)
-from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
-from opentelemetry.semconv.resource import ResourceAttributes
-from opentelemetry.trace import Span, get_current_span, get_tracer_provider, set_tracer_provider
-from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-from opentelemetry.trace.status import StatusCode
from configs import dify_config
from dify_app import DifyApp
@@ -39,120 +17,199 @@ from dify_app import DifyApp
@user_logged_in.connect
@user_loaded_from_request.connect
def on_user_loaded(_sender, user):
- if user:
- current_span = get_current_span()
- if current_span:
- current_span.set_attribute("service.tenant.id", user.current_tenant_id)
- current_span.set_attribute("service.user.id", user.id)
+ if dify_config.ENABLE_OTEL:
+ from opentelemetry.trace import get_current_span
+
+ if user:
+ current_span = get_current_span()
+ if current_span:
+ current_span.set_attribute("service.tenant.id", user.current_tenant_id)
+ current_span.set_attribute("service.user.id", user.id)
def init_app(app: DifyApp):
- if dify_config.ENABLE_OTEL:
- setup_context_propagation()
- # Initialize OpenTelemetry
- # Follow Semantic Convertions 1.32.0 to define resource attributes
- resource = Resource(
- attributes={
- ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
- ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
- ResourceAttributes.PROCESS_PID: os.getpid(),
- ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
- ResourceAttributes.HOST_NAME: socket.gethostname(),
- ResourceAttributes.HOST_ARCH: platform.machine(),
- "custom.deployment.git_commit": dify_config.COMMIT_SHA,
- ResourceAttributes.HOST_ID: platform.node(),
- ResourceAttributes.OS_TYPE: platform.system().lower(),
- ResourceAttributes.OS_DESCRIPTION: platform.platform(),
- ResourceAttributes.OS_VERSION: platform.version(),
- }
+ from opentelemetry.semconv.trace import SpanAttributes
+
+ def is_celery_worker():
+ return "celery" in sys.argv[0].lower()
+
+ def instrument_exception_logging():
+ exception_handler = ExceptionLoggingHandler()
+ logging.getLogger().addHandler(exception_handler)
+
+ def init_flask_instrumentor(app: DifyApp):
+ meter = get_meter("http_metrics", version=dify_config.CURRENT_VERSION)
+ _http_response_counter = meter.create_counter(
+ "http.server.response.count",
+ description="Total number of HTTP responses by status code, method and target",
+ unit="{response}",
)
- sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE)
- provider = TracerProvider(resource=resource, sampler=sampler)
- set_tracer_provider(provider)
- exporter: Union[OTLPSpanExporter, ConsoleSpanExporter]
- metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter]
- if dify_config.OTEL_EXPORTER_TYPE == "otlp":
- exporter = OTLPSpanExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
- )
- metric_exporter = OTLPMetricExporter(
- endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
- headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
- )
- else:
- # Fallback to console exporter
- exporter = ConsoleSpanExporter()
- metric_exporter = ConsoleMetricExporter()
-
- provider.add_span_processor(
- BatchSpanProcessor(
- exporter,
- max_queue_size=dify_config.OTEL_MAX_QUEUE_SIZE,
- schedule_delay_millis=dify_config.OTEL_BATCH_EXPORT_SCHEDULE_DELAY,
- max_export_batch_size=dify_config.OTEL_MAX_EXPORT_BATCH_SIZE,
- export_timeout_millis=dify_config.OTEL_BATCH_EXPORT_TIMEOUT,
+
+ def response_hook(span: Span, status: str, response_headers: list):
+ if span and span.is_recording():
+ if status.startswith("2"):
+ span.set_status(StatusCode.OK)
+ else:
+ span.set_status(StatusCode.ERROR, status)
+
+ status = status.split(" ")[0]
+ status_code = int(status)
+ status_class = f"{status_code // 100}xx"
+ attributes: dict[str, str | int] = {"status_code": status_code, "status_class": status_class}
+ request = flask.request
+ if request and request.url_rule:
+ attributes[SpanAttributes.HTTP_TARGET] = str(request.url_rule.rule)
+ if request and request.method:
+ attributes[SpanAttributes.HTTP_METHOD] = str(request.method)
+ _http_response_counter.add(1, attributes)
+
+ instrumentor = FlaskInstrumentor()
+ if dify_config.DEBUG:
+ logging.info("Initializing Flask instrumentor")
+ instrumentor.instrument_app(app, response_hook=response_hook)
+
+ def init_sqlalchemy_instrumentor(app: DifyApp):
+ with app.app_context():
+ engines = list(app.extensions["sqlalchemy"].engines.values())
+ SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
+
+ def setup_context_propagation():
+ # Configure propagators
+ set_global_textmap(
+ CompositePropagator(
+ [
+ TraceContextTextMapPropagator(), # W3C trace context
+ B3Format(), # B3 propagation (used by many systems)
+ ]
)
)
- reader = PeriodicExportingMetricReader(
- metric_exporter,
- export_interval_millis=dify_config.OTEL_METRIC_EXPORT_INTERVAL,
- export_timeout_millis=dify_config.OTEL_METRIC_EXPORT_TIMEOUT,
+
+ def shutdown_tracer():
+ provider = trace.get_tracer_provider()
+ if hasattr(provider, "force_flush"):
+ provider.force_flush()
+
+ class ExceptionLoggingHandler(logging.Handler):
+ """Custom logging handler that creates spans for logging.exception() calls"""
+
+ def emit(self, record):
+ try:
+ if record.exc_info:
+ tracer = get_tracer_provider().get_tracer("dify.exception.logging")
+ with tracer.start_as_current_span(
+ "log.exception",
+ attributes={
+ "log.level": record.levelname,
+ "log.message": record.getMessage(),
+ "log.logger": record.name,
+ "log.file.path": record.pathname,
+ "log.file.line": record.lineno,
+ },
+ ) as span:
+ span.set_status(StatusCode.ERROR)
+ span.record_exception(record.exc_info[1])
+ span.set_attribute("exception.type", record.exc_info[0].__name__)
+ span.set_attribute("exception.message", str(record.exc_info[1]))
+ except Exception:
+ pass
+
+ from opentelemetry import trace
+ from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter
+ from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
+ from opentelemetry.instrumentation.celery import CeleryInstrumentor
+ from opentelemetry.instrumentation.flask import FlaskInstrumentor
+ from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
+ from opentelemetry.metrics import get_meter, get_meter_provider, set_meter_provider
+ from opentelemetry.propagate import set_global_textmap
+ from opentelemetry.propagators.b3 import B3Format
+ from opentelemetry.propagators.composite import CompositePropagator
+ from opentelemetry.sdk.metrics import MeterProvider
+ from opentelemetry.sdk.metrics.export import ConsoleMetricExporter, PeriodicExportingMetricReader
+ from opentelemetry.sdk.resources import Resource
+ from opentelemetry.sdk.trace import TracerProvider
+ from opentelemetry.sdk.trace.export import (
+ BatchSpanProcessor,
+ ConsoleSpanExporter,
+ )
+ from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
+ from opentelemetry.semconv.resource import ResourceAttributes
+ from opentelemetry.trace import Span, get_tracer_provider, set_tracer_provider
+ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+ from opentelemetry.trace.status import StatusCode
+
+ setup_context_propagation()
+ # Initialize OpenTelemetry
+ # Follow Semantic Convertions 1.32.0 to define resource attributes
+ resource = Resource(
+ attributes={
+ ResourceAttributes.SERVICE_NAME: dify_config.APPLICATION_NAME,
+ ResourceAttributes.SERVICE_VERSION: f"dify-{dify_config.CURRENT_VERSION}-{dify_config.COMMIT_SHA}",
+ ResourceAttributes.PROCESS_PID: os.getpid(),
+ ResourceAttributes.DEPLOYMENT_ENVIRONMENT: f"{dify_config.DEPLOY_ENV}-{dify_config.EDITION}",
+ ResourceAttributes.HOST_NAME: socket.gethostname(),
+ ResourceAttributes.HOST_ARCH: platform.machine(),
+ "custom.deployment.git_commit": dify_config.COMMIT_SHA,
+ ResourceAttributes.HOST_ID: platform.node(),
+ ResourceAttributes.OS_TYPE: platform.system().lower(),
+ ResourceAttributes.OS_DESCRIPTION: platform.platform(),
+ ResourceAttributes.OS_VERSION: platform.version(),
+ }
+ )
+ sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE)
+ provider = TracerProvider(resource=resource, sampler=sampler)
+ set_tracer_provider(provider)
+ exporter: Union[OTLPSpanExporter, ConsoleSpanExporter]
+ metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter]
+ if dify_config.OTEL_EXPORTER_TYPE == "otlp":
+ exporter = OTLPSpanExporter(
+ endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces",
+ headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
+ )
+ metric_exporter = OTLPMetricExporter(
+ endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics",
+ headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"},
)
- set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))
- if not is_celery_worker():
- init_flask_instrumentor(app)
- CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
- init_sqlalchemy_instrumentor(app)
- atexit.register(shutdown_tracer)
-
-
-def is_celery_worker():
- return "celery" in sys.argv[0].lower()
-
-
-def init_flask_instrumentor(app: DifyApp):
- def response_hook(span: Span, status: str, response_headers: list):
- if span and span.is_recording():
- if status.startswith("2"):
- span.set_status(StatusCode.OK)
- else:
- span.set_status(StatusCode.ERROR, status)
-
- instrumentor = FlaskInstrumentor()
- if dify_config.DEBUG:
- logging.info("Initializing Flask instrumentor")
- instrumentor.instrument_app(app, response_hook=response_hook)
-
-
-def init_sqlalchemy_instrumentor(app: DifyApp):
- with app.app_context():
- engines = list(app.extensions["sqlalchemy"].engines.values())
- SQLAlchemyInstrumentor().instrument(enable_commenter=True, engines=engines)
-
-
-def setup_context_propagation():
- # Configure propagators
- set_global_textmap(
- CompositePropagator(
- [
- TraceContextTextMapPropagator(), # W3C trace context
- B3Format(), # B3 propagation (used by many systems)
- ]
+ else:
+ # Fallback to console exporter
+ exporter = ConsoleSpanExporter()
+ metric_exporter = ConsoleMetricExporter()
+
+ provider.add_span_processor(
+ BatchSpanProcessor(
+ exporter,
+ max_queue_size=dify_config.OTEL_MAX_QUEUE_SIZE,
+ schedule_delay_millis=dify_config.OTEL_BATCH_EXPORT_SCHEDULE_DELAY,
+ max_export_batch_size=dify_config.OTEL_MAX_EXPORT_BATCH_SIZE,
+ export_timeout_millis=dify_config.OTEL_BATCH_EXPORT_TIMEOUT,
)
)
+ reader = PeriodicExportingMetricReader(
+ metric_exporter,
+ export_interval_millis=dify_config.OTEL_METRIC_EXPORT_INTERVAL,
+ export_timeout_millis=dify_config.OTEL_METRIC_EXPORT_TIMEOUT,
+ )
+ set_meter_provider(MeterProvider(resource=resource, metric_readers=[reader]))
+ if not is_celery_worker():
+ init_flask_instrumentor(app)
+ CeleryInstrumentor(tracer_provider=get_tracer_provider(), meter_provider=get_meter_provider()).instrument()
+ instrument_exception_logging()
+ init_sqlalchemy_instrumentor(app)
+ atexit.register(shutdown_tracer)
-@worker_init.connect(weak=False)
-def init_celery_worker(*args, **kwargs):
- tracer_provider = get_tracer_provider()
- metric_provider = get_meter_provider()
- if dify_config.DEBUG:
- logging.info("Initializing OpenTelemetry for Celery worker")
- CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
+def is_enabled():
+ return dify_config.ENABLE_OTEL
-def shutdown_tracer():
- provider = trace.get_tracer_provider()
- if hasattr(provider, "force_flush"):
- provider.force_flush()
+@worker_init.connect(weak=False)
+def init_celery_worker(*args, **kwargs):
+ if dify_config.ENABLE_OTEL:
+ from opentelemetry.instrumentation.celery import CeleryInstrumentor
+ from opentelemetry.metrics import get_meter_provider
+ from opentelemetry.trace import get_tracer_provider
+
+ tracer_provider = get_tracer_provider()
+ metric_provider = get_meter_provider()
+ if dify_config.DEBUG:
+ logging.info("Initializing OpenTelemetry for Celery worker")
+ CeleryInstrumentor(tracer_provider=tracer_provider, meter_provider=metric_provider).instrument()
diff --git a/api/extensions/ext_repositories.py b/api/extensions/ext_repositories.py
deleted file mode 100644
index 27d8408ec1..0000000000
--- a/api/extensions/ext_repositories.py
+++ /dev/null
@@ -1,18 +0,0 @@
-"""
-Extension for initializing repositories.
-
-This extension registers repository implementations with the RepositoryFactory.
-"""
-
-from dify_app import DifyApp
-from repositories.repository_registry import register_repositories
-
-
-def init_app(_app: DifyApp) -> None:
- """
- Initialize repository implementations.
-
- Args:
- _app: The Flask application instance (unused)
- """
- register_repositories()
diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py
index 4c811c66ba..bd35278544 100644
--- a/api/extensions/ext_storage.py
+++ b/api/extensions/ext_storage.py
@@ -102,6 +102,9 @@ class Storage:
def delete(self, filename):
return self.storage_runner.delete(filename)
+ def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
+ return self.storage_runner.scan(path, files=files, directories=directories)
+
storage = Storage()
diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py
index 0dedd7ff8c..0393206e54 100644
--- a/api/extensions/storage/base_storage.py
+++ b/api/extensions/storage/base_storage.py
@@ -30,3 +30,11 @@ class BaseStorage(ABC):
@abstractmethod
def delete(self, filename):
raise NotImplementedError
+
+ def scan(self, path, files=True, directories=False) -> list[str]:
+ """
+ Scan files and directories in the given path.
+ This method is implemented only in some storage backends.
+ If a storage backend doesn't support scanning, it will raise NotImplementedError.
+ """
+ raise NotImplementedError("This storage backend doesn't support scanning")
diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py
index ee8cfa9179..12e2738e9d 100644
--- a/api/extensions/storage/opendal_storage.py
+++ b/api/extensions/storage/opendal_storage.py
@@ -80,3 +80,20 @@ class OpenDALStorage(BaseStorage):
logger.debug(f"file {filename} deleted")
return
logger.debug(f"file {filename} not found, skip delete")
+
+ def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
+ if not self.exists(path):
+ raise FileNotFoundError("Path not found")
+
+ all_files = self.op.scan(path=path)
+ if files and directories:
+ logger.debug(f"files and directories on {path} scanned")
+ return [f.path for f in all_files]
+ if files:
+ logger.debug(f"files on {path} scanned")
+ return [f.path for f in all_files if not f.path.endswith("/")]
+ elif directories:
+ logger.debug(f"directories on {path} scanned")
+ return [f.path for f in all_files if f.path.endswith("/")]
+ else:
+ raise ValueError("At least one of files or directories must be True")
diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py
index 4b2d2cc769..4b12afb528 100644
--- a/api/factories/agent_factory.py
+++ b/api/factories/agent_factory.py
@@ -1,12 +1,12 @@
from core.agent.strategy.plugin import PluginAgentStrategy
-from core.plugin.manager.agent import PluginAgentManager
+from core.plugin.impl.agent import PluginAgentClient
def get_plugin_agent_strategy(
tenant_id: str, agent_strategy_provider_name: str, agent_strategy_name: str
) -> PluginAgentStrategy:
# TODO: use contexts to cache the agent provider
- manager = PluginAgentManager()
+ manager = PluginAgentClient()
agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name)
for agent_strategy in agent_provider.declaration.strategies:
if agent_strategy.identity.name == agent_strategy_name:
diff --git a/api/fields/annotation_fields.py b/api/fields/annotation_fields.py
index 1c58b3a257..379dcc6d16 100644
--- a/api/fields/annotation_fields.py
+++ b/api/fields/annotation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/api_based_extension_fields.py b/api/fields/api_based_extension_fields.py
index d40407bfcc..a85d4a34db 100644
--- a/api/fields/api_based_extension_fields.py
+++ b/api/fields/api_based_extension_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py
index f42364f110..0b0e2a2f54 100644
--- a/api/fields/app_fields.py
+++ b/api/fields/app_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.workflow_fields import workflow_partial_fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py
index 78e0794833..370e8a5a58 100644
--- a/api/fields/conversation_fields.py
+++ b/api/fields/conversation_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py
index c6385efb5a..71785e7d67 100644
--- a/api/fields/conversation_variable_fields.py
+++ b/api/fields/conversation_variable_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
@@ -19,3 +19,9 @@ paginated_conversation_variable_fields = {
"has_more": fields.Boolean,
"data": fields.List(fields.Nested(conversation_variable_fields), attribute="data"),
}
+
+conversation_variable_infinite_scroll_pagination_fields = {
+ "limit": fields.Integer,
+ "has_more": fields.Boolean,
+ "data": fields.List(fields.Nested(conversation_variable_fields)),
+}
diff --git a/api/fields/data_source_fields.py b/api/fields/data_source_fields.py
index 608672121e..071071376f 100644
--- a/api/fields/data_source_fields.py
+++ b/api/fields/data_source_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/dataset_fields.py b/api/fields/dataset_fields.py
index 67d183c70d..32a88cc5db 100644
--- a/api/fields/dataset_fields.py
+++ b/api/fields/dataset_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/document_fields.py b/api/fields/document_fields.py
index 6d59ee9baa..7fd43e8dbe 100644
--- a/api/fields/document_fields.py
+++ b/api/fields/document_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.dataset_fields import dataset_fields
from libs.helper import TimestampField
diff --git a/api/fields/end_user_fields.py b/api/fields/end_user_fields.py
index aefa0b2758..99e529f9d1 100644
--- a/api/fields/end_user_fields.py
+++ b/api/fields/end_user_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
simple_end_user_fields = {
"id": fields.String,
diff --git a/api/fields/file_fields.py b/api/fields/file_fields.py
index f896c15f0f..8b4839ef97 100644
--- a/api/fields/file_fields.py
+++ b/api/fields/file_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
@@ -19,6 +19,7 @@ file_fields = {
"mime_type": fields.String,
"created_by": fields.String,
"created_at": TimestampField,
+ "preview_url": fields.String,
}
remote_file_info_fields = {
diff --git a/api/fields/hit_testing_fields.py b/api/fields/hit_testing_fields.py
index 4514c1b8ca..9d67999ea4 100644
--- a/api/fields/hit_testing_fields.py
+++ b/api/fields/hit_testing_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/installed_app_fields.py b/api/fields/installed_app_fields.py
index 16f265b9bb..e0b3e340f6 100644
--- a/api/fields/installed_app_fields.py
+++ b/api/fields/installed_app_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import AppIconUrlField, TimestampField
diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py
index 0900bffb8a..8007b7e052 100644
--- a/api/fields/member_fields.py
+++ b/api/fields/member_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import AvatarUrlField, TimestampField
diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py
index 76e61f0707..e6aebd810f 100644
--- a/api/fields/message_fields.py
+++ b/api/fields/message_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.conversation_fields import message_file_fields
from libs.helper import TimestampField
diff --git a/api/fields/raws.py b/api/fields/raws.py
index 493d4b6cce..15ec16ab13 100644
--- a/api/fields/raws.py
+++ b/api/fields/raws.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from core.file import File
diff --git a/api/fields/segment_fields.py b/api/fields/segment_fields.py
index 82311e5bb9..4126c24598 100644
--- a/api/fields/segment_fields.py
+++ b/api/fields/segment_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from libs.helper import TimestampField
diff --git a/api/fields/tag_fields.py b/api/fields/tag_fields.py
index 986cd725f7..9af4fc57dd 100644
--- a/api/fields/tag_fields.py
+++ b/api/fields/tag_fields.py
@@ -1,3 +1,3 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String, "binding_count": fields.String}
diff --git a/api/fields/workflow_app_log_fields.py b/api/fields/workflow_app_log_fields.py
index e8f8684ae0..823c99ec6b 100644
--- a/api/fields/workflow_app_log_fields.py
+++ b/api/fields/workflow_app_log_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py
index 971e99c259..9f1bef3b36 100644
--- a/api/fields/workflow_fields.py
+++ b/api/fields/workflow_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
diff --git a/api/fields/workflow_run_fields.py b/api/fields/workflow_run_fields.py
index ef59c57ec3..74fdf8bd97 100644
--- a/api/fields/workflow_run_fields.py
+++ b/api/fields/workflow_run_fields.py
@@ -1,4 +1,4 @@
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
diff --git a/api/libs/external_api.py b/api/libs/external_api.py
index 922d2d9cd3..2070df3e55 100644
--- a/api/libs/external_api.py
+++ b/api/libs/external_api.py
@@ -3,7 +3,7 @@ import sys
from typing import Any
from flask import current_app, got_request_exception
-from flask_restful import Api, http_status_message # type: ignore
+from flask_restful import Api, http_status_message
from werkzeug.datastructures import Headers
from werkzeug.exceptions import HTTPException
diff --git a/api/libs/helper.py b/api/libs/helper.py
index f0325734d8..afc8f31681 100644
--- a/api/libs/helper.py
+++ b/api/libs/helper.py
@@ -13,7 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional, Union, cast
from zoneinfo import available_timezones
from flask import Response, stream_with_context
-from flask_restful import fields # type: ignore
+from flask_restful import fields
from configs import dify_config
from core.app.features.rate_limiting.rate_limit import RateLimitGenerator
diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py
index a5ba08d351..1c151633f0 100644
--- a/api/libs/oauth_data_source.py
+++ b/api/libs/oauth_data_source.py
@@ -3,7 +3,7 @@ import urllib.parse
from typing import Any
import requests
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from extensions.ext_database import db
from models.source import DataSourceOauthBinding
diff --git a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py
index c17d1db77a..00f2b15802 100644
--- a/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py
+++ b/api/migrations/versions/2024_10_10_0516-bbadea11becb_add_name_and_size_to_tool_files.py
@@ -5,45 +5,61 @@ Revises: 33f5fac87f29
Create Date: 2024-10-10 05:16:14.764268
"""
-from alembic import op
-import models as models
+
import sqlalchemy as sa
-from sqlalchemy.dialects import postgresql
+from alembic import op, context
# revision identifiers, used by Alembic.
-revision = 'bbadea11becb'
-down_revision = 'd8e744d88ed6'
+revision = "bbadea11becb"
+down_revision = "d8e744d88ed6"
branch_labels = None
depends_on = None
def upgrade():
+ def _has_name_or_size_column() -> bool:
+ # We cannot access the database in offline mode, so assume
+ # the "name" and "size" columns do not exist.
+ if context.is_offline_mode():
+ # Log a warning message to inform the user that the database schema cannot be inspected
+ # in offline mode, and the generated SQL may not accurately reflect the actual execution.
+ op.execute(
+ "-- Executing in offline mode, assuming the name and size columns do not exist.\n"
+ "-- The generated SQL may differ from what will actually be executed.\n"
+ "-- Please review the migration script carefully!"
+ )
+
+ return False
+ # Use SQLAlchemy inspector to get the columns of the 'tool_files' table
+ inspector = sa.inspect(conn)
+ columns = [col["name"] for col in inspector.get_columns("tool_files")]
+
+ # If 'name' or 'size' columns already exist, exit the upgrade function
+ if "name" in columns or "size" in columns:
+ return True
+ return False
+
# ### commands auto generated by Alembic - please adjust! ###
# Get the database connection
conn = op.get_bind()
-
- # Use SQLAlchemy inspector to get the columns of the 'tool_files' table
- inspector = sa.inspect(conn)
- columns = [col['name'] for col in inspector.get_columns('tool_files')]
-
- # If 'name' or 'size' columns already exist, exit the upgrade function
- if 'name' in columns or 'size' in columns:
- return
-
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.add_column(sa.Column('name', sa.String(), nullable=True))
- batch_op.add_column(sa.Column('size', sa.Integer(), nullable=True))
+
+ if _has_name_or_size_column():
+ return
+
+ with op.batch_alter_table("tool_files", schema=None) as batch_op:
+ batch_op.add_column(sa.Column("name", sa.String(), nullable=True))
+ batch_op.add_column(sa.Column("size", sa.Integer(), nullable=True))
op.execute("UPDATE tool_files SET name = '' WHERE name IS NULL")
op.execute("UPDATE tool_files SET size = -1 WHERE size IS NULL")
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.alter_column('name', existing_type=sa.String(), nullable=False)
- batch_op.alter_column('size', existing_type=sa.Integer(), nullable=False)
+ with op.batch_alter_table("tool_files", schema=None) as batch_op:
+ batch_op.alter_column("name", existing_type=sa.String(), nullable=False)
+ batch_op.alter_column("size", existing_type=sa.Integer(), nullable=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
- with op.batch_alter_table('tool_files', schema=None) as batch_op:
- batch_op.drop_column('size')
- batch_op.drop_column('name')
+ with op.batch_alter_table("tool_files", schema=None) as batch_op:
+ batch_op.drop_column("size")
+ batch_op.drop_column("name")
# ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
index 0facd0ecc0..ae9f2de9b1 100644
--- a/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
+++ b/api/migrations/versions/2024_12_20_0628-e1944c35e15e_add_retry_index_field_to_node_execution_.py
@@ -35,4 +35,4 @@ def downgrade():
# batch_op.drop_column('retry_index')
pass
- # ### end Alembic commands ###
\ No newline at end of file
+ # ### end Alembic commands ###
diff --git a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py
index ea129d15f7..adf6421e57 100644
--- a/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py
+++ b/api/migrations/versions/2024_12_23_1154-d7999dfa4aae_remove_workflow_node_executions_retry_.py
@@ -5,28 +5,38 @@ Revises: e1944c35e15e
Create Date: 2024-12-23 11:54:15.344543
"""
-from alembic import op
-import models as models
-import sqlalchemy as sa
+
+from alembic import op, context
from sqlalchemy import inspect
# revision identifiers, used by Alembic.
-revision = 'd7999dfa4aae'
-down_revision = 'e1944c35e15e'
+revision = "d7999dfa4aae"
+down_revision = "e1944c35e15e"
branch_labels = None
depends_on = None
def upgrade():
- # Check if column exists before attempting to remove it
- conn = op.get_bind()
- inspector = inspect(conn)
- has_column = 'retry_index' in [col['name'] for col in inspector.get_columns('workflow_node_executions')]
-
+ def _has_retry_index_column() -> bool:
+ if context.is_offline_mode():
+ # Log a warning message to inform the user that the database schema cannot be inspected
+ # in offline mode, and the generated SQL may not accurately reflect the actual execution.
+ op.execute(
+ '-- Executing in offline mode: assuming the "retry_index" column does not exist.\n'
+ "-- The generated SQL may differ from what will actually be executed.\n"
+ "-- Please review the migration script carefully!"
+ )
+ return False
+ conn = op.get_bind()
+ inspector = inspect(conn)
+ return "retry_index" in [col["name"] for col in inspector.get_columns("workflow_node_executions")]
+
+ has_column = _has_retry_index_column()
+
if has_column:
- with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
- batch_op.drop_column('retry_index')
+ with op.batch_alter_table("workflow_node_executions", schema=None) as batch_op:
+ batch_op.drop_column("retry_index")
def downgrade():
diff --git a/api/migrations/versions/64b051264f32_init.py b/api/migrations/versions/64b051264f32_init.py
index 8c45ae898d..b0fb3deac6 100644
--- a/api/migrations/versions/64b051264f32_init.py
+++ b/api/migrations/versions/64b051264f32_init.py
@@ -1,7 +1,7 @@
"""init
Revision ID: 64b051264f32
-Revises:
+Revises:
Create Date: 2023-05-13 14:26:59.085018
"""
diff --git a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py
index fcca705d21..c18126286c 100644
--- a/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py
+++ b/api/migrations/versions/de95f5c77138_migration_serpapi_api_key.py
@@ -99,12 +99,12 @@ def upgrade():
id=id,
tenant_id=tenant_id,
user_id=user_id,
- provider='google',
+ provider='google',
encrypted_credentials=encrypted_credentials,
created_at=created_at,
updated_at=updated_at
)
-
+
# ### end Alembic commands ###
diff --git a/api/models/engine.py b/api/models/engine.py
index dda93bc941..05c1cacdcb 100644
--- a/api/models/engine.py
+++ b/api/models/engine.py
@@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = {
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
+
+# ****** IMPORTANT NOTICE ******
+#
+# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the
+# `controllers` package.
+#
+# Instead, import `db` within the `controllers` package and pass it as an argument to
+# functions or class constructors.
+#
+# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain.
+#
+# Whenever possible, avoid this pattern in new code.
db = SQLAlchemy(metadata=metadata)
diff --git a/api/models/model.py b/api/models/model.py
index 6577492d1b..fd05d67e9a 100644
--- a/api/models/model.py
+++ b/api/models/model.py
@@ -3,19 +3,17 @@ import re
import uuid
from collections.abc import Mapping
from datetime import datetime
-from enum import Enum
-from typing import TYPE_CHECKING, Optional
+from enum import Enum, StrEnum
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
+from core.tools.signature import sign_tool_file
from services.plugin.plugin_service import PluginService
if TYPE_CHECKING:
from models.workflow import Workflow
-from enum import StrEnum
-from typing import TYPE_CHECKING, Any, Literal, cast
-
import sqlalchemy as sa
from flask import request
from flask_login import UserMixin # type: ignore
@@ -26,7 +24,6 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
-from core.file.tool_file_parser import ToolFileParser
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
@@ -989,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined]
if not tool_file_id:
continue
- sign_url = ToolFileParser.get_tool_file_manager().sign_file(
- tool_file_id=tool_file_id, extension=extension
- )
+ sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
elif "file-preview" in url:
# get upload file id
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="
@@ -1015,7 +1010,9 @@ class Message(db.Model): # type: ignore[name-defined]
sign_url = file_helpers.get_signed_file_url(upload_file_id)
else:
continue
-
+ # if as_attachment is in the url, add it to the sign_url.
+ if "as_attachment" in url:
+ sign_url += "&as_attachment=true"
re_sign_file_url_answer = re_sign_file_url_answer.replace(url, sign_url)
return re_sign_file_url_answer
diff --git a/api/models/tools.py b/api/models/tools.py
index aef1490729..e027475e38 100644
--- a/api/models/tools.py
+++ b/api/models/tools.py
@@ -1,6 +1,6 @@
import json
from datetime import datetime
-from typing import Any, Optional, cast
+from typing import Any, cast
import sqlalchemy as sa
from deprecated import deprecated
@@ -263,8 +263,8 @@ class ToolConversationVariables(Base):
class ToolFile(Base):
- """
- store the file created by agent
+ """This table stores file metadata generated in workflows,
+ not only files created by agent.
"""
__tablename__ = "tool_files"
@@ -304,8 +304,11 @@ class DeprecatedPublishedAppTool(Base):
db.UniqueConstraint("app_id", "user_id", name="unique_published_app_tool"),
)
+ id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
# id of the app
app_id = db.Column(StringUUID, ForeignKey("apps.id"), nullable=False)
+
+ user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
# who published this tool
description = db.Column(db.Text, nullable=False)
# llm_description of the tool, for LLM
@@ -325,34 +328,3 @@ class DeprecatedPublishedAppTool(Base):
@property
def description_i18n(self) -> I18nObject:
return I18nObject(**json.loads(self.description))
-
- id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
- user_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False)
- conversation_id: Mapped[Optional[str]] = db.Column(StringUUID, nullable=True)
- file_key: Mapped[str] = db.Column(db.String(255), nullable=False)
- mimetype: Mapped[str] = db.Column(db.String(255), nullable=False)
- original_url: Mapped[Optional[str]] = db.Column(db.String(2048), nullable=True)
- name: Mapped[str] = mapped_column(default="")
- size: Mapped[int] = mapped_column(default=-1)
-
- def __init__(
- self,
- *,
- user_id: str,
- tenant_id: str,
- conversation_id: Optional[str] = None,
- file_key: str,
- mimetype: str,
- original_url: Optional[str] = None,
- name: str,
- size: int,
- ):
- self.user_id = user_id
- self.tenant_id = tenant_id
- self.conversation_id = conversation_id
- self.file_key = file_key
- self.mimetype = mimetype
- self.original_url = original_url
- self.name = name
- self.size = size
diff --git a/api/models/workflow.py b/api/models/workflow.py
index 51f2f4cc9f..da60617de5 100644
--- a/api/models/workflow.py
+++ b/api/models/workflow.py
@@ -1,14 +1,12 @@
import json
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
-from enum import Enum
+from enum import Enum, StrEnum
from typing import TYPE_CHECKING, Any, Optional, Self, Union
from uuid import uuid4
if TYPE_CHECKING:
from models.model import AppMode
-from enum import StrEnum
-from typing import TYPE_CHECKING
import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func
@@ -245,6 +243,13 @@ class Workflow(Base):
@property
def tool_published(self) -> bool:
+ """
+ DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
+ It only checks if there's a WorkflowToolProvider for the app, not if this specific workflow version
+ is the one being used by the tool.
+
+ For accurate checking, use a direct query with tenant_id, app_id, and version.
+ """
from models.tools import WorkflowToolProvider
return (
diff --git a/api/mypy.ini b/api/mypy.ini
index 2898b9b52d..865be3c17d 100644
--- a/api/mypy.ini
+++ b/api/mypy.ini
@@ -7,3 +7,13 @@ exclude = (?x)(
| tests/
| migrations/
)
+
+[mypy-flask_login]
+ignore_missing_imports=True
+
+[mypy-flask_restful]
+ignore_missing_imports=True
+
+[mypy-flask_restful.inputs]
+ignore_missing_imports=True
+
diff --git a/api/pyproject.toml b/api/pyproject.toml
index 08f9c1e229..65315e9be7 100644
--- a/api/pyproject.toml
+++ b/api/pyproject.toml
@@ -1,6 +1,6 @@
[project]
name = "dify-api"
-version = "1.2.0"
+dynamic = ["version"]
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -10,7 +10,7 @@ dependencies = [
"boto3==1.35.99",
"bs4~=0.0.1",
"cachetools~=5.3.0",
- "celery~=5.4.0",
+ "celery~=5.5.2",
"chardet~=5.1.0",
"flask~=3.1.0",
"flask-compress~=1.17",
@@ -77,18 +77,23 @@ dependencies = [
"sentry-sdk[flask]~=1.44.1",
"sqlalchemy~=2.0.29",
"starlette==0.41.0",
- "tiktoken~=0.8.0",
+ "tiktoken~=0.9.0",
"tokenizers~=0.15.0",
"transformers~=4.35.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
- "validators==0.21.0",
+ "weave~=0.51.34",
"yarl~=1.18.3",
+ "webvtt-py~=0.5.1",
]
# Before adding new dependency, consider place it in
# alphabet order (a-z) and suitable group.
+[tool.setuptools]
+packages = []
+
[tool.uv]
default-groups = ["storage", "tools", "vdb"]
+package = false
[dependency-groups]
@@ -115,6 +120,7 @@ dev = [
"types-defusedxml~=0.7.0",
"types-deprecated~=1.2.15",
"types-docutils~=0.21.0",
+ "types-jsonschema~=4.23.0",
"types-flask-cors~=5.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=24.11.0",
@@ -178,7 +184,7 @@ vdb = [
"couchbase~=4.3.0",
"elasticsearch==8.14.0",
"opensearch-py==2.4.0",
- "oracledb~=2.2.1",
+ "oracledb==3.0.0",
"pgvecto-rs[sqlalchemy]~=0.2.1",
"pgvector==0.2.5",
"pymilvus~=2.5.0",
@@ -190,6 +196,6 @@ vdb = [
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.156",
- "weaviate-client~=3.21.0",
+ "weaviate-client~=3.24.0",
"xinference-client~=1.2.2",
]
diff --git a/api/repositories/__init__.py b/api/repositories/__init__.py
deleted file mode 100644
index 4cc339688b..0000000000
--- a/api/repositories/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-"""
-Repository implementations for data access.
-
-This package contains concrete implementations of the repository interfaces
-defined in the core.repository package.
-"""
diff --git a/api/repositories/repository_registry.py b/api/repositories/repository_registry.py
deleted file mode 100644
index aa0a208d8e..0000000000
--- a/api/repositories/repository_registry.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""
-Registry for repository implementations.
-
-This module is responsible for registering factory functions with the repository factory.
-"""
-
-import logging
-from collections.abc import Mapping
-from typing import Any
-
-from sqlalchemy.orm import sessionmaker
-
-from configs import dify_config
-from core.repository.repository_factory import RepositoryFactory
-from extensions.ext_database import db
-from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
-
-logger = logging.getLogger(__name__)
-
-# Storage type constants
-STORAGE_TYPE_RDBMS = "rdbms"
-STORAGE_TYPE_HYBRID = "hybrid"
-
-
-def register_repositories() -> None:
- """
- Register repository factory functions with the RepositoryFactory.
-
- This function reads configuration settings to determine which repository
- implementations to register.
- """
- # Configure WorkflowNodeExecutionRepository factory based on configuration
- workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
-
- # Check storage type and register appropriate implementation
- if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
- # Register SQLAlchemy implementation for RDBMS storage
- logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
- RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
- elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
- # Hybrid storage is not yet implemented
- raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
- else:
- # Unknown storage type
- raise ValueError(
- f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
- f"Supported types: {STORAGE_TYPE_RDBMS}"
- )
-
-
-def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
- """
- Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
-
- This factory function creates a repository for the RDBMS storage type.
-
- Args:
- params: Parameters for creating the repository, including:
- - tenant_id: Required. The tenant ID for multi-tenancy.
- - app_id: Optional. The application ID for filtering.
- - session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
- a new sessionmaker will be created using the global database engine.
-
- Returns:
- A WorkflowNodeExecutionRepository instance
-
- Raises:
- ValueError: If required parameters are missing
- """
- # Extract required parameters
- tenant_id = params.get("tenant_id")
- if tenant_id is None:
- raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
-
- # Extract optional parameters
- app_id = params.get("app_id")
-
- # Use the session_factory from params if provided, otherwise create one using the global db engine
- session_factory = params.get("session_factory")
- if session_factory is None:
- # Create a sessionmaker using the same engine as the global db session
- session_factory = sessionmaker(bind=db.engine)
-
- # Create and return the repository
- return SQLAlchemyWorkflowNodeExecutionRepository(
- session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
- )
diff --git a/api/repositories/workflow_node_execution/__init__.py b/api/repositories/workflow_node_execution/__init__.py
deleted file mode 100644
index eed827bd05..0000000000
--- a/api/repositories/workflow_node_execution/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-"""
-WorkflowNodeExecution repository implementations.
-"""
-
-from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
-
-__all__ = [
- "SQLAlchemyWorkflowNodeExecutionRepository",
-]
diff --git a/api/services/agent_service.py b/api/services/agent_service.py
index 0ff144052f..503b31ede2 100644
--- a/api/services/agent_service.py
+++ b/api/services/agent_service.py
@@ -2,12 +2,12 @@ import threading
from typing import Optional
import pytz
-from flask_login import current_user # type: ignore
+from flask_login import current_user
import contexts
from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager
-from core.plugin.manager.agent import PluginAgentManager
-from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.plugin.impl.agent import PluginAgentClient
+from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.account import Account
@@ -161,7 +161,7 @@ class AgentService:
"""
List agent providers
"""
- manager = PluginAgentManager()
+ manager = PluginAgentClient()
return manager.fetch_agent_strategy_providers(tenant_id)
@classmethod
@@ -169,7 +169,7 @@ class AgentService:
"""
Get agent provider
"""
- manager = PluginAgentManager()
+ manager = PluginAgentClient()
try:
return manager.fetch_agent_strategy_provider(tenant_id, provider_name)
except PluginDaemonClientSideError as e:
diff --git a/api/services/annotation_service.py b/api/services/annotation_service.py
index 45ec1e9b5a..ae7b372b82 100644
--- a/api/services/annotation_service.py
+++ b/api/services/annotation_service.py
@@ -3,7 +3,7 @@ import uuid
from typing import cast
import pandas as pd
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from sqlalchemy import or_
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py
index 2e2b729021..a2775fe6ad 100644
--- a/api/services/app_dsl_service.py
+++ b/api/services/app_dsl_service.py
@@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
-CURRENT_DSL_VERSION = "0.1.5"
+CURRENT_DSL_VERSION = "0.2.0"
class ImportMode(StrEnum):
@@ -77,13 +77,19 @@ def _check_version_compatibility(imported_version: str) -> ImportStatus:
except version.InvalidVersion:
return ImportStatus.FAILED
- # Compare major version and minor version
- if current_ver.major != imported_ver.major or current_ver.minor != imported_ver.minor:
+ # If imported version is newer than current, always return PENDING
+ if imported_ver > current_ver:
return ImportStatus.PENDING
- if current_ver.micro != imported_ver.micro:
+ # If imported version is older than current's major, return PENDING
+ if imported_ver.major < current_ver.major:
+ return ImportStatus.PENDING
+
+ # If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
+ if imported_ver.minor < current_ver.minor:
return ImportStatus.COMPLETED_WITH_WARNINGS
+ # If imported version equals or is older than current's micro, return COMPLETED
return ImportStatus.COMPLETED
diff --git a/api/services/app_service.py b/api/services/app_service.py
index e87a1c7931..2fae479e05 100644
--- a/api/services/app_service.py
+++ b/api/services/app_service.py
@@ -3,7 +3,7 @@ import logging
from datetime import UTC, datetime
from typing import Optional, cast
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from flask_sqlalchemy.pagination import Pagination
from configs import dify_config
diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py
index 6485cbf37d..afdaa49465 100644
--- a/api/services/conversation_service.py
+++ b/api/services/conversation_service.py
@@ -9,9 +9,14 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.llm_generator.llm_generator import LLMGenerator
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
+from models import ConversationVariable
from models.account import Account
from models.model import App, Conversation, EndUser, Message
-from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
+from services.errors.conversation import (
+ ConversationNotExistsError,
+ ConversationVariableNotExistsError,
+ LastConversationNotExistsError,
+)
from services.errors.message import MessageNotExistsError
@@ -166,3 +171,50 @@ class ConversationService:
conversation.is_deleted = True
conversation.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
+
+ @classmethod
+ def get_conversational_variable(
+ cls,
+ app_model: App,
+ conversation_id: str,
+ user: Optional[Union[Account, EndUser]],
+ limit: int,
+ last_id: Optional[str],
+ ) -> InfiniteScrollPagination:
+ conversation = cls.get_conversation(app_model, conversation_id, user)
+
+ stmt = (
+ select(ConversationVariable)
+ .where(ConversationVariable.app_id == app_model.id)
+ .where(ConversationVariable.conversation_id == conversation.id)
+ .order_by(ConversationVariable.created_at)
+ )
+
+ with Session(db.engine) as session:
+ if last_id:
+ last_variable = session.scalar(stmt.where(ConversationVariable.id == last_id))
+ if not last_variable:
+ raise ConversationVariableNotExistsError()
+
+ # Filter for variables created after the last_id
+ stmt = stmt.where(ConversationVariable.created_at > last_variable.created_at)
+
+ # Apply limit to query
+ query_stmt = stmt.limit(limit) # Get one extra to check if there are more
+ rows = session.scalars(query_stmt).all()
+
+ has_more = False
+ if len(rows) > limit:
+ has_more = True
+ rows = rows[:limit] # Remove the extra item
+
+ variables = [
+ {
+ "created_at": row.created_at,
+ "updated_at": row.updated_at,
+ **row.to_variable().model_dump(),
+ }
+ for row in rows
+ ]
+
+ return InfiniteScrollPagination(variables, limit, has_more)
diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py
index 44d2594ee8..de90355ebf 100644
--- a/api/services/dataset_service.py
+++ b/api/services/dataset_service.py
@@ -8,7 +8,7 @@ import uuid
from collections import Counter
from typing import Any, Optional
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from sqlalchemy import func
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
diff --git a/api/services/errors/conversation.py b/api/services/errors/conversation.py
index 139dd9a70a..f8051e3417 100644
--- a/api/services/errors/conversation.py
+++ b/api/services/errors/conversation.py
@@ -11,3 +11,7 @@ class ConversationNotExistsError(BaseServiceError):
class ConversationCompletedError(Exception):
pass
+
+
+class ConversationVariableNotExistsError(BaseServiceError):
+ pass
diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py
index d9ee221a3c..6b75c29d95 100644
--- a/api/services/external_knowledge_service.py
+++ b/api/services/external_knowledge_service.py
@@ -2,9 +2,9 @@ import json
from copy import deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
+from urllib.parse import urlparse
import httpx
-import validators
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@@ -72,7 +72,9 @@ class ExternalDatasetService:
endpoint = f"{settings['endpoint']}/retrieval"
api_key = settings["api_key"]
- if not validators.url(endpoint, simple_host=True):
+
+ parsed_url = urlparse(endpoint)
+ if not all([parsed_url.scheme, parsed_url.netloc]):
if not endpoint.startswith("http://") and not endpoint.startswith("https://"):
raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://")
else:
diff --git a/api/services/file_service.py b/api/services/file_service.py
index b4442c36c3..2ca6b4f9aa 100644
--- a/api/services/file_service.py
+++ b/api/services/file_service.py
@@ -4,7 +4,7 @@ import os
import uuid
from typing import Any, Literal, Union
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from werkzeug.exceptions import NotFound
from configs import dify_config
diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py
index 0b98065f5d..56e06cc33e 100644
--- a/api/services/hit_testing_service.py
+++ b/api/services/hit_testing_service.py
@@ -69,6 +69,7 @@ class HitTestingService:
query: str,
account: Account,
external_retrieval_model: dict,
+ metadata_filtering_conditions: dict,
) -> dict:
if dataset.provider != "external":
return {
@@ -82,6 +83,7 @@ class HitTestingService:
dataset_id=dataset.id,
query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model,
+ metadata_filtering_conditions=metadata_filtering_conditions,
)
end = time.perf_counter()
diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py
index 4cd2f9e8cb..c47c16f2f7 100644
--- a/api/services/metadata_service.py
+++ b/api/services/metadata_service.py
@@ -3,7 +3,7 @@ import datetime
import logging
from typing import Optional
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from extensions.ext_database import db
diff --git a/api/services/ops_service.py b/api/services/ops_service.py
index 06b4732304..6b317212d1 100644
--- a/api/services/ops_service.py
+++ b/api/services/ops_service.py
@@ -67,7 +67,14 @@ class OpsService:
new_decrypt_tracing_config.update({"project_url": project_url})
except Exception:
new_decrypt_tracing_config.update({"project_url": "https://www.comet.com/opik/"})
-
+ if tracing_provider == "weave" and (
+ "project_url" not in decrypt_tracing_config or not decrypt_tracing_config.get("project_url")
+ ):
+ try:
+ project_url = OpsTraceManager.get_trace_config_project_url(decrypt_tracing_config, tracing_provider)
+ new_decrypt_tracing_config.update({"project_url": project_url})
+ except Exception:
+ new_decrypt_tracing_config.update({"project_url": "https://wandb.ai/"})
trace_config_data.tracing_config = new_decrypt_tracing_config
return trace_config_data.to_dict()
diff --git a/api/services/plugin/data_migration.py b/api/services/plugin/data_migration.py
index 597585588b..1c5abfecba 100644
--- a/api/services/plugin/data_migration.py
+++ b/api/services/plugin/data_migration.py
@@ -86,9 +86,9 @@ limit 1000"""
update_retrieval_model_sql = ", retrieval_model = :retrieval_model"
params["retrieval_model"] = json.dumps(retrieval_model)
- sql = f"""update {table_name}
- set {provider_column_name} =
- concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
+ sql = f"""update {table_name}
+ set {provider_column_name} =
+ concat('{DEFAULT_PLUGIN_ID}/', {provider_column_name}, '/', {provider_column_name})
{update_retrieval_model_sql}
where id = :record_id"""
conn.execute(db.text(sql), params)
@@ -131,10 +131,10 @@ limit 1000"""
while True:
sql = f"""
- SELECT id, {provider_column_name} AS provider_name
+ SELECT id, {provider_column_name} AS provider_name
FROM {table_name}
- WHERE {provider_column_name} NOT LIKE '%/%'
- AND {provider_column_name} IS NOT NULL
+ WHERE {provider_column_name} NOT LIKE '%/%'
+ AND {provider_column_name} IS NOT NULL
AND {provider_column_name} != ''
AND id > :last_id
ORDER BY id ASC
@@ -183,8 +183,8 @@ limit 1000"""
if batch_updates:
update_sql = f"""
- UPDATE {table_name}
- SET {provider_column_name} = :updated_value
+ UPDATE {table_name}
+ SET {provider_column_name} = :updated_value
WHERE id = :record_id
"""
conn.execute(db.text(update_sql), [{"updated_value": u, "record_id": r} for u, r in batch_updates])
diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py
index 07e624b4e8..830d3a4769 100644
--- a/api/services/plugin/dependencies_analysis.py
+++ b/api/services/plugin/dependencies_analysis.py
@@ -1,7 +1,7 @@
from configs import dify_config
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID
-from core.plugin.manager.plugin import PluginInstallationManager
+from core.plugin.impl.plugin import PluginInstaller
class DependenciesAnalysisService:
@@ -38,7 +38,7 @@ class DependenciesAnalysisService:
for dependency in dependencies:
required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier)
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
# get leaked dependencies
missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers)
@@ -64,7 +64,7 @@ class DependenciesAnalysisService:
Generate dependencies through the list of plugin ids
"""
dependencies = list(set(dependencies))
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies)
result = []
for plugin in plugins:
diff --git a/api/services/plugin/endpoint_service.py b/api/services/plugin/endpoint_service.py
index 35961345a8..11b8e0a3d9 100644
--- a/api/services/plugin/endpoint_service.py
+++ b/api/services/plugin/endpoint_service.py
@@ -1,10 +1,10 @@
-from core.plugin.manager.endpoint import PluginEndpointManager
+from core.plugin.impl.endpoint import PluginEndpointClient
class EndpointService:
@classmethod
def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict):
- return PluginEndpointManager().create_endpoint(
+ return PluginEndpointClient().create_endpoint(
tenant_id=tenant_id,
user_id=user_id,
plugin_unique_identifier=plugin_unique_identifier,
@@ -14,7 +14,7 @@ class EndpointService:
@classmethod
def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int):
- return PluginEndpointManager().list_endpoints(
+ return PluginEndpointClient().list_endpoints(
tenant_id=tenant_id,
user_id=user_id,
page=page,
@@ -23,7 +23,7 @@ class EndpointService:
@classmethod
def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
- return PluginEndpointManager().list_endpoints_for_single_plugin(
+ return PluginEndpointClient().list_endpoints_for_single_plugin(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
@@ -33,7 +33,7 @@ class EndpointService:
@classmethod
def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
- return PluginEndpointManager().update_endpoint(
+ return PluginEndpointClient().update_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
@@ -43,7 +43,7 @@ class EndpointService:
@classmethod
def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
- return PluginEndpointManager().delete_endpoint(
+ return PluginEndpointClient().delete_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
@@ -51,7 +51,7 @@ class EndpointService:
@classmethod
def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
- return PluginEndpointManager().enable_endpoint(
+ return PluginEndpointClient().enable_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
@@ -59,7 +59,7 @@ class EndpointService:
@classmethod
def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str):
- return PluginEndpointManager().disable_endpoint(
+ return PluginEndpointClient().disable_endpoint(
tenant_id=tenant_id,
user_id=user_id,
endpoint_id=endpoint_id,
diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py
new file mode 100644
index 0000000000..461247419b
--- /dev/null
+++ b/api/services/plugin/oauth_service.py
@@ -0,0 +1,7 @@
+from core.plugin.impl.base import BasePluginClient
+
+
+class OAuthService(BasePluginClient):
+ @classmethod
+ def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str:
+ return "1234567890"
diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py
index ec9e0aa8dc..dbaaa7160e 100644
--- a/api/services/plugin/plugin_migration.py
+++ b/api/services/plugin/plugin_migration.py
@@ -17,7 +17,7 @@ from core.agent.entities import AgentToolEntity
from core.helper import marketplace
from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus
-from core.plugin.manager.plugin import PluginInstallationManager
+from core.plugin.impl.plugin import PluginInstaller
from core.tools.entities.tool_entities import ToolProviderType
from models.account import Tenant
from models.engine import db
@@ -331,7 +331,7 @@ class PluginMigration:
"""
Install plugins.
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
plugins = cls.extract_unique_plugins(extracted_plugins)
not_installed = []
@@ -426,7 +426,7 @@ class PluginMigration:
"""
Install plugins for a tenant.
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
# download all the plugins and upload
thread_pool = ThreadPoolExecutor(max_workers=10)
diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py
index 96a07d36b9..be722a59ad 100644
--- a/api/services/plugin/plugin_service.py
+++ b/api/services/plugin/plugin_service.py
@@ -18,9 +18,9 @@ from core.plugin.entities.plugin import (
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse
-from core.plugin.manager.asset import PluginAssetManager
-from core.plugin.manager.debugging import PluginDebuggingManager
-from core.plugin.manager.plugin import PluginInstallationManager
+from core.plugin.impl.asset import PluginAssetManager
+from core.plugin.impl.debugging import PluginDebuggingClient
+from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
@@ -91,7 +91,7 @@ class PluginService:
"""
get the debugging key of the tenant
"""
- manager = PluginDebuggingManager()
+ manager = PluginDebuggingClient()
return manager.get_debugging_key(tenant_id)
@staticmethod
@@ -106,7 +106,7 @@ class PluginService:
"""
list all plugins of the tenant
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
plugins = manager.list_plugins(tenant_id)
return plugins
@@ -115,7 +115,7 @@ class PluginService:
"""
List plugin installations from ids
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.fetch_plugin_installation_by_ids(tenant_id, ids)
@staticmethod
@@ -133,7 +133,7 @@ class PluginService:
"""
check if the plugin unique identifier is already installed by other tenant
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier)
@staticmethod
@@ -141,7 +141,7 @@ class PluginService:
"""
Fetch plugin manifest
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
@staticmethod
@@ -149,12 +149,12 @@ class PluginService:
"""
Fetch plugin installation tasks
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size)
@staticmethod
def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask:
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.fetch_plugin_installation_task(tenant_id, task_id)
@staticmethod
@@ -162,7 +162,7 @@ class PluginService:
"""
Delete a plugin installation task
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.delete_plugin_installation_task(tenant_id, task_id)
@staticmethod
@@ -172,7 +172,7 @@ class PluginService:
"""
Delete all plugin installation task items
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.delete_all_plugin_installation_task_items(tenant_id)
@staticmethod
@@ -180,7 +180,7 @@ class PluginService:
"""
Delete a plugin installation task item
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier)
@staticmethod
@@ -190,11 +190,14 @@ class PluginService:
"""
Upgrade plugin with marketplace
"""
+ if not dify_config.MARKETPLACE_ENABLED:
+ raise ValueError("marketplace is not enabled")
+
if original_plugin_unique_identifier == new_plugin_unique_identifier:
raise ValueError("you should not upgrade plugin with the same plugin")
# check if plugin pkg is already downloaded
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
try:
manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier)
@@ -227,7 +230,7 @@ class PluginService:
"""
Upgrade plugin with github
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.upgrade_plugin(
tenant_id,
original_plugin_unique_identifier,
@@ -247,7 +250,7 @@ class PluginService:
returns: plugin_unique_identifier
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.upload_pkg(tenant_id, pkg, verify_signature)
@staticmethod
@@ -262,7 +265,7 @@ class PluginService:
f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE
)
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.upload_pkg(
tenant_id,
pkg,
@@ -276,12 +279,12 @@ class PluginService:
"""
Upload a plugin bundle and return the dependencies.
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.upload_bundle(tenant_id, bundle, verify_signature)
@staticmethod
def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]):
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.install_from_identifiers(
tenant_id,
plugin_unique_identifiers,
@@ -295,7 +298,7 @@ class PluginService:
Install plugin from github release package files,
returns plugin_unique_identifier
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.install_from_identifiers(
tenant_id,
[plugin_unique_identifier],
@@ -316,7 +319,10 @@ class PluginService:
"""
Fetch marketplace package
"""
- manager = PluginInstallationManager()
+ if not dify_config.MARKETPLACE_ENABLED:
+ raise ValueError("marketplace is not enabled")
+
+ manager = PluginInstaller()
try:
declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier)
except Exception:
@@ -333,7 +339,10 @@ class PluginService:
Install plugin from marketplace package files,
returns installation task id
"""
- manager = PluginInstallationManager()
+ if not dify_config.MARKETPLACE_ENABLED:
+ raise ValueError("marketplace is not enabled")
+
+ manager = PluginInstaller()
# check if already downloaded
for plugin_unique_identifier in plugin_unique_identifiers:
@@ -359,7 +368,7 @@ class PluginService:
@staticmethod
def uninstall(tenant_id: str, plugin_installation_id: str) -> bool:
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.uninstall(tenant_id, plugin_installation_id)
@staticmethod
@@ -367,5 +376,5 @@ class PluginService:
"""
Check if the tools exist
"""
- manager = PluginInstallationManager()
+ manager = PluginInstaller()
return manager.check_tools_existence(tenant_id, provider_ids)
diff --git a/api/services/tag_service.py b/api/services/tag_service.py
index 1fbaee96e8..21cb861f87 100644
--- a/api/services/tag_service.py
+++ b/api/services/tag_service.py
@@ -1,7 +1,7 @@
import uuid
from typing import Optional
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from sqlalchemy import func
from werkzeug.exceptions import NotFound
diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py
index 075c60842b..3ccd14415d 100644
--- a/api/services/tools/builtin_tools_manage_service.py
+++ b/api/services/tools/builtin_tools_manage_service.py
@@ -8,7 +8,7 @@ from configs import dify_config
from core.helper.position_helper import is_filtered
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
-from core.plugin.manager.exc import PluginDaemonClientSideError
+from core.plugin.impl.exc import PluginDaemonClientSideError
from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError
diff --git a/api/services/website_service.py b/api/services/website_service.py
index 460a637a43..3913dc2efe 100644
--- a/api/services/website_service.py
+++ b/api/services/website_service.py
@@ -3,7 +3,7 @@ import json
from typing import Any
import requests
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from core.helper import encrypter
from core.rag.extractor.firecrawl.firecrawl_app import FirecrawlApp
diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py
index ff3b33eecd..6d5b737962 100644
--- a/api/services/workflow_run_service.py
+++ b/api/services/workflow_run_service.py
@@ -2,8 +2,8 @@ import threading
from typing import Optional
import contexts
-from core.repository import RepositoryFactory
-from core.repository.workflow_node_execution_repository import OrderConfig
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
+from core.workflow.repository.workflow_node_execution_repository import OrderConfig
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import WorkflowRunTriggeredFrom
@@ -129,12 +129,8 @@ class WorkflowRunService:
return []
# Use the repository to get the node executions
- repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": app_model.tenant_id,
- "app_id": app_model.id,
- "session_factory": db.session.get_bind,
- }
+ repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
)
# Use the repository to get the node executions with ordering
diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py
index b88c7b296d..331dba8bf1 100644
--- a/api/services/workflow_service.py
+++ b/api/services/workflow_service.py
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.model_runtime.utils.encoders import jsonable_encoder
-from core.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -28,6 +28,7 @@ from extensions.ext_database import db
from models.account import Account
from models.enums import CreatedByRole
from models.model import App, AppMode
+from models.tools import WorkflowToolProvider
from models.workflow import (
Workflow,
WorkflowNodeExecution,
@@ -284,12 +285,8 @@ class WorkflowService:
workflow_node_execution.workflow_id = draft_workflow.id
# Use the repository to save the workflow node execution
- repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": app_model.tenant_id,
- "app_id": app_model.id,
- "session_factory": db.session.get_bind,
- }
+ repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id
)
repository.save(workflow_node_execution)
@@ -523,8 +520,19 @@ class WorkflowService:
# Cannot delete a workflow that's currently in use by an app
raise WorkflowInUseError(f"Cannot delete workflow that is currently in use by app '{app.name}'")
- # Check if this workflow is published as a tool
- if workflow.tool_published:
+ # Don't use workflow.tool_published as it's not accurate for specific workflow versions
+ # Check if there's a tool provider using this specific workflow version
+ tool_provider = (
+ session.query(WorkflowToolProvider)
+ .filter(
+ WorkflowToolProvider.tenant_id == workflow.tenant_id,
+ WorkflowToolProvider.app_id == workflow.app_id,
+ WorkflowToolProvider.version == workflow.version,
+ )
+ .first()
+ )
+
+ if tool_provider:
# Cannot delete a workflow that's published as a tool
raise WorkflowInUseError("Cannot delete workflow that is published as a tool")
diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py
index e012fd4296..125e0c1b1e 100644
--- a/api/services/workspace_service.py
+++ b/api/services/workspace_service.py
@@ -1,4 +1,4 @@
-from flask_login import current_user # type: ignore
+from flask_login import current_user
from configs import dify_config
from extensions.ext_database import db
diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py
index 2b49e4bb23..2e77332ffe 100644
--- a/api/tasks/ops_trace_task.py
+++ b/api/tasks/ops_trace_task.py
@@ -44,7 +44,10 @@ def process_trace_tasks(file_info):
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logging.info(f"Processing trace tasks success, app_id: {app_id}")
- except Exception:
+ except Exception as e:
+ logging.info(
+ f"error:\n\n\n{e}\n\n\n\n",
+ )
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logging.info(f"Processing trace tasks failed, app_id: {app_id}")
diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py
index 4542b1b923..d5a783396a 100644
--- a/api/tasks/remove_app_and_related_data_task.py
+++ b/api/tasks/remove_app_and_related_data_task.py
@@ -7,7 +7,7 @@ from celery import shared_task # type: ignore
from sqlalchemy import delete
from sqlalchemy.exc import SQLAlchemyError
-from core.repository import RepositoryFactory
+from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from extensions.ext_database import db
from models.dataset import AppDatasetJoin
from models.model import (
@@ -189,12 +189,8 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str):
def _delete_app_workflow_node_executions(tenant_id: str, app_id: str):
# Create a repository instance for WorkflowNodeExecution
- repository = RepositoryFactory.create_workflow_node_execution_repository(
- params={
- "tenant_id": tenant_id,
- "app_id": app_id,
- "session_factory": db.session.get_bind,
- }
+ repository = SQLAlchemyWorkflowNodeExecutionRepository(
+ session_factory=db.engine, tenant_id=tenant_id, app_id=app_id
)
# Use the clear method to delete all records for this tenant_id and app_id
diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html
index 88e78f41c7..0f7ddc62a9 100644
--- a/api/templates/clean_document_job_mail_template-US.html
+++ b/api/templates/clean_document_job_mail_template-US.html
@@ -77,7 +77,7 @@
Some Documents in Your Knowledge Base Have Been Disabled
Dear {{userName}},
- We're sorry for the inconvenience. To ensure optimal performance, documents
+ We're sorry for the inconvenience. To ensure optimal performance, documents
that haven’t been updated or accessed in the past 30 days have been disabled in
your knowledge bases:
@@ -97,4 +97,4 @@