diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index e7a8c98d26..44c1ddf739 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,5 +1,4 @@ FROM mcr.microsoft.com/devcontainers/python:3.12 -# [Optional] Uncomment this section to install additional OS packages. -# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ -# && apt-get -y install --no-install-recommends +RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \ + && apt-get -y install libgmp-dev libmpfr-dev libmpc-dev diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 30c0ff000d..b06ab9653e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -139,6 +139,7 @@ jobs: - name: Checkout code uses: actions/checkout@v4 with: + fetch-depth: 0 persist-credentials: false - name: Check changed files diff --git a/README.md b/README.md index efb37d6083..ef654ced1e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Introducing Dify Workflow File Upload: Recreate Google NotebookLM Podcast @@ -87,8 +87,6 @@ Please refer to our [FAQ](https://docs.dify.ai/getting-started/install-self-host **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. -https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - **2. Comprehensive model support**: Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_AR.md b/README_AR.md index 4f93802fda..1310a9d802 100644 --- a/README_AR.md +++ b/README_AR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -54,8 +54,6 @@ **1. سير العمل**: قم ببناء واختبار سير عمل الذكاء الاصطناعي القوي على قماش بصري، مستفيدًا من جميع الميزات التالية وأكثر. - - **2. الدعم الشامل للنماذج**: تكامل سلس مع مئات من LLMs الخاصة / مفتوحة المصدر من عشرات من موفري التحليل والحلول المستضافة ذاتيًا، مما يغطي GPT و Mistral و Llama3 وأي نماذج متوافقة مع واجهة OpenAI API. يمكن العثور على قائمة كاملة بمزودي النموذج المدعومين [هنا](https://docs.dify.ai/getting-started/readme/model-providers). ![providers-v5](https://github.com/langgenius/dify/assets/13230914/5a17bdbe-097a-4100-8363-40255b70f6e3) diff --git a/README_BN.md b/README_BN.md index 7599fae9ff..be1f8cc72e 100644 --- a/README_BN.md +++ b/README_BN.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 ডিফাই ওয়ার্কফ্লো ফাইল আপলোড পরিচিতি: গুগল নোটবুক-এলএম পডকাস্ট পুনর্নির্মাণ @@ -84,8 +84,6 @@ docker compose up -d **১. ওয়ার্কফ্লো**: ভিজ্যুয়াল ক্যানভাসে AI ওয়ার্কফ্লো তৈরি এবং পরীক্ষা করুন, নিম্নলিখিত সব ফিচার এবং তার বাইরেও আরও অনেক কিছু ব্যবহার করে। - - **২. মডেল সাপোর্ট**: GPT, Mistral, Llama3, এবং যেকোনো OpenAI API-সামঞ্জস্যপূর্ণ মডেলসহ, কয়েক ডজন ইনফারেন্স প্রদানকারী এবং সেল্ফ-হোস্টেড সমাধান থেকে শুরু করে প্রোপ্রাইটরি/ওপেন-সোর্স LLM-এর সাথে সহজে ইন্টিগ্রেশন। সমর্থিত মডেল প্রদানকারীদের একটি সম্পূর্ণ তালিকা পাওয়া যাবে [এখানে](https://docs.dify.ai/getting-started/readme/model-providers)। diff --git a/README_CN.md b/README_CN.md index 973629f459..1d4d51e856 100644 --- a/README_CN.md +++ b/README_CN.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify 云服务 · @@ -61,11 +61,6 @@ Dify 是一个开源的 LLM 应用开发平台。其直观的界面结合了 AI **1. 工作流**: 在画布上构建和测试功能强大的 AI 工作流程,利用以下所有功能以及更多功能。 - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 全面的模型支持**: 与数百种专有/开源 LLMs 以及数十种推理提供商和自托管解决方案无缝集成,涵盖 GPT、Mistral、Llama3 以及任何与 OpenAI API 兼容的模型。完整的支持模型提供商列表可在[此处](https://docs.dify.ai/getting-started/readme/model-providers)找到。 diff --git a/README_DE.md b/README_DE.md index 738c0e3b67..7cc2710fde 100644 --- a/README_DE.md +++ b/README_DE.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Einführung in Dify Workflow File Upload: Google NotebookLM Podcast nachbilden @@ -83,11 +83,6 @@ Bitte beachten Sie unsere [FAQ](https://docs.dify.ai/getting-started/install-sel **1. Workflow**: Erstellen und testen Sie leistungsstarke KI-Workflows auf einer visuellen Oberfläche, wobei Sie alle der folgenden Funktionen und darüber hinaus nutzen können. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Umfassende Modellunterstützung**: Nahtlose Integration mit Hunderten von proprietären und Open-Source-LLMs von Dutzenden Inferenzanbietern und selbstgehosteten Lösungen, die GPT, Mistral, Llama3 und alle mit der OpenAI API kompatiblen Modelle abdecken. Eine vollständige Liste der unterstützten Modellanbieter finden Sie [hier](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_ES.md b/README_ES.md index 212268b73d..12f2ce8c11 100644 --- a/README_ES.md +++ b/README_ES.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify es una plataforma de desarrollo de aplicaciones de LLM de código abierto. **1. Flujo de trabajo**: Construye y prueba potentes flujos de trabajo de IA en un lienzo visual, aprovechando todas las siguientes características y más. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Soporte de modelos completo**: Integración perfecta con cientos de LLMs propietarios / de código abierto de docenas de proveedores de inferencia y soluciones auto-alojadas, que cubren GPT, Mistral, Llama3 y cualquier modelo compatible con la API de OpenAI. Se puede encontrar una lista completa de proveedores de modelos admitidos [aquí](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_FR.md b/README_FR.md index 89eea7d058..b106615b31 100644 --- a/README_FR.md +++ b/README_FR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify est une plateforme de développement d'applications LLM open source. Son in **1. Flux de travail** : Construisez et testez des flux de travail d'IA puissants sur un canevas visuel, en utilisant toutes les fonctionnalités suivantes et plus encore. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Prise en charge complète des modèles** : Intégration transparente avec des centaines de LLM propriétaires / open source provenant de dizaines de fournisseurs d'inférence et de solutions auto-hébergées, couvrant GPT, Mistral, Llama3, et tous les modèles compatibles avec l'API OpenAI. Une liste complète des fournisseurs de modèles pris en charge se trouve [ici](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_JA.md b/README_JA.md index adca219753..be5d8644e0 100644 --- a/README_JA.md +++ b/README_JA.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -60,11 +60,6 @@ DifyはオープンソースのLLMアプリケーション開発プラットフ **1. ワークフロー**: 強力なAIワークフローをビジュアルキャンバス上で構築し、テストできます。すべての機能、および以下の機能を使用できます。 - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 総合的なモデルサポート**: 数百ものプロプライエタリ/オープンソースのLLMと、数十もの推論プロバイダーおよびセルフホスティングソリューションとのシームレスな統合を提供します。GPT、Mistral、Llama3、OpenAI APIと互換性のあるすべてのモデルを統合されています。サポートされているモデルプロバイダーの完全なリストは[こちら](https://docs.dify.ai/getting-started/readme/model-providers)をご覧ください。 diff --git a/README_KL.md b/README_KL.md index 17e6c9d509..3ebd756929 100644 --- a/README_KL.md +++ b/README_KL.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -59,11 +59,6 @@ Dify is an open-source LLM app development platform. Its intuitive interface com **1. Workflow**: Build and test powerful AI workflows on a visual canvas, leveraging all the following features and beyond. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Comprehensive model support**: Seamless integration with hundreds of proprietary / open-source LLMs from dozens of inference providers and self-hosted solutions, covering GPT, Mistral, Llama3, and any OpenAI API-compatible models. A full list of supported model providers can be found [here](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_KR.md b/README_KR.md index d44723f9b6..ecbe2f6b74 100644 --- a/README_KR.md +++ b/README_KR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify 클라우드 · @@ -54,11 +54,6 @@ **1. 워크플로우**: 다음 기능들을 비롯한 다양한 기능을 활용하여 시각적 캔버스에서 강력한 AI 워크플로우를 구축하고 테스트하세요. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. 포괄적인 모델 지원:**: 수십 개의 추론 제공업체와 자체 호스팅 솔루션에서 제공하는 수백 개의 독점 및 오픈 소스 LLM과 원활하게 통합되며, GPT, Mistral, Llama3 및 모든 OpenAI API 호환 모델을 포함합니다. 지원되는 모델 제공업체의 전체 목록은 [여기](https://docs.dify.ai/getting-started/readme/model-providers)에서 확인할 수 있습니다. diff --git a/README_PT.md b/README_PT.md index 9dc2207279..157772d528 100644 --- a/README_PT.md +++ b/README_PT.md @@ -1,5 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) - +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Introduzindo o Dify Workflow com Upload de Arquivo: Recrie o Podcast Google NotebookLM

@@ -59,11 +58,6 @@ Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. **1. Workflow**: Construa e teste workflows poderosos de IA em uma interface visual, aproveitando todos os recursos a seguir e muito mais. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Suporte abrangente a modelos**: Integração perfeita com centenas de LLMs proprietários e de código aberto de diversas provedoras e soluções auto-hospedadas, abrangendo GPT, Mistral, Llama3 e qualquer modelo compatível com a API da OpenAI. A lista completa de provedores suportados pode ser encontrada [aqui](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_SI.md b/README_SI.md index 9a38b558b4..45aec75dcd 100644 --- a/README_SI.md +++ b/README_SI.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 Predstavljamo nalaganje datotek Dify Workflow: znova ustvarite Google NotebookLM Podcast @@ -81,11 +81,6 @@ Prosimo, glejte naša pogosta vprašanja [FAQ](https://docs.dify.ai/getting-star **1. Potek dela**: Zgradite in preizkusite zmogljive poteke dela AI na vizualnem platnu, pri čemer izkoristite vse naslednje funkcije in več. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Celovita podpora za modele**: Brezhibna integracija s stotinami lastniških/odprtokodnih LLM-jev ducatov ponudnikov sklepanja in samostojnih rešitev, ki pokrivajo GPT, Mistral, Llama3 in vse modele, združljive z API-jem OpenAI. Celoten seznam podprtih ponudnikov modelov najdete [tukaj](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/README_TR.md b/README_TR.md index ab2853a019..94ab6ee338 100644 --- a/README_TR.md +++ b/README_TR.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Bulut · @@ -55,11 +55,6 @@ Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel aray **1. Workflow**: Görsel bir arayüz üzerinde güçlü AI iş akışları oluşturun ve test edin, aşağıdaki tüm özellikleri ve daha fazlasını kullanarak. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Kapsamlı model desteği**: Çok sayıda çıkarım sağlayıcısı ve kendi kendine barındırılan çözümlerden yüzlerce özel / açık kaynaklı LLM ile sorunsuz entegrasyon sağlar. GPT, Mistral, Llama3 ve OpenAI API uyumlu tüm modelleri kapsar. Desteklenen model sağlayıcılarının tam listesine [buradan](https://docs.dify.ai/getting-started/readme/model-providers) ulaşabilirsiniz. diff --git a/README_TW.md b/README_TW.md index 8263a22b64..c252292e44 100644 --- a/README_TW.md +++ b/README_TW.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

📌 介紹 Dify 工作流程檔案上傳功能:重現 Google NotebookLM Podcast @@ -86,8 +86,6 @@ docker compose up -d **1. 工作流程**: 在視覺化畫布上建立和測試強大的 AI 工作流程,利用以下所有功能及更多。 -https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - **2. 全面的模型支援**: 無縫整合來自數十個推理提供商和自託管解決方案的數百個專有/開源 LLM,涵蓋 GPT、Mistral、Llama3 和任何與 OpenAI API 兼容的模型。您可以在[此處](https://docs.dify.ai/getting-started/readme/model-providers)找到支援的模型提供商完整列表。 diff --git a/README_VI.md b/README_VI.md index 852ed7aaa0..4e1e05cbf3 100644 --- a/README_VI.md +++ b/README_VI.md @@ -1,4 +1,4 @@ -![cover-v5-optimized](https://github.com/langgenius/dify/assets/13230914/f9e19af5-61ba-4119-b926-d10c4c06ebab) +![cover-v5-optimized](./images/GitHub_README_if.png)

Dify Cloud · @@ -55,11 +55,6 @@ Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Gia **1. Quy trình làm việc**: Xây dựng và kiểm tra các quy trình làm việc AI mạnh mẽ trên một canvas trực quan, tận dụng tất cả các tính năng sau đây và hơn thế nữa. - - https://github.com/langgenius/dify/assets/13230914/356df23e-1604-483d-80a6-9517ece318aa - - - **2. Hỗ trợ mô hình toàn diện**: Tích hợp liền mạch với hàng trăm mô hình LLM độc quyền / mã nguồn mở từ hàng chục nhà cung cấp suy luận và giải pháp tự lưu trữ, bao gồm GPT, Mistral, Llama3, và bất kỳ mô hình tương thích API OpenAI nào. Danh sách đầy đủ các nhà cung cấp mô hình được hỗ trợ có thể được tìm thấy [tại đây](https://docs.dify.ai/getting-started/readme/model-providers). diff --git a/api/.env.example b/api/.env.example index 2cc6410cdd..9dec564c28 100644 --- a/api/.env.example +++ b/api/.env.example @@ -348,6 +348,7 @@ SENTRY_DSN= # DEBUG DEBUG=false +ENABLE_REQUEST_LOGGING=False SQLALCHEMY_ECHO=false # Notion import configuration, support public and internal @@ -476,6 +477,7 @@ LOGIN_LOCKOUT_DURATION=86400 ENABLE_OTEL=false OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= +OTEL_EXPORTER_OTLP_PROTOCOL= OTEL_EXPORTER_TYPE=otlp OTEL_SAMPLING_RATE=0.1 OTEL_BATCH_EXPORT_SCHEDULE_DELAY=5000 diff --git a/api/app_factory.py b/api/app_factory.py index 1c886ac5c7..3a258be28f 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -54,6 +54,7 @@ def initialize_extensions(app: DifyApp): ext_otel, ext_proxy_fix, ext_redis, + ext_request_logging, ext_sentry, ext_set_secretkey, ext_storage, @@ -83,6 +84,7 @@ def initialize_extensions(app: DifyApp): ext_blueprints, ext_commands, ext_otel, + ext_request_logging, ] for ext in extensions: short_name = ext.__name__.split(".")[-1] diff --git a/api/configs/deploy/__init__.py b/api/configs/deploy/__init__.py index 950936d3c6..63f4dfba63 100644 --- a/api/configs/deploy/__init__.py +++ b/api/configs/deploy/__init__.py @@ -17,6 +17,12 @@ class DeploymentConfig(BaseSettings): default=False, ) + # Request logging configuration + ENABLE_REQUEST_LOGGING: bool = Field( + description="Enable request and response body logging", + default=False, + ) + EDITION: str = Field( description="Deployment edition of the application (e.g., 'SELF_HOSTED', 'CLOUD')", default="SELF_HOSTED", diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 568a800d10..1b88ddcfe6 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -27,6 +27,11 @@ class OTelConfig(BaseSettings): default="otlp", ) + OTEL_EXPORTER_OTLP_PROTOCOL: str = Field( + description="OTLP exporter protocol ('grpc' or 'http')", + default="http", + ) + OTEL_SAMPLING_RATE: float = Field(default=0.1, description="Sampling rate for traces (0.0 to 1.0)") OTEL_BATCH_EXPORT_SCHEDULE_DELAY: int = Field( diff --git a/api/configs/packaging/__init__.py b/api/configs/packaging/__init__.py index c7960e1356..de13018f37 100644 --- a/api/configs/packaging/__init__.py +++ b/api/configs/packaging/__init__.py @@ -9,7 +9,7 @@ class PackagingInfo(BaseSettings): CURRENT_VERSION: str = Field( description="Dify version", - default="1.3.1", + default="1.4.0", ) COMMIT_SHA: str = Field( diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index f97209c369..860166a61a 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -17,15 +17,13 @@ from controllers.console.wraps import ( ) from core.ops.ops_trace_manager import OpsTraceManager from extensions.ext_database import db -from fields.app_fields import ( - app_detail_fields, - app_detail_fields_with_site, - app_pagination_fields, -) +from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields from libs.login import login_required from models import Account, App from services.app_dsl_service import AppDslService, ImportMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"] @@ -75,7 +73,17 @@ class AppListApi(Resource): if not app_pagination: return {"data": [], "total": 0, "page": 1, "limit": 20, "has_more": False} - return marshal(app_pagination, app_pagination_fields) + if FeatureService.get_system_features().webapp_auth.enabled: + app_ids = [str(app.id) for app in app_pagination.items] + res = EnterpriseService.WebAppAuth.batch_get_app_access_mode_by_id(app_ids=app_ids) + if len(res) != len(app_ids): + raise BadRequest("Invalid app id in webapp auth") + + for app in app_pagination.items: + if str(app.id) in res: + app.access_mode = res[str(app.id)].access_mode + + return marshal(app_pagination, app_pagination_fields), 200 @setup_required @login_required @@ -119,6 +127,10 @@ class AppApi(Resource): app_model = app_service.get_app(app_model) + if FeatureService.get_system_features().webapp_auth.enabled: + app_setting = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id=str(app_model.id)) + app_model.access_mode = app_setting.access_mode + return app_model @setup_required diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index 0c13adce9b..cbbdd324ba 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -81,8 +81,7 @@ class DraftWorkflowApi(Resource): parser.add_argument("graph", type=dict, required=True, nullable=False, location="json") parser.add_argument("features", type=dict, required=True, nullable=False, location="json") parser.add_argument("hash", type=str, required=False, location="json") - # TODO: set this to required=True after frontend is updated - parser.add_argument("environment_variables", type=list, required=False, location="json") + parser.add_argument("environment_variables", type=list, required=True, location="json") parser.add_argument("conversation_variables", type=list, required=False, location="json") args = parser.parse_args() elif "text/plain" in content_type: diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 08ab61bbb9..9099700213 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,3 +1,6 @@ +from typing import cast + +from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from flask_restful.inputs import int_range @@ -12,8 +15,7 @@ from fields.workflow_run_fields import ( ) from libs.helper import uuid_value from libs.login import login_required -from models import App -from models.model import AppMode +from models import Account, App, AppMode, EndUser from services.workflow_run_service import WorkflowRunService @@ -90,7 +92,12 @@ class WorkflowRunNodeExecutionListApi(Resource): run_id = str(run_id) workflow_run_service = WorkflowRunService() - node_executions = workflow_run_service.get_workflow_run_node_executions(app_model=app_model, run_id=run_id) + user = cast("Account | EndUser", current_user) + node_executions = workflow_run_service.get_workflow_run_node_executions( + app_model=app_model, + run_id=run_id, + user=user, + ) return {"data": node_executions} diff --git a/api/controllers/console/auth/forgot_password.py b/api/controllers/console/auth/forgot_password.py index d73d8ce701..8db19095d2 100644 --- a/api/controllers/console/auth/forgot_password.py +++ b/api/controllers/console/auth/forgot_password.py @@ -24,7 +24,7 @@ from libs.password import hash_password, valid_password from models.account import Account from services.account_service import AccountService, TenantService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService @@ -119,6 +119,9 @@ class ForgotPasswordResetApi(Resource): if not reset_data: raise InvalidTokenError() # Must use token in reset phase + if reset_data.get("phase", "") != "reset": + raise InvalidTokenError() + # Must use token in reset phase if reset_data.get("phase", "") != "reset": raise InvalidTokenError() @@ -168,6 +171,8 @@ class ForgotPasswordResetApi(Resource): ) except WorkSpaceNotAllowedCreateError: pass + except WorkspacesLimitExceededError: + pass except AccountRegisterError: raise AccountInFreezeError() diff --git a/api/controllers/console/auth/login.py b/api/controllers/console/auth/login.py index 27864bab3d..86231bf616 100644 --- a/api/controllers/console/auth/login.py +++ b/api/controllers/console/auth/login.py @@ -21,6 +21,7 @@ from controllers.console.error import ( AccountNotFound, EmailSendIpLimitError, NotAllowedCreateWorkspace, + WorkspacesLimitExceeded, ) from controllers.console.wraps import email_password_login_enabled, setup_required from events.tenant_event import tenant_was_created @@ -30,7 +31,7 @@ from models.account import Account from services.account_service import AccountService, RegisterService, TenantService from services.billing_service import BillingService from services.errors.account import AccountRegisterError -from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService @@ -88,10 +89,15 @@ class LoginApi(Resource): # SELF_HOSTED only have one workspace tenants = TenantService.get_join_tenants(account) if len(tenants) == 0: - return { - "result": "fail", - "data": "workspace not found, please contact system admin to invite you to join in a workspace", - } + system_features = FeatureService.get_system_features() + + if system_features.is_allow_create_workspace and not system_features.license.workspaces.is_available(): + raise WorkspacesLimitExceeded() + else: + return { + "result": "fail", + "data": "workspace not found, please contact system admin to invite you to join in a workspace", + } token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) @@ -198,6 +204,9 @@ class EmailCodeLoginApi(Resource): if account: tenant = TenantService.get_join_tenants(account) if not tenant: + workspaces = FeatureService.get_system_features().license.workspaces + if not workspaces.is_available(): + raise WorkspacesLimitExceeded() if not FeatureService.get_system_features().is_allow_create_workspace: raise NotAllowedCreateWorkspace() else: @@ -215,6 +224,8 @@ class EmailCodeLoginApi(Resource): return NotAllowedCreateWorkspace() except AccountRegisterError as are: raise AccountInFreezeError() + except WorkspacesLimitExceededError: + raise WorkspacesLimitExceeded() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) AccountService.reset_login_error_rate_limit(args["email"]) return {"result": "success", "data": token_pair.model_dump()} diff --git a/api/controllers/console/error.py b/api/controllers/console/error.py index b8fd1f0358..6944c56bf8 100644 --- a/api/controllers/console/error.py +++ b/api/controllers/console/error.py @@ -46,6 +46,18 @@ class NotAllowedCreateWorkspace(BaseHTTPException): code = 400 +class WorkspaceMembersLimitExceeded(BaseHTTPException): + error_code = "limit_exceeded" + description = "Unable to add member because the maximum workspace's member limit was exceeded" + code = 400 + + +class WorkspacesLimitExceeded(BaseHTTPException): + error_code = "limit_exceeded" + description = "Unable to create workspace because the maximum workspace limit was exceeded" + code = 400 + + class AccountBannedError(BaseHTTPException): error_code = "account_banned" description = "Account is banned." diff --git a/api/controllers/console/explore/error.py b/api/controllers/console/explore/error.py index 18221b7797..1e05ff4206 100644 --- a/api/controllers/console/explore/error.py +++ b/api/controllers/console/explore/error.py @@ -23,3 +23,9 @@ class AppSuggestedQuestionsAfterAnswerDisabledError(BaseHTTPException): error_code = "app_suggested_questions_after_answer_disabled" description = "Function Suggested questions after answer disabled." code = 403 + + +class AppAccessDeniedError(BaseHTTPException): + error_code = "access_denied" + description = "App access denied." + code = 403 diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 4062972d08..f73a226c89 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,3 +1,4 @@ +import logging from datetime import UTC, datetime from typing import Any @@ -15,6 +16,11 @@ from fields.installed_app_fields import installed_app_list_fields from libs.login import login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService + +logger = logging.getLogger(__name__) class InstalledAppsListApi(Resource): @@ -48,6 +54,21 @@ class InstalledAppsListApi(Resource): for installed_app in installed_apps if installed_app.app is not None ] + + # filter out apps that user doesn't have access to + if FeatureService.get_system_features().webapp_auth.enabled: + user_id = current_user.id + res = [] + for installed_app in installed_app_list: + app_code = AppService.get_app_code_by_id(str(installed_app["app"].id)) + if EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=user_id, + app_code=app_code, + ): + res.append(installed_app) + installed_app_list = res + logger.debug(f"installed_app_list: {installed_app_list}, user_id: {user_id}") + installed_app_list.sort( key=lambda app: ( -app["is_pinned"], diff --git a/api/controllers/console/explore/wraps.py b/api/controllers/console/explore/wraps.py index 49ea81a8a0..afbd78bd5b 100644 --- a/api/controllers/console/explore/wraps.py +++ b/api/controllers/console/explore/wraps.py @@ -4,10 +4,14 @@ from flask_login import current_user from flask_restful import Resource from werkzeug.exceptions import NotFound +from controllers.console.explore.error import AppAccessDeniedError from controllers.console.wraps import account_initialization_required from extensions.ext_database import db from libs.login import login_required from models import InstalledApp +from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService def installed_app_required(view=None): @@ -48,6 +52,36 @@ def installed_app_required(view=None): return decorator +def user_allowed_to_access_app(view=None): + def decorator(view): + @wraps(view) + def decorated(installed_app: InstalledApp, *args, **kwargs): + feature = FeatureService.get_system_features() + if feature.webapp_auth.enabled: + app_id = installed_app.app_id + app_code = AppService.get_app_code_by_id(app_id) + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp( + user_id=str(current_user.id), + app_code=app_code, + ) + if not res: + raise AppAccessDeniedError() + + return view(installed_app, *args, **kwargs) + + return decorated + + if view: + return decorator(view) + return decorator + + class InstalledAppResource(Resource): # must be reversed if there are multiple decorators - method_decorators = [installed_app_required, account_initialization_required, login_required] + + method_decorators = [ + user_allowed_to_access_app, + installed_app_required, + account_initialization_required, + login_required, + ] diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index a0031307fd..db49da7840 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -6,6 +6,7 @@ from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api +from controllers.console.error import WorkspaceMembersLimitExceeded from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, @@ -17,6 +18,7 @@ from libs.login import login_required from models.account import Account, TenantAccountRole from services.account_service import RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError +from services.feature_service import FeatureService class MemberListApi(Resource): @@ -54,6 +56,12 @@ class MemberInviteEmailApi(Resource): inviter = current_user invitation_results = [] console_web_url = dify_config.CONSOLE_WEB_URL + + workspace_members = FeatureService.get_features(tenant_id=inviter.current_tenant.id).workspace_members + + if not workspace_members.is_available(len(invitee_emails)): + raise WorkspaceMembersLimitExceeded() + for invitee_email in invitee_emails: try: token = RegisterService.invite_new_member( diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index f147a3453f..d51db4322a 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -5,5 +5,6 @@ from libs.external_api import ExternalApi bp = Blueprint("inner_api", __name__, url_prefix="/inner/api") api = ExternalApi(bp) +from . import mail from .plugin import plugin from .workspace import workspace diff --git a/api/controllers/inner_api/mail.py b/api/controllers/inner_api/mail.py new file mode 100644 index 0000000000..ce3373d65c --- /dev/null +++ b/api/controllers/inner_api/mail.py @@ -0,0 +1,27 @@ +from flask_restful import ( + Resource, # type: ignore + reqparse, +) + +from controllers.console.wraps import setup_required +from controllers.inner_api import api +from controllers.inner_api.wraps import enterprise_inner_api_only +from services.enterprise.mail_service import DifyMail, EnterpriseMailService + + +class EnterpriseMail(Resource): + @setup_required + @enterprise_inner_api_only + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("to", type=str, action="append", required=True) + parser.add_argument("subject", type=str, required=True) + parser.add_argument("body", type=str, required=True) + parser.add_argument("substitutions", type=dict, required=False) + args = parser.parse_args() + + EnterpriseMailService.send_mail(DifyMail(**args)) + return {"message": "success"}, 200 + + +api.add_resource(EnterpriseMail, "/enterprise/mail") diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 92daa8cf7b..56a5a76c6f 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -313,7 +313,7 @@ class DatasetApi(DatasetApiResource): try: if DatasetService.delete_dataset(dataset_id_str, current_user): DatasetPermissionService.clear_partial_member_list(dataset_id_str) - return {"result": "success"}, 204 + return 204 else: raise NotFound("Dataset not found.") except services.errors.dataset.DatasetInUseError: diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index f0f39fc2e5..44c75f40ef 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 204 + return 204 class DocumentListApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index fb3ca1e15f..ea4be4e511 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {"result": "success"}, 204 + return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): @@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource): except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return {"result": "success"}, 204 + return 204 @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") diff --git a/api/controllers/web/app.py b/api/controllers/web/app.py index c9a37af5ed..bb4486bd91 100644 --- a/api/controllers/web/app.py +++ b/api/controllers/web/app.py @@ -1,12 +1,15 @@ -from flask_restful import marshal_with +from flask import request +from flask_restful import Resource, marshal_with, reqparse from controllers.common import fields from controllers.web import api from controllers.web.error import AppUnavailableError from controllers.web.wraps import WebApiResource from core.app.app_config.common.parameters_mapping import get_parameters_from_feature_dict +from libs.passport import PassportService from models.model import App, AppMode from services.app_service import AppService +from services.enterprise.enterprise_service import EnterpriseService class AppParameterApi(WebApiResource): @@ -40,5 +43,51 @@ class AppMeta(WebApiResource): return AppService().get_app_meta(app_model) +class AppAccessMode(Resource): + def get(self): + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + res = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_id) + + return {"accessMode": res.access_mode} + + +class AppWebAuthPermission(Resource): + def get(self): + user_id = "visitor" + try: + auth_header = request.headers.get("Authorization") + if auth_header is None: + raise + if " " not in auth_header: + raise + + auth_scheme, tk = auth_header.split(None, 1) + auth_scheme = auth_scheme.lower() + if auth_scheme != "bearer": + raise + + decoded = PassportService().verify(tk) + user_id = decoded.get("user_id", "visitor") + except Exception as e: + pass + + parser = reqparse.RequestParser() + parser.add_argument("appId", type=str, required=True, location="args") + args = parser.parse_args() + + app_id = args["appId"] + app_code = AppService.get_app_code_by_id(app_id) + + res = EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(str(user_id), app_code) + return {"result": res} + + api.add_resource(AppParameterApi, "/parameters") api.add_resource(AppMeta, "/meta") +# webapp auth apis +api.add_resource(AppAccessMode, "/webapp/access-mode") +api.add_resource(AppWebAuthPermission, "/webapp/permission") diff --git a/api/controllers/web/error.py b/api/controllers/web/error.py index 9fe5d08d54..4371e679db 100644 --- a/api/controllers/web/error.py +++ b/api/controllers/web/error.py @@ -121,9 +121,15 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class WebSSOAuthRequiredError(BaseHTTPException): +class WebAppAuthRequiredError(BaseHTTPException): error_code = "web_sso_auth_required" - description = "Web SSO authentication required." + description = "Web app authentication required." + code = 401 + + +class WebAppAuthAccessDeniedError(BaseHTTPException): + error_code = "web_app_access_denied" + description = "You do not have permission to access this web app." code = 401 diff --git a/api/controllers/web/login.py b/api/controllers/web/login.py new file mode 100644 index 0000000000..06c2274440 --- /dev/null +++ b/api/controllers/web/login.py @@ -0,0 +1,120 @@ +from flask import request +from flask_restful import Resource, reqparse +from jwt import InvalidTokenError # type: ignore +from werkzeug.exceptions import BadRequest + +import services +from controllers.console.auth.error import EmailCodeError, EmailOrPasswordMismatchError, InvalidEmailError +from controllers.console.error import AccountBannedError, AccountNotFound +from controllers.console.wraps import setup_required +from libs.helper import email +from libs.password import valid_password +from services.account_service import AccountService +from services.webapp_auth_service import WebAppAuthService + + +class LoginApi(Resource): + """Resource for web app email/password login.""" + + def post(self): + """Authenticate user and login.""" + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("password", type=valid_password, required=True, location="json") + args = parser.parse_args() + + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise BadRequest("X-App-Code header is missing.") + + try: + account = WebAppAuthService.authenticate(args["email"], args["password"]) + except services.errors.account.AccountLoginError: + raise AccountBannedError() + except services.errors.account.AccountPasswordError: + raise EmailOrPasswordMismatchError() + except services.errors.account.AccountNotFoundError: + raise AccountNotFound() + + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=args["email"], app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) + return {"result": "success", "token": token} + + +# class LogoutApi(Resource): +# @setup_required +# def get(self): +# account = cast(Account, flask_login.current_user) +# if isinstance(account, flask_login.AnonymousUserMixin): +# return {"result": "success"} +# flask_login.logout_user() +# return {"result": "success"} + + +class EmailCodeLoginSendEmailApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + account = WebAppAuthService.get_user_through_email(args["email"]) + if account is None: + raise AccountNotFound() + else: + token = WebAppAuthService.send_email_code_login_email(account=account, language=language) + + return {"result": "success", "data": token} + + +class EmailCodeLoginApi(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=str, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, location="json") + args = parser.parse_args() + + user_email = args["email"] + app_code = request.headers.get("X-App-Code") + if app_code is None: + raise BadRequest("X-App-Code header is missing.") + + token_data = WebAppAuthService.get_email_code_login_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if token_data["email"] != args["email"]: + raise InvalidEmailError() + + if token_data["code"] != args["code"]: + raise EmailCodeError() + + WebAppAuthService.revoke_email_code_login_token(args["token"]) + account = WebAppAuthService.get_user_through_email(user_email) + if not account: + raise AccountNotFound() + + WebAppAuthService._validate_user_accessibility(account=account, app_code=app_code) + + end_user = WebAppAuthService.create_end_user(email=user_email, app_code=app_code) + + token = WebAppAuthService.login(account=account, app_code=app_code, end_user_id=end_user.id) + AccountService.reset_login_error_rate_limit(args["email"]) + return {"result": "success", "token": token} + + +# api.add_resource(LoginApi, "/login") +# api.add_resource(LogoutApi, "/logout") +# api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") +# api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") diff --git a/api/controllers/web/passport.py b/api/controllers/web/passport.py index 267dac223d..154c6772b7 100644 --- a/api/controllers/web/passport.py +++ b/api/controllers/web/passport.py @@ -5,7 +5,7 @@ from flask_restful import Resource from werkzeug.exceptions import NotFound, Unauthorized from controllers.web import api -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -24,10 +24,10 @@ class PassportResource(Resource): if app_code is None: raise Unauthorized("X-App-Code header is missing.") - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + if not app_settings or not app_settings.access_mode == "public": + raise WebAppAuthRequiredError() # get site from db and check if it is normal site = db.session.query(Site).filter(Site.code == app_code, Site.status == "normal").first() diff --git a/api/controllers/web/wraps.py b/api/controllers/web/wraps.py index c327c3df18..3bb029d6eb 100644 --- a/api/controllers/web/wraps.py +++ b/api/controllers/web/wraps.py @@ -4,7 +4,7 @@ from flask import request from flask_restful import Resource from werkzeug.exceptions import BadRequest, NotFound, Unauthorized -from controllers.web.error import WebSSOAuthRequiredError +from controllers.web.error import WebAppAuthAccessDeniedError, WebAppAuthRequiredError from extensions.ext_database import db from libs.passport import PassportService from models.model import App, EndUser, Site @@ -29,7 +29,7 @@ def validate_jwt_token(view=None): def decode_jwt_token(): system_features = FeatureService.get_system_features() - app_code = request.headers.get("X-App-Code") + app_code = str(request.headers.get("X-App-Code")) try: auth_header = request.headers.get("Authorization") if auth_header is None: @@ -57,35 +57,53 @@ def decode_jwt_token(): if not end_user: raise NotFound() - _validate_web_sso_token(decoded, system_features, app_code) + # for enterprise webapp auth + app_web_auth_enabled = False + if system_features.webapp_auth.enabled: + app_web_auth_enabled = ( + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code).access_mode != "public" + ) + + _validate_webapp_token(decoded, app_web_auth_enabled, system_features.webapp_auth.enabled) + _validate_user_accessibility(decoded, app_code, app_web_auth_enabled, system_features.webapp_auth.enabled) return app_model, end_user except Unauthorized as e: - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - raise WebSSOAuthRequiredError() + if system_features.webapp_auth.enabled: + app_web_auth_enabled = ( + EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=str(app_code)).access_mode != "public" + ) + if app_web_auth_enabled: + raise WebAppAuthRequiredError() raise Unauthorized(e.description) -def _validate_web_sso_token(decoded, system_features, app_code): - app_web_sso_enabled = False - - # Check if SSO is enforced for web, and if the token source is not SSO, raise an error and redirect to SSO login - if system_features.sso_enforced_for_web: - app_web_sso_enabled = EnterpriseService.get_app_web_sso_enabled(app_code).get("enabled", False) - if app_web_sso_enabled: - source = decoded.get("token_source") - if not source or source != "sso": - raise WebSSOAuthRequiredError() +def _validate_webapp_token(decoded, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + # Check if authentication is enforced for web app, and if the token source is not webapp, + # raise an error and redirect to login + if system_webapp_auth_enabled and app_web_auth_enabled: + source = decoded.get("token_source") + if not source or source != "webapp": + raise WebAppAuthRequiredError() - # Check if SSO is not enforced for web, and if the token source is SSO, + # Check if authentication is not enforced for web, and if the token source is webapp, # raise an error and redirect to normal passport login - if not system_features.sso_enforced_for_web or not app_web_sso_enabled: + if not system_webapp_auth_enabled or not app_web_auth_enabled: source = decoded.get("token_source") - if source and source == "sso": - raise Unauthorized("sso token expired.") + if source and source == "webapp": + raise Unauthorized("webapp token expired.") + + +def _validate_user_accessibility(decoded, app_code, app_web_auth_enabled: bool, system_webapp_auth_enabled: bool): + if system_webapp_auth_enabled and app_web_auth_enabled: + # Check if the user is allowed to access the web app + user_id = decoded.get("user_id") + if not user_id: + raise WebAppAuthRequiredError() + + if not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(user_id, app_code=app_code): + raise WebAppAuthAccessDeniedError() class WebApiResource(Resource): diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b0e64130b..b74100bb19 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,9 +29,7 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory -from models.account import Account -from models.model import App, Conversation, EndUser, Message -from models.workflow import Workflow +from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError @@ -165,8 +163,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -231,8 +230,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -295,8 +295,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f71c49d112..735b2a9709 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -70,7 +70,7 @@ from events.message_event import message_was_created from extensions.ext_database import db from models import Conversation, EndUser, Message, MessageFile from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.workflow import ( Workflow, WorkflowRunStatus, @@ -105,11 +105,11 @@ class AdvancedChatAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise NotImplementedError(f"User type not supported: {type(user)}") @@ -739,9 +739,9 @@ class AdvancedChatAppGenerateTaskPipeline: url=file["remote_url"], belongs_to="assistant", upload_file_id=file["related_id"], - created_by_role=CreatedByRole.ACCOUNT + created_by_role=CreatorUserRole.ACCOUNT if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER, + else CreatorUserRole.END_USER, created_by=message.from_account_id or message.from_end_user_id or "", ) for file in self._recorded_files diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 995082b79d..58b94f4d43 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -25,7 +25,7 @@ from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBa from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db from models import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.conversation import ConversationNotExistsError @@ -223,7 +223,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): belongs_to="user", url=file.remote_url, upload_file_id=file.related_id, - created_by_role=(CreatedByRole.ACCOUNT if account_id else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) db.session.add(message_file) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 1d67671974..d49ff682b9 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,7 +27,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow from core.workflow.workflow_app_generate_task_pipeline import WorkflowAppGenerateTaskPipeline from extensions.ext_database import db from factories import file_factory -from models import Account, App, EndUser, Workflow +from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -138,10 +138,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) return self._generate( @@ -262,10 +264,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( @@ -325,10 +329,12 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create workflow node execution repository session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( session_factory=session_factory, - tenant_id=application_generate_entity.app_config.tenant_id, + user=user, app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) return self._generate( diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 817699bd20..0c2d617f80 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, ConfigDict from core.model_runtime.entities.llm_entities import LLMResult from core.model_runtime.utils.encoders import jsonable_encoder -from core.workflow.entities.node_entities import AgentNodeStrategyInit +from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunMetadataKey from models.workflow import WorkflowNodeExecutionStatus @@ -244,7 +244,7 @@ class NodeStartStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None created_at: int extras: dict = {} parallel_id: Optional[str] = None @@ -301,13 +301,13 @@ class NodeFinishStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] @@ -370,13 +370,13 @@ class NodeRetryStreamResponse(StreamResponse): title: str index: int predecessor_node_id: Optional[str] = None - inputs: Optional[dict] = None - process_data: Optional[dict] = None - outputs: Optional[dict] = None + inputs: Optional[Mapping[str, Any]] = None + process_data: Optional[Mapping[str, Any]] = None + outputs: Optional[Mapping[str, Any]] = None status: str error: Optional[str] = None elapsed_time: float - execution_metadata: Optional[dict] = None + execution_metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None created_at: int finished_at: int files: Optional[Sequence[Mapping[str, Any]]] = [] diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index 86887c9b4a..66d8d0f414 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -754,7 +754,7 @@ class ProviderConfiguration(BaseModel): :param only_active: return active model only :return: """ - provider_models = self.get_provider_models(model_type, only_active) + provider_models = self.get_provider_models(model_type, only_active, model) for provider_model in provider_models: if provider_model.model == model: @@ -763,12 +763,13 @@ class ProviderConfiguration(BaseModel): return None def get_provider_models( - self, model_type: Optional[ModelType] = None, only_active: bool = False + self, model_type: Optional[ModelType] = None, only_active: bool = False, model: Optional[str] = None ) -> list[ModelWithProviderEntity]: """ Get provider models. :param model_type: model type :param only_active: only active models + :param model: model name :return: """ model_provider_factory = ModelProviderFactory(self.tenant_id) @@ -791,7 +792,10 @@ class ProviderConfiguration(BaseModel): ) else: provider_models = self._get_custom_provider_models( - model_types=model_types, provider_schema=provider_schema, model_setting_map=model_setting_map + model_types=model_types, + provider_schema=provider_schema, + model_setting_map=model_setting_map, + model=model, ) if only_active: @@ -897,37 +901,36 @@ class ProviderConfiguration(BaseModel): ) except Exception as ex: logger.warning(f"get custom model schema failed, {ex}") - - if not custom_model_schema: - continue - - if custom_model_schema.model_type not in model_types: - continue - - status = ModelStatus.ACTIVE - if ( - custom_model_schema.model_type in model_setting_map - and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] - ): - model_setting = model_setting_map[custom_model_schema.model_type][ - custom_model_schema.model - ] - if model_setting.enabled is False: - status = ModelStatus.DISABLED - - provider_models.append( - ModelWithProviderEntity( - model=custom_model_schema.model, - label=custom_model_schema.label, - model_type=custom_model_schema.model_type, - features=custom_model_schema.features, - fetch_from=FetchFrom.PREDEFINED_MODEL, - model_properties=custom_model_schema.model_properties, - deprecated=custom_model_schema.deprecated, - provider=SimpleModelProviderEntity(self.provider), - status=status, - ) + continue + + if not custom_model_schema: + continue + + if custom_model_schema.model_type not in model_types: + continue + + status = ModelStatus.ACTIVE + if ( + custom_model_schema.model_type in model_setting_map + and custom_model_schema.model in model_setting_map[custom_model_schema.model_type] + ): + model_setting = model_setting_map[custom_model_schema.model_type][custom_model_schema.model] + if model_setting.enabled is False: + status = ModelStatus.DISABLED + + provider_models.append( + ModelWithProviderEntity( + model=custom_model_schema.model, + label=custom_model_schema.label, + model_type=custom_model_schema.model_type, + features=custom_model_schema.features, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties=custom_model_schema.model_properties, + deprecated=custom_model_schema.deprecated, + provider=SimpleModelProviderEntity(self.provider), + status=status, ) + ) # if llm name not in restricted llm list, remove it restrict_model_names = [rm.model for rm in restrict_models] @@ -944,6 +947,7 @@ class ProviderConfiguration(BaseModel): model_types: Sequence[ModelType], provider_schema: ProviderEntity, model_setting_map: dict[ModelType, dict[str, ModelSettings]], + model: Optional[str] = None, ) -> list[ModelWithProviderEntity]: """ Get custom provider models. @@ -996,7 +1000,8 @@ class ProviderConfiguration(BaseModel): for model_configuration in self.custom_configuration.models: if model_configuration.model_type not in model_types: continue - + if model and model != model_configuration.model: + continue try: custom_model_schema = self.get_model_schema( model_type=model_configuration.model_type, diff --git a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py index f486da3a6d..46ba1c45b9 100644 --- a/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py +++ b/api/core/ops/langfuse_trace/entities/langfuse_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -155,10 +156,10 @@ class LangfuseSpan(BaseModel): description="The status message of the span. Additional field for context of the event. E.g. the error " "message of an error event.", ) - input: Optional[Union[str, dict[str, Any], list, None]] = Field( + input: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The input of the span. Can be any JSON object." ) - output: Optional[Union[str, dict[str, Any], list, None]] = Field( + output: Optional[Union[str, Mapping[str, Any], list, None]] = Field( default=None, description="The output of the span. Can be any JSON object." ) version: Optional[str] = Field( diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index c74617e558..120c36f53d 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -1,11 +1,10 @@ -import json import logging import os from datetime import datetime, timedelta from typing import Optional from langfuse import Langfuse # type: ignore -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangfuseConfig @@ -30,8 +29,9 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser +from models import Account, App, EndUser, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -113,8 +113,29 @@ class LangFuseDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -124,23 +145,22 @@ class LangFuseDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - metadata = json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} + execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -152,7 +172,7 @@ class LangFuseDataTrace(BaseTraceInstance): "status": status, } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} model_provider = process_data.get("model_provider", None) model_name = process_data.get("model_name", None) if model_provider is not None and model_name is not None: diff --git a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py index 348b7ba501..4fd01136ba 100644 --- a/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py +++ b/api/core/ops/langsmith_trace/entities/langsmith_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from datetime import datetime from enum import StrEnum from typing import Any, Optional, Union @@ -30,8 +31,8 @@ class LangSmithMultiModel(BaseModel): class LangSmithRunModel(LangSmithTokenUsage, LangSmithMultiModel): name: Optional[str] = Field(..., description="Name of the run") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the run") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the run") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the run") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the run") run_type: LangSmithRunType = Field(..., description="Type of the run") start_time: Optional[datetime | str] = Field(None, description="Start time of the run") end_time: Optional[datetime | str] = Field(None, description="End time of the run") diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index d1e16d3152..6631727c79 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,7 +6,7 @@ from typing import Optional, cast from langsmith import Client from langsmith.schemas import RunBase -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import LangSmithConfig @@ -29,8 +28,10 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -137,8 +138,29 @@ class LangSmithDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -148,27 +170,23 @@ class LangSmithDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - metadata = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + metadata = {str(key): value for key, value in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -181,7 +199,7 @@ class LangSmithDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": run_type = LangSmithRunType.llm @@ -191,7 +209,7 @@ class LangSmithDataTrace(BaseTraceInstance): "ls_model_name": process_data.get("model_name", ""), } ) - elif node_type == "knowledge-retrieval": + elif node_type == NodeType.KNOWLEDGE_RETRIEVAL: run_type = LangSmithRunType.retriever else: run_type = LangSmithRunType.tool diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 1484041447..c22df55357 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,7 +6,7 @@ from typing import Optional, cast from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import OpikConfig @@ -23,8 +22,10 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -150,8 +151,29 @@ class OpikDataTrace(BaseTraceInstance): # through workflow_run_id get all_nodes_execution using repository session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=session_factory, tenant_id=trace_info.tenant_id, app_id=trace_info.metadata.get("app_id") + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) # Get all executions for this workflow run @@ -161,26 +183,22 @@ class OpikDataTrace(BaseTraceInstance): for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - metadata = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + metadata = {str(k): v for k, v in execution_metadata.items()} metadata.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -193,7 +211,7 @@ class OpikDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} provider = None model = None @@ -226,7 +244,7 @@ class OpikDataTrace(BaseTraceInstance): parent_span_id = trace_info.workflow_app_log_id or trace_info.workflow_run_id if not total_tokens: - total_tokens = execution_metadata.get("total_tokens", 0) + total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 span_data = { "trace_id": opik_trace_id, diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 2bcca6ccea..8ddeb1b846 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -287,7 +287,9 @@ class OpsTraceManager: :return: """ # auth check - if tracing_provider not in provider_config_map and tracing_provider is not None: + try: + provider_config_map[tracing_provider] + except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") app_config: Optional[App] = db.session.query(App).filter(App.id == app_id).first() diff --git a/api/core/ops/weave_trace/entities/weave_trace_entity.py b/api/core/ops/weave_trace/entities/weave_trace_entity.py index e423f5ccbb..7f489f37ac 100644 --- a/api/core/ops/weave_trace/entities/weave_trace_entity.py +++ b/api/core/ops/weave_trace/entities/weave_trace_entity.py @@ -1,3 +1,4 @@ +from collections.abc import Mapping from typing import Any, Optional, Union from pydantic import BaseModel, Field, field_validator @@ -19,8 +20,8 @@ class WeaveMultiModel(BaseModel): class WeaveTraceModel(WeaveTokenUsage, WeaveMultiModel): id: str = Field(..., description="ID of the trace") op: str = Field(..., description="Name of the operation") - inputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Inputs of the trace") - outputs: Optional[Union[str, dict[str, Any], list, None]] = Field(None, description="Outputs of the trace") + inputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Inputs of the trace") + outputs: Optional[Union[str, Mapping[str, Any], list, None]] = Field(None, description="Outputs of the trace") attributes: Optional[Union[str, dict[str, Any], list, None]] = Field( None, description="Metadata and attributes associated with trace" ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 49594cb0f1..a4f38dfbba 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -1,4 +1,3 @@ -import json import logging import os import uuid @@ -7,6 +6,7 @@ from typing import Any, Optional, cast import wandb import weave +from sqlalchemy.orm import Session, sessionmaker from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import WeaveConfig @@ -22,9 +22,11 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel +from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType from extensions.ext_database import db -from models.model import EndUser, MessageFile -from models.workflow import WorkflowNodeExecution +from models import Account, App, EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom logger = logging.getLogger(__name__) @@ -128,58 +130,57 @@ class WeaveDataTrace(BaseTraceInstance): self.start_call(workflow_run, parent_run_id=trace_info.message_id) - # through workflow_run_id get all_nodes_execution - workflow_nodes_execution_id_records = ( - db.session.query(WorkflowNodeExecution.id) - .filter(WorkflowNodeExecution.workflow_run_id == trace_info.workflow_run_id) - .all() + # through workflow_run_id get all_nodes_execution using repository + session_factory = sessionmaker(bind=db.engine) + # Find the app's creator account + with Session(db.engine, expire_on_commit=False) as session: + # Get the app to find its creator + app_id = trace_info.metadata.get("app_id") + if not app_id: + raise ValueError("No app_id found in trace_info metadata") + + app = session.query(App).filter(App.id == app_id).first() + if not app: + raise ValueError(f"App with id {app_id} not found") + + if not app.created_by: + raise ValueError(f"App with id {app_id} has no creator (created_by is None)") + + service_account = session.query(Account).filter(Account.id == app.created_by).first() + if not service_account: + raise ValueError(f"Creator account with id {app.created_by} not found for app {app_id}") + + workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + session_factory=session_factory, + user=service_account, + app_id=trace_info.metadata.get("app_id"), + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - for node_execution_id_record in workflow_nodes_execution_id_records: - node_execution = ( - db.session.query( - WorkflowNodeExecution.id, - WorkflowNodeExecution.tenant_id, - WorkflowNodeExecution.app_id, - WorkflowNodeExecution.title, - WorkflowNodeExecution.node_type, - WorkflowNodeExecution.status, - WorkflowNodeExecution.inputs, - WorkflowNodeExecution.outputs, - WorkflowNodeExecution.created_at, - WorkflowNodeExecution.elapsed_time, - WorkflowNodeExecution.process_data, - WorkflowNodeExecution.execution_metadata, - ) - .filter(WorkflowNodeExecution.id == node_execution_id_record.id) - .first() - ) - - if not node_execution: - continue + # Get all executions for this workflow run + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( + workflow_run_id=trace_info.workflow_run_id + ) + for node_execution in workflow_node_executions: node_execution_id = node_execution.id - tenant_id = node_execution.tenant_id - app_id = node_execution.app_id + tenant_id = trace_info.tenant_id # Use from trace_info instead + app_id = trace_info.metadata.get("app_id") # Use from trace_info instead node_name = node_execution.title node_type = node_execution.node_type status = node_execution.status - if node_type == "llm": - inputs = ( - json.loads(node_execution.process_data).get("prompts", {}) if node_execution.process_data else {} - ) + if node_type == NodeType.LLM: + inputs = node_execution.process_data.get("prompts", {}) if node_execution.process_data else {} else: - inputs = json.loads(node_execution.inputs) if node_execution.inputs else {} - outputs = json.loads(node_execution.outputs) if node_execution.outputs else {} + inputs = node_execution.inputs if node_execution.inputs else {} + outputs = node_execution.outputs if node_execution.outputs else {} created_at = node_execution.created_at or datetime.now() elapsed_time = node_execution.elapsed_time finished_at = created_at + timedelta(seconds=elapsed_time) - execution_metadata = ( - json.loads(node_execution.execution_metadata) if node_execution.execution_metadata else {} - ) - node_total_tokens = execution_metadata.get("total_tokens", 0) - attributes = execution_metadata.copy() + execution_metadata = node_execution.metadata if node_execution.metadata else {} + node_total_tokens = execution_metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) or 0 + attributes = {str(k): v for k, v in execution_metadata.items()} attributes.update( { "workflow_run_id": trace_info.workflow_run_id, @@ -192,7 +193,7 @@ class WeaveDataTrace(BaseTraceInstance): } ) - process_data = json.loads(node_execution.process_data) if node_execution.process_data else {} + process_data = node_execution.process_data if node_execution.process_data else {} if process_data and process_data.get("model_mode") == "chat": attributes.update( { diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index db07e52f3f..7898795ce2 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -64,9 +64,9 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) return { - "inputs": execution.inputs_dict, - "outputs": execution.outputs_dict, - "process_data": execution.process_data_dict, + "inputs": execution.inputs, + "outputs": execution.outputs, + "process_data": execution.process_data, } @classmethod @@ -113,7 +113,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): ) return { - "inputs": execution.inputs_dict, - "outputs": execution.outputs_dict, - "process_data": execution.process_data_dict, + "inputs": execution.inputs, + "outputs": execution.outputs, + "process_data": execution.process_data, } diff --git a/api/core/plugin/endpoint/exc.py b/api/core/plugin/endpoint/exc.py new file mode 100644 index 0000000000..aa29f1e9a1 --- /dev/null +++ b/api/core/plugin/endpoint/exc.py @@ -0,0 +1,6 @@ +class EndpointSetupFailedError(ValueError): + """ + Endpoint setup failed error + """ + + pass diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 4f1d808a3e..591e7b0525 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -17,6 +17,7 @@ from core.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, @@ -219,6 +220,8 @@ class BasePluginClient: raise InvokeServerUnavailableError(description=args.get("description")) case CredentialsValidateFailedError.__name__: raise CredentialsValidateFailedError(error_object.get("message")) + case EndpointSetupFailedError.__name__: + raise EndpointSetupFailedError(error_object.get("message")) case _: raise PluginInvokeError(description=message) case PluginDaemonInternalServerError.__name__: diff --git a/api/core/rag/datasource/vdb/pgvector/pgvector.py b/api/core/rag/datasource/vdb/pgvector/pgvector.py index 366a21c381..04e9cf801e 100644 --- a/api/core/rag/datasource/vdb/pgvector/pgvector.py +++ b/api/core/rag/datasource/vdb/pgvector/pgvector.py @@ -1,3 +1,4 @@ +import hashlib import json import logging import uuid @@ -61,12 +62,12 @@ CREATE TABLE IF NOT EXISTS {table_name} ( """ SQL_CREATE_INDEX = """ -CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx ON {table_name} +CREATE INDEX IF NOT EXISTS embedding_cosine_v1_idx_{index_hash} ON {table_name} USING hnsw (embedding vector_cosine_ops) WITH (m = 16, ef_construction = 64); """ SQL_CREATE_INDEX_PG_BIGM = """ -CREATE INDEX IF NOT EXISTS bigm_idx ON {table_name} +CREATE INDEX IF NOT EXISTS bigm_idx_{index_hash} ON {table_name} USING gin (text gin_bigm_ops); """ @@ -76,6 +77,7 @@ class PGVector(BaseVector): super().__init__(collection_name) self.pool = self._create_connection_pool(config) self.table_name = f"embedding_{collection_name}" + self.index_hash = hashlib.md5(self.table_name.encode()).hexdigest()[:8] self.pg_bigm = config.pg_bigm def get_type(self) -> str: @@ -256,10 +258,9 @@ class PGVector(BaseVector): # PG hnsw index only support 2000 dimension or less # ref: https://github.com/pgvector/pgvector?tab=readme-ov-file#indexing if dimension <= 2000: - cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name)) + cur.execute(SQL_CREATE_INDEX.format(table_name=self.table_name, index_hash=self.index_hash)) if self.pg_bigm: - cur.execute("CREATE EXTENSION IF NOT EXISTS pg_bigm") - cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name)) + cur.execute(SQL_CREATE_INDEX_PG_BIGM.format(table_name=self.table_name, index_hash=self.index_hash)) redis_client.set(collection_exist_cache_key, 1, ex=3600) diff --git a/api/core/rag/extractor/watercrawl/client.py b/api/core/rag/extractor/watercrawl/client.py index 6eaede7dbc..6d596e07d8 100644 --- a/api/core/rag/extractor/watercrawl/client.py +++ b/api/core/rag/extractor/watercrawl/client.py @@ -6,6 +6,12 @@ from urllib.parse import urljoin import requests from requests import Response +from core.rag.extractor.watercrawl.exceptions import ( + WaterCrawlAuthenticationError, + WaterCrawlBadRequestError, + WaterCrawlPermissionError, +) + class BaseAPIClient: def __init__(self, api_key, base_url): @@ -53,6 +59,15 @@ class WaterCrawlAPIClient(BaseAPIClient): yield data def process_response(self, response: Response) -> dict | bytes | list | None | Generator: + if response.status_code == 401: + raise WaterCrawlAuthenticationError(response) + + if response.status_code == 403: + raise WaterCrawlPermissionError(response) + + if 400 <= response.status_code < 500: + raise WaterCrawlBadRequestError(response) + response.raise_for_status() if response.status_code == 204: return None diff --git a/api/core/rag/extractor/watercrawl/exceptions.py b/api/core/rag/extractor/watercrawl/exceptions.py new file mode 100644 index 0000000000..e407a594e0 --- /dev/null +++ b/api/core/rag/extractor/watercrawl/exceptions.py @@ -0,0 +1,32 @@ +import json + + +class WaterCrawlError(Exception): + pass + + +class WaterCrawlBadRequestError(WaterCrawlError): + def __init__(self, response): + self.status_code = response.status_code + self.response = response + data = response.json() + self.message = data.get("message", "Unknown error occurred") + self.errors = data.get("errors", {}) + super().__init__(self.message) + + @property + def flat_errors(self): + return json.dumps(self.errors) + + def __str__(self): + return f"WaterCrawlBadRequestError: {self.message} \n {self.flat_errors}" + + +class WaterCrawlPermissionError(WaterCrawlBadRequestError): + def __str__(self): + return f"You are exceeding your WaterCrawl API limits. {self.message}" + + +class WaterCrawlAuthenticationError(WaterCrawlBadRequestError): + def __str__(self): + return "WaterCrawl API key is invalid or expired. Please check your API key and try again." diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index a4ccdcafd3..bff0acc48f 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -19,7 +19,7 @@ from core.rag.extractor.extractor_base import BaseExtractor from core.rag.models.document import Document from extensions.ext_database import db from extensions.ext_storage import storage -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import UploadFile logger = logging.getLogger(__name__) @@ -116,7 +116,7 @@ class WordExtractor(BaseExtractor): extension=str(image_ext), mime_type=mime_type or "", created_by=self.user_id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=self.user_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index d3605da146..c4adf6de4d 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -190,7 +190,7 @@ class DatasetRetrieval: retrieve_config.rerank_mode or "reranking_model", retrieve_config.reranking_model, retrieve_config.weights, - retrieve_config.reranking_enabled or True, + True if retrieve_config.reranking_enabled is None else retrieve_config.reranking_enabled, message_id, metadata_filter_document_ids, metadata_condition, diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 8bf2ab8761..7e0115e4e7 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -2,16 +2,31 @@ SQLAlchemy implementation of the WorkflowNodeExecutionRepository. """ +import json import logging -from collections.abc import Sequence -from typing import Optional +from collections.abc import Mapping, Sequence +from typing import Any, Optional, Union, cast from sqlalchemy import UnaryExpression, asc, delete, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.model_runtime.utils.encoders import jsonable_encoder +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import ( + NodeExecution, + NodeExecutionStatus, +) +from core.workflow.nodes.enums import NodeType from core.workflow.repository.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom +from models import ( + Account, + CreatorUserRole, + EndUser, + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, + WorkflowNodeExecutionTriggeredFrom, +) logger = logging.getLogger(__name__) @@ -23,16 +38,26 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This implementation supports multi-tenancy by filtering operations based on tenant_id. Each method creates its own session, handles the transaction, and commits changes to the database. This prevents long-running connections in the workflow core. + + This implementation also includes an in-memory cache for node executions to improve + performance by reducing database queries. """ - def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None): + def __init__( + self, + session_factory: sessionmaker | Engine, + user: Union[Account, EndUser], + app_id: Optional[str], + triggered_from: Optional[WorkflowNodeExecutionTriggeredFrom], + ): """ - Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context. + Initialize the repository with a SQLAlchemy sessionmaker or engine and context information. Args: session_factory: SQLAlchemy sessionmaker or engine for creating sessions - tenant_id: Tenant ID for multi-tenancy - app_id: Optional app ID for filtering by application + user: Account or EndUser object containing tenant_id, user ID, and role information + app_id: App ID for filtering by application (can be None) + triggered_from: Source of the execution trigger (SINGLE_STEP or WORKFLOW_RUN) """ # If an engine is provided, create a sessionmaker from it if isinstance(session_factory, Engine): @@ -44,38 +69,170 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Invalid session_factory type {type(session_factory).__name__}; expected sessionmaker or Engine" ) + # Extract tenant_id from user + tenant_id: str | None = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id + if not tenant_id: + raise ValueError("User must have a tenant_id or current_tenant_id") self._tenant_id = tenant_id + + # Store app context self._app_id = app_id - def save(self, execution: WorkflowNodeExecution) -> None: + # Extract user context + self._triggered_from = triggered_from + self._creator_user_id = user.id + + # Determine user role based on user type + self._creator_user_role = CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER + + # Initialize in-memory cache for node executions + # Key: node_execution_id, Value: NodeExecution + self._node_execution_cache: dict[str, NodeExecution] = {} + + def _to_domain_model(self, db_model: WorkflowNodeExecution) -> NodeExecution: """ - Save a WorkflowNodeExecution instance and commit changes to the database. + Convert a database model to a domain model. Args: - execution: The WorkflowNodeExecution instance to save + db_model: The database model to convert + + Returns: + The domain model """ - with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id + # Parse JSON fields + inputs = db_model.inputs_dict + process_data = db_model.process_data_dict + outputs = db_model.outputs_dict + metadata = db_model.execution_metadata_dict + + # Convert status to domain enum + status = NodeExecutionStatus(db_model.status) + + return NodeExecution( + id=db_model.id, + node_execution_id=db_model.node_execution_id, + workflow_id=db_model.workflow_id, + workflow_run_id=db_model.workflow_run_id, + index=db_model.index, + predecessor_node_id=db_model.predecessor_node_id, + node_id=db_model.node_id, + node_type=NodeType(db_model.node_type), + title=db_model.title, + inputs=inputs, + process_data=process_data, + outputs=outputs, + status=status, + error=db_model.error, + elapsed_time=db_model.elapsed_time, + # FIXME(QuantumGhost): a temporary workaround for the following type check failure in Python 3.11. + # However, this problem is not occurred in Python 3.12. + # + # A case of this error is: + # https://github.com/langgenius/dify/actions/runs/15112698604/job/42475659482?pr=19737#step:9:24 + metadata=cast(Mapping[NodeRunMetadataKey, Any] | None, metadata), + created_at=db_model.created_at, + finished_at=db_model.finished_at, + ) + + def to_db_model(self, domain_model: NodeExecution) -> WorkflowNodeExecution: + """ + Convert a domain model to a database model. - # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id + Args: + domain_model: The domain model to convert + + Returns: + The database model + """ + # Use values from constructor if provided + if not self._triggered_from: + raise ValueError("triggered_from is required in repository constructor") + if not self._creator_user_id: + raise ValueError("created_by is required in repository constructor") + if not self._creator_user_role: + raise ValueError("created_by_role is required in repository constructor") + + db_model = WorkflowNodeExecution() + db_model.id = domain_model.id + db_model.tenant_id = self._tenant_id + if self._app_id is not None: + db_model.app_id = self._app_id + db_model.workflow_id = domain_model.workflow_id + db_model.triggered_from = self._triggered_from + db_model.workflow_run_id = domain_model.workflow_run_id + db_model.index = domain_model.index + db_model.predecessor_node_id = domain_model.predecessor_node_id + db_model.node_execution_id = domain_model.node_execution_id + db_model.node_id = domain_model.node_id + db_model.node_type = domain_model.node_type + db_model.title = domain_model.title + db_model.inputs = json.dumps(domain_model.inputs) if domain_model.inputs else None + db_model.process_data = json.dumps(domain_model.process_data) if domain_model.process_data else None + db_model.outputs = json.dumps(domain_model.outputs) if domain_model.outputs else None + db_model.status = domain_model.status + db_model.error = domain_model.error + db_model.elapsed_time = domain_model.elapsed_time + db_model.execution_metadata = ( + json.dumps(jsonable_encoder(domain_model.metadata)) if domain_model.metadata else None + ) + db_model.created_at = domain_model.created_at + db_model.created_by_role = self._creator_user_role + db_model.created_by = self._creator_user_id + db_model.finished_at = domain_model.finished_at + return db_model + + def save(self, execution: NodeExecution) -> None: + """ + Save or update a NodeExecution domain entity to the database. + + This method serves as a domain-to-database adapter that: + 1. Converts the domain entity to its database representation + 2. Persists the database model using SQLAlchemy's merge operation + 3. Maintains proper multi-tenancy by including tenant context during conversion + 4. Updates the in-memory cache for faster subsequent lookups + + The method handles both creating new records and updating existing ones through + SQLAlchemy's merge operation. + + Args: + execution: The NodeExecution domain entity to persist + """ + # Convert domain model to database model using tenant context and other attributes + db_model = self.to_db_model(execution) - session.add(execution) + # Create a new database session + with self._session_factory() as session: + # SQLAlchemy merge intelligently handles both insert and update operations + # based on the presence of the primary key + session.merge(db_model) session.commit() - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + # Update the in-memory cache for faster subsequent lookups + # Only cache if we have a node_execution_id to use as the cache key + if db_model.node_execution_id: + logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") + self._node_execution_cache[db_model.node_execution_id] = execution + + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. + + First checks the in-memory cache, and if not found, queries the database. + If found in the database, adds it to the cache for future lookups. Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ + # First check the cache + if node_execution_id in self._node_execution_cache: + logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") + return self._node_execution_cache[node_execution_id] + + # If not in cache, query the database + logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( WorkflowNodeExecution.node_execution_id == node_execution_id, @@ -85,15 +242,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if self._app_id: stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) - return session.scalar(stmt) + db_model = session.scalar(stmt) + if db_model: + # Convert to domain model + domain_model = self._to_domain_model(db_model) + + # Add to cache + self._node_execution_cache[node_execution_id] = domain_model + + return domain_model + + return None def get_by_workflow_run( self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, - ) -> Sequence[WorkflowNodeExecution]: + ) -> Sequence[NodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. + + This method always queries the database to ensure complete and ordered results, + but updates the cache with any retrieved executions. Args: workflow_run_id: The workflow run ID @@ -102,7 +272,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of NodeExecution instances """ with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( @@ -129,17 +299,31 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if order_columns: stmt = stmt.order_by(*order_columns) - return session.scalars(stmt).all() + db_models = session.scalars(stmt).all() + + # Convert database models to domain models and update cache + domain_models = [] + for model in db_models: + domain_model = self._to_domain_model(model) + # Update cache if node_execution_id is present + if domain_model.node_execution_id: + self._node_execution_cache[domain_model.node_execution_id] = domain_model + domain_models.append(domain_model) + + return domain_models - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all running NodeExecution instances for a specific workflow run. + + This method queries the database directly and updates the cache with any + retrieved executions that have a node_execution_id. Args: workflow_run_id: The workflow run ID Returns: - A list of running WorkflowNodeExecution instances + A list of running NodeExecution instances """ with self._session_factory() as session: stmt = select(WorkflowNodeExecution).where( @@ -152,26 +336,17 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) if self._app_id: stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id) - return session.scalars(stmt).all() + db_models = session.scalars(stmt).all() + domain_models = [] - def update(self, execution: WorkflowNodeExecution) -> None: - """ - Update an existing WorkflowNodeExecution instance and commit changes to the database. + for model in db_models: + domain_model = self._to_domain_model(model) + # Update cache if node_execution_id is present + if domain_model.node_execution_id: + self._node_execution_cache[domain_model.node_execution_id] = domain_model + domain_models.append(domain_model) - Args: - execution: The WorkflowNodeExecution instance to update - """ - with self._session_factory() as session: - # Ensure tenant_id is set - if not execution.tenant_id: - execution.tenant_id = self._tenant_id - - # Set app_id if provided and not already set - if self._app_id and not execution.app_id: - execution.app_id = self._app_id - - session.merge(execution) - session.commit() + return domain_models def clear(self) -> None: """ @@ -179,6 +354,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) This method deletes all WorkflowNodeExecution records that match the tenant_id and app_id (if provided) associated with this repository instance. + It also clears the in-memory cache. """ with self._session_factory() as session: stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id) @@ -194,3 +370,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" + (f" and app {self._app_id}" if self._app_id else "") ) + + # Clear the in-memory cache + self._node_execution_cache.clear() + logger.info("Cleared in-memory node execution cache") diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 3dce1ca293..178f2b9689 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -32,7 +32,7 @@ from core.tools.errors import ( from core.tools.utils.message_transformer import ToolFileMessageTransformer from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import Message, MessageFile @@ -339,9 +339,9 @@ class ToolEngine: url=message.url, upload_file_id=tool_file_id, created_by_role=( - CreatedByRole.ACCOUNT + CreatorUserRole.ACCOUNT if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatedByRole.END_USER + else CreatorUserRole.END_USER ), created_by=user_id, ) diff --git a/api/core/variables/consts.py b/api/core/variables/consts.py new file mode 100644 index 0000000000..03b277d619 --- /dev/null +++ b/api/core/variables/consts.py @@ -0,0 +1,7 @@ +# The minimal selector length for valid variables. +# +# The first element of the selector is the node id, and the second element is the variable name. +# +# If the selector length is more than 2, the remaining parts are the keys / indexes paths used +# to extract part of the variable value. +MIN_SELECTORS_LENGTH = 2 diff --git a/api/core/variables/utils.py b/api/core/variables/utils.py new file mode 100644 index 0000000000..e5d222af7d --- /dev/null +++ b/api/core/variables/utils.py @@ -0,0 +1,8 @@ +from collections.abc import Iterable, Sequence + + +def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: + selectors = [node_id, name] + if paths: + selectors.extend(paths) + return selectors diff --git a/api/core/workflow/entities/node_execution_entities.py b/api/core/workflow/entities/node_execution_entities.py new file mode 100644 index 0000000000..5e5ead062f --- /dev/null +++ b/api/core/workflow/entities/node_execution_entities.py @@ -0,0 +1,98 @@ +""" +Domain entities for workflow node execution. + +This module contains the domain model for workflow node execution, which is used +by the core workflow module. These models are independent of the storage mechanism +and don't contain implementation details like tenant_id, app_id, etc. +""" + +from collections.abc import Mapping +from datetime import datetime +from enum import StrEnum +from typing import Any, Optional + +from pydantic import BaseModel, Field + +from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.nodes.enums import NodeType + + +class NodeExecutionStatus(StrEnum): + """ + Node Execution Status Enum. + """ + + RUNNING = "running" + SUCCEEDED = "succeeded" + FAILED = "failed" + EXCEPTION = "exception" + RETRY = "retry" + + +class NodeExecution(BaseModel): + """ + Domain model for workflow node execution. + + This model represents the core business entity of a node execution, + without implementation details like tenant_id, app_id, etc. + + Note: User/context-specific fields (triggered_from, created_by, created_by_role) + have been moved to the repository implementation to keep the domain model clean. + These fields are still accepted in the constructor for backward compatibility, + but they are not stored in the model. + """ + + # Core identification fields + id: str # Unique identifier for this execution record + node_execution_id: Optional[str] = None # Optional secondary ID for cross-referencing + workflow_id: str # ID of the workflow this node belongs to + workflow_run_id: Optional[str] = None # ID of the specific workflow run (null for single-step debugging) + + # Execution positioning and flow + index: int # Sequence number for ordering in trace visualization + predecessor_node_id: Optional[str] = None # ID of the node that executed before this one + node_id: str # ID of the node being executed + node_type: NodeType # Type of node (e.g., start, llm, knowledge) + title: str # Display title of the node + + # Execution data + inputs: Optional[Mapping[str, Any]] = None # Input variables used by this node + process_data: Optional[Mapping[str, Any]] = None # Intermediate processing data + outputs: Optional[Mapping[str, Any]] = None # Output variables produced by this node + + # Execution state + status: NodeExecutionStatus = NodeExecutionStatus.RUNNING # Current execution status + error: Optional[str] = None # Error message if execution failed + elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds + + # Additional metadata + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None # Execution metadata (tokens, cost, etc.) + + # Timing information + created_at: datetime # When execution started + finished_at: Optional[datetime] = None # When execution completed + + def update_from_mapping( + self, + inputs: Optional[Mapping[str, Any]] = None, + process_data: Optional[Mapping[str, Any]] = None, + outputs: Optional[Mapping[str, Any]] = None, + metadata: Optional[Mapping[NodeRunMetadataKey, Any]] = None, + ) -> None: + """ + Update the model from mappings. + + Args: + inputs: The inputs to update + process_data: The process data to update + outputs: The outputs to update + metadata: The metadata to update + """ + if inputs is not None: + self.inputs = dict(inputs) + if process_data is not None: + self.process_data = dict(process_data) + if outputs is not None: + self.outputs = dict(outputs) + if metadata is not None: + self.metadata = dict(metadata) diff --git a/api/core/workflow/repository/workflow_node_execution_repository.py b/api/core/workflow/repository/workflow_node_execution_repository.py index 9bb790cb0f..3ca9e2ecab 100644 --- a/api/core/workflow/repository/workflow_node_execution_repository.py +++ b/api/core/workflow/repository/workflow_node_execution_repository.py @@ -2,12 +2,12 @@ from collections.abc import Sequence from dataclasses import dataclass from typing import Literal, Optional, Protocol -from models.workflow import WorkflowNodeExecution +from core.workflow.entities.node_execution_entities import NodeExecution @dataclass class OrderConfig: - """Configuration for ordering WorkflowNodeExecution instances.""" + """Configuration for ordering NodeExecution instances.""" order_by: list[str] order_direction: Optional[Literal["asc", "desc"]] = None @@ -15,10 +15,10 @@ class OrderConfig: class WorkflowNodeExecutionRepository(Protocol): """ - Repository interface for WorkflowNodeExecution. + Repository interface for NodeExecution. This interface defines the contract for accessing and manipulating - WorkflowNodeExecution data, regardless of the underlying storage mechanism. + NodeExecution data, regardless of the underlying storage mechanism. Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), and trigger sources (triggered_from) should be handled at the implementation level, not in @@ -26,24 +26,28 @@ class WorkflowNodeExecutionRepository(Protocol): application domains or deployment scenarios. """ - def save(self, execution: WorkflowNodeExecution) -> None: + def save(self, execution: NodeExecution) -> None: """ - Save a WorkflowNodeExecution instance. + Save or update a NodeExecution instance. + + This method handles both creating new records and updating existing ones. + The implementation should determine whether to create or update based on + the execution's ID or other identifying fields. Args: - execution: The WorkflowNodeExecution instance to save + execution: The NodeExecution instance to save or update """ ... - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: + def get_by_node_execution_id(self, node_execution_id: str) -> Optional[NodeExecution]: """ - Retrieve a WorkflowNodeExecution by its node_execution_id. + Retrieve a NodeExecution by its node_execution_id. Args: node_execution_id: The node execution ID Returns: - The WorkflowNodeExecution instance if found, None otherwise + The NodeExecution instance if found, None otherwise """ ... @@ -51,9 +55,9 @@ class WorkflowNodeExecutionRepository(Protocol): self, workflow_run_id: str, order_config: Optional[OrderConfig] = None, - ) -> Sequence[WorkflowNodeExecution]: + ) -> Sequence[NodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run. + Retrieve all NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID @@ -62,34 +66,25 @@ class WorkflowNodeExecutionRepository(Protocol): order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of WorkflowNodeExecution instances + A list of NodeExecution instances """ ... - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: + def get_running_executions(self, workflow_run_id: str) -> Sequence[NodeExecution]: """ - Retrieve all running WorkflowNodeExecution instances for a specific workflow run. + Retrieve all running NodeExecution instances for a specific workflow run. Args: workflow_run_id: The workflow run ID Returns: - A list of running WorkflowNodeExecution instances - """ - ... - - def update(self, execution: WorkflowNodeExecution) -> None: - """ - Update an existing WorkflowNodeExecution instance. - - Args: - execution: The WorkflowNodeExecution instance to update + A list of running NodeExecution instances """ ... def clear(self) -> None: """ - Clear all WorkflowNodeExecution records based on implementation-specific criteria. + Clear all NodeExecution records based on implementation-specific criteria. This method is intended to be used for bulk deletion operations, such as removing all records associated with a specific app_id and tenant_id in multi-tenant implementations. diff --git a/api/core/workflow/workflow_app_generate_task_pipeline.py b/api/core/workflow/workflow_app_generate_task_pipeline.py index 10a2d8b38b..0396fa8157 100644 --- a/api/core/workflow/workflow_app_generate_task_pipeline.py +++ b/api/core/workflow/workflow_app_generate_task_pipeline.py @@ -58,7 +58,7 @@ from core.workflow.repository.workflow_node_execution_repository import Workflow from core.workflow.workflow_cycle_manager import WorkflowCycleManager from extensions.ext_database import db from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser from models.workflow import ( Workflow, @@ -94,11 +94,11 @@ class WorkflowAppGenerateTaskPipeline: if isinstance(user, EndUser): self._user_id = user.id user_session_id = user.session_id - self._created_by_role = CreatedByRole.END_USER + self._created_by_role = CreatorUserRole.END_USER elif isinstance(user, Account): self._user_id = user.id user_session_id = user.id - self._created_by_role = CreatedByRole.ACCOUNT + self._created_by_role = CreatorUserRole.ACCOUNT else: raise ValueError(f"Invalid user type: {type(user)}") diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 01d5db4303..6d33d7372c 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -46,26 +46,28 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.exc import WorkflowRunNotFoundError from core.file import FILE_MODEL_IDENTITY, File -from core.model_runtime.utils.encoders import jsonable_encoder from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.tools.tool_manager import ToolManager from core.workflow.entities.node_entities import NodeRunMetadataKey +from core.workflow.entities.node_execution_entities import ( + NodeExecution, + NodeExecutionStatus, +) from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.workflow_entry import WorkflowEntry -from models.account import Account -from models.enums import CreatedByRole, WorkflowRunTriggeredFrom -from models.model import EndUser -from models.workflow import ( +from models import ( + Account, + CreatorUserRole, + EndUser, Workflow, - WorkflowNodeExecution, WorkflowNodeExecutionStatus, - WorkflowNodeExecutionTriggeredFrom, WorkflowRun, WorkflowRunStatus, + WorkflowRunTriggeredFrom, ) @@ -78,7 +80,6 @@ class WorkflowCycleManager: workflow_node_execution_repository: WorkflowNodeExecutionRepository, ) -> None: self._workflow_run: WorkflowRun | None = None - self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {} self._application_generate_entity = application_generate_entity self._workflow_system_variables = workflow_system_variables self._workflow_node_execution_repository = workflow_node_execution_repository @@ -89,7 +90,7 @@ class WorkflowCycleManager: session: Session, workflow_id: str, user_id: str, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, ) -> WorkflowRun: workflow_stmt = select(Workflow).where(Workflow.id == workflow_id) workflow = session.scalar(workflow_stmt) @@ -258,21 +259,22 @@ class WorkflowCycleManager: workflow_run.exceptions_count = exceptions_count # Use the instance repository to find running executions for a workflow run - running_workflow_node_executions = self._workflow_node_execution_repository.get_running_executions( + running_domain_executions = self._workflow_node_execution_repository.get_running_executions( workflow_run_id=workflow_run.id ) - # Update the cache with the retrieved executions - for execution in running_workflow_node_executions: - if execution.node_execution_id: - self._workflow_node_executions[execution.node_execution_id] = execution + # Update the domain models + now = datetime.now(UTC).replace(tzinfo=None) + for domain_execution in running_domain_executions: + if domain_execution.node_execution_id: + # Update the domain model + domain_execution.status = NodeExecutionStatus.FAILED + domain_execution.error = error + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - domain_execution.created_at).total_seconds() - for workflow_node_execution in running_workflow_node_executions: - now = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error - workflow_node_execution.finished_at = now - workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds() + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) if trace_manager: trace_manager.add_trace_task( @@ -286,63 +288,67 @@ class WorkflowCycleManager: return workflow_run - def _handle_node_execution_start( - self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent - ) -> WorkflowNodeExecution: - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.index = event.node_run_index - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.execution_metadata = json.dumps( - { - NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, - NodeRunMetadataKey.LOOP_ID: event.in_loop_id, - } + def _handle_node_execution_start(self, *, workflow_run: WorkflowRun, event: QueueNodeStartedEvent) -> NodeExecution: + # Create a domain model + created_at = datetime.now(UTC).replace(tzinfo=None) + metadata = { + NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, + NodeRunMetadataKey.LOOP_ID: event.in_loop_id, + } + + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + index=event.node_run_index, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=NodeExecutionStatus.RUNNING, + metadata=metadata, + created_at=created_at, ) - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) + + return domain_execution - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> NodeExecution: + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - def _handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) - execution_metadata_dict = dict(event.execution_metadata or {}) - execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - process_data = WorkflowEntry.handle_special_values(event.process_data) + # Update domain model + domain_execution.status = NodeExecutionStatus.SUCCEEDED + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - # Use the instance repository to update the workflow node execution - self._workflow_node_execution_repository.update(workflow_node_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_failed( self, @@ -351,43 +357,52 @@ class WorkflowCycleManager: | QueueNodeInIterationFailedEvent | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param event: queue node failed event :return: """ - workflow_node_execution = self._get_workflow_node_execution(node_execution_id=event.node_execution_id) + # Get the domain model from repository + domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + # Process data inputs = WorkflowEntry.handle_special_values(event.inputs) process_data = WorkflowEntry.handle_special_values(event.process_data) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings + execution_metadata_dict = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + finished_at = datetime.now(UTC).replace(tzinfo=None) elapsed_time = (finished_at - event.start_at).total_seconds() - execution_metadata = ( - json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None - ) - process_data = WorkflowEntry.handle_special_values(event.process_data) - workflow_node_execution.status = ( - WorkflowNodeExecutionStatus.FAILED.value + + # Update domain model + domain_execution.status = ( + NodeExecutionStatus.FAILED if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION.value + else NodeExecutionStatus.EXCEPTION ) - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.process_data = json.dumps(process_data) if process_data else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.execution_metadata = execution_metadata + domain_execution.error = event.error + domain_execution.update_from_mapping( + inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time - self._workflow_node_execution_repository.update(workflow_node_execution) + # Update the repository with the domain model + self._workflow_node_execution_repository.save(domain_execution) - return workflow_node_execution + return domain_execution def _handle_workflow_node_execution_retried( self, *, workflow_run: WorkflowRun, event: QueueNodeRetryEvent - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Workflow node execution failed :param workflow_run: workflow run @@ -399,47 +414,47 @@ class WorkflowCycleManager: elapsed_time = (finished_at - created_at).total_seconds() inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = WorkflowEntry.handle_special_values(event.outputs) + + # Convert metadata keys to strings origin_metadata = { NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, NodeRunMetadataKey.LOOP_ID: event.in_loop_id, } - merged_metadata = ( - {**jsonable_encoder(event.execution_metadata), **origin_metadata} - if event.execution_metadata is not None - else origin_metadata + + # Convert execution metadata keys to strings + execution_metadata_dict: dict[NodeRunMetadataKey, str | None] = {} + if event.execution_metadata: + for key, value in event.execution_metadata.items(): + execution_metadata_dict[key] = value + + merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata + + # Create a domain model + domain_execution = NodeExecution( + id=str(uuid4()), + workflow_id=workflow_run.workflow_id, + workflow_run_id=workflow_run.id, + predecessor_node_id=event.predecessor_node_id, + node_execution_id=event.node_execution_id, + node_id=event.node_id, + node_type=event.node_type, + title=event.node_data.title, + status=NodeExecutionStatus.RETRY, + created_at=created_at, + finished_at=finished_at, + elapsed_time=elapsed_time, + error=event.error, + index=event.node_run_index, ) - execution_metadata = json.dumps(merged_metadata) - - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = workflow_run.tenant_id - workflow_node_execution.app_id = workflow_run.app_id - workflow_node_execution.workflow_id = workflow_run.workflow_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value - workflow_node_execution.workflow_run_id = workflow_run.id - workflow_node_execution.predecessor_node_id = event.predecessor_node_id - workflow_node_execution.node_execution_id = event.node_execution_id - workflow_node_execution.node_id = event.node_id - workflow_node_execution.node_type = event.node_type.value - workflow_node_execution.title = event.node_data.title - workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value - workflow_node_execution.created_by_role = workflow_run.created_by_role - workflow_node_execution.created_by = workflow_run.created_by - workflow_node_execution.created_at = created_at - workflow_node_execution.finished_at = finished_at - workflow_node_execution.elapsed_time = elapsed_time - workflow_node_execution.error = event.error - workflow_node_execution.inputs = json.dumps(inputs) if inputs else None - workflow_node_execution.outputs = json.dumps(outputs) if outputs else None - workflow_node_execution.execution_metadata = execution_metadata - workflow_node_execution.index = event.node_run_index - - # Use the instance repository to save the workflow node execution - self._workflow_node_execution_repository.save(workflow_node_execution) - - self._workflow_node_executions[event.node_execution_id] = workflow_node_execution - return workflow_node_execution + + # Update with mappings + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) + + # Use the instance repository to save the domain model + self._workflow_node_execution_repository.save(domain_execution) + + return domain_execution def _workflow_start_to_stream_response( self, @@ -469,7 +484,7 @@ class WorkflowCycleManager: workflow_run: WorkflowRun, ) -> WorkflowFinishStreamResponse: created_by = None - if workflow_run.created_by_role == CreatedByRole.ACCOUNT: + if workflow_run.created_by_role == CreatorUserRole.ACCOUNT: stmt = select(Account).where(Account.id == workflow_run.created_by) account = session.scalar(stmt) if account: @@ -478,7 +493,7 @@ class WorkflowCycleManager: "name": account.name, "email": account.email, } - elif workflow_run.created_by_role == CreatedByRole.END_USER: + elif workflow_run.created_by_role == CreatorUserRole.END_USER: stmt = select(EndUser).where(EndUser.id == workflow_run.created_by) end_user = session.scalar(stmt) if end_user: @@ -515,9 +530,9 @@ class WorkflowCycleManager: *, event: QueueNodeStartedEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeStartStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -532,7 +547,7 @@ class WorkflowCycleManager: title=workflow_node_execution.title, index=workflow_node_execution.index, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, + inputs=workflow_node_execution.inputs, created_at=int(workflow_node_execution.created_at.timestamp()), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, @@ -565,9 +580,9 @@ class WorkflowCycleManager: | QueueNodeInLoopFailedEvent | QueueNodeExceptionEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[NodeFinishStreamResponse]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -584,16 +599,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -608,9 +623,9 @@ class WorkflowCycleManager: *, event: QueueNodeRetryEvent, task_id: str, - workflow_node_execution: WorkflowNodeExecution, + workflow_node_execution: NodeExecution, ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: - if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: + if workflow_node_execution.node_type in {NodeType.ITERATION, NodeType.LOOP}: return None if not workflow_node_execution.workflow_run_id: return None @@ -627,16 +642,16 @@ class WorkflowCycleManager: index=workflow_node_execution.index, title=workflow_node_execution.title, predecessor_node_id=workflow_node_execution.predecessor_node_id, - inputs=workflow_node_execution.inputs_dict, - process_data=workflow_node_execution.process_data_dict, - outputs=workflow_node_execution.outputs_dict, + inputs=workflow_node_execution.inputs, + process_data=workflow_node_execution.process_data, + outputs=workflow_node_execution.outputs, status=workflow_node_execution.status, error=workflow_node_execution.error, elapsed_time=workflow_node_execution.elapsed_time, - execution_metadata=workflow_node_execution.execution_metadata_dict, + execution_metadata=workflow_node_execution.metadata, created_at=int(workflow_node_execution.created_at.timestamp()), finished_at=int(workflow_node_execution.finished_at.timestamp()), - files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}), + files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs or {}), parallel_id=event.parallel_id, parallel_start_node_id=event.parallel_start_node_id, parent_parallel_id=event.parent_parallel_id, @@ -908,23 +923,6 @@ class WorkflowCycleManager: return workflow_run - def _get_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: - # First check the cache for performance - if node_execution_id in self._workflow_node_executions: - cached_execution = self._workflow_node_executions[node_execution_id] - # No need to merge with session since expire_on_commit=False - return cached_execution - - # If not in cache, use the instance repository to get by node_execution_id - execution = self._workflow_node_execution_repository.get_by_node_execution_id(node_execution_id) - - if not execution: - raise ValueError(f"Workflow node execution not found: {node_execution_id}") - - # Update cache - self._workflow_node_executions[node_execution_id] = execution - return execution - def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse: """ Handle agent log diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index 3cbdc8560b..59982d96ad 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -114,8 +114,10 @@ def init_app(app: DifyApp): pass from opentelemetry import trace - from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter - from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter + from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter from opentelemetry.instrumentation.celery import CeleryInstrumentor from opentelemetry.instrumentation.flask import FlaskInstrumentor from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor @@ -158,19 +160,32 @@ def init_app(app: DifyApp): sampler = ParentBasedTraceIdRatio(dify_config.OTEL_SAMPLING_RATE) provider = TracerProvider(resource=resource, sampler=sampler) set_tracer_provider(provider) - exporter: Union[OTLPSpanExporter, ConsoleSpanExporter] - metric_exporter: Union[OTLPMetricExporter, ConsoleMetricExporter] + exporter: Union[GRPCSpanExporter, HTTPSpanExporter, ConsoleSpanExporter] + metric_exporter: Union[GRPCMetricExporter, HTTPMetricExporter, ConsoleMetricExporter] + protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": - exporter = OTLPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, - ) - metric_exporter = OTLPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, - ) + if protocol == "grpc": + exporter = GRPCSpanExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT, + # Header field names must consist of lowercase letters, check RFC7540 + headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), + insecure=True, + ) + metric_exporter = GRPCMetricExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT, + headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), + insecure=True, + ) + else: + exporter = HTTPSpanExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", + headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + ) + metric_exporter = HTTPMetricExporter( + endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", + headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + ) else: - # Fallback to console exporter exporter = ConsoleSpanExporter() metric_exporter = ConsoleMetricExporter() diff --git a/api/extensions/ext_request_logging.py b/api/extensions/ext_request_logging.py new file mode 100644 index 0000000000..7c69483e0f --- /dev/null +++ b/api/extensions/ext_request_logging.py @@ -0,0 +1,73 @@ +import json +import logging + +import flask +import werkzeug.http +from flask import Flask +from flask.signals import request_finished, request_started + +from configs import dify_config + +_logger = logging.getLogger(__name__) + + +def _is_content_type_json(content_type: str) -> bool: + if not content_type: + return False + content_type_no_option, _ = werkzeug.http.parse_options_header(content_type) + return content_type_no_option.lower() == "application/json" + + +def _log_request_started(_sender, **_extra): + """Log the start of a request.""" + if not _logger.isEnabledFor(logging.DEBUG): + return + + request = flask.request + if not (_is_content_type_json(request.content_type) and request.data): + _logger.debug("Received Request %s -> %s", request.method, request.path) + return + try: + json_data = json.loads(request.data) + except (TypeError, ValueError): + _logger.exception("Failed to parse JSON request") + return + formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2) + _logger.debug( + "Received Request %s -> %s, Request Body:\n%s", + request.method, + request.path, + formatted_json, + ) + + +def _log_request_finished(_sender, response, **_extra): + """Log the end of a request.""" + if not _logger.isEnabledFor(logging.DEBUG) or response is None: + return + + if not _is_content_type_json(response.content_type): + _logger.debug("Response %s %s", response.status, response.content_type) + return + + response_data = response.get_data(as_text=True) + try: + json_data = json.loads(response_data) + except (TypeError, ValueError): + _logger.exception("Failed to parse JSON response") + return + formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2) + _logger.debug( + "Response %s %s, Response Body:\n%s", + response.status, + response.content_type, + formatted_json, + ) + + +def init_app(app: Flask): + """Initialize the request logging extension.""" + if not dify_config.ENABLE_REQUEST_LOGGING: + return + request_started.connect(_log_request_started, app) + request_finished.connect(_log_request_finished, app) diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 0b0e2a2f54..416870cafa 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -63,6 +63,7 @@ app_detail_fields = { "created_at": TimestampField, "updated_by": fields.String, "updated_at": TimestampField, + "access_mode": fields.String, } prompt_config_fields = { @@ -98,6 +99,7 @@ app_partial_fields = { "updated_by": fields.String, "updated_at": TimestampField, "tags": fields.List(fields.Nested(tag_fields)), + "access_mode": fields.String, } @@ -176,6 +178,7 @@ app_detail_fields_with_site = { "updated_by": fields.String, "updated_at": TimestampField, "deleted_tools": fields.List(fields.Nested(deleted_tool_fields)), + "access_mode": fields.String, } diff --git a/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py new file mode 100644 index 0000000000..5bf394b21c --- /dev/null +++ b/api/migrations/versions/2025_05_15_1531-2adcbe1f5dfb_add_workflowdraftvariable_model.py @@ -0,0 +1,51 @@ +"""add WorkflowDraftVariable model + +Revision ID: 2adcbe1f5dfb +Revises: d28f2004b072 +Create Date: 2025-05-15 15:31:03.128680 + +""" + +import sqlalchemy as sa +from alembic import op + +import models as models + +# revision identifiers, used by Alembic. +revision = "2adcbe1f5dfb" +down_revision = "d28f2004b072" +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "workflow_draft_variables", + sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuid_generate_v4()"), nullable=False), + sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False), + sa.Column("app_id", models.types.StringUUID(), nullable=False), + sa.Column("last_edited_at", sa.DateTime(), nullable=True), + sa.Column("node_id", sa.String(length=255), nullable=False), + sa.Column("name", sa.String(length=255), nullable=False), + sa.Column("description", sa.String(length=255), nullable=False), + sa.Column("selector", sa.String(length=255), nullable=False), + sa.Column("value_type", sa.String(length=20), nullable=False), + sa.Column("value", sa.Text(), nullable=False), + sa.Column("visible", sa.Boolean(), nullable=False), + sa.Column("editable", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("workflow_draft_variables_pkey")), + sa.UniqueConstraint("app_id", "node_id", "name", name=op.f("workflow_draft_variables_app_id_key")), + ) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + + # Dropping `workflow_draft_variables` also drops any index associated with it. + op.drop_table("workflow_draft_variables") + + # ### end Alembic commands ### diff --git a/api/models/__init__.py b/api/models/__init__.py index 2066481a61..f652449e98 100644 --- a/api/models/__init__.py +++ b/api/models/__init__.py @@ -27,7 +27,7 @@ from .dataset import ( Whitelist, ) from .engine import db -from .enums import CreatedByRole, UserFrom, WorkflowRunTriggeredFrom +from .enums import CreatorUserRole, UserFrom, WorkflowRunTriggeredFrom from .model import ( ApiRequest, ApiToken, @@ -112,7 +112,7 @@ __all__ = [ "CeleryTaskSet", "Conversation", "ConversationVariable", - "CreatedByRole", + "CreatorUserRole", "DataSourceApiKeyAuthBinding", "DataSourceOauthBinding", "Dataset", diff --git a/api/models/enums.py b/api/models/enums.py index 7b9500ebe4..4434c3fec8 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -1,7 +1,7 @@ from enum import StrEnum -class CreatedByRole(StrEnum): +class CreatorUserRole(StrEnum): ACCOUNT = "account" END_USER = "end_user" @@ -14,3 +14,10 @@ class UserFrom(StrEnum): class WorkflowRunTriggeredFrom(StrEnum): DEBUGGING = "debugging" APP_RUN = "app-run" + + +class DraftVariableType(StrEnum): + # node means that the correspond variable + NODE = "node" + SYS = "sys" + CONVERSATION = "conversation" diff --git a/api/models/model.py b/api/models/model.py index ab426649c5..ee79fbd6b5 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -29,7 +29,7 @@ from libs.helper import generate_string from .account import Account, Tenant from .base import Base from .engine import db -from .enums import CreatedByRole +from .enums import CreatorUserRole from .types import StringUUID from .workflow import WorkflowRunStatus @@ -1270,7 +1270,7 @@ class MessageFile(Base): url: str | None = None, belongs_to: Literal["user", "assistant"] | None = None, upload_file_id: str | None = None, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, ): self.message_id = message_id @@ -1417,7 +1417,7 @@ class EndUser(Base, UserMixin): ) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) - tenant_id = db.Column(StringUUID, nullable=False) + tenant_id: Mapped[str] = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=True) type = db.Column(db.String(255), nullable=False) external_user_id = db.Column(db.String(255), nullable=True) @@ -1547,7 +1547,7 @@ class UploadFile(Base): size: int, extension: str, mime_type: str, - created_by_role: CreatedByRole, + created_by_role: CreatorUserRole, created_by: str, created_at: datetime, used: bool, diff --git a/api/models/types.py b/api/models/types.py index cb6773e70c..e5581c3ab0 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,4 +1,7 @@ -from sqlalchemy import CHAR, TypeDecorator +import enum +from typing import Generic, TypeVar + +from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy.dialects.postgresql import UUID @@ -24,3 +27,51 @@ class StringUUID(TypeDecorator): if value is None: return value return str(value) + + +_E = TypeVar("_E", bound=enum.StrEnum) + + +class EnumText(TypeDecorator, Generic[_E]): + impl = VARCHAR + cache_ok = True + + _length: int + _enum_class: type[_E] + + def __init__(self, enum_class: type[_E], length: int | None = None): + self._enum_class = enum_class + max_enum_value_len = max(len(e.value) for e in enum_class) + if length is not None: + if length < max_enum_value_len: + raise ValueError("length should be greater than enum value length.") + self._length = length + else: + # leave some rooms for future longer enum values. + self._length = max(max_enum_value_len, 20) + + def process_bind_param(self, value: _E | str | None, dialect): + if value is None: + return value + if isinstance(value, self._enum_class): + return value.value + elif isinstance(value, str): + self._enum_class(value) + return value + else: + raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(VARCHAR(self._length)) + + def process_result_value(self, value, dialect) -> _E | None: + if value is None: + return value + if not isinstance(value, str): + raise TypeError(f"expected str, got {type(value)}") + return self._enum_class(value) + + def compare_values(self, x, y): + if x is None or y is None: + return x is y + return x == y diff --git a/api/models/workflow.py b/api/models/workflow.py index 730ddcfcb5..91be1cda6e 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,29 +1,36 @@ import json +import logging from collections.abc import Mapping, Sequence from datetime import UTC, datetime from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Self, Union from uuid import uuid4 +from core.variables import utils as variable_utils +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID +from factories.variable_factory import build_segment + if TYPE_CHECKING: from models.model import AppMode import sqlalchemy as sa -from sqlalchemy import func +from sqlalchemy import UniqueConstraint, func from sqlalchemy.orm import Mapped, mapped_column import contexts from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from core.variables import SecretVariable, Variable +from core.variables import SecretVariable, Segment, SegmentType, Variable from factories import variable_factory from libs import helper from .account import Account from .base import Base from .engine import db -from .enums import CreatedByRole -from .types import StringUUID +from .enums import CreatorUserRole, DraftVariableType +from .types import EnumText, StringUUID + +_logger = logging.getLogger(__name__) if TYPE_CHECKING: from models.model import AppMode @@ -429,15 +436,15 @@ class WorkflowRun(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property def graph_dict(self): @@ -634,24 +641,24 @@ class WorkflowNodeExecution(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) + created_by_role = CreatorUserRole(self.created_by_role) # TODO(-LAN-): Avoid using db.session.get() here. - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None @property def inputs_dict(self): return json.loads(self.inputs) if self.inputs else None @property - def outputs_dict(self): + def outputs_dict(self) -> dict[str, Any] | None: return json.loads(self.outputs) if self.outputs else None @property @@ -659,7 +666,7 @@ class WorkflowNodeExecution(Base): return json.loads(self.process_data) if self.process_data else None @property - def execution_metadata_dict(self): + def execution_metadata_dict(self) -> dict[str, Any] | None: return json.loads(self.execution_metadata) if self.execution_metadata else None @property @@ -755,15 +762,15 @@ class WorkflowAppLog(Base): @property def created_by_account(self): - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(Account, self.created_by) if created_by_role == CreatedByRole.ACCOUNT else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None @property def created_by_end_user(self): from models.model import EndUser - created_by_role = CreatedByRole(self.created_by_role) - return db.session.get(EndUser, self.created_by) if created_by_role == CreatedByRole.END_USER else None + created_by_role = CreatorUserRole(self.created_by_role) + return db.session.get(EndUser, self.created_by) if created_by_role == CreatorUserRole.END_USER else None class ConversationVariable(Base): @@ -797,3 +804,201 @@ class ConversationVariable(Base): def to_variable(self) -> Variable: mapping = json.loads(self.data) return variable_factory.build_conversation_variable_from_mapping(mapping) + + +# Only `sys.query` and `sys.files` could be modified. +_EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) + + +def _naive_utc_datetime(): + return datetime.now(UTC).replace(tzinfo=None) + + +class WorkflowDraftVariable(Base): + @staticmethod + def unique_columns() -> list[str]: + return [ + "app_id", + "node_id", + "name", + ] + + __tablename__ = "workflow_draft_variables" + __table_args__ = (UniqueConstraint(*unique_columns()),) + + # id is the unique identifier of a draft variable. + id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()")) + + created_at = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + ) + + updated_at = mapped_column( + db.DateTime, + nullable=False, + default=_naive_utc_datetime, + server_default=func.current_timestamp(), + onupdate=func.current_timestamp(), + ) + + # "`app_id` maps to the `id` field in the `model.App` model." + app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + + # `last_edited_at` records when the value of a given draft variable + # is edited. + # + # If it's not edited after creation, its value is `None`. + last_edited_at: Mapped[datetime | None] = mapped_column( + db.DateTime, + nullable=True, + default=None, + ) + + # The `node_id` field is special. + # + # If the variable is a conversation variable or a system variable, then the value of `node_id` + # is `conversation` or `sys`, respective. + # + # Otherwise, if the variable is a variable belonging to a specific node, the value of `_node_id` is + # the identity of correspond node in graph definition. An example of node id is `"1745769620734"`. + # + # However, there's one caveat. The id of the first "Answer" node in chatflow is "answer". (Other + # "Answer" node conform the rules above.) + node_id: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="node_id") + + # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than + # 80 chars. + # + # ref: api/core/workflow/entities/variable_pool.py:18 + name: Mapped[str] = mapped_column(sa.String(255), nullable=False) + description: Mapped[str] = mapped_column( + sa.String(255), + default="", + nullable=False, + ) + + selector: Mapped[str] = mapped_column(sa.String(255), nullable=False, name="selector") + + value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20)) + # JSON string + value: Mapped[str] = mapped_column(sa.Text, nullable=False, name="value") + + # visible + visible: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=True) + editable: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, default=False) + + def get_selector(self) -> list[str]: + selector = json.loads(self.selector) + if not isinstance(selector, list): + _logger.error( + "invalid selector loaded from database, type=%s, value=%s", + type(selector), + self.selector, + ) + raise ValueError("invalid selector.") + return selector + + def _set_selector(self, value: list[str]): + self.selector = json.dumps(value) + + def get_value(self) -> Segment | None: + return build_segment(json.loads(self.value)) + + def set_name(self, name: str): + self.name = name + self._set_selector([self.node_id, name]) + + def set_value(self, value: Segment): + self.value = json.dumps(value.value) + self.value_type = value.value_type + + def get_node_id(self) -> str | None: + if self.get_variable_type() == DraftVariableType.NODE: + return self.node_id + else: + return None + + def get_variable_type(self) -> DraftVariableType: + match self.node_id: + case DraftVariableType.CONVERSATION: + return DraftVariableType.CONVERSATION + case DraftVariableType.SYS: + return DraftVariableType.SYS + case _: + return DraftVariableType.NODE + + @classmethod + def _new( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + description: str = "", + ) -> "WorkflowDraftVariable": + variable = WorkflowDraftVariable() + variable.created_at = _naive_utc_datetime() + variable.updated_at = _naive_utc_datetime() + variable.description = description + variable.app_id = app_id + variable.node_id = node_id + variable.name = name + variable.set_value(value) + variable._set_selector(list(variable_utils.to_selector(node_id, name))) + return variable + + @classmethod + def new_conversation_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + ) -> "WorkflowDraftVariable": + variable = cls._new( + app_id=app_id, + node_id=CONVERSATION_VARIABLE_NODE_ID, + name=name, + value=value, + ) + return variable + + @classmethod + def new_sys_variable( + cls, + *, + app_id: str, + name: str, + value: Segment, + editable: bool = False, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=SYSTEM_VARIABLE_NODE_ID, name=name, value=value) + variable.editable = editable + return variable + + @classmethod + def new_node_variable( + cls, + *, + app_id: str, + node_id: str, + name: str, + value: Segment, + visible: bool = True, + ) -> "WorkflowDraftVariable": + variable = cls._new(app_id=app_id, node_id=node_id, name=name, value=value) + variable.visible = visible + variable.editable = True + return variable + + @property + def edited(self): + return self.last_edited_at is not None + + +def is_system_variable_editable(name: str) -> bool: + return name in _EDITABLE_SYSTEM_VARIABLE diff --git a/api/pyproject.toml b/api/pyproject.toml index 18fe5db1a6..26a0bdb11e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -72,7 +72,7 @@ dependencies = [ "python-dotenv==1.0.1", "pyyaml~=6.0.1", "readabilipy~=0.3.0", - "redis[hiredis]~=6.0.0", + "redis[hiredis]~=6.1.0", "resend~=2.9.0", "sentry-sdk[flask]~=2.28.0", "sqlalchemy~=2.0.29", diff --git a/api/services/account_service.py b/api/services/account_service.py index d21926d746..ac84a46299 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -49,7 +49,7 @@ from services.errors.account import ( RoleAlreadyAssignedError, TenantNotFoundError, ) -from services.errors.workspace import WorkSpaceNotAllowedCreateError +from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError from services.feature_service import FeatureService from tasks.delete_account_task import delete_account_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code @@ -628,6 +628,10 @@ class TenantService: if not FeatureService.get_system_features().is_allow_create_workspace and not is_setup: raise WorkSpaceNotAllowedCreateError() + workspaces = FeatureService.get_system_features().license.workspaces + if not workspaces.is_available(): + raise WorkspacesLimitExceededError() + if name: tenant = TenantService.create_tenant(name=name, is_setup=is_setup) else: @@ -928,7 +932,11 @@ class RegisterService: if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) - if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required: + if ( + FeatureService.get_system_features().is_allow_create_workspace + and create_workspace_required + and FeatureService.get_system_features().license.workspaces.is_available() + ): tenant = TenantService.create_tenant(f"{account.name}'s Workspace") TenantService.create_tenant_member(tenant, account, role="owner") account.current_tenant = tenant diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index a2775fe6ad..d2875180d8 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -40,7 +40,7 @@ IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:" CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:" IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB -CURRENT_DSL_VERSION = "0.2.0" +CURRENT_DSL_VERSION = "0.3.0" class ImportMode(StrEnum): diff --git a/api/services/app_service.py b/api/services/app_service.py index 2fae479e05..ebebf8fa58 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -18,8 +18,10 @@ from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db from models.account import Account -from models.model import App, AppMode, AppModelConfig +from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider +from services.enterprise.enterprise_service import EnterpriseService +from services.feature_service import FeatureService from services.tag_service import TagService from tasks.remove_app_and_related_data_task import remove_app_and_related_data_task @@ -155,6 +157,10 @@ class AppService: app_was_created.send(app, account=account) + if FeatureService.get_system_features().webapp_auth.enabled: + # update web app setting as private + EnterpriseService.WebAppAuth.update_app_access_mode(app.id, "private") + return app def get_app(self, app: App) -> App: @@ -307,6 +313,10 @@ class AppService: db.session.delete(app) db.session.commit() + # clean up web app settings + if FeatureService.get_system_features().webapp_auth.enabled: + EnterpriseService.WebAppAuth.cleanup_webapp(app.id) + # Trigger asynchronous deletion of app and related data remove_app_and_related_data_task.delay(tenant_id=app.tenant_id, app_id=app.id) @@ -373,3 +383,15 @@ class AppService: meta["tool_icons"][tool_name] = {"background": "#252525", "content": "\ud83d\ude01"} return meta + + @staticmethod + def get_app_code_by_id(app_id: str) -> str: + """ + Get app code by app id + :param app_id: app id + :return: app code + """ + site = db.session.query(Site).filter(Site.app_id == app_id).first() + if not site: + raise ValueError(f"App with id {app_id} not found") + return str(site.code) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 15b2ae4f15..6411acda7f 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -994,7 +994,7 @@ class DocumentService: created_by=account.id, ) else: - logging.warn( + logging.warning( f"Invalid process rule mode: {process_rule.mode}, can not find dataset process rule" ) return diff --git a/api/services/enterprise/enterprise_service.py b/api/services/enterprise/enterprise_service.py index abc01ddf8f..1be78d2e62 100644 --- a/api/services/enterprise/enterprise_service.py +++ b/api/services/enterprise/enterprise_service.py @@ -1,11 +1,90 @@ +from pydantic import BaseModel, Field + from services.enterprise.base import EnterpriseRequest +class WebAppSettings(BaseModel): + access_mode: str = Field( + description="Access mode for the web app. Can be 'public' or 'private'", + default="private", + alias="accessMode", + ) + + class EnterpriseService: @classmethod def get_info(cls): return EnterpriseRequest.send_request("GET", "/info") @classmethod - def get_app_web_sso_enabled(cls, app_code): - return EnterpriseRequest.send_request("GET", f"/app-sso-setting?appCode={app_code}") + def get_workspace_info(cls, tenant_id: str): + return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info") + + class WebAppAuth: + @classmethod + def is_user_allowed_to_access_webapp(cls, user_id: str, app_code: str): + params = {"userId": user_id, "appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/permission", params=params) + + return data.get("result", False) + + @classmethod + def get_app_access_mode_by_id(cls, app_id: str) -> WebAppSettings: + if not app_id: + raise ValueError("app_id must be provided.") + params = {"appId": app_id} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/id", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def batch_get_app_access_mode_by_id(cls, app_ids: list[str]) -> dict[str, WebAppSettings]: + if not app_ids: + return {} + body = {"appIds": app_ids} + data: dict[str, str] = EnterpriseRequest.send_request("POST", "/webapp/access-mode/batch/id", json=body) + if not data: + raise ValueError("No data found.") + + if not isinstance(data["accessModes"], dict): + raise ValueError("Invalid data format.") + + ret = {} + for key, value in data["accessModes"].items(): + curr = WebAppSettings() + curr.access_mode = value + ret[key] = curr + + return ret + + @classmethod + def get_app_access_mode_by_code(cls, app_code: str) -> WebAppSettings: + if not app_code: + raise ValueError("app_code must be provided.") + params = {"appCode": app_code} + data = EnterpriseRequest.send_request("GET", "/webapp/access-mode/code", params=params) + if not data: + raise ValueError("No data found.") + return WebAppSettings(**data) + + @classmethod + def update_app_access_mode(cls, app_id: str, access_mode: str): + if not app_id: + raise ValueError("app_id must be provided.") + if access_mode not in ["public", "private", "private_all"]: + raise ValueError("access_mode must be either 'public', 'private', or 'private_all'") + + data = {"appId": app_id, "accessMode": access_mode} + + response = EnterpriseRequest.send_request("POST", "/webapp/access-mode", json=data) + + return response.get("result", False) + + @classmethod + def cleanup_webapp(cls, app_id: str): + if not app_id: + raise ValueError("app_id must be provided.") + + body = {"appId": app_id} + EnterpriseRequest.send_request("DELETE", "/webapp/clean", json=body) diff --git a/api/services/enterprise/mail_service.py b/api/services/enterprise/mail_service.py new file mode 100644 index 0000000000..630e7679ac --- /dev/null +++ b/api/services/enterprise/mail_service.py @@ -0,0 +1,18 @@ +from pydantic import BaseModel + +from tasks.mail_enterprise_task import send_enterprise_email_task + + +class DifyMail(BaseModel): + to: list[str] + subject: str + body: str + substitutions: dict[str, str] = {} + + +class EnterpriseMailService: + @classmethod + def send_mail(cls, mail: DifyMail): + send_enterprise_email_task.delay( + to=mail.to, subject=mail.subject, body=mail.body, substitutions=mail.substitutions + ) diff --git a/api/services/errors/workspace.py b/api/services/errors/workspace.py index 714064ffdf..577238507f 100644 --- a/api/services/errors/workspace.py +++ b/api/services/errors/workspace.py @@ -7,3 +7,7 @@ class WorkSpaceNotAllowedCreateError(BaseServiceError): class WorkSpaceNotFoundError(BaseServiceError): pass + + +class WorkspacesLimitExceededError(BaseServiceError): + pass diff --git a/api/services/feature_service.py b/api/services/feature_service.py index c2226c319f..be85a03e80 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -1,6 +1,6 @@ from enum import StrEnum -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from configs import dify_config from services.billing_service import BillingService @@ -27,6 +27,32 @@ class LimitationModel(BaseModel): limit: int = 0 +class LicenseLimitationModel(BaseModel): + """ + - enabled: whether this limit is enforced + - size: current usage count + - limit: maximum allowed count; 0 means unlimited + """ + + enabled: bool = Field(False, description="Whether this limit is currently active") + size: int = Field(0, description="Number of resources already consumed") + limit: int = Field(0, description="Maximum number of resources allowed; 0 means no limit") + + def is_available(self, required: int = 1) -> bool: + """ + Determine whether the requested amount can be allocated. + + Returns True if: + - this limit is not active, or + - the limit is zero (unlimited), or + - there is enough remaining quota. + """ + if not self.enabled or self.limit == 0: + return True + + return (self.limit - self.size) >= required + + class LicenseStatus(StrEnum): NONE = "none" INACTIVE = "inactive" @@ -39,6 +65,27 @@ class LicenseStatus(StrEnum): class LicenseModel(BaseModel): status: LicenseStatus = LicenseStatus.NONE expired_at: str = "" + workspaces: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) + + +class BrandingModel(BaseModel): + enabled: bool = False + application_title: str = "" + login_page_logo: str = "" + workspace_logo: str = "" + favicon: str = "" + + +class WebAppAuthSSOModel(BaseModel): + protocol: str = "" + + +class WebAppAuthModel(BaseModel): + enabled: bool = False + allow_sso: bool = False + sso_config: WebAppAuthSSOModel = WebAppAuthSSOModel() + allow_email_code_login: bool = False + allow_email_password_login: bool = False class FeatureModel(BaseModel): @@ -54,6 +101,8 @@ class FeatureModel(BaseModel): can_replace_logo: bool = False model_load_balancing_enabled: bool = False dataset_operator_enabled: bool = False + webapp_copyright_enabled: bool = False + workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -68,9 +117,6 @@ class KnowledgeRateLimitModel(BaseModel): class SystemFeatureModel(BaseModel): sso_enforced_for_signin: bool = False sso_enforced_for_signin_protocol: str = "" - sso_enforced_for_web: bool = False - sso_enforced_for_web_protocol: str = "" - enable_web_sso_switch_component: bool = False enable_marketplace: bool = False max_plugin_package_size: int = dify_config.PLUGIN_MAX_PACKAGE_SIZE enable_email_code_login: bool = False @@ -80,6 +126,8 @@ class SystemFeatureModel(BaseModel): is_allow_create_workspace: bool = False is_email_setup: bool = False license: LicenseModel = LicenseModel() + branding: BrandingModel = BrandingModel() + webapp_auth: WebAppAuthModel = WebAppAuthModel() class FeatureService: @@ -92,6 +140,10 @@ class FeatureService: if dify_config.BILLING_ENABLED and tenant_id: cls._fulfill_params_from_billing_api(features, tenant_id) + if dify_config.ENTERPRISE_ENABLED: + features.webapp_copyright_enabled = True + cls._fulfill_params_from_workspace_info(features, tenant_id) + return features @classmethod @@ -111,8 +163,8 @@ class FeatureService: cls._fulfill_system_params_from_env(system_features) if dify_config.ENTERPRISE_ENABLED: - system_features.enable_web_sso_switch_component = True - + system_features.branding.enabled = True + system_features.webapp_auth.enabled = True cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -136,6 +188,14 @@ class FeatureService: features.dataset_operator_enabled = dify_config.DATASET_OPERATOR_ENABLED features.education.enabled = dify_config.EDUCATION_ENABLED + @classmethod + def _fulfill_params_from_workspace_info(cls, features: FeatureModel, tenant_id: str): + workspace_info = EnterpriseService.get_workspace_info(tenant_id) + if "WorkspaceMembers" in workspace_info: + features.workspace_members.size = workspace_info["WorkspaceMembers"]["used"] + features.workspace_members.limit = workspace_info["WorkspaceMembers"]["limit"] + features.workspace_members.enabled = workspace_info["WorkspaceMembers"]["enabled"] + @classmethod def _fulfill_params_from_billing_api(cls, features: FeatureModel, tenant_id: str): billing_info = BillingService.get_info(tenant_id) @@ -145,6 +205,9 @@ class FeatureService: features.billing.subscription.interval = billing_info["subscription"]["interval"] features.education.activated = billing_info["subscription"].get("education", False) + if features.billing.subscription.plan != "sandbox": + features.webapp_copyright_enabled = True + if "members" in billing_info: features.members.size = billing_info["members"]["size"] features.members.limit = billing_info["members"]["limit"] @@ -178,38 +241,53 @@ class FeatureService: features.knowledge_rate_limit = billing_info["knowledge_rate_limit"]["limit"] @classmethod - def _fulfill_params_from_enterprise(cls, features): + def _fulfill_params_from_enterprise(cls, features: SystemFeatureModel): enterprise_info = EnterpriseService.get_info() - if "sso_enforced_for_signin" in enterprise_info: - features.sso_enforced_for_signin = enterprise_info["sso_enforced_for_signin"] + if "SSOEnforcedForSignin" in enterprise_info: + features.sso_enforced_for_signin = enterprise_info["SSOEnforcedForSignin"] - if "sso_enforced_for_signin_protocol" in enterprise_info: - features.sso_enforced_for_signin_protocol = enterprise_info["sso_enforced_for_signin_protocol"] + if "SSOEnforcedForSigninProtocol" in enterprise_info: + features.sso_enforced_for_signin_protocol = enterprise_info["SSOEnforcedForSigninProtocol"] - if "sso_enforced_for_web" in enterprise_info: - features.sso_enforced_for_web = enterprise_info["sso_enforced_for_web"] + if "EnableEmailCodeLogin" in enterprise_info: + features.enable_email_code_login = enterprise_info["EnableEmailCodeLogin"] - if "sso_enforced_for_web_protocol" in enterprise_info: - features.sso_enforced_for_web_protocol = enterprise_info["sso_enforced_for_web_protocol"] + if "EnableEmailPasswordLogin" in enterprise_info: + features.enable_email_password_login = enterprise_info["EnableEmailPasswordLogin"] - if "enable_email_code_login" in enterprise_info: - features.enable_email_code_login = enterprise_info["enable_email_code_login"] + if "IsAllowRegister" in enterprise_info: + features.is_allow_register = enterprise_info["IsAllowRegister"] - if "enable_email_password_login" in enterprise_info: - features.enable_email_password_login = enterprise_info["enable_email_password_login"] + if "IsAllowCreateWorkspace" in enterprise_info: + features.is_allow_create_workspace = enterprise_info["IsAllowCreateWorkspace"] - if "is_allow_register" in enterprise_info: - features.is_allow_register = enterprise_info["is_allow_register"] + if "Branding" in enterprise_info: + features.branding.application_title = enterprise_info["Branding"].get("applicationTitle", "") + features.branding.login_page_logo = enterprise_info["Branding"].get("loginPageLogo", "") + features.branding.workspace_logo = enterprise_info["Branding"].get("workspaceLogo", "") + features.branding.favicon = enterprise_info["Branding"].get("favicon", "") - if "is_allow_create_workspace" in enterprise_info: - features.is_allow_create_workspace = enterprise_info["is_allow_create_workspace"] + if "WebAppAuth" in enterprise_info: + features.webapp_auth.allow_sso = enterprise_info["WebAppAuth"].get("allowSso", False) + features.webapp_auth.allow_email_code_login = enterprise_info["WebAppAuth"].get( + "allowEmailCodeLogin", False + ) + features.webapp_auth.allow_email_password_login = enterprise_info["WebAppAuth"].get( + "allowEmailPasswordLogin", False + ) + features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "") - if "license" in enterprise_info: - license_info = enterprise_info["license"] + if "License" in enterprise_info: + license_info = enterprise_info["License"] if "status" in license_info: features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE)) - if "expired_at" in license_info: - features.license.expired_at = license_info["expired_at"] + if "expiredAt" in license_info: + features.license.expired_at = license_info["expiredAt"] + + if "workspaces" in license_info: + features.license.workspaces.enabled = license_info["workspaces"]["enabled"] + features.license.workspaces.limit = license_info["workspaces"]["limit"] + features.license.workspaces.size = license_info["workspaces"]["used"] diff --git a/api/services/file_service.py b/api/services/file_service.py index 2ca6b4f9aa..2d68f30c5a 100644 --- a/api/services/file_service.py +++ b/api/services/file_service.py @@ -19,7 +19,7 @@ from core.rag.extractor.extract_processor import ExtractProcessor from extensions.ext_database import db from extensions.ext_storage import storage from models.account import Account -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.model import EndUser, UploadFile from .errors.file import FileTooLargeError, UnsupportedFileTypeError @@ -81,7 +81,7 @@ class FileService: size=file_size, extension=extension, mime_type=mimetype, - created_by_role=(CreatedByRole.ACCOUNT if isinstance(user, Account) else CreatedByRole.END_USER), + created_by_role=(CreatorUserRole.ACCOUNT if isinstance(user, Account) else CreatorUserRole.END_USER), created_by=user.id, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=False, @@ -133,7 +133,7 @@ class FileService: extension="txt", mime_type="text/plain", created_by=current_user.id, - created_by_role=CreatedByRole.ACCOUNT, + created_by_role=CreatorUserRole.ACCOUNT, created_at=datetime.datetime.now(datetime.UTC).replace(tzinfo=None), used=True, used_by=current_user.id, diff --git a/api/services/ops_service.py b/api/services/ops_service.py index 6b317212d1..a9c2b28476 100644 --- a/api/services/ops_service.py +++ b/api/services/ops_service.py @@ -87,7 +87,9 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map and tracing_provider: + try: + provider_config_map[tracing_provider] + except KeyError: return {"error": f"Invalid tracing provider: {tracing_provider}"} config_class, other_keys = ( @@ -150,7 +152,9 @@ class OpsService: :param tracing_config: tracing config :return: """ - if tracing_provider not in provider_config_map: + try: + provider_config_map[tracing_provider] + except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") # check if trace config already exists diff --git a/api/services/webapp_auth_service.py b/api/services/webapp_auth_service.py new file mode 100644 index 0000000000..79d5217de7 --- /dev/null +++ b/api/services/webapp_auth_service.py @@ -0,0 +1,141 @@ +import random +from datetime import UTC, datetime, timedelta +from typing import Any, Optional, cast + +from werkzeug.exceptions import NotFound, Unauthorized + +from configs import dify_config +from controllers.web.error import WebAppAuthAccessDeniedError +from extensions.ext_database import db +from libs.helper import TokenManager +from libs.passport import PassportService +from libs.password import compare_password +from models.account import Account, AccountStatus +from models.model import App, EndUser, Site +from services.enterprise.enterprise_service import EnterpriseService +from services.errors.account import AccountLoginError, AccountNotFoundError, AccountPasswordError +from services.feature_service import FeatureService +from tasks.mail_email_code_login import send_email_code_login_mail_task + + +class WebAppAuthService: + """Service for web app authentication.""" + + @staticmethod + def authenticate(email: str, password: str) -> Account: + """authenticate account with email and password""" + + account = Account.query.filter_by(email=email).first() + if not account: + raise AccountNotFoundError() + + if account.status == AccountStatus.BANNED.value: + raise AccountLoginError("Account is banned.") + + if account.password is None or not compare_password(password, account.password, account.password_salt): + raise AccountPasswordError("Invalid email or password.") + + return cast(Account, account) + + @classmethod + def login(cls, account: Account, app_code: str, end_user_id: str) -> str: + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise NotFound("Site not found.") + + access_token = cls._get_account_jwt_token(account=account, site=site, end_user_id=end_user_id) + + return access_token + + @classmethod + def get_user_through_email(cls, email: str): + account = db.session.query(Account).filter(Account.email == email).first() + if not account: + return None + + if account.status == AccountStatus.BANNED.value: + raise Unauthorized("Account is banned.") + + return account + + @classmethod + def send_email_code_login_email( + cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" + ): + email = account.email if account else email + if email is None: + raise ValueError("Email must be provided.") + + code = "".join([str(random.randint(0, 9)) for _ in range(6)]) + token = TokenManager.generate_token( + account=account, email=email, token_type="webapp_email_code_login", additional_data={"code": code} + ) + send_email_code_login_mail_task.delay( + language=language, + to=account.email if account else email, + code=code, + ) + + return token + + @classmethod + def get_email_code_login_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "webapp_email_code_login") + + @classmethod + def revoke_email_code_login_token(cls, token: str): + TokenManager.revoke_token(token, "webapp_email_code_login") + + @classmethod + def create_end_user(cls, app_code, email) -> EndUser: + site = db.session.query(Site).filter(Site.code == app_code).first() + if not site: + raise NotFound("Site not found.") + app_model = db.session.query(App).filter(App.id == site.app_id).first() + if not app_model: + raise NotFound("App not found.") + end_user = EndUser( + tenant_id=app_model.tenant_id, + app_id=app_model.id, + type="browser", + is_anonymous=False, + session_id=email, + name="enterpriseuser", + external_user_id="enterpriseuser", + ) + db.session.add(end_user) + db.session.commit() + + return end_user + + @classmethod + def _validate_user_accessibility(cls, account: Account, app_code: str): + """Check if the user is allowed to access the app.""" + system_features = FeatureService.get_system_features() + if system_features.webapp_auth.enabled: + app_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_code(app_code=app_code) + + if ( + app_settings.access_mode != "public" + and not EnterpriseService.WebAppAuth.is_user_allowed_to_access_webapp(account.id, app_code=app_code) + ): + raise WebAppAuthAccessDeniedError() + + @classmethod + def _get_account_jwt_token(cls, account: Account, site: Site, end_user_id: str) -> str: + exp_dt = datetime.now(UTC) + timedelta(hours=dify_config.ACCESS_TOKEN_EXPIRE_MINUTES * 24) + exp = int(exp_dt.timestamp()) + + payload = { + "iss": site.id, + "sub": "Web API Passport", + "app_id": site.app_id, + "app_code": site.code, + "user_id": account.id, + "end_user_id": end_user_id, + "token_source": "webapp", + "exp": exp, + } + + token: str = PassportService().issue(payload) + return token diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index e526517b51..a899ebe278 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -5,7 +5,7 @@ from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from models import App, EndUser, WorkflowAppLog, WorkflowRun -from models.enums import CreatedByRole +from models.enums import CreatorUserRole from models.workflow import WorkflowRunStatus @@ -58,7 +58,7 @@ class WorkflowAppService: stmt = stmt.outerjoin( EndUser, - and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatedByRole.END_USER), + and_(WorkflowRun.created_by == EndUser.id, WorkflowRun.created_by_role == CreatorUserRole.END_USER), ).where(or_(*keyword_conditions)) if status: diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 6d5b737962..3efdb3a79d 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -1,4 +1,5 @@ import threading +from collections.abc import Sequence from typing import Optional import contexts @@ -6,11 +7,13 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination -from models.enums import WorkflowRunTriggeredFrom -from models.model import App -from models.workflow import ( +from models import ( + Account, + App, + EndUser, WorkflowNodeExecution, WorkflowRun, + WorkflowRunTriggeredFrom, ) @@ -116,7 +119,12 @@ class WorkflowRunService: return workflow_run - def get_workflow_run_node_executions(self, app_model: App, run_id: str) -> list[WorkflowNodeExecution]: + def get_workflow_run_node_executions( + self, + app_model: App, + run_id: str, + user: Account | EndUser, + ) -> Sequence[WorkflowNodeExecution]: """ Get workflow run node execution list """ @@ -128,13 +136,18 @@ class WorkflowRunService: if not workflow_run: return [] - # Use the repository to get the node executions repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=user, + app_id=app_model.id, + triggered_from=None, ) # Use the repository to get the node executions with ordering order_config = OrderConfig(order_by=["index"], order_direction="desc") node_executions = repository.get_by_workflow_run(workflow_run_id=run_id, order_config=order_config) - return list(node_executions) + # Convert domain models to database models + workflow_node_executions = [repository.to_db_model(node_execution) for node_execution in node_executions] + + return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 331dba8bf1..06d92b9e29 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -10,10 +10,10 @@ from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager -from core.model_runtime.utils.encoders import jsonable_encoder from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.node_execution_entities import NodeExecution, NodeExecutionStatus from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -26,7 +26,6 @@ from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from models.account import Account -from models.enums import CreatedByRole from models.model import App, AppMode from models.tools import WorkflowToolProvider from models.workflow import ( @@ -268,33 +267,37 @@ class WorkflowService: # run draft workflow node start_at = time.perf_counter() - workflow_node_execution = self._handle_node_run_result( - getter=lambda: WorkflowEntry.single_step_run( + node_execution = self._handle_node_run_result( + invoke_node_fn=lambda: WorkflowEntry.single_step_run( workflow=draft_workflow, node_id=node_id, user_inputs=user_inputs, user_id=account.id, ), start_at=start_at, - tenant_id=app_model.tenant_id, node_id=node_id, ) - workflow_node_execution.app_id = app_model.id - workflow_node_execution.created_by = account.id - workflow_node_execution.workflow_id = draft_workflow.id + # Set workflow_id on the NodeExecution + node_execution.workflow_id = draft_workflow.id - # Use the repository to save the workflow node execution + # Create repository and save the node execution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=app_model.tenant_id, app_id=app_model.id + session_factory=db.engine, + user=account, + app_id=app_model.id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - repository.save(workflow_node_execution) + repository.save(node_execution) + + # Convert node_execution to WorkflowNodeExecution after save + workflow_node_execution = repository.to_db_model(node_execution) return workflow_node_execution def run_free_workflow_node( self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any] - ) -> WorkflowNodeExecution: + ) -> NodeExecution: """ Run draft workflow node """ @@ -302,7 +305,7 @@ class WorkflowService: start_at = time.perf_counter() workflow_node_execution = self._handle_node_run_result( - getter=lambda: WorkflowEntry.run_free_node( + invoke_node_fn=lambda: WorkflowEntry.run_free_node( node_id=node_id, node_data=node_data, tenant_id=tenant_id, @@ -310,7 +313,6 @@ class WorkflowService: user_inputs=user_inputs, ), start_at=start_at, - tenant_id=tenant_id, node_id=node_id, ) @@ -318,21 +320,12 @@ class WorkflowService: def _handle_node_run_result( self, - getter: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], + invoke_node_fn: Callable[[], tuple[BaseNode, Generator[NodeEvent | InNodeEvent, None, None]]], start_at: float, - tenant_id: str, node_id: str, - ) -> WorkflowNodeExecution: - """ - Handle node run result - - :param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]] - :param start_at: float - :param tenant_id: str - :param node_id: str - """ + ) -> NodeExecution: try: - node_instance, generator = getter() + node_instance, generator = invoke_node_fn() node_run_result: NodeRunResult | None = None for event in generator: @@ -381,20 +374,21 @@ class WorkflowService: node_run_result = None error = e.error - workflow_node_execution = WorkflowNodeExecution() - workflow_node_execution.id = str(uuid4()) - workflow_node_execution.tenant_id = tenant_id - workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value - workflow_node_execution.index = 1 - workflow_node_execution.node_id = node_id - workflow_node_execution.node_type = node_instance.node_type - workflow_node_execution.title = node_instance.node_data.title - workflow_node_execution.elapsed_time = time.perf_counter() - start_at - workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value - workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None) - workflow_node_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) + # Create a NodeExecution domain model + node_execution = NodeExecution( + id=str(uuid4()), + workflow_id="", # This is a single-step execution, so no workflow ID + index=1, + node_id=node_id, + node_type=node_instance.node_type, + title=node_instance.node_data.title, + elapsed_time=time.perf_counter() - start_at, + created_at=datetime.now(UTC).replace(tzinfo=None), + finished_at=datetime.now(UTC).replace(tzinfo=None), + ) + if run_succeeded and node_run_result: - # create workflow node execution + # Set inputs, process_data, and outputs as dictionaries (not JSON strings) inputs = WorkflowEntry.handle_special_values(node_run_result.inputs) if node_run_result.inputs else None process_data = ( WorkflowEntry.handle_special_values(node_run_result.process_data) @@ -403,23 +397,23 @@ class WorkflowService: ) outputs = WorkflowEntry.handle_special_values(node_run_result.outputs) if node_run_result.outputs else None - workflow_node_execution.inputs = json.dumps(inputs) - workflow_node_execution.process_data = json.dumps(process_data) - workflow_node_execution.outputs = json.dumps(outputs) - workflow_node_execution.execution_metadata = ( - json.dumps(jsonable_encoder(node_run_result.metadata)) if node_run_result.metadata else None - ) + node_execution.inputs = inputs + node_execution.process_data = process_data + node_execution.outputs = outputs + node_execution.metadata = node_run_result.metadata + + # Map status from WorkflowNodeExecutionStatus to NodeExecutionStatus if node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: - workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value + node_execution.status = NodeExecutionStatus.SUCCEEDED elif node_run_result.status == WorkflowNodeExecutionStatus.EXCEPTION: - workflow_node_execution.status = WorkflowNodeExecutionStatus.EXCEPTION.value - workflow_node_execution.error = node_run_result.error + node_execution.status = NodeExecutionStatus.EXCEPTION + node_execution.error = node_run_result.error else: - # create workflow node execution - workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value - workflow_node_execution.error = error + # Set failed status and error + node_execution.status = NodeExecutionStatus.FAILED + node_execution.error = error - return workflow_node_execution + return node_execution def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: """ diff --git a/api/tasks/mail_email_code_login.py b/api/tasks/mail_email_code_login.py index 5dc935548f..ddad331725 100644 --- a/api/tasks/mail_email_code_login.py +++ b/api/tasks/mail_email_code_login.py @@ -6,6 +6,7 @@ from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -25,10 +26,24 @@ def send_email_code_login_mail_task(language: str, to: str, code: str): # send email code login mail using different languages try: if language == "zh-Hans": - html_content = render_template("email_code_login_mail_template_zh-CN.html", to=to, code=code) + template = "email_code_login_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/email_code_login_mail_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + else: + html_content = render_template(template, to=to, code=code) mail.send(to=to, subject="邮箱验证码", html=html_content) else: - html_content = render_template("email_code_login_mail_template_en-US.html", to=to, code=code) + template = "email_code_login_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/email_code_login_mail_template_en-US.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + else: + html_content = render_template(template, to=to, code=code) mail.send(to=to, subject="Email Code", html=html_content) end_at = time.perf_counter() diff --git a/api/tasks/mail_enterprise_task.py b/api/tasks/mail_enterprise_task.py new file mode 100644 index 0000000000..b9d8fd55df --- /dev/null +++ b/api/tasks/mail_enterprise_task.py @@ -0,0 +1,33 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template_string + +from extensions.ext_mail import mail + + +@shared_task(queue="mail") +def send_enterprise_email_task(to, subject, body, substitutions): + if not mail.is_inited(): + return + + logging.info(click.style("Start enterprise mail to {} with subject {}".format(to, subject), fg="green")) + start_at = time.perf_counter() + + try: + html_content = render_template_string(body, **substitutions) + + if isinstance(to, list): + for t in to: + mail.send(to=t, subject=subject, html=html_content) + else: + mail.send(to=to, subject=subject, html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style("Send enterprise mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send enterprise mail to {} failed".format(to)) diff --git a/api/tasks/mail_invite_member_task.py b/api/tasks/mail_invite_member_task.py index 3094527fd4..7ca85c7f2d 100644 --- a/api/tasks/mail_invite_member_task.py +++ b/api/tasks/mail_invite_member_task.py @@ -7,6 +7,7 @@ from flask import render_template from configs import dify_config from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -33,23 +34,45 @@ def send_invite_member_mail_task(language: str, to: str, token: str, inviter_nam try: url = f"{dify_config.CONSOLE_WEB_URL}/activate?token={token}" if language == "zh-Hans": - html_content = render_template( - "invite_member_mail_template_zh-CN.html", - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - ) - mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) + template = "invite_member_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/invite_member_mail_template_zh-CN.html" + html_content = render_template( + template, + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + application_title=application_title, + ) + mail.send(to=to, subject=f"立即加入 {application_title} 工作空间", html=html_content) + else: + html_content = render_template( + template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url + ) + mail.send(to=to, subject="立即加入 Dify 工作空间", html=html_content) else: - html_content = render_template( - "invite_member_mail_template_en-US.html", - to=to, - inviter_name=inviter_name, - workspace_name=workspace_name, - url=url, - ) - mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) + template = "invite_member_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/invite_member_mail_template_en-US.html" + html_content = render_template( + template, + to=to, + inviter_name=inviter_name, + workspace_name=workspace_name, + url=url, + application_title=application_title, + ) + mail.send(to=to, subject=f"Join {application_title} Workspace Now", html=html_content) + else: + html_content = render_template( + template, to=to, inviter_name=inviter_name, workspace_name=workspace_name, url=url + ) + mail.send(to=to, subject="Join Dify Workspace Now", html=html_content) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/mail_reset_password_task.py b/api/tasks/mail_reset_password_task.py index d5be94431b..d4f4482a48 100644 --- a/api/tasks/mail_reset_password_task.py +++ b/api/tasks/mail_reset_password_task.py @@ -6,6 +6,7 @@ from celery import shared_task # type: ignore from flask import render_template from extensions.ext_mail import mail +from services.feature_service import FeatureService @shared_task(queue="mail") @@ -25,11 +26,27 @@ def send_reset_password_mail_task(language: str, to: str, code: str): # send reset password mail using different languages try: if language == "zh-Hans": - html_content = render_template("reset_password_mail_template_zh-CN.html", to=to, code=code) - mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) + template = "reset_password_mail_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/reset_password_mail_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + mail.send(to=to, subject=f"设置您的 {application_title} 密码", html=html_content) + else: + html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject="设置您的 Dify 密码", html=html_content) else: - html_content = render_template("reset_password_mail_template_en-US.html", to=to, code=code) - mail.send(to=to, subject="Set Your Dify Password", html=html_content) + template = "reset_password_mail_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + application_title = system_features.branding.application_title + template = "without-brand/reset_password_mail_template_en-US.html" + html_content = render_template(template, to=to, code=code, application_title=application_title) + mail.send(to=to, subject=f"Set Your {application_title} Password", html=html_content) + else: + html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject="Set Your Dify Password", html=html_content) end_at = time.perf_counter() logging.info( diff --git a/api/tasks/remove_app_and_related_data_task.py b/api/tasks/remove_app_and_related_data_task.py index d5a783396a..d9c1980d3f 100644 --- a/api/tasks/remove_app_and_related_data_task.py +++ b/api/tasks/remove_app_and_related_data_task.py @@ -4,16 +4,19 @@ from collections.abc import Callable import click from celery import shared_task # type: ignore -from sqlalchemy import delete +from sqlalchemy import delete, select from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import Session from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from extensions.ext_database import db -from models.dataset import AppDatasetJoin -from models.model import ( +from models import ( + Account, ApiToken, + App, AppAnnotationHitHistory, AppAnnotationSetting, + AppDatasetJoin, AppModelConfig, Conversation, EndUser, @@ -188,9 +191,24 @@ def _delete_app_workflow_runs(tenant_id: str, app_id: str): def _delete_app_workflow_node_executions(tenant_id: str, app_id: str): + # Get app's owner + with Session(db.engine, expire_on_commit=False) as session: + stmt = select(Account).where(Account.id == App.owner_id).where(App.id == app_id) + user = session.scalar(stmt) + + if user is None: + errmsg = ( + f"Failed to delete workflow node executions for tenant {tenant_id} and app {app_id}, app's owner not found" + ) + logging.error(errmsg) + raise ValueError(errmsg) + # Create a repository instance for WorkflowNodeExecution repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, tenant_id=tenant_id, app_id=app_id + session_factory=db.engine, + user=user, + app_id=app_id, + triggered_from=None, ) # Use the clear method to delete all records for this tenant_id and app_id diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html index 0f7ddc62a9..2d8f78b46a 100644 --- a/api/templates/clean_document_job_mail_template-US.html +++ b/api/templates/clean_document_job_mail_template-US.html @@ -69,7 +69,7 @@