diff --git a/.devcontainer/post_create_command.sh b/.devcontainer/post_create_command.sh index 93ecac48f2..022f71bfb4 100755 --- a/.devcontainer/post_create_command.sh +++ b/.devcontainer/post_create_command.sh @@ -1,6 +1,6 @@ #!/bin/bash -npm add -g pnpm@10.11.1 +npm add -g pnpm@10.13.1 cd web && pnpm install pipx install uv @@ -12,3 +12,4 @@ echo 'alias start-containers="cd /workspaces/dify/docker && docker-compose -f do echo 'alias stop-containers="cd /workspaces/dify/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env down"' >> ~/.bashrc source /home/vscode/.bashrc + diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml index a9580a3ba3..d684fe9144 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ b/.github/ISSUE_TEMPLATE/bug_report.yml @@ -8,13 +8,15 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). + required: true - label: This is only for bug report, if you would like to ask a question, please head to [Discussions](https://github.com/langgenius/dify/discussions/categories/general). required: true - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: 【中文用户 & Non English User】请使用英语提交,否则会被关闭 :) required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true @@ -42,20 +44,22 @@ body: attributes: label: Steps to reproduce description: We highly suggest including screenshots and a bug report log. Please use the right markdown syntax for code blocks. - placeholder: Having detailed steps helps us reproduce the bug. + placeholder: Having detailed steps helps us reproduce the bug. If you have logs, please use fenced code blocks (triple backticks ```) to format them. validations: required: true - type: textarea attributes: label: ✔️ Expected Behavior - placeholder: What were you expecting? + description: Describe what you expected to happen. + placeholder: What were you expecting? Please do not copy and paste the steps to reproduce here. validations: - required: false + required: true - type: textarea attributes: label: ❌ Actual Behavior - placeholder: What happened instead? + description: Describe what actually happened. + placeholder: What happened instead? Please do not copy and paste the steps to reproduce here. validations: required: false diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 6877c382c4..c1666d24cf 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -1,5 +1,11 @@ blank_issues_enabled: false contact_links: + - name: "\U0001F4A1 Model Providers & Plugins" + url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose" + about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details. + - name: "\U0001F4AC Documentation Issues" + url: "https://github.com/langgenius/dify-docs/issues/new" + about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue. - name: "\U0001F4E7 Discussions" url: https://github.com/langgenius/dify/discussions/categories/general - about: General discussions and request help from the community + about: General discussions and seek help from the community diff --git a/.github/ISSUE_TEMPLATE/document_issue.yml b/.github/ISSUE_TEMPLATE/document_issue.yml deleted file mode 100644 index 8fdbc0fb9a..0000000000 --- a/.github/ISSUE_TEMPLATE/document_issue.yml +++ /dev/null @@ -1,24 +0,0 @@ -name: "📚 Documentation Issue" -description: Report issues in our documentation -labels: - - documentation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: textarea - attributes: - label: Provide a description of requested docs changes - placeholder: Briefly describe which document needs to be corrected and why. - validations: - required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml index b1952c63a9..bd293e2442 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ b/.github/ISSUE_TEMPLATE/feature_request.yml @@ -8,11 +8,11 @@ body: label: Self Checks description: "To make sure we get to you in time, please check the following :)" options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. + - label: I have read the [Contributing Guide](https://github.com/langgenius/dify/blob/main/CONTRIBUTING.md) and [Language Policy](https://github.com/langgenius/dify/issues/1542). required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). + - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" + - label: I confirm that I am using English to submit this report, otherwise it will be closed. required: true - label: "Please do not modify this template :) and fill in all the required fields." required: true diff --git a/.github/ISSUE_TEMPLATE/translation_issue.yml b/.github/ISSUE_TEMPLATE/translation_issue.yml deleted file mode 100644 index f9c2dfb7d2..0000000000 --- a/.github/ISSUE_TEMPLATE/translation_issue.yml +++ /dev/null @@ -1,55 +0,0 @@ -name: "🌐 Localization/Translation issue" -description: Report incorrect translations. [please use English :)] -labels: - - translation -body: - - type: checkboxes - attributes: - label: Self Checks - description: "To make sure we get to you in time, please check the following :)" - options: - - label: I have searched for existing issues [search for existing issues](https://github.com/langgenius/dify/issues), including closed ones. - required: true - - label: I confirm that I am using English to submit this report (我已阅读并同意 [Language Policy](https://github.com/langgenius/dify/issues/1542)). - required: true - - label: "[FOR CHINESE USERS] 请务必使用英文提交 Issue,否则会被关闭。谢谢!:)" - required: true - - label: "Please do not modify this template :) and fill in all the required fields." - required: true - - type: input - attributes: - label: Dify version - description: Hover over system tray icon or look at Settings - validations: - required: true - - type: input - attributes: - label: Utility with translation issue - placeholder: Some area - description: Please input here the utility with the translation issue - validations: - required: true - - type: input - attributes: - label: 🌐 Language affected - placeholder: "German" - validations: - required: true - - type: textarea - attributes: - label: ❌ Actual phrase(s) - placeholder: What is there? Please include a screenshot as that is extremely helpful. - validations: - required: true - - type: textarea - attributes: - label: ✔️ Expected phrase(s) - placeholder: What was expected? - validations: - required: true - - type: textarea - attributes: - label: ℹ Why is the current translation wrong - placeholder: Why do you feel this is incorrect? - validations: - required: true diff --git a/.github/workflows/build-push.yml b/.github/workflows/build-push.yml index cc735ae67c..b933560a5e 100644 --- a/.github/workflows/build-push.yml +++ b/.github/workflows/build-push.yml @@ -6,6 +6,7 @@ on: - "main" - "deploy/dev" - "deploy/enterprise" + - "build/**" tags: - "*" diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index b06ab9653e..a283f8d5ca 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -28,7 +28,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | api/** @@ -75,7 +75,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: web/** @@ -113,7 +113,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | docker/generate_docker_compose @@ -144,7 +144,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: | **.sh @@ -152,13 +152,15 @@ jobs: **.yml **Dockerfile dev/** + .editorconfig - name: Super-linter - uses: super-linter/super-linter/slim@v7 + uses: super-linter/super-linter/slim@v8 if: steps.changed-files.outputs.any_changed == 'true' env: BASH_SEVERITY: warning - DEFAULT_BRANCH: main + DEFAULT_BRANCH: origin/main + EDITORCONFIG_FILE_NAME: editorconfig-checker.json FILTER_REGEX_INCLUDE: pnpm-lock.yaml GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} IGNORE_GENERATED_FILES: true @@ -168,16 +170,6 @@ jobs: # FIXME: temporarily disabled until api-docker.yaml's run script is fixed for shellcheck # VALIDATE_GITHUB_ACTIONS: true VALIDATE_DOCKERFILE_HADOLINT: true + VALIDATE_EDITORCONFIG: true VALIDATE_XML: true VALIDATE_YAML: true - - - name: EditorConfig checks - uses: super-linter/super-linter/slim@v7 - env: - DEFAULT_BRANCH: main - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - IGNORE_GENERATED_FILES: true - IGNORE_GITIGNORED_FILES: true - # EditorConfig validation - VALIDATE_EDITORCONFIG: true - EDITORCONFIG_FILE_NAME: editorconfig-checker.json diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index 37cfdc5c1e..c3f8fdbaf6 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -27,7 +27,7 @@ jobs: - name: Check changed files id: changed-files - uses: tj-actions/changed-files@v45 + uses: tj-actions/changed-files@v46 with: files: web/** diff --git a/api/.env.example b/api/.env.example index eab017a624..daa0df535b 100644 --- a/api/.env.example +++ b/api/.env.example @@ -5,17 +5,17 @@ SECRET_KEY= # Console API base URL -CONSOLE_API_URL=http://127.0.0.1:5001 -CONSOLE_WEB_URL=http://127.0.0.1:3000 +CONSOLE_API_URL=http://localhost:5001 +CONSOLE_WEB_URL=http://localhost:3000 # Service API base URL -SERVICE_API_URL=http://127.0.0.1:5001 +SERVICE_API_URL=http://localhost:5001 # Web APP base URL -APP_WEB_URL=http://127.0.0.1:3000 +APP_WEB_URL=http://localhost:3000 # Files URL -FILES_URL=http://127.0.0.1:5001 +FILES_URL=http://localhost:5001 # INTERNAL_FILES_URL is used for plugin daemon communication within Docker network. # Set this to the internal Docker service URL for proper plugin file access. @@ -54,7 +54,7 @@ REDIS_CLUSTERS_PASSWORD= # celery configuration CELERY_BROKER_URL=redis://:difyai123456@localhost:${REDIS_PORT}/1 - +CELERY_BACKEND=redis # PostgreSQL database configuration DB_USERNAME=postgres DB_PASSWORD=difyai123456 @@ -138,12 +138,14 @@ SUPABASE_API_KEY=your-access-key SUPABASE_URL=your-server-url # CORS configuration -WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* -CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,* +WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,* +CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,* # Vector database configuration -# support: weaviate, qdrant, milvus, myscale, relyt, pgvecto_rs, pgvector, pgvector, chroma, opensearch, tidb_vector, couchbase, vikingdb, upstash, lindorm, oceanbase, opengauss, tablestore, matrixone +# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`. VECTOR_STORE=weaviate +# Prefix used to create collection name in vector database +VECTOR_INDEX_NAME_PREFIX=Vector_index # Weaviate configuration WEAVIATE_ENDPOINT=http://localhost:8080 @@ -495,6 +497,8 @@ ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id} # Reset password token expiry minutes RESET_PASSWORD_TOKEN_EXPIRY_MINUTES=5 +CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES=5 +OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES=5 CREATE_TIDB_SERVICE_JOB_ENABLED=false @@ -505,6 +509,8 @@ LOGIN_LOCKOUT_DURATION=86400 # Enable OpenTelemetry ENABLE_OTEL=false +OTLP_TRACE_ENDPOINT= +OTLP_METRIC_ENDPOINT= OTLP_BASE_ENDPOINT=http://localhost:4318 OTLP_API_KEY= OTEL_EXPORTER_OTLP_PROTOCOL= diff --git a/api/commands.py b/api/commands.py index 86769847c1..9f933a378c 100644 --- a/api/commands.py +++ b/api/commands.py @@ -2,19 +2,22 @@ import base64 import json import logging import secrets -from typing import Optional +from typing import Any, Optional import click from flask import current_app +from pydantic import TypeAdapter from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config from constants.languages import languages +from core.plugin.entities.plugin import ToolProviderID from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_type import VectorType from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.models.document import Document +from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params from events.app_event import app_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -27,6 +30,7 @@ from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, D from models.dataset import Document as DatasetDocument from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation from models.provider import Provider, ProviderModel +from models.tools import ToolOAuthSystemClient from services.account_service import AccountService, RegisterService, TenantService from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpiredLogs from services.plugin.data_migration import PluginDataMigration @@ -1155,3 +1159,49 @@ def remove_orphaned_files_on_storage(force: bool): click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) else: click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) + + +@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.") +@click.option("--provider", prompt=True, help="Provider name") +@click.option("--client-params", prompt=True, help="Client Params") +def setup_system_tool_oauth_client(provider, client_params): + """ + Setup system tool oauth client + """ + provider_id = ToolProviderID(provider) + provider_name = provider_id.provider_name + plugin_id = provider_id.plugin_id + + try: + # json validate + click.echo(click.style(f"Validating client params: {client_params}", fg="yellow")) + client_params_dict = TypeAdapter(dict[str, Any]).validate_json(client_params) + click.echo(click.style("Client params validated successfully.", fg="green")) + + click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow")) + click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow")) + oauth_client_params = encrypt_system_oauth_params(client_params_dict) + click.echo(click.style("Client params encrypted successfully.", fg="green")) + except Exception as e: + click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red")) + return + + deleted_count = ( + db.session.query(ToolOAuthSystemClient) + .filter_by( + provider=provider_name, + plugin_id=plugin_id, + ) + .delete() + ) + if deleted_count > 0: + click.echo(click.style(f"Deleted {deleted_count} existing oauth client params.", fg="yellow")) + + oauth_client = ToolOAuthSystemClient( + provider=provider_name, + plugin_id=plugin_id, + encrypted_oauth_params=oauth_client_params, + ) + db.session.add(oauth_client) + db.session.commit() + click.echo(click.style(f"OAuth client params setup successfully. id: {oauth_client.id}", fg="green")) diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index f6a8b037ca..f1d529355d 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -31,6 +31,15 @@ class SecurityConfig(BaseSettings): description="Duration in minutes for which a password reset token remains valid", default=5, ) + CHANGE_EMAIL_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a change email token remains valid", + default=5, + ) + + OWNER_TRANSFER_TOKEN_EXPIRY_MINUTES: PositiveInt = Field( + description="Duration in minutes for which a owner transfer token remains valid", + default=5, + ) LOGIN_DISABLED: bool = Field( description="Whether to disable login checks", @@ -614,6 +623,16 @@ class AuthConfig(BaseSettings): default=86400, ) + CHANGE_EMAIL_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying change email after exceeding the rate limit.", + default=86400, + ) + + OWNER_TRANSFER_LOCKOUT_DURATION: PositiveInt = Field( + description="Time (in seconds) a user must wait before retrying owner transfer after exceeding the rate limit.", + default=86400, + ) + class ModerationConfig(BaseSettings): """ diff --git a/api/configs/middleware/__init__.py b/api/configs/middleware/__init__.py index 0c0c06dd46..587ea55ca7 100644 --- a/api/configs/middleware/__init__.py +++ b/api/configs/middleware/__init__.py @@ -85,6 +85,11 @@ class VectorStoreConfig(BaseSettings): default=False, ) + VECTOR_INDEX_NAME_PREFIX: Optional[str] = Field( + description="Prefix used to create collection name in vector database", + default="Vector_index", + ) + class KeywordStoreConfig(BaseSettings): KEYWORD_STORE: str = Field( @@ -211,7 +216,7 @@ class DatabaseConfig(BaseSettings): class CeleryConfig(DatabaseConfig): CELERY_BACKEND: str = Field( description="Backend for Celery task results. Options: 'database', 'redis'.", - default="database", + default="redis", ) CELERY_BROKER_URL: Optional[str] = Field( diff --git a/api/configs/observability/otel/otel_config.py b/api/configs/observability/otel/otel_config.py index 1b88ddcfe6..7572a696ce 100644 --- a/api/configs/observability/otel/otel_config.py +++ b/api/configs/observability/otel/otel_config.py @@ -12,6 +12,16 @@ class OTelConfig(BaseSettings): default=False, ) + OTLP_TRACE_ENDPOINT: str = Field( + description="OTLP trace endpoint", + default="", + ) + + OTLP_METRIC_ENDPOINT: str = Field( + description="OTLP metric endpoint", + default="", + ) + OTLP_BASE_ENDPOINT: str = Field( description="OTLP base endpoint", default="http://localhost:4318", diff --git a/api/constants/__init__.py b/api/constants/__init__.py index a84de0a451..9e052320ac 100644 --- a/api/constants/__init__.py +++ b/api/constants/__init__.py @@ -1,6 +1,7 @@ from configs import dify_config HIDDEN_VALUE = "[__HIDDEN__]" +UNKNOWN_VALUE = "[__UNKNOWN__]" UUID_NIL = "00000000-0000-0000-0000-000000000000" DEFAULT_FILE_NUMBER_LIMITS = 3 diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 860166a61a..9fe32dde6d 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -151,6 +151,7 @@ class AppApi(Resource): parser.add_argument("icon", type=str, location="json") parser.add_argument("icon_background", type=str, location="json") parser.add_argument("use_icon_as_answer_icon", type=bool, location="json") + parser.add_argument("max_active_requests", type=int, location="json") args = parser.parse_args() app_service = AppService() diff --git a/api/controllers/console/app/conversation.py b/api/controllers/console/app/conversation.py index 70d6216497..4eef9fed43 100644 --- a/api/controllers/console/app/conversation.py +++ b/api/controllers/console/app/conversation.py @@ -1,4 +1,4 @@ -from datetime import UTC, datetime +from datetime import datetime import pytz # pip install pytz from flask_login import current_user @@ -19,6 +19,7 @@ from fields.conversation_fields import ( conversation_pagination_fields, conversation_with_summary_pagination_fields, ) +from libs.datetime_utils import naive_utc_now from libs.helper import DatetimeString from libs.login import login_required from models import Conversation, EndUser, Message, MessageAnnotation @@ -315,7 +316,7 @@ def _get_conversation(app_model, conversation_id): raise NotFound("Conversation Not Exists.") if not conversation.read_at: - conversation.read_at = datetime.now(UTC).replace(tzinfo=None) + conversation.read_at = naive_utc_now() conversation.read_account_id = current_user.id db.session.commit() diff --git a/api/controllers/console/app/mcp_server.py b/api/controllers/console/app/mcp_server.py index 0f53860f56..503393f264 100644 --- a/api/controllers/console/app/mcp_server.py +++ b/api/controllers/console/app/mcp_server.py @@ -35,16 +35,20 @@ class AppMCPServerController(Resource): @get_app_model @marshal_with(app_server_fields) def post(self, app_model): - # The role of the current user in the ta table must be editor, admin, or owner if not current_user.is_editor: raise NotFound() parser = reqparse.RequestParser() - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") args = parser.parse_args() + + description = args.get("description") + if not description: + description = app_model.description or "" + server = AppMCPServer( name=app_model.name, - description=args["description"], + description=description, parameters=json.dumps(args["parameters"], ensure_ascii=False), status=AppMCPServerStatus.ACTIVE, app_id=app_model.id, @@ -65,14 +69,22 @@ class AppMCPServerController(Resource): raise NotFound() parser = reqparse.RequestParser() parser.add_argument("id", type=str, required=True, location="json") - parser.add_argument("description", type=str, required=True, location="json") + parser.add_argument("description", type=str, required=False, location="json") parser.add_argument("parameters", type=dict, required=True, location="json") parser.add_argument("status", type=str, required=False, location="json") args = parser.parse_args() server = db.session.query(AppMCPServer).filter(AppMCPServer.id == args["id"]).first() if not server: raise NotFound() - server.description = args["description"] + + description = args.get("description") + if description is None: + pass + elif not description: + server.description = app_model.description or "" + else: + server.description = description + server.parameters = json.dumps(args["parameters"], ensure_ascii=False) if args["status"]: if args["status"] not in [status.value for status in AppMCPServerStatus]: diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index b7a4c31a15..ea659f9f5b 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -5,6 +5,7 @@ from flask_restful import Resource, fields, marshal_with, reqparse from flask_restful.inputs import int_range from werkzeug.exceptions import Forbidden, InternalServerError, NotFound +import services from controllers.console import api from controllers.console.app.error import ( CompletionRequestError, @@ -27,7 +28,7 @@ from fields.conversation_fields import annotation_fields, message_detail_fields from libs.helper import uuid_value from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.login import login_required -from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback +from models.model import AppMode, Conversation, Message, MessageAnnotation from services.annotation_service import AppAnnotationService from services.errors.conversation import ConversationNotExistsError from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError @@ -124,33 +125,16 @@ class MessageFeedbackApi(Resource): parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") args = parser.parse_args() - message_id = str(args["message_id"]) - - message = db.session.query(Message).filter(Message.id == message_id, Message.app_id == app_model.id).first() - - if not message: - raise NotFound("Message Not Exists.") - - feedback = message.admin_feedback - - if not args["rating"] and feedback: - db.session.delete(feedback) - elif args["rating"] and feedback: - feedback.rating = args["rating"] - elif not args["rating"] and not feedback: - raise ValueError("rating cannot be None when feedback not exists") - else: - feedback = MessageFeedback( - app_id=app_model.id, - conversation_id=message.conversation_id, - message_id=message.id, - rating=args["rating"], - from_source="admin", - from_account_id=current_user.id, + try: + MessageService.create_feedback( + app_model=app_model, + message_id=str(args["message_id"]), + user=current_user, + rating=args.get("rating"), + content=None, ) - db.session.add(feedback) - - db.session.commit() + except services.errors.message.MessageNotExistsError: + raise NotFound("Message Not Exists.") return {"result": "success"} diff --git a/api/controllers/console/app/site.py b/api/controllers/console/app/site.py index 3c3a359eeb..358a5e8cdb 100644 --- a/api/controllers/console/app/site.py +++ b/api/controllers/console/app/site.py @@ -1,5 +1,3 @@ -from datetime import UTC, datetime - from flask_login import current_user from flask_restful import Resource, marshal_with, reqparse from werkzeug.exceptions import Forbidden, NotFound @@ -10,6 +8,7 @@ from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from extensions.ext_database import db from fields.app_fields import app_site_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Site @@ -77,7 +76,7 @@ class AppSite(Resource): setattr(site, attr_name, value) site.updated_by = current_user.id - site.updated_at = datetime.now(UTC).replace(tzinfo=None) + site.updated_at = naive_utc_now() db.session.commit() return site @@ -101,7 +100,7 @@ class AppSiteAccessTokenReset(Resource): site.code = Site.generate_code(16) site.updated_by = current_user.id - site.updated_at = datetime.now(UTC).replace(tzinfo=None) + site.updated_at = naive_utc_now() db.session.commit() return site diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 00d6fa3cbf..ba93f82756 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -68,13 +68,18 @@ def _create_pagination_parser(): return parser +def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str: + value_type = workflow_draft_var.value_type + return value_type.exposed_type().value + + _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { "id": fields.String, "type": fields.String(attribute=lambda model: model.get_variable_type()), "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { "name": fields.String, "description": fields.String, "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), - "value_type": fields.String, + "value_type": fields.String(attribute=_serialize_variable_type), "edited": fields.Boolean(attribute=lambda model: model.edited), "visible": fields.Boolean, } @@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource): "name": v.name, "description": v.description, "selector": v.selector, - "value_type": v.value_type.value, + "value_type": v.value_type.exposed_type().value, "value": v.value, # Do not track edited for env vars. "edited": False, diff --git a/api/controllers/console/auth/activate.py b/api/controllers/console/auth/activate.py index 1795563ff7..2562fb5eb8 100644 --- a/api/controllers/console/auth/activate.py +++ b/api/controllers/console/auth/activate.py @@ -1,5 +1,3 @@ -import datetime - from flask import request from flask_restful import Resource, reqparse @@ -7,6 +5,7 @@ from constants.languages import supported_language from controllers.console import api from controllers.console.error import AlreadyActivateError from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.helper import StrLen, email, extract_remote_ip, timezone from models.account import AccountStatus from services.account_service import AccountService, RegisterService @@ -65,7 +64,7 @@ class ActivateApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() token_pair = AccountService.login(account, ip_address=extract_remote_ip(request)) diff --git a/api/controllers/console/auth/error.py b/api/controllers/console/auth/error.py index b40934dbf5..8c5e23de58 100644 --- a/api/controllers/console/auth/error.py +++ b/api/controllers/console/auth/error.py @@ -27,7 +27,19 @@ class InvalidTokenError(BaseHTTPException): class PasswordResetRateLimitExceededError(BaseHTTPException): error_code = "password_reset_rate_limit_exceeded" - description = "Too many password reset emails have been sent. Please try again in 1 minutes." + description = "Too many password reset emails have been sent. Please try again in 1 minute." + code = 429 + + +class EmailChangeRateLimitExceededError(BaseHTTPException): + error_code = "email_change_rate_limit_exceeded" + description = "Too many email change emails have been sent. Please try again in 1 minute." + code = 429 + + +class OwnerTransferRateLimitExceededError(BaseHTTPException): + error_code = "owner_transfer_rate_limit_exceeded" + description = "Too many owner transfer emails have been sent. Please try again in 1 minute." code = 429 @@ -65,3 +77,39 @@ class EmailPasswordResetLimitError(BaseHTTPException): error_code = "email_password_reset_limit" description = "Too many failed password reset attempts. Please try again in 24 hours." code = 429 + + +class EmailChangeLimitError(BaseHTTPException): + error_code = "email_change_limit" + description = "Too many failed email change attempts. Please try again in 24 hours." + code = 429 + + +class EmailAlreadyInUseError(BaseHTTPException): + error_code = "email_already_in_use" + description = "A user with this email already exists." + code = 400 + + +class OwnerTransferLimitError(BaseHTTPException): + error_code = "owner_transfer_limit" + description = "Too many failed owner transfer attempts. Please try again in 24 hours." + code = 429 + + +class NotOwnerError(BaseHTTPException): + error_code = "not_owner" + description = "You are not the owner of the workspace." + code = 400 + + +class CannotTransferOwnerToSelfError(BaseHTTPException): + error_code = "cannot_transfer_owner_to_self" + description = "You cannot transfer ownership to yourself." + code = 400 + + +class MemberNotInTenantError(BaseHTTPException): + error_code = "member_not_in_tenant" + description = "The member is not in the workspace." + code = 400 diff --git a/api/controllers/console/auth/oauth.py b/api/controllers/console/auth/oauth.py index 395367c9e2..d0a4f3ff6d 100644 --- a/api/controllers/console/auth/oauth.py +++ b/api/controllers/console/auth/oauth.py @@ -1,5 +1,4 @@ import logging -from datetime import UTC, datetime from typing import Optional import requests @@ -13,6 +12,7 @@ from configs import dify_config from constants.languages import languages from events.tenant_event import tenant_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.helper import extract_remote_ip from libs.oauth import GitHubOAuth, GoogleOAuth, OAuthUserInfo from models import Account @@ -110,7 +110,7 @@ class OAuthCallback(Resource): if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() try: diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 7b0d9373cf..b49f8affc8 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,4 +1,3 @@ -import datetime import json from flask import request @@ -15,6 +14,7 @@ from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.notion_extractor import NotionExtractor from extensions.ext_database import db from fields.data_source_fields import integrate_list_fields, integrate_notion_info_list_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import DataSourceOauthBinding, Document from services.dataset_service import DatasetService, DocumentService @@ -88,7 +88,7 @@ class DataSourceApi(Resource): if action == "enable": if data_source_binding.disabled: data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.add(data_source_binding) db.session.commit() else: @@ -97,7 +97,7 @@ class DataSourceApi(Resource): if action == "disable": if not data_source_binding.disabled: data_source_binding.disabled = True - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.add(data_source_binding) db.session.commit() else: diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 1611214cb3..4f62ac78b4 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -211,10 +211,6 @@ class DatasetApi(Resource): else: data["embedding_available"] = True - if data.get("permission") == "partial_members": - part_users_list = DatasetPermissionService.get_dataset_partial_member_list(dataset_id_str) - data.update({"partial_member_list": part_users_list}) - return data, 200 @setup_required diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index b2fcf3ce7b..28a2e93049 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -1,6 +1,5 @@ import logging from argparse import ArgumentTypeError -from datetime import UTC, datetime from typing import cast from flask import request @@ -49,6 +48,7 @@ from fields.document_fields import ( document_status_fields, document_with_segments_fields, ) +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from services.dataset_service import DatasetService, DocumentService @@ -750,7 +750,7 @@ class DocumentProcessingApi(DocumentResource): raise InvalidActionError("Document not in indexing state.") document.paused_by = current_user.id - document.paused_at = datetime.now(UTC).replace(tzinfo=None) + document.paused_at = naive_utc_now() document.is_paused = True db.session.commit() @@ -830,7 +830,7 @@ class DocumentMetadataApi(DocumentResource): document.doc_metadata[key] = value document.doc_type = doc_type - document.updated_at = datetime.now(UTC).replace(tzinfo=None) + document.updated_at = naive_utc_now() db.session.commit() return {"result": "success", "message": "Document metadata updated."}, 200 diff --git a/api/controllers/console/datasets/error.py b/api/controllers/console/datasets/error.py index 2f00a84de6..cb68bb5e81 100644 --- a/api/controllers/console/datasets/error.py +++ b/api/controllers/console/datasets/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/controllers/console/datasets/website.py b/api/controllers/console/datasets/website.py index 4200a51709..fcdc91ec67 100644 --- a/api/controllers/console/datasets/website.py +++ b/api/controllers/console/datasets/website.py @@ -4,7 +4,7 @@ from controllers.console import api from controllers.console.datasets.error import WebsiteCrawlError from controllers.console.wraps import account_initialization_required, setup_required from libs.login import login_required -from services.website_service import WebsiteService +from services.website_service import WebsiteCrawlApiRequest, WebsiteCrawlStatusApiRequest, WebsiteService class WebsiteCrawlApi(Resource): @@ -24,10 +24,16 @@ class WebsiteCrawlApi(Resource): parser.add_argument("url", type=str, required=True, nullable=True, location="json") parser.add_argument("options", type=dict, required=True, nullable=True, location="json") args = parser.parse_args() - WebsiteService.document_create_args_validate(args) - # crawl url + + # Create typed request and validate + try: + api_request = WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Crawl URL using typed request try: - result = WebsiteService.crawl_url(args) + result = WebsiteService.crawl_url(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 @@ -43,9 +49,16 @@ class WebsiteCrawlStatusApi(Resource): "provider", type=str, choices=["firecrawl", "watercrawl", "jinareader"], required=True, location="args" ) args = parser.parse_args() - # get crawl status + + # Create typed request and validate + try: + api_request = WebsiteCrawlStatusApiRequest.from_args(args, job_id) + except ValueError as e: + raise WebsiteCrawlError(str(e)) + + # Get crawl status using typed request try: - result = WebsiteService.get_crawl_status(job_id, args["provider"]) + result = WebsiteService.get_crawl_status_typed(api_request) except Exception as e: raise WebsiteCrawlError(str(e)) return result, 200 diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index 4367da1162..4842fefc57 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -1,5 +1,4 @@ import logging -from datetime import UTC, datetime from flask_login import current_user from flask_restful import reqparse @@ -27,6 +26,7 @@ from core.errors.error import ( from core.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper +from libs.datetime_utils import naive_utc_now from libs.helper import uuid_value from models.model import AppMode from services.app_generate_service import AppGenerateService @@ -51,7 +51,7 @@ class CompletionApi(InstalledAppResource): streaming = args["response_mode"] == "streaming" args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) + installed_app.last_used_at = naive_utc_now() db.session.commit() try: @@ -111,7 +111,7 @@ class ChatApi(InstalledAppResource): args["auto_generate_name"] = False - installed_app.last_used_at = datetime.now(UTC).replace(tzinfo=None) + installed_app.last_used_at = naive_utc_now() db.session.commit() try: diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 9d0c08564e..29111fb865 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -1,5 +1,4 @@ import logging -from datetime import UTC, datetime from typing import Any from flask import request @@ -13,6 +12,7 @@ from controllers.console.explore.wraps import InstalledAppResource from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check from extensions.ext_database import db from fields.installed_app_fields import installed_app_list_fields +from libs.datetime_utils import naive_utc_now from libs.login import login_required from models import App, InstalledApp, RecommendedApp from services.account_service import TenantService @@ -122,7 +122,7 @@ class InstalledAppsListApi(Resource): tenant_id=current_tenant_id, app_owner_tenant_id=app.tenant_id, is_pinned=False, - last_used_at=datetime.now(UTC).replace(tzinfo=None), + last_used_at=naive_utc_now(), ) db.session.add(new_installed_app) db.session.commit() diff --git a/api/controllers/console/workspace/account.py b/api/controllers/console/workspace/account.py index a9dbf44456..7f7e64a59c 100644 --- a/api/controllers/console/workspace/account.py +++ b/api/controllers/console/workspace/account.py @@ -1,13 +1,21 @@ -import datetime - import pytz from flask import request from flask_login import current_user from flask_restful import Resource, fields, marshal_with, reqparse +from sqlalchemy import select +from sqlalchemy.orm import Session from configs import dify_config from constants.languages import supported_language from controllers.console import api +from controllers.console.auth.error import ( + EmailAlreadyInUseError, + EmailChangeLimitError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, +) +from controllers.console.error import AccountNotFound, EmailSendIpLimitError from controllers.console.workspace.error import ( AccountAlreadyInitedError, CurrentPasswordIncorrectError, @@ -18,15 +26,18 @@ from controllers.console.workspace.error import ( from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_enabled, + enable_change_email, enterprise_license_required, only_edition_cloud, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_fields -from libs.helper import TimestampField, timezone +from libs.datetime_utils import naive_utc_now +from libs.helper import TimestampField, email, extract_remote_ip, timezone from libs.login import login_required from models import AccountIntegrate, InvitationCode +from models.account import Account from services.account_service import AccountService from services.billing_service import BillingService from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError @@ -68,7 +79,7 @@ class AccountInitApi(Resource): raise InvalidInvitationCodeError() invitation_code.status = "used" - invitation_code.used_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + invitation_code.used_at = naive_utc_now() invitation_code.used_by_tenant_id = account.current_tenant_id invitation_code.used_by_account_id = account.id @@ -76,7 +87,7 @@ class AccountInitApi(Resource): account.timezone = args["timezone"] account.interface_theme = "light" account.status = "active" - account.initialized_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() return {"result": "success"} @@ -369,6 +380,134 @@ class EducationAutoCompleteApi(Resource): return BillingService.EducationIdentity.autocomplete(args["keywords"], args["page"], args["limit"]) +class ChangeEmailSendEmailApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("language", type=str, required=False, location="json") + parser.add_argument("phase", type=str, required=False, location="json") + parser.add_argument("token", type=str, required=False, location="json") + args = parser.parse_args() + + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + account = None + user_email = args["email"] + if args["phase"] is not None and args["phase"] == "new_email": + if args["token"] is None: + raise InvalidTokenError() + + reset_data = AccountService.get_change_email_data(args["token"]) + if reset_data is None: + raise InvalidTokenError() + user_email = reset_data.get("email", "") + + if user_email != current_user.email: + raise InvalidEmailError() + else: + with Session(db.engine) as session: + account = session.execute(select(Account).filter_by(email=args["email"])).scalar_one_or_none() + if account is None: + raise AccountNotFound() + + token = AccountService.send_change_email_email( + account=account, email=args["email"], old_email=user_email, language=language, phase=args["phase"] + ) + return {"result": "success", "data": token} + + +class ChangeEmailCheckApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + user_email = args["email"] + + is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args["email"]) + if is_change_email_error_rate_limit: + raise EmailChangeLimitError() + + token_data = AccountService.get_change_email_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_change_email_error_rate_limit(args["email"]) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_change_email_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_change_email_token( + user_email, code=args["code"], old_email=token_data.get("old_email"), additional_data={} + ) + + AccountService.reset_change_email_error_rate_limit(args["email"]) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class ChangeEmailResetApi(Resource): + @enable_change_email + @setup_required + @login_required + @account_initialization_required + @marshal_with(account_fields) + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("new_email", type=email, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + reset_data = AccountService.get_change_email_data(args["token"]) + if not reset_data: + raise InvalidTokenError() + + AccountService.revoke_change_email_token(args["token"]) + + if not AccountService.check_email_unique(args["new_email"]): + raise EmailAlreadyInUseError() + + old_email = reset_data.get("old_email", "") + if current_user.email != old_email: + raise AccountNotFound() + + updated_account = AccountService.update_account(current_user, email=args["new_email"]) + + return updated_account + + +class CheckEmailUnique(Resource): + @setup_required + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("email", type=email, required=True, location="json") + args = parser.parse_args() + if not AccountService.check_email_unique(args["email"]): + raise EmailAlreadyInUseError() + return {"result": "success"} + + # Register API resources api.add_resource(AccountInitApi, "/account/init") api.add_resource(AccountProfileApi, "/account/profile") @@ -385,5 +524,10 @@ api.add_resource(AccountDeleteUpdateFeedbackApi, "/account/delete/feedback") api.add_resource(EducationVerifyApi, "/account/education/verify") api.add_resource(EducationApi, "/account/education") api.add_resource(EducationAutoCompleteApi, "/account/education/autocomplete") +# Change email +api.add_resource(ChangeEmailSendEmailApi, "/account/change-email") +api.add_resource(ChangeEmailCheckApi, "/account/change-email/validity") +api.add_resource(ChangeEmailResetApi, "/account/change-email/reset") +api.add_resource(CheckEmailUnique, "/account/change-email/check-email-unique") # api.add_resource(AccountEmailApi, '/account/email') # api.add_resource(AccountEmailVerifyApi, '/account/email-verify') diff --git a/api/controllers/console/workspace/error.py b/api/controllers/console/workspace/error.py index 8b70ca62b9..4427d1ff72 100644 --- a/api/controllers/console/workspace/error.py +++ b/api/controllers/console/workspace/error.py @@ -13,12 +13,6 @@ class CurrentPasswordIncorrectError(BaseHTTPException): code = 400 -class ProviderRequestFailedError(BaseHTTPException): - error_code = "provider_request_failed" - description = None - code = 400 - - class InvalidInvitationCodeError(BaseHTTPException): error_code = "invalid_invitation_code" description = "Invalid invitation code." diff --git a/api/controllers/console/workspace/members.py b/api/controllers/console/workspace/members.py index 48225ac90d..b1f79ffdec 100644 --- a/api/controllers/console/workspace/members.py +++ b/api/controllers/console/workspace/members.py @@ -1,22 +1,34 @@ from urllib import parse +from flask import request from flask_login import current_user from flask_restful import Resource, abort, marshal_with, reqparse import services from configs import dify_config from controllers.console import api -from controllers.console.error import WorkspaceMembersLimitExceeded +from controllers.console.auth.error import ( + CannotTransferOwnerToSelfError, + EmailCodeError, + InvalidEmailError, + InvalidTokenError, + MemberNotInTenantError, + NotOwnerError, + OwnerTransferLimitError, +) +from controllers.console.error import EmailSendIpLimitError, WorkspaceMembersLimitExceeded from controllers.console.wraps import ( account_initialization_required, cloud_edition_billing_resource_check, + is_allow_transfer_owner, setup_required, ) from extensions.ext_database import db from fields.member_fields import account_with_role_list_fields +from libs.helper import extract_remote_ip from libs.login import login_required from models.account import Account, TenantAccountRole -from services.account_service import RegisterService, TenantService +from services.account_service import AccountService, RegisterService, TenantService from services.errors.account import AccountAlreadyInTenantError from services.feature_service import FeatureService @@ -156,8 +168,146 @@ class DatasetOperatorMemberListApi(Resource): return {"result": "success", "accounts": members}, 200 +class SendOwnerTransferEmailApi(Resource): + """Send owner transfer email.""" + + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("language", type=str, required=False, location="json") + args = parser.parse_args() + ip_address = extract_remote_ip(request) + if AccountService.is_email_send_ip_limit(ip_address): + raise EmailSendIpLimitError() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if args["language"] is not None and args["language"] == "zh-Hans": + language = "zh-Hans" + else: + language = "en-US" + + email = current_user.email + + token = AccountService.send_owner_transfer_email( + account=current_user, + email=email, + language=language, + workspace_name=current_user.current_tenant.name, + ) + + return {"result": "success", "data": token} + + +class OwnerTransferCheckApi(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self): + parser = reqparse.RequestParser() + parser.add_argument("code", type=str, required=True, location="json") + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + user_email = current_user.email + + is_owner_transfer_error_rate_limit = AccountService.is_owner_transfer_error_rate_limit(user_email) + if is_owner_transfer_error_rate_limit: + raise OwnerTransferLimitError() + + token_data = AccountService.get_owner_transfer_data(args["token"]) + if token_data is None: + raise InvalidTokenError() + + if user_email != token_data.get("email"): + raise InvalidEmailError() + + if args["code"] != token_data.get("code"): + AccountService.add_owner_transfer_error_rate_limit(user_email) + raise EmailCodeError() + + # Verified, revoke the first token + AccountService.revoke_owner_transfer_token(args["token"]) + + # Refresh token data by generating a new token + _, new_token = AccountService.generate_owner_transfer_token(user_email, code=args["code"], additional_data={}) + + AccountService.reset_owner_transfer_error_rate_limit(user_email) + return {"is_valid": True, "email": token_data.get("email"), "token": new_token} + + +class OwnerTransfer(Resource): + @setup_required + @login_required + @account_initialization_required + @is_allow_transfer_owner + def post(self, member_id): + parser = reqparse.RequestParser() + parser.add_argument("token", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + # check if the current user is the owner of the workspace + if not TenantService.is_owner(current_user, current_user.current_tenant): + raise NotOwnerError() + + if current_user.id == str(member_id): + raise CannotTransferOwnerToSelfError() + + transfer_token_data = AccountService.get_owner_transfer_data(args["token"]) + if not transfer_token_data: + raise InvalidTokenError() + + if transfer_token_data.get("email") != current_user.email: + raise InvalidEmailError() + + AccountService.revoke_owner_transfer_token(args["token"]) + + member = db.session.get(Account, str(member_id)) + if not member: + abort(404) + else: + member_account = member + if not TenantService.is_member(member_account, current_user.current_tenant): + raise MemberNotInTenantError() + + try: + assert member is not None, "Member not found" + TenantService.update_member_role(current_user.current_tenant, member, "owner", current_user) + + AccountService.send_new_owner_transfer_notify_email( + account=member, + email=member.email, + workspace_name=current_user.current_tenant.name, + ) + + AccountService.send_old_owner_transfer_notify_email( + account=current_user, + email=current_user.email, + workspace_name=current_user.current_tenant.name, + new_owner_email=member.email, + ) + + except Exception as e: + raise ValueError(str(e)) + + return {"result": "success"} + + api.add_resource(MemberListApi, "/workspaces/current/members") api.add_resource(MemberInviteEmailApi, "/workspaces/current/members/invite-email") api.add_resource(MemberCancelInviteApi, "/workspaces/current/members/") api.add_resource(MemberUpdateRoleApi, "/workspaces/current/members//update-role") api.add_resource(DatasetOperatorMemberListApi, "/workspaces/current/dataset-operators") +# owner transfer +api.add_resource(SendOwnerTransferEmailApi, "/workspaces/current/members/send-owner-transfer-confirm-email") +api.add_resource(OwnerTransferCheckApi, "/workspaces/current/members/owner-transfer-check") +api.add_resource(OwnerTransfer, "/workspaces/current/members//owner-transfer") diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index df50871a38..c70bf84d2a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -1,26 +1,35 @@ import io from urllib.parse import urlparse -from flask import redirect, send_file +from flask import make_response, redirect, request, send_file from flask_login import current_user -from flask_restful import Resource, reqparse -from sqlalchemy.orm import Session +from flask_restful import ( + Resource, + reqparse, +) from werkzeug.exceptions import Forbidden from configs import dify_config from controllers.console import api -from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required +from controllers.console.wraps import ( + account_initialization_required, + enterprise_license_required, + setup_required, +) from core.mcp.auth.auth_flow import auth, handle_callback from core.mcp.auth.auth_provider import OAuthClientProvider from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.model_runtime.utils.encoders import jsonable_encoder -from extensions.ext_database import db -from libs.helper import alphanumeric, uuid_value +from core.plugin.entities.plugin import ToolProviderID +from core.plugin.impl.oauth import OAuthHandler +from core.tools.entities.tool_entities import CredentialType +from libs.helper import StrLen, alphanumeric, uuid_value from libs.login import login_required +from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService @@ -89,7 +98,7 @@ class ToolBuiltinProviderInfoApi(Resource): user_id = user.id tenant_id = user.current_tenant_id - return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(user_id, tenant_id, provider)) + return jsonable_encoder(BuiltinToolManageService.get_builtin_tool_provider_info(tenant_id, provider)) class ToolBuiltinProviderDeleteApi(Resource): @@ -98,17 +107,47 @@ class ToolBuiltinProviderDeleteApi(Resource): @account_initialization_required def post(self, provider): user = current_user - if not user.is_admin_or_owner: raise Forbidden() - user_id = user.id tenant_id = user.current_tenant_id + req = reqparse.RequestParser() + req.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + args = req.parse_args() return BuiltinToolManageService.delete_builtin_tool_provider( - user_id, tenant_id, provider, + args["credential_id"], + ) + + +class ToolBuiltinProviderAddApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + user = current_user + + user_id = user.id + tenant_id = user.current_tenant_id + + parser = reqparse.RequestParser() + parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=False, location="json") + parser.add_argument("type", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + + if args["type"] not in CredentialType.values(): + raise ValueError(f"Invalid credential type: {args['type']}") + + return BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=args["credentials"], + name=args["name"], + api_type=CredentialType.of(args["type"]), ) @@ -126,19 +165,20 @@ class ToolBuiltinProviderUpdateApi(Resource): tenant_id = user.current_tenant_id parser = reqparse.RequestParser() - parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json") + parser.add_argument("credential_id", type=str, required=True, nullable=False, location="json") + parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json") + parser.add_argument("name", type=StrLen(30), required=False, nullable=True, location="json") args = parser.parse_args() - with Session(db.engine) as session: - result = BuiltinToolManageService.update_builtin_tool_provider( - session=session, - user_id=user_id, - tenant_id=tenant_id, - provider_name=provider, - credentials=args["credentials"], - ) - session.commit() + result = BuiltinToolManageService.update_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credential_id=args["credential_id"], + credentials=args.get("credentials", None), + name=args.get("name", ""), + ) return result @@ -149,9 +189,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): def get(self, provider): tenant_id = current_user.current_tenant_id - return BuiltinToolManageService.get_builtin_tool_provider_credentials( - tenant_id=tenant_id, - provider_name=provider, + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) ) @@ -344,12 +386,15 @@ class ToolBuiltinProviderCredentialsSchemaApi(Resource): @setup_required @login_required @account_initialization_required - def get(self, provider): + def get(self, provider, credential_type): user = current_user - tenant_id = user.current_tenant_id - return BuiltinToolManageService.list_builtin_provider_credentials_schema(provider, tenant_id) + return jsonable_encoder( + BuiltinToolManageService.list_builtin_provider_credentials_schema( + provider, CredentialType.of(credential_type), tenant_id + ) + ) class ToolApiProviderSchemaApi(Resource): @@ -586,15 +631,12 @@ class ToolApiListApi(Resource): @account_initialization_required def get(self): user = current_user - - user_id = user.id tenant_id = user.current_tenant_id return jsonable_encoder( [ provider.to_dict() for provider in ApiToolManageService.list_api_tools( - user_id, tenant_id, ) ] @@ -631,6 +673,179 @@ class ToolLabelsApi(Resource): return jsonable_encoder(ToolLabelsService.list_tool_labels()) +class ToolPluginOAuthApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + + # todo check permission + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + tenant_id = user.current_tenant_id + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id=tenant_id, provider=provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + oauth_handler = OAuthHandler() + context_id = OAuthProxyService.create_proxy_context( + user_id=current_user.id, tenant_id=tenant_id, plugin_id=plugin_id, provider=provider_name + ) + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + authorization_url_response = oauth_handler.get_authorization_url( + tenant_id=tenant_id, + user_id=user.id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + ) + response = make_response(jsonable_encoder(authorization_url_response)) + response.set_cookie( + "context_id", + context_id, + httponly=True, + samesite="Lax", + max_age=OAuthProxyService.__MAX_AGE__, + ) + return response + + +class ToolOAuthCallback(Resource): + @setup_required + def get(self, provider): + context_id = request.cookies.get("context_id") + if not context_id: + raise Forbidden("context_id not found") + + context = OAuthProxyService.use_proxy_context(context_id) + if context is None: + raise Forbidden("Invalid context_id") + + tool_provider = ToolProviderID(provider) + plugin_id = tool_provider.plugin_id + provider_name = tool_provider.provider_name + user_id, tenant_id = context.get("user_id"), context.get("tenant_id") + + oauth_handler = OAuthHandler() + oauth_client_params = BuiltinToolManageService.get_oauth_client(tenant_id, provider) + if oauth_client_params is None: + raise Forbidden("no oauth available client config found for this tool provider") + + redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback" + credentials = oauth_handler.get_credentials( + tenant_id=tenant_id, + user_id=user_id, + plugin_id=plugin_id, + provider=provider_name, + redirect_uri=redirect_uri, + system_credentials=oauth_client_params, + request=request, + ).credentials + + if not credentials: + raise Exception("the plugin credentials failed") + + # add credentials to database + BuiltinToolManageService.add_builtin_tool_provider( + user_id=user_id, + tenant_id=tenant_id, + provider=provider, + credentials=dict(credentials), + api_type=CredentialType.OAUTH2, + ) + return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback") + + +class ToolBuiltinProviderSetDefaultApi(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("id", type=str, required=True, nullable=False, location="json") + args = parser.parse_args() + return BuiltinToolManageService.set_default_provider( + tenant_id=current_user.current_tenant_id, user_id=current_user.id, provider=provider, id=args["id"] + ) + + +class ToolOAuthCustomClient(Resource): + @setup_required + @login_required + @account_initialization_required + def post(self, provider): + parser = reqparse.RequestParser() + parser.add_argument("client_params", type=dict, required=False, nullable=True, location="json") + parser.add_argument("enable_oauth_custom_client", type=bool, required=False, nullable=True, location="json") + args = parser.parse_args() + + user = current_user + + if not user.is_admin_or_owner: + raise Forbidden() + + return BuiltinToolManageService.save_custom_oauth_client_params( + tenant_id=user.current_tenant_id, + provider=provider, + client_params=args.get("client_params", {}), + enable_oauth_custom_client=args.get("enable_oauth_custom_client", True), + ) + + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + @setup_required + @login_required + @account_initialization_required + def delete(self, provider): + return jsonable_encoder( + BuiltinToolManageService.delete_custom_oauth_client_params( + tenant_id=current_user.current_tenant_id, provider=provider + ) + ) + + +class ToolBuiltinProviderGetOauthClientSchemaApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_oauth_client_schema( + tenant_id=current_user.current_tenant_id, provider_name=provider + ) + ) + + +class ToolBuiltinProviderGetCredentialInfoApi(Resource): + @setup_required + @login_required + @account_initialization_required + def get(self, provider): + tenant_id = current_user.current_tenant_id + + return jsonable_encoder( + BuiltinToolManageService.get_builtin_tool_provider_credential_info( + tenant_id=tenant_id, + provider=provider, + ) + ) + + class ToolProviderMCPApi(Resource): @setup_required @login_required @@ -794,17 +1009,33 @@ class ToolMCPCallbackApi(Resource): # tool provider api.add_resource(ToolProviderListApi, "/workspaces/current/tool-providers") +# tool oauth +api.add_resource(ToolPluginOAuthApi, "/oauth/plugin//tool/authorization-url") +api.add_resource(ToolOAuthCallback, "/oauth/plugin//tool/callback") +api.add_resource(ToolOAuthCustomClient, "/workspaces/current/tool-provider/builtin//oauth/custom-client") + # builtin tool provider api.add_resource(ToolBuiltinProviderListToolsApi, "/workspaces/current/tool-provider/builtin//tools") api.add_resource(ToolBuiltinProviderInfoApi, "/workspaces/current/tool-provider/builtin//info") +api.add_resource(ToolBuiltinProviderAddApi, "/workspaces/current/tool-provider/builtin//add") api.add_resource(ToolBuiltinProviderDeleteApi, "/workspaces/current/tool-provider/builtin//delete") api.add_resource(ToolBuiltinProviderUpdateApi, "/workspaces/current/tool-provider/builtin//update") +api.add_resource( + ToolBuiltinProviderSetDefaultApi, "/workspaces/current/tool-provider/builtin//default-credential" +) +api.add_resource( + ToolBuiltinProviderGetCredentialInfoApi, "/workspaces/current/tool-provider/builtin//credential/info" +) api.add_resource( ToolBuiltinProviderGetCredentialsApi, "/workspaces/current/tool-provider/builtin//credentials" ) api.add_resource( ToolBuiltinProviderCredentialsSchemaApi, - "/workspaces/current/tool-provider/builtin//credentials_schema", + "/workspaces/current/tool-provider/builtin//credential/schema/", +) +api.add_resource( + ToolBuiltinProviderGetOauthClientSchemaApi, + "/workspaces/current/tool-provider/builtin//oauth/client-schema", ) api.add_resource(ToolBuiltinProviderIconApi, "/workspaces/current/tool-provider/builtin//icon") diff --git a/api/controllers/console/wraps.py b/api/controllers/console/wraps.py index ca122772de..d862dac373 100644 --- a/api/controllers/console/wraps.py +++ b/api/controllers/console/wraps.py @@ -235,3 +235,29 @@ def email_password_login_enabled(view): abort(403) return decorated + + +def enable_change_email(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_system_features() + if features.enable_change_email: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated + + +def is_allow_transfer_owner(view): + @wraps(view) + def decorated(*args, **kwargs): + features = FeatureService.get_features(current_user.current_tenant_id) + if features.is_allow_transfer_workspace: + return view(*args, **kwargs) + + # otherwise, return 403 + abort(403) + + return decorated diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 327e9ce834..5dfe41eb6b 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -175,6 +175,7 @@ class PluginInvokeToolApi(Resource): provider=payload.provider, tool_name=payload.tool, tool_parameters=payload.tool_parameters, + credential_id=payload.credential_id, ), ) diff --git a/api/controllers/service_api/dataset/error.py b/api/controllers/service_api/dataset/error.py index 5ff5e08c72..ecc47b40a1 100644 --- a/api/controllers/service_api/dataset/error.py +++ b/api/controllers/service_api/dataset/error.py @@ -25,12 +25,6 @@ class UnsupportedFileTypeError(BaseHTTPException): code = 415 -class HighQualityDatasetOnlyError(BaseHTTPException): - error_code = "high_quality_dataset_only" - description = "Current operation only supports 'high-quality' datasets." - code = 400 - - class DatasetNotInitializedError(BaseHTTPException): error_code = "dataset_not_initialized" description = "The dataset is still being initialized or indexing. Please wait a moment." diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 5b919a68d4..eeed321430 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -1,6 +1,6 @@ import time from collections.abc import Callable -from datetime import UTC, datetime, timedelta +from datetime import timedelta from enum import Enum from functools import wraps from typing import Optional @@ -15,6 +15,7 @@ from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.datetime_utils import naive_utc_now from libs.login import _get_user from models.account import Account, Tenant, TenantAccountJoin, TenantStatus from models.dataset import Dataset, RateLimitLog @@ -256,7 +257,7 @@ def validate_and_get_api_token(scope: str | None = None): if auth_scheme != "bearer": raise Unauthorized("Authorization scheme must be 'Bearer'") - current_time = datetime.now(UTC).replace(tzinfo=None) + current_time = naive_utc_now() cutoff_time = current_time - timedelta(minutes=1) with Session(db.engine, expire_on_commit=False) as session: update_stmt = ( diff --git a/api/core/agent/entities.py b/api/core/agent/entities.py index 143a3a51aa..a31c1050bd 100644 --- a/api/core/agent/entities.py +++ b/api/core/agent/entities.py @@ -16,6 +16,7 @@ class AgentToolEntity(BaseModel): tool_name: str tool_parameters: dict[str, Any] = Field(default_factory=dict) plugin_unique_identifier: str | None = None + credential_id: str | None = None class AgentPromptEntity(BaseModel): diff --git a/api/core/agent/plugin_entities.py b/api/core/agent/plugin_entities.py index 3b48288710..a3438fc2c7 100644 --- a/api/core/agent/plugin_entities.py +++ b/api/core/agent/plugin_entities.py @@ -41,6 +41,7 @@ class AgentStrategyParameter(PluginParameter): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value # deprecated, should not use. SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value diff --git a/api/core/agent/strategy/base.py b/api/core/agent/strategy/base.py index ead81a7a0e..a52a1dfd7a 100644 --- a/api/core/agent/strategy/base.py +++ b/api/core/agent/strategy/base.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyParameter +from core.plugin.entities.request import InvokeCredentials class BaseAgentStrategy(ABC): @@ -18,11 +19,12 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. """ - yield from self._invoke(params, user_id, conversation_id, app_id, message_id) + yield from self._invoke(params, user_id, conversation_id, app_id, message_id, credentials) def get_parameters(self) -> Sequence[AgentStrategyParameter]: """ @@ -38,5 +40,6 @@ class BaseAgentStrategy(ABC): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: pass diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index 4cfcfbf86a..04661581a7 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -4,6 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.strategy.base import BaseAgentStrategy +from core.plugin.entities.request import InvokeCredentials, PluginInvokeContext from core.plugin.impl.agent import PluginAgentClient from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -40,6 +41,7 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + credentials: Optional[InvokeCredentials] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent strategy. @@ -58,4 +60,5 @@ class PluginAgentStrategy(BaseAgentStrategy): conversation_id=conversation_id, app_id=app_id, message_id=message_id, + context=PluginInvokeContext(credentials=credentials or InvokeCredentials()), ) diff --git a/api/core/app/app_config/easy_ui_based_app/agent/manager.py b/api/core/app/app_config/easy_ui_based_app/agent/manager.py index 590b944c0d..8887d2500c 100644 --- a/api/core/app/app_config/easy_ui_based_app/agent/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/agent/manager.py @@ -39,6 +39,7 @@ class AgentConfigManager: "provider_id": tool["provider_id"], "tool_name": tool["tool_name"], "tool_parameters": tool.get("tool_parameters", {}), + "credential_id": tool.get("credential_id", None), } agent_tools.append(AgentToolEntity(**agent_tool_properties)) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 4b8f5ebe27..bd5ad9c51b 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -17,7 +17,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 840a3c9d3b..af15324f46 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -16,9 +16,10 @@ from core.app.entities.queue_entities import ( QueueTextChunkEvent, ) from core.moderation.base import ModerationError +from core.variables.variables import VariableUnion from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -64,7 +65,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): if not workflow: raise ValueError("Workflow not initialized") - user_id = None + user_id: str | None = None if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}: end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first() if end_user: @@ -136,23 +137,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): session.commit() # Create a variable pool. - system_inputs = { - SystemVariableKey.QUERY: query, - SystemVariableKey.FILES: files, - SystemVariableKey.CONVERSATION_ID: self.conversation.id, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.DIALOGUE_COUNT: self._dialogue_count, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_run_id, - } + system_inputs = SystemVariable( + query=query, + files=files, + conversation_id=self.conversation.id, + user_id=user_id, + dialogue_count=self._dialogue_count, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_run_id, + ) # init variable pool variable_pool = VariablePool( system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), ) # init graph diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index 140ccbc443..a42679f600 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time -from collections.abc import Generator, Mapping +from collections.abc import Callable, Generator, Mapping +from contextlib import contextmanager from threading import Thread from typing import Any, Optional, Union @@ -15,6 +16,7 @@ from core.app.entities.app_invoke_entities import ( InvokeFrom, ) from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAdvancedChatMessageEndEvent, QueueAgentLogEvent, QueueAnnotationReplyEvent, @@ -44,6 +46,7 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, + WorkflowQueueMessage, ) from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -52,6 +55,7 @@ from core.app.entities.task_entities import ( MessageAudioEndStreamResponse, MessageAudioStreamResponse, MessageEndStreamResponse, + PingStreamResponse, StreamResponse, WorkflowTaskState, ) @@ -61,12 +65,12 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.model_runtime.entities.llm_entities import LLMUsage from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.nodes import NodeType from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from events.message_event import message_was_created from extensions.ext_database import db @@ -116,16 +120,16 @@ class AdvancedChatAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.QUERY: message.query, - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.CONVERSATION_ID: conversation.id, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.DIALOGUE_COUNT: dialogue_count, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_run_id, - }, + workflow_system_variables=SystemVariable( + query=message.query, + files=application_generate_entity.files, + conversation_id=conversation.id, + user_id=user_session_id, + dialogue_count=dialogue_count, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_run_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), @@ -162,12 +166,12 @@ class AdvancedChatAppGenerateTaskPipeline: Process generate task pipeline. :return: """ - # start generate conversation name thread self._conversation_name_generate_thread = self._message_cycle_manager.generate_conversation_name( conversation_id=self._conversation_id, query=self._application_generate_entity.query ) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) + if self._base_task_pipeline._stream: return self._to_stream_response(generator) else: @@ -253,15 +257,12 @@ class AdvancedChatAppGenerateTaskPipeline: yield response start_listener_time = time.time() - # timeout while (time.time() - start_listener_time) < TTS_AUTO_PLAY_TIMEOUT: try: if not tts_publisher: break audio_trunk = tts_publisher.check_and_get_audio() if audio_trunk is None: - # release cpu - # sleep 20 ms ( 40ms => 1280 byte audio file,20ms => 640 byte audio file) time.sleep(TTS_AUTO_PLAY_YIELD_CPU_TIME) continue if audio_trunk.status == "finish": @@ -275,403 +276,618 @@ class AdvancedChatAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) - def _process_stream_response( + @contextmanager + def _database_session(self): + """Context manager for database sessions.""" + with Session(db.engine, expire_on_commit=False) as session: + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + + def _ensure_workflow_initialized(self) -> None: + """Fluent validation for workflow state.""" + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + """Fluent validation for graph runtime state.""" + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + return graph_runtime_state + + def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: + """Handle ping events.""" + yield self._base_task_pipeline._ping_stream_response() + + def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: + """Handle error events.""" + with self._database_session() as session: + err = self._base_task_pipeline._handle_error(event=event, session=session, message_id=self._message_id) + yield self._base_task_pipeline._error_to_stream_response(err) + + def _handle_workflow_started_event( + self, event: QueueWorkflowStartedEvent, *, graph_runtime_state: Optional[GraphRuntimeState] = None, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle workflow started events.""" + # Override graph runtime state - this is a side effect but necessary + graph_runtime_state = event.graph_runtime_state + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ + + message = self._get_message(session=session) + if not message: + raise ValueError(f"Message not found: {self._message_id}") + + message.workflow_run_id = workflow_execution.id_ + workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + + yield workflow_start_resp + + def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle node retry events.""" + self._ensure_workflow_initialized() + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_retry_resp: + yield node_retry_resp + + def _handle_node_started_event( + self, event: QueueNodeStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node started events.""" + self._ensure_workflow_initialized() + + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + + node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_start_resp: + yield node_start_resp + + def _handle_node_succeeded_event( + self, event: QueueNodeSucceededEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node succeeded events.""" + # Record files if it's an answer node or end node + if event.node_type in [NodeType.ANSWER, NodeType.END]: + self._recorded_files.extend( + self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) + ) + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + self._save_output_for_event(event, workflow_node_execution.id) + + if node_finish_resp: + yield node_finish_resp + + def _handle_node_failed_events( self, + event: Union[ + QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent + ], + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle various node failure events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed(event=event) + + node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if isinstance(event, QueueNodeExceptionEvent): + self._save_output_for_event(event, workflow_node_execution.id) + + if node_finish_resp: + yield node_finish_resp + + def _handle_text_chunk_event( + self, + event: QueueTextChunkEvent, + *, tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + **kwargs, ) -> Generator[StreamResponse, None, None]: - """ - Process stream response. - :return: - """ - # init fake graph runtime state - graph_runtime_state: Optional[GraphRuntimeState] = None + """Handle text chunk events.""" + delta_text = event.text + if delta_text is None: + return + + # Handle output moderation chunk + should_direct_answer = self._handle_output_moderation_chunk(delta_text) + if should_direct_answer: + return + + # Only publish tts message at text chunk streaming + if tts_publisher and queue_message: + tts_publisher.publish(queue_message) + + self._task_state.answer += delta_text + yield self._message_cycle_manager.message_to_stream_response( + answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector + ) - for queue_message in self._base_task_pipeline._queue_manager.listen(): - event = queue_message.event + def _handle_parallel_branch_started_event( + self, event: QueueParallelBranchRunStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch started events.""" + self._ensure_workflow_initialized() - if isinstance(event, QueuePingEvent): - yield self._base_task_pipeline._ping_stream_response() - elif isinstance(event, QueueErrorEvent): - with Session(db.engine, expire_on_commit=False) as session: - err = self._base_task_pipeline._handle_error( - event=event, session=session, message_id=self._message_id - ) - session.commit() - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueWorkflowStartedEvent): - # override graph runtime state - graph_runtime_state = event.graph_runtime_state - - with Session(db.engine, expire_on_commit=False) as session: - # init workflow run - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ - message = self._get_message(session=session) - if not message: - raise ValueError(f"Message not found: {self._message_id}") - message.workflow_run_id = workflow_execution.id_ - workflow_start_resp = self._workflow_response_converter.workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() + parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_start_resp - yield workflow_start_resp - elif isinstance( - event, - QueueNodeRetryEvent, - ): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, event=event - ) - node_retry_resp = self._workflow_response_converter.workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + def _handle_parallel_branch_finished_events( + self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch finished events.""" + self._ensure_workflow_initialized() - if node_retry_resp: - yield node_retry_resp - elif isinstance(event, QueueNodeStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_finish_resp - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) + def _handle_iteration_start_event( + self, event: QueueIterationStartEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration start events.""" + self._ensure_workflow_initialized() - node_start_resp = self._workflow_response_converter.workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_start_resp - if node_start_resp: - yield node_start_resp - elif isinstance(event, QueueNodeSucceededEvent): - # Record files if it's an answer node or end node - if event.node_type in [NodeType.ANSWER, NodeType.END]: - self._recorded_files.extend( - self._workflow_response_converter.fetch_files_from_node_outputs(event.outputs or {}) - ) - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( - event=event - ) + def _handle_iteration_next_event( + self, event: QueueIterationNextEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration next events.""" + self._ensure_workflow_initialized() - node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() - self._save_output_for_event(event, workflow_node_execution.id) + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_next_resp - if node_finish_resp: - yield node_finish_resp - elif isinstance( - event, - QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event - ) + def _handle_iteration_completed_event( + self, event: QueueIterationCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration completed events.""" + self._ensure_workflow_initialized() - node_finish_resp = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) - - if node_finish_resp: - yield node_finish_resp - elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - - parallel_start_resp = ( - self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_finish_resp - yield parallel_start_resp - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop start events.""" + self._ensure_workflow_initialized() - parallel_finish_resp = ( - self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_start_resp - yield parallel_finish_resp - elif isinstance(event, QueueIterationStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop next events.""" + self._ensure_workflow_initialized() - iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_next_resp - yield iter_start_resp - elif isinstance(event, QueueIterationNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_loop_completed_event( + self, event: QueueLoopCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle loop completed events.""" + self._ensure_workflow_initialized() - iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_finish_resp - yield iter_next_resp - elif isinstance(event, QueueIterationCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_workflow_succeeded_event( + self, + event: QueueWorkflowSucceededEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow succeeded events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + # Store outputs into metadata for external usage (exclude the answer key) + workflow_outputs = workflow_finish_resp.data.outputs + if workflow_outputs: + filtered_outputs = {k: v for k, v in workflow_outputs.items() if k != "answer"} + self._task_state.metadata.outputs = filtered_outputs - iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) - yield iter_finish_resp - elif isinstance(event, QueueLoopStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_workflow_partial_success_event( + self, + event: QueueWorkflowPartialSuccessEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow partial success events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) - loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + yield workflow_finish_resp + self._base_task_pipeline._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) + + def _handle_workflow_failed_event( + self, + event: QueueWorkflowFailedEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow failed events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + status=WorkflowExecutionStatus.FAILED, + error_message=event.error, + conversation_id=self._conversation_id, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count, + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) + err = self._base_task_pipeline._handle_error(event=err_event, session=session, message_id=self._message_id) - yield loop_start_resp - elif isinstance(event, QueueLoopNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + yield workflow_finish_resp + yield self._base_task_pipeline._error_to_stream_response(err) - loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + def _handle_stop_event( + self, + event: QueueStopEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle stop events.""" + if self._workflow_run_id and graph_runtime_state: + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=graph_runtime_state.total_tokens, + total_steps=graph_runtime_state.node_run_steps, + status=WorkflowExecutionStatus.STOPPED, + error_message=event.get_stop_reason(), + conversation_id=self._conversation_id, + trace_manager=trace_manager, + ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, + workflow_execution=workflow_execution, ) + # Save message + self._save_message(session=session, graph_runtime_state=graph_runtime_state) - yield loop_next_resp - elif isinstance(event, QueueLoopCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + yield workflow_finish_resp + elif event.stopped_by in ( + QueueStopEvent.StopBy.INPUT_MODERATION, + QueueStopEvent.StopBy.ANNOTATION_REPLY, + ): + # When hitting input-moderation or annotation-reply, the workflow will not start + with self._database_session() as session: + # Save message + self._save_message(session=session) - loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + yield self._message_end_to_stream_response() - yield loop_finish_resp - elif isinstance(event, QueueWorkflowSucceededEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_advanced_chat_message_end_event( + self, + event: QueueAdvancedChatMessageEndEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle advanced chat message end events.""" + self._ensure_graph_runtime_initialized(graph_runtime_state) - if not graph_runtime_state: - raise ValueError("workflow run not initialized.") + output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( + self._task_state.answer + ) + if output_moderation_answer: + self._task_state.answer = output_moderation_answer + yield self._message_cycle_manager.message_replace_to_stream_response( + answer=output_moderation_answer, + reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + ) - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - workflow_outputs = workflow_finish_resp.data.outputs - if workflow_outputs: - filtered_outputs = {k: v for k, v in workflow_outputs.items() if k != "answer"} - self._task_state.metadata.outputs = filtered_outputs - yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + # Save message + with self._database_session() as session: + self._save_message(session=session, graph_runtime_state=graph_runtime_state) - yield workflow_finish_resp - self._base_task_pipeline._queue_manager.publish( - QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE - ) - elif isinstance(event, QueueWorkflowFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED, - error_message=event.error, - conversation_id=self._conversation_id, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_execution.error_message}")) - err = self._base_task_pipeline._handle_error( - event=err_event, session=session, message_id=self._message_id - ) + yield self._message_end_to_stream_response() - yield workflow_finish_resp - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueStopEvent): - if self._workflow_run_id and graph_runtime_state: - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.STOPPED, - error_message=event.get_stop_reason(), - conversation_id=self._conversation_id, - trace_manager=trace_manager, - ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - # Save message - self._save_message(session=session, graph_runtime_state=graph_runtime_state) - session.commit() - - yield workflow_finish_resp - elif event.stopped_by in ( - QueueStopEvent.StopBy.INPUT_MODERATION, - QueueStopEvent.StopBy.ANNOTATION_REPLY, - ): - # When hitting input-moderation or annotation-reply, the workflow will not start - with Session(db.engine, expire_on_commit=False) as session: - # Save message - self._save_message(session=session) - session.commit() - - yield self._message_end_to_stream_response() - break - elif isinstance(event, QueueRetrieverResourcesEvent): - self._message_cycle_manager.handle_retriever_resources(event) - - with Session(db.engine, expire_on_commit=False) as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() - session.commit() - elif isinstance(event, QueueAnnotationReplyEvent): - self._message_cycle_manager.handle_annotation_reply(event) - - with Session(db.engine, expire_on_commit=False) as session: - message = self._get_message(session=session) - message.message_metadata = self._task_state.metadata.model_dump_json() - session.commit() - elif isinstance(event, QueueTextChunkEvent): - delta_text = event.text - if delta_text is None: - continue + def _handle_retriever_resources_event( + self, event: QueueRetrieverResourcesEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle retriever resources events.""" + self._message_cycle_manager.handle_retriever_resources(event) - # handle output moderation chunk - should_direct_answer = self._handle_output_moderation_chunk(delta_text) - if should_direct_answer: - continue + with self._database_session() as session: + message = self._get_message(session=session) + message.message_metadata = self._task_state.metadata.model_dump_json() + return + yield # Make this a generator - # only publish tts message at text chunk streaming - if tts_publisher: - tts_publisher.publish(queue_message) + def _handle_annotation_reply_event( + self, event: QueueAnnotationReplyEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle annotation reply events.""" + self._message_cycle_manager.handle_annotation_reply(event) - self._task_state.answer += delta_text - yield self._message_cycle_manager.message_to_stream_response( - answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector - ) - elif isinstance(event, QueueMessageReplaceEvent): - # published by moderation - yield self._message_cycle_manager.message_replace_to_stream_response( - answer=event.text, reason=event.reason - ) - elif isinstance(event, QueueAdvancedChatMessageEndEvent): - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") + with self._database_session() as session: + message = self._get_message(session=session) + message.message_metadata = self._task_state.metadata.model_dump_json() + return + yield # Make this a generator - output_moderation_answer = self._base_task_pipeline._handle_output_moderation_when_task_finished( - self._task_state.answer - ) - if output_moderation_answer: - self._task_state.answer = output_moderation_answer - yield self._message_cycle_manager.message_replace_to_stream_response( - answer=output_moderation_answer, - reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION, + def _handle_message_replace_event( + self, event: QueueMessageReplaceEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle message replace events.""" + yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason) + + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle agent log events.""" + yield self._workflow_response_converter.handle_agent_log( + task_id=self._application_generate_entity.task_id, event=event + ) + + def _get_event_handlers(self) -> dict[type, Callable]: + """Get mapping of event types to their handlers using fluent pattern.""" + return { + # Basic events + QueuePingEvent: self._handle_ping_event, + QueueErrorEvent: self._handle_error_event, + QueueTextChunkEvent: self._handle_text_chunk_event, + # Workflow events + QueueWorkflowStartedEvent: self._handle_workflow_started_event, + QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, + QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + QueueWorkflowFailedEvent: self._handle_workflow_failed_event, + # Node events + QueueNodeRetryEvent: self._handle_node_retry_event, + QueueNodeStartedEvent: self._handle_node_started_event, + QueueNodeSucceededEvent: self._handle_node_succeeded_event, + # Parallel branch events + QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, + # Iteration events + QueueIterationStartEvent: self._handle_iteration_start_event, + QueueIterationNextEvent: self._handle_iteration_next_event, + QueueIterationCompletedEvent: self._handle_iteration_completed_event, + # Loop events + QueueLoopStartEvent: self._handle_loop_start_event, + QueueLoopNextEvent: self._handle_loop_next_event, + QueueLoopCompletedEvent: self._handle_loop_completed_event, + # Control events + QueueStopEvent: self._handle_stop_event, + # Message events + QueueRetrieverResourcesEvent: self._handle_retriever_resources_event, + QueueAnnotationReplyEvent: self._handle_annotation_reply_event, + QueueMessageReplaceEvent: self._handle_message_replace_event, + QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event, + QueueAgentLogEvent: self._handle_agent_log_event, + } + + def _dispatch_event( + self, + event: Any, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + ) -> Generator[StreamResponse, None, None]: + """Dispatch events using elegant pattern matching.""" + handlers = self._get_event_handlers() + event_type = type(event) + + # Direct handler lookup + if handler := handlers.get(event_type): + yield from handler( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle node failure events with isinstance check + if isinstance( + event, + ( + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ), + ): + yield from self._handle_node_failed_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # Handle parallel branch finished events with isinstance check + if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): + yield from self._handle_parallel_branch_finished_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return + + # For unhandled events, we continue (original behavior) + return + + def _process_stream_response( + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> Generator[StreamResponse, None, None]: + """ + Process stream response using elegant Fluent Python patterns. + Maintains exact same functionality as original 57-if-statement version. + """ + # Initialize graph runtime state + graph_runtime_state: Optional[GraphRuntimeState] = None + + for queue_message in self._base_task_pipeline._queue_manager.listen(): + event = queue_message.event + + match event: + case QueueWorkflowStartedEvent(): + graph_runtime_state = event.graph_runtime_state + yield from self._handle_workflow_started_event(event) + + case QueueTextChunkEvent(): + yield from self._handle_text_chunk_event( + event, tts_publisher=tts_publisher, queue_message=queue_message ) - # Save message - with Session(db.engine, expire_on_commit=False) as session: - self._save_message(session=session, graph_runtime_state=graph_runtime_state) - session.commit() - - yield self._message_end_to_stream_response() - elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_response_converter.handle_agent_log( - task_id=self._application_generate_entity.task_id, event=event - ) - else: - continue - # publish None when task finished + case QueueErrorEvent(): + yield from self._handle_error_event(event) + break + + case QueueWorkflowFailedEvent(): + yield from self._handle_workflow_failed_event( + event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager + ) + break + + case QueueStopEvent(): + yield from self._handle_stop_event( + event, graph_runtime_state=graph_runtime_state, trace_manager=trace_manager + ) + break + + # Handle all other events through elegant dispatch + case _: + if responses := list( + self._dispatch_event( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + ): + yield from responses + if tts_publisher: tts_publisher.publish(None) @@ -743,7 +959,6 @@ class AdvancedChatAppGenerateTaskPipeline: """ if self._base_task_pipeline._output_moderation_handler: if self._base_task_pipeline._output_moderation_handler.should_direct_output(): - # stop subscribe new token when output moderation should direct output self._task_state.answer = self._base_task_pipeline._output_moderation_handler.get_final_output() self._base_task_pipeline._queue_manager.publish( QueueTextChunkEvent(text=self._task_state.answer), PublishFrom.TASK_PIPELINE diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index edea6199d3..8665bc9d11 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -15,7 +15,8 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfigManager from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.app.apps.agent_chat.generate_response_converter import AgentChatAppGenerateResponseConverter -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 0ba33fbe0d..9da0bae56a 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -169,7 +169,3 @@ class AppQueueManager: raise TypeError( "Critical Error: Passing SQLAlchemy Model instances that cause thread safety issues is not allowed." ) - - -class GenerateTaskStoppedError(Exception): - pass diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index a3f0cf7f9f..6e8c261a6a 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -38,69 +38,6 @@ _logger = logging.getLogger(__name__) class AppRunner: - def get_pre_calculate_rest_tokens( - self, - app_record: App, - model_config: ModelConfigWithCredentialsEntity, - prompt_template_entity: PromptTemplateEntity, - inputs: Mapping[str, str], - files: Sequence["File"], - query: Optional[str] = None, - ) -> int: - """ - Get pre calculate rest tokens - :param app_record: app record - :param model_config: model config entity - :param prompt_template_entity: prompt template entity - :param inputs: inputs - :param files: files - :param query: query - :return: - """ - # Invoke model - model_instance = ModelInstance( - provider_model_bundle=model_config.provider_model_bundle, model=model_config.model - ) - - model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - - max_tokens = 0 - for parameter_rule in model_config.model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_config.parameters.get(parameter_rule.name) - or model_config.parameters.get(parameter_rule.use_template or "") - ) or 0 - - if model_context_tokens is None: - return -1 - - if max_tokens is None: - max_tokens = 0 - - # get prompt messages without memory and context - prompt_messages, stop = self.organize_prompt_messages( - app_record=app_record, - model_config=model_config, - prompt_template_entity=prompt_template_entity, - inputs=inputs, - files=files, - query=query, - ) - - prompt_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - rest_tokens: int = model_context_tokens - max_tokens - prompt_tokens - if rest_tokens < 0: - raise InvokeBadRequestError( - "Query or prefix prompt is too long, you can reduce the prefix prompt, " - "or shrink the max token, or switch to a llm with a larger token limit size." - ) - - return rest_tokens - def recalc_llm_max_tokens( self, model_config: ModelConfigWithCredentialsEntity, prompt_messages: list[PromptMessage] ): @@ -181,7 +118,7 @@ class AppRunner: else: memory_config = MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)) - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] if model_mode == ModelMode.COMPLETION: advanced_completion_prompt_template = prompt_template_entity.advanced_completion_prompt_template diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index a28c106ce9..0c76cc39ae 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -11,10 +11,11 @@ from configs import dify_config from constants import UUID_NIL from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.chat.app_config_manager import ChatAppConfigManager from core.app.apps.chat.app_runner import ChatAppRunner from core.app.apps.chat.generate_response_converter import ChatAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 966a6f1d66..195e7e2e3d 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -10,10 +10,11 @@ from pydantic import ValidationError from configs import dify_config from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.completion.app_config_manager import CompletionAppConfigManager from core.app.apps.completion.app_runner import CompletionAppRunner from core.app.apps.completion.generate_response_converter import CompletionAppGenerateResponseConverter +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom diff --git a/api/core/app/apps/exc.py b/api/core/app/apps/exc.py new file mode 100644 index 0000000000..4187118b9b --- /dev/null +++ b/api/core/app/apps/exc.py @@ -0,0 +1,2 @@ +class GenerateTaskStoppedError(Exception): + pass diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index e84d59209d..d50cf1c941 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -1,12 +1,12 @@ import json import logging from collections.abc import Generator -from datetime import UTC, datetime from typing import Optional, Union, cast from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppModelConfigFrom from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError +from core.app.apps.base_app_queue_manager import AppQueueManager +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, AgentChatAppGenerateEntity, @@ -24,6 +24,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models import Account from models.enums import CreatorUserRole from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile @@ -183,7 +184,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): db.session.commit() db.session.refresh(conversation) else: - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() message = Message( diff --git a/api/core/app/apps/message_based_app_queue_manager.py b/api/core/app/apps/message_based_app_queue_manager.py index 363c3c82bb..8507f23f17 100644 --- a/api/core/app/apps/message_based_app_queue_manager.py +++ b/api/core/app/apps/message_based_app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 2f9632e97d..6f560b3253 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -13,7 +13,8 @@ import contexts from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner diff --git a/api/core/app/apps/workflow/app_queue_manager.py b/api/core/app/apps/workflow/app_queue_manager.py index 349b8eb51b..40fc03afb7 100644 --- a/api/core/app/apps/workflow/app_queue_manager.py +++ b/api/core/app/apps/workflow/app_queue_manager.py @@ -1,4 +1,5 @@ -from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom +from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import ( AppQueueEvent, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 07aeb57fa3..3a66ffa578 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -11,7 +11,7 @@ from core.app.entities.app_invoke_entities import ( ) from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -95,13 +95,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): files = self.application_generate_entity.files # Create a variable pool. - system_inputs = { - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - SystemVariableKey.APP_ID: app_config.app_id, - SystemVariableKey.WORKFLOW_ID: app_config.workflow_id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id, - } + + system_inputs = SystemVariable( + files=files, + user_id=user_id, + app_id=app_config.app_id, + workflow_id=app_config.workflow_id, + workflow_execution_id=self.application_generate_entity.workflow_execution_id, + ) variable_pool = VariablePool( system_variables=system_inputs, diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index c6b326d8a4..9a39b2e01e 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -1,7 +1,8 @@ import logging import time -from collections.abc import Generator -from typing import Optional, Union +from collections.abc import Callable, Generator +from contextlib import contextmanager +from typing import Any, Optional, Union from sqlalchemy.orm import Session @@ -13,6 +14,7 @@ from core.app.entities.app_invoke_entities import ( WorkflowAppGenerateEntity, ) from core.app.entities.queue_entities import ( + MessageQueueMessage, QueueAgentLogEvent, QueueErrorEvent, QueueIterationCompletedEvent, @@ -38,11 +40,13 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, + WorkflowQueueMessage, ) from core.app.entities.task_entities import ( ErrorStreamResponse, MessageAudioEndStreamResponse, MessageAudioStreamResponse, + PingStreamResponse, StreamResponse, TextChunkStreamResponse, WorkflowAppBlockingResponse, @@ -54,10 +58,11 @@ from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTas from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.workflow.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType -from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_cycle_manager import CycleManagerWorkflowInfo, WorkflowCycleManager from extensions.ext_database import db from models.account import Account @@ -107,13 +112,13 @@ class WorkflowAppGenerateTaskPipeline: self._workflow_cycle_manager = WorkflowCycleManager( application_generate_entity=application_generate_entity, - workflow_system_variables={ - SystemVariableKey.FILES: application_generate_entity.files, - SystemVariableKey.USER_ID: user_session_id, - SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - SystemVariableKey.WORKFLOW_EXECUTION_ID: application_generate_entity.workflow_execution_id, - }, + workflow_system_variables=SystemVariable( + files=application_generate_entity.files, + user_id=user_session_id, + app_id=application_generate_entity.app_config.app_id, + workflow_id=workflow.id, + workflow_execution_id=application_generate_entity.workflow_execution_id, + ), workflow_info=CycleManagerWorkflowInfo( workflow_id=workflow.id, workflow_type=WorkflowType(workflow.type), @@ -246,315 +251,492 @@ class WorkflowAppGenerateTaskPipeline: if tts_publisher: yield MessageAudioEndStreamResponse(audio="", task_id=task_id) - def _process_stream_response( + @contextmanager + def _database_session(self): + """Context manager for database sessions.""" + with Session(db.engine, expire_on_commit=False) as session: + try: + yield session + session.commit() + except Exception: + session.rollback() + raise + + def _ensure_workflow_initialized(self) -> None: + """Fluent validation for workflow state.""" + if not self._workflow_run_id: + raise ValueError("workflow run not initialized.") + + def _ensure_graph_runtime_initialized(self, graph_runtime_state: Optional[GraphRuntimeState]) -> GraphRuntimeState: + """Fluent validation for graph runtime state.""" + if not graph_runtime_state: + raise ValueError("graph runtime state not initialized.") + return graph_runtime_state + + def _handle_ping_event(self, event: QueuePingEvent, **kwargs) -> Generator[PingStreamResponse, None, None]: + """Handle ping events.""" + yield self._base_task_pipeline._ping_stream_response() + + def _handle_error_event(self, event: QueueErrorEvent, **kwargs) -> Generator[ErrorStreamResponse, None, None]: + """Handle error events.""" + err = self._base_task_pipeline._handle_error(event=event) + yield self._base_task_pipeline._error_to_stream_response(err) + + def _handle_workflow_started_event( + self, event: QueueWorkflowStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle workflow started events.""" + # init workflow run + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() + self._workflow_run_id = workflow_execution.id_ + start_resp = self._workflow_response_converter.workflow_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) + yield start_resp + + def _handle_node_retry_event(self, event: QueueNodeRetryEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle node retry events.""" + self._ensure_workflow_initialized() + + with self._database_session() as session: + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( + workflow_execution_id=self._workflow_run_id, + event=event, + ) + response = self._workflow_response_converter.workflow_node_retry_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if response: + yield response + + def _handle_node_started_event( + self, event: QueueNodeStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node started events.""" + self._ensure_workflow_initialized() + + workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( + workflow_execution_id=self._workflow_run_id, event=event + ) + node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + if node_start_response: + yield node_start_response + + def _handle_node_succeeded_event( + self, event: QueueNodeSucceededEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle node succeeded events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success(event=event) + node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) + + self._save_output_for_event(event, workflow_node_execution.id) + + if node_success_response: + yield node_success_response + + def _handle_node_failed_events( self, - tts_publisher: Optional[AppGeneratorTTSPublisher] = None, - trace_manager: Optional[TraceQueueManager] = None, + event: Union[ + QueueNodeFailedEvent, QueueNodeInIterationFailedEvent, QueueNodeInLoopFailedEvent, QueueNodeExceptionEvent + ], + **kwargs, ) -> Generator[StreamResponse, None, None]: - """ - Process stream response. - :return: - """ - graph_runtime_state = None + """Handle various node failure events.""" + workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( + event=event, + ) + node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( + event=event, + task_id=self._application_generate_entity.task_id, + workflow_node_execution=workflow_node_execution, + ) - for queue_message in self._base_task_pipeline._queue_manager.listen(): - event = queue_message.event + if isinstance(event, QueueNodeExceptionEvent): + self._save_output_for_event(event, workflow_node_execution.id) - if isinstance(event, QueuePingEvent): - yield self._base_task_pipeline._ping_stream_response() - elif isinstance(event, QueueErrorEvent): - err = self._base_task_pipeline._handle_error(event=event) - yield self._base_task_pipeline._error_to_stream_response(err) - break - elif isinstance(event, QueueWorkflowStartedEvent): - # override graph runtime state - graph_runtime_state = event.graph_runtime_state - - # init workflow run - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_start() - self._workflow_run_id = workflow_execution.id_ - start_resp = self._workflow_response_converter.workflow_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) + if node_failed_response: + yield node_failed_response - yield start_resp - elif isinstance( - event, - QueueNodeRetryEvent, - ): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - with Session(db.engine, expire_on_commit=False) as session: - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_retried( - workflow_execution_id=self._workflow_run_id, - event=event, - ) - response = self._workflow_response_converter.workflow_node_retry_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - session.commit() + def _handle_parallel_branch_started_event( + self, event: QueueParallelBranchRunStartedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch started events.""" + self._ensure_workflow_initialized() - if response: - yield response - elif isinstance(event, QueueNodeStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + parallel_start_resp = self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_start_resp - workflow_node_execution = self._workflow_cycle_manager.handle_node_execution_start( - workflow_execution_id=self._workflow_run_id, event=event - ) - node_start_response = self._workflow_response_converter.workflow_node_start_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + def _handle_parallel_branch_finished_events( + self, event: Union[QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent], **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle parallel branch finished events.""" + self._ensure_workflow_initialized() - if node_start_response: - yield node_start_response - elif isinstance(event, QueueNodeSucceededEvent): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_success( - event=event - ) - node_success_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) + parallel_finish_resp = self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield parallel_finish_resp - self._save_output_for_event(event, workflow_node_execution.id) + def _handle_iteration_start_event( + self, event: QueueIterationStartEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration start events.""" + self._ensure_workflow_initialized() - if node_success_response: - yield node_success_response - elif isinstance( - event, - QueueNodeFailedEvent - | QueueNodeInIterationFailedEvent - | QueueNodeInLoopFailedEvent - | QueueNodeExceptionEvent, - ): - workflow_node_execution = self._workflow_cycle_manager.handle_workflow_node_execution_failed( - event=event, - ) - node_failed_response = self._workflow_response_converter.workflow_node_finish_to_stream_response( - event=event, - task_id=self._application_generate_entity.task_id, - workflow_node_execution=workflow_node_execution, - ) - if isinstance(event, QueueNodeExceptionEvent): - self._save_output_for_event(event, workflow_node_execution.id) + iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_start_resp - if node_failed_response: - yield node_failed_response + def _handle_iteration_next_event( + self, event: QueueIterationNextEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration next events.""" + self._ensure_workflow_initialized() - elif isinstance(event, QueueParallelBranchRunStartedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_next_resp - parallel_start_resp = ( - self._workflow_response_converter.workflow_parallel_branch_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) + def _handle_iteration_completed_event( + self, event: QueueIterationCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle iteration completed events.""" + self._ensure_workflow_initialized() - yield parallel_start_resp + iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield iter_finish_resp - elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_loop_start_event(self, event: QueueLoopStartEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop start events.""" + self._ensure_workflow_initialized() - parallel_finish_resp = ( - self._workflow_response_converter.workflow_parallel_branch_finished_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) - ) + loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_start_resp - yield parallel_finish_resp + def _handle_loop_next_event(self, event: QueueLoopNextEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle loop next events.""" + self._ensure_workflow_initialized() - elif isinstance(event, QueueIterationStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_next_resp - iter_start_resp = self._workflow_response_converter.workflow_iteration_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + def _handle_loop_completed_event( + self, event: QueueLoopCompletedEvent, **kwargs + ) -> Generator[StreamResponse, None, None]: + """Handle loop completed events.""" + self._ensure_workflow_initialized() - yield iter_start_resp + loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( + task_id=self._application_generate_entity.task_id, + workflow_execution_id=self._workflow_run_id, + event=event, + ) + yield loop_finish_resp - elif isinstance(event, QueueIterationNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_workflow_succeeded_event( + self, + event: QueueWorkflowSucceededEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow succeeded events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + conversation_id=None, + trace_manager=trace_manager, + ) - iter_next_resp = self._workflow_response_converter.workflow_iteration_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - yield iter_next_resp + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) - elif isinstance(event, QueueIterationCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + yield workflow_finish_resp - iter_finish_resp = self._workflow_response_converter.workflow_iteration_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + def _handle_workflow_partial_success_event( + self, + event: QueueWorkflowPartialSuccessEvent, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow partial success events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + outputs=event.outputs, + exceptions_count=event.exceptions_count, + conversation_id=None, + trace_manager=trace_manager, + ) - yield iter_finish_resp + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - elif isinstance(event, QueueLoopStartEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) - loop_start_resp = self._workflow_response_converter.workflow_loop_start_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + yield workflow_finish_resp - yield loop_start_resp + def _handle_workflow_failed_and_stop_events( + self, + event: Union[QueueWorkflowFailedEvent, QueueStopEvent], + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + trace_manager: Optional[TraceQueueManager] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle workflow failed and stop events.""" + self._ensure_workflow_initialized() + validated_state = self._ensure_graph_runtime_initialized(graph_runtime_state) + + with self._database_session() as session: + workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( + workflow_run_id=self._workflow_run_id, + total_tokens=validated_state.total_tokens, + total_steps=validated_state.node_run_steps, + status=WorkflowExecutionStatus.FAILED + if isinstance(event, QueueWorkflowFailedEvent) + else WorkflowExecutionStatus.STOPPED, + error_message=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(), + conversation_id=None, + trace_manager=trace_manager, + exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, + ) - elif isinstance(event, QueueLoopNextEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + # save workflow app log + self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) - loop_next_resp = self._workflow_response_converter.workflow_loop_next_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( + session=session, + task_id=self._application_generate_entity.task_id, + workflow_execution=workflow_execution, + ) - yield loop_next_resp + yield workflow_finish_resp - elif isinstance(event, QueueLoopCompletedEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") + def _handle_text_chunk_event( + self, + event: QueueTextChunkEvent, + *, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + **kwargs, + ) -> Generator[StreamResponse, None, None]: + """Handle text chunk events.""" + delta_text = event.text + if delta_text is None: + return - loop_finish_resp = self._workflow_response_converter.workflow_loop_completed_to_stream_response( - task_id=self._application_generate_entity.task_id, - workflow_execution_id=self._workflow_run_id, - event=event, - ) + # only publish tts message at text chunk streaming + if tts_publisher and queue_message: + tts_publisher.publish(queue_message) - yield loop_finish_resp - - elif isinstance(event, QueueWorkflowSucceededEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - conversation_id=None, - trace_manager=trace_manager, - ) + yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector) - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]: + """Handle agent log events.""" + yield self._workflow_response_converter.handle_agent_log( + task_id=self._application_generate_entity.task_id, event=event + ) - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_finish_resp - elif isinstance(event, QueueWorkflowPartialSuccessEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_partial_success( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - outputs=event.outputs, - exceptions_count=event.exceptions_count, - conversation_id=None, - trace_manager=trace_manager, - ) + def _get_event_handlers(self) -> dict[type, Callable]: + """Get mapping of event types to their handlers using fluent pattern.""" + return { + # Basic events + QueuePingEvent: self._handle_ping_event, + QueueErrorEvent: self._handle_error_event, + QueueTextChunkEvent: self._handle_text_chunk_event, + # Workflow events + QueueWorkflowStartedEvent: self._handle_workflow_started_event, + QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event, + QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event, + # Node events + QueueNodeRetryEvent: self._handle_node_retry_event, + QueueNodeStartedEvent: self._handle_node_started_event, + QueueNodeSucceededEvent: self._handle_node_succeeded_event, + # Parallel branch events + QueueParallelBranchRunStartedEvent: self._handle_parallel_branch_started_event, + # Iteration events + QueueIterationStartEvent: self._handle_iteration_start_event, + QueueIterationNextEvent: self._handle_iteration_next_event, + QueueIterationCompletedEvent: self._handle_iteration_completed_event, + # Loop events + QueueLoopStartEvent: self._handle_loop_start_event, + QueueLoopNextEvent: self._handle_loop_next_event, + QueueLoopCompletedEvent: self._handle_loop_completed_event, + # Agent events + QueueAgentLogEvent: self._handle_agent_log_event, + } + + def _dispatch_event( + self, + event: Any, + *, + graph_runtime_state: Optional[GraphRuntimeState] = None, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + queue_message: Optional[Union[WorkflowQueueMessage, MessageQueueMessage]] = None, + ) -> Generator[StreamResponse, None, None]: + """Dispatch events using elegant pattern matching.""" + handlers = self._get_event_handlers() + event_type = type(event) - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + # Direct handler lookup + if handler := handlers.get(event_type): + yield from handler( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() - - yield workflow_finish_resp - elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): - if not self._workflow_run_id: - raise ValueError("workflow run not initialized.") - if not graph_runtime_state: - raise ValueError("graph runtime state not initialized.") - - with Session(db.engine, expire_on_commit=False) as session: - workflow_execution = self._workflow_cycle_manager.handle_workflow_run_failed( - workflow_run_id=self._workflow_run_id, - total_tokens=graph_runtime_state.total_tokens, - total_steps=graph_runtime_state.node_run_steps, - status=WorkflowExecutionStatus.FAILED - if isinstance(event, QueueWorkflowFailedEvent) - else WorkflowExecutionStatus.STOPPED, - error_message=event.error - if isinstance(event, QueueWorkflowFailedEvent) - else event.get_stop_reason(), - conversation_id=None, - trace_manager=trace_manager, - exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0, - ) + # Handle node failure events with isinstance check + if isinstance( + event, + ( + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ), + ): + yield from self._handle_node_failed_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return - # save workflow app log - self._save_workflow_app_log(session=session, workflow_execution=workflow_execution) + # Handle parallel branch finished events with isinstance check + if isinstance(event, (QueueParallelBranchRunSucceededEvent, QueueParallelBranchRunFailedEvent)): + yield from self._handle_parallel_branch_finished_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return - workflow_finish_resp = self._workflow_response_converter.workflow_finish_to_stream_response( - session=session, - task_id=self._application_generate_entity.task_id, - workflow_execution=workflow_execution, - ) - session.commit() + # Handle workflow failed and stop events with isinstance check + if isinstance(event, (QueueWorkflowFailedEvent, QueueStopEvent)): + yield from self._handle_workflow_failed_and_stop_events( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + return - yield workflow_finish_resp - elif isinstance(event, QueueTextChunkEvent): - delta_text = event.text - if delta_text is None: - continue + # For unhandled events, we continue (original behavior) + return - # only publish tts message at text chunk streaming - if tts_publisher: - tts_publisher.publish(queue_message) + def _process_stream_response( + self, + tts_publisher: Optional[AppGeneratorTTSPublisher] = None, + trace_manager: Optional[TraceQueueManager] = None, + ) -> Generator[StreamResponse, None, None]: + """ + Process stream response using elegant Fluent Python patterns. + Maintains exact same functionality as original 44-if-statement version. + """ + # Initialize graph runtime state + graph_runtime_state = None - yield self._text_chunk_to_stream_response( - delta_text, from_variable_selector=event.from_variable_selector - ) - elif isinstance(event, QueueAgentLogEvent): - yield self._workflow_response_converter.handle_agent_log( - task_id=self._application_generate_entity.task_id, event=event - ) - else: - continue + for queue_message in self._base_task_pipeline._queue_manager.listen(): + event = queue_message.event + + match event: + case QueueWorkflowStartedEvent(): + graph_runtime_state = event.graph_runtime_state + yield from self._handle_workflow_started_event(event) + + case QueueTextChunkEvent(): + yield from self._handle_text_chunk_event( + event, tts_publisher=tts_publisher, queue_message=queue_message + ) + + case QueueErrorEvent(): + yield from self._handle_error_event(event) + break + + # Handle all other events through elegant dispatch + case _: + if responses := list( + self._dispatch_event( + event, + graph_runtime_state=graph_runtime_state, + tts_publisher=tts_publisher, + trace_manager=trace_manager, + queue_message=queue_message, + ) + ): + yield from responses if tts_publisher: tts_publisher.publish(None) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 17b9ac5827..2f4d234ecd 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -62,6 +62,7 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db @@ -166,7 +167,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) @@ -263,7 +264,7 @@ class WorkflowBasedAppRunner(AppRunner): # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=workflow.environment_variables, ) diff --git a/api/core/app/task_pipeline/exc.py b/api/core/app/task_pipeline/exc.py index e4b4168d08..df62776977 100644 --- a/api/core/app/task_pipeline/exc.py +++ b/api/core/app/task_pipeline/exc.py @@ -10,8 +10,3 @@ class RecordNotFoundError(TaskPipilineError): class WorkflowRunNotFoundError(RecordNotFoundError): def __init__(self, workflow_run_id: str): super().__init__("WorkflowRun", workflow_run_id) - - -class WorkflowNodeExecutionNotFoundError(RecordNotFoundError): - def __init__(self, workflow_node_execution_id: str): - super().__init__("WorkflowNodeExecution", workflow_node_execution_id) diff --git a/api/core/entities/parameter_entities.py b/api/core/entities/parameter_entities.py index 2fa347c204..fbd62437e6 100644 --- a/api/core/entities/parameter_entities.py +++ b/api/core/entities/parameter_entities.py @@ -14,6 +14,7 @@ class CommonParameterType(StrEnum): APP_SELECTOR = "app-selector" MODEL_SELECTOR = "model-selector" TOOLS_SELECTOR = "array[tools]" + ANY = "any" # Dynamic select parameter # Once you are not sure about the available options until authorization is done diff --git a/api/core/file/file_manager.py b/api/core/file/file_manager.py index ada19ef8ce..f8c050c2ac 100644 --- a/api/core/file/file_manager.py +++ b/api/core/file/file_manager.py @@ -7,6 +7,7 @@ from core.model_runtime.entities import ( AudioPromptMessageContent, DocumentPromptMessageContent, ImagePromptMessageContent, + TextPromptMessageContent, VideoPromptMessageContent, ) from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes @@ -44,11 +45,44 @@ def to_prompt_message_content( *, image_detail_config: ImagePromptMessageContent.DETAIL | None = None, ) -> PromptMessageContentUnionTypes: + """ + Convert a file to prompt message content. + + This function converts files to their appropriate prompt message content types. + For supported file types (IMAGE, AUDIO, VIDEO, DOCUMENT), it creates the + corresponding message content with proper encoding/URL. + + For unsupported file types, instead of raising an error, it returns a + TextPromptMessageContent with a descriptive message about the file. + + Args: + f: The file to convert + image_detail_config: Optional detail configuration for image files + + Returns: + PromptMessageContentUnionTypes: The appropriate message content type + + Raises: + ValueError: If file extension or mime_type is missing + """ if f.extension is None: raise ValueError("Missing file extension") if f.mime_type is None: raise ValueError("Missing file mime_type") + prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { + FileType.IMAGE: ImagePromptMessageContent, + FileType.AUDIO: AudioPromptMessageContent, + FileType.VIDEO: VideoPromptMessageContent, + FileType.DOCUMENT: DocumentPromptMessageContent, + } + + # Check if file type is supported + if f.type not in prompt_class_map: + # For unsupported file types, return a text description + return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") + + # Process supported file types params = { "base64_data": _get_encoded_string(f) if dify_config.MULTIMODAL_SEND_FORMAT == "base64" else "", "url": _to_url(f) if dify_config.MULTIMODAL_SEND_FORMAT == "url" else "", @@ -58,17 +92,7 @@ def to_prompt_message_content( if f.type == FileType.IMAGE: params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - try: - return prompt_class_map[f.type].model_validate(params) - except KeyError: - raise ValueError(f"file type {f.type} is not supported") + return prompt_class_map[f.type].model_validate(params) def download(f: File, /): diff --git a/api/core/file/tool_file_parser.py b/api/core/file/tool_file_parser.py index 656c9d48ed..fac68beb0f 100644 --- a/api/core/file/tool_file_parser.py +++ b/api/core/file/tool_file_parser.py @@ -7,13 +7,6 @@ if TYPE_CHECKING: _tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None -class ToolFileParser: - @staticmethod - def get_tool_file_manager() -> "ToolFileManager": - assert _tool_file_manager_factory is not None - return _tool_file_manager_factory() - - def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None: global _tool_file_manager_factory _tool_file_manager_factory = factory diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 744fce1cf9..1e40997a8b 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -21,7 +21,7 @@ def encrypt_token(tenant_id: str, token: str): return base64.b64encode(encrypted_token).decode() -def decrypt_token(tenant_id: str, token: str): +def decrypt_token(tenant_id: str, token: str) -> str: return rsa.decrypt(base64.b64decode(token), tenant_id) diff --git a/api/core/helper/provider_cache.py b/api/core/helper/provider_cache.py new file mode 100644 index 0000000000..48ec3be5c8 --- /dev/null +++ b/api/core/helper/provider_cache.py @@ -0,0 +1,84 @@ +import json +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Optional + +from extensions.ext_redis import redis_client + + +class ProviderCredentialsCache(ABC): + """Base class for provider credentials cache""" + + def __init__(self, **kwargs): + self.cache_key = self._generate_cache_key(**kwargs) + + @abstractmethod + def _generate_cache_key(self, **kwargs) -> str: + """Generate cache key based on subclass implementation""" + pass + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + cached_credentials = redis_client.get(self.cache_key) + if cached_credentials: + try: + cached_credentials = cached_credentials.decode("utf-8") + return dict(json.loads(cached_credentials)) + except JSONDecodeError: + return None + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + redis_client.setex(self.cache_key, 86400, json.dumps(config)) + + def delete(self) -> None: + """Delete cached provider credentials""" + redis_client.delete(self.cache_key) + + +class SingletonProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool single provider credentials""" + + def __init__(self, tenant_id: str, provider_type: str, provider_identity: str): + super().__init__( + tenant_id=tenant_id, + provider_type=provider_type, + provider_identity=provider_identity, + ) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider_type = kwargs["provider_type"] + identity_name = kwargs["provider_identity"] + identity_id = f"{provider_type}.{identity_name}" + return f"{provider_type}_credentials:tenant_id:{tenant_id}:id:{identity_id}" + + +class ToolProviderCredentialsCache(ProviderCredentialsCache): + """Cache for tool provider credentials""" + + def __init__(self, tenant_id: str, provider: str, credential_id: str): + super().__init__(tenant_id=tenant_id, provider=provider, credential_id=credential_id) + + def _generate_cache_key(self, **kwargs) -> str: + tenant_id = kwargs["tenant_id"] + provider = kwargs["provider"] + credential_id = kwargs["credential_id"] + return f"tool_credentials:tenant_id:{tenant_id}:provider:{provider}:credential_id:{credential_id}" + + +class NoOpProviderCredentialCache: + """No-op provider credential cache""" + + def get(self) -> Optional[dict]: + """Get cached provider credentials""" + return None + + def set(self, config: dict[str, Any]) -> None: + """Cache provider credentials""" + pass + + def delete(self) -> None: + """Delete cached provider credentials""" + pass diff --git a/api/core/helper/tool_provider_cache.py b/api/core/helper/tool_provider_cache.py deleted file mode 100644 index 2e4a04c579..0000000000 --- a/api/core/helper/tool_provider_cache.py +++ /dev/null @@ -1,51 +0,0 @@ -import json -from enum import Enum -from json import JSONDecodeError -from typing import Optional - -from extensions.ext_redis import redis_client - - -class ToolProviderCredentialsCacheType(Enum): - PROVIDER = "tool_provider" - ENDPOINT = "endpoint" - - -class ToolProviderCredentialsCache: - def __init__(self, tenant_id: str, identity_id: str, cache_type: ToolProviderCredentialsCacheType): - self.cache_key = f"{cache_type.value}_credentials:tenant_id:{tenant_id}:id:{identity_id}" - - def get(self) -> Optional[dict]: - """ - Get cached model provider credentials. - - :return: - """ - cached_provider_credentials = redis_client.get(self.cache_key) - if cached_provider_credentials: - try: - cached_provider_credentials = cached_provider_credentials.decode("utf-8") - cached_provider_credentials = json.loads(cached_provider_credentials) - except JSONDecodeError: - return None - - return dict(cached_provider_credentials) - else: - return None - - def set(self, credentials: dict) -> None: - """ - Cache model provider credentials. - - :param credentials: provider credentials - :return: - """ - redis_client.setex(self.cache_key, 86400, json.dumps(credentials)) - - def delete(self) -> None: - """ - Delete cached model provider credentials. - - :return: - """ - redis_client.delete(self.cache_key) diff --git a/api/core/helper/url_signer.py b/api/core/helper/url_signer.py deleted file mode 100644 index dfb143f4c4..0000000000 --- a/api/core/helper/url_signer.py +++ /dev/null @@ -1,52 +0,0 @@ -import base64 -import hashlib -import hmac -import os -import time - -from pydantic import BaseModel, Field - -from configs import dify_config - - -class SignedUrlParams(BaseModel): - sign_key: str = Field(..., description="The sign key") - timestamp: str = Field(..., description="Timestamp") - nonce: str = Field(..., description="Nonce") - sign: str = Field(..., description="Signature") - - -class UrlSigner: - @classmethod - def get_signed_url(cls, url: str, sign_key: str, prefix: str) -> str: - signed_url_params = cls.get_signed_url_params(sign_key, prefix) - return ( - f"{url}?timestamp={signed_url_params.timestamp}" - f"&nonce={signed_url_params.nonce}&sign={signed_url_params.sign}" - ) - - @classmethod - def get_signed_url_params(cls, sign_key: str, prefix: str) -> SignedUrlParams: - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return SignedUrlParams(sign_key=sign_key, timestamp=timestamp, nonce=nonce, sign=sign) - - @classmethod - def verify(cls, sign_key: str, timestamp: str, nonce: str, sign: str, prefix: str) -> bool: - recalculated_sign = cls._sign(sign_key, timestamp, nonce, prefix) - - return sign == recalculated_sign - - @classmethod - def _sign(cls, sign_key: str, timestamp: str, nonce: str, prefix: str) -> str: - if not dify_config.SECRET_KEY: - raise Exception("SECRET_KEY is not set") - - data_to_sign = f"{prefix}|{sign_key}|{timestamp}|{nonce}" - secret_key = dify_config.SECRET_KEY.encode() - sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - - return encoded_sign diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index e01896a491..f7fd93be4a 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -148,9 +148,11 @@ class LLMGenerator: model_manager = ModelManager() - model_instance = model_manager.get_default_model_instance( + model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, + provider=model_config.get("provider", ""), + model=model_config.get("name", ""), ) try: diff --git a/api/core/mcp/auth/auth_flow.py b/api/core/mcp/auth/auth_flow.py index b63478e822..bcb31a816f 100644 --- a/api/core/mcp/auth/auth_flow.py +++ b/api/core/mcp/auth/auth_flow.py @@ -240,7 +240,7 @@ def refresh_authorization( response = requests.post(token_url, data=params) if not response.ok: raise ValueError(f"Token refresh failed: HTTP {response.status_code}") - return OAuthTokens.parse_obj(response.json()) + return OAuthTokens.model_validate(response.json()) def register_client( diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index cd55dbf64f..00d5a25956 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -8,7 +8,7 @@ from core.mcp.types import ( OAuthTokens, ) from models.tools import MCPToolProvider -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService LATEST_PROTOCOL_VERSION = "1.0" diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index e9036de8c6..5fe52c008a 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -68,15 +68,17 @@ class MCPClient: } parsed_url = urlparse(self.server_url) - path = parsed_url.path + path = parsed_url.path or "" method_name = path.rstrip("/").split("/")[-1] if path else "" - try: + if method_name in connection_methods: client_factory = connection_methods[method_name] self.connect_server(client_factory, method_name) - except KeyError: + else: try: + logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.") self.connect_server(sse_client, "sse") except MCPConnectionError: + logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") def connect_server( @@ -91,7 +93,7 @@ class MCPClient: else {} ) self._streams_context = client_factory(url=self.server_url, headers=headers) - if self._streams_context is None: + if not self._streams_context: raise MCPConnectionError("Failed to create connection context") # Use exit_stack to manage context managers properly @@ -141,10 +143,11 @@ class MCPClient: try: # ExitStack will handle proper cleanup of all managed context managers self.exit_stack.close() + except Exception as e: + logging.exception("Error during cleanup") + raise ValueError(f"Error during cleanup: {e}") + finally: self._session = None self._session_context = None self._streams_context = None self._initialized = False - except Exception as e: - logging.exception("Error during cleanup") - raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index 1c2cf570e2..20ff7e7524 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -148,9 +148,7 @@ class MCPServerStreamableHTTPRequestHandler: if not self.end_user: raise ValueError("User not found") request = cast(types.CallToolRequest, self.request.root) - args = request.params.arguments - if not args: - raise ValueError("No arguments provided") + args = request.params.arguments or {} if self.app.mode in {AppMode.WORKFLOW.value}: args = {"inputs": args} elif self.app.mode in {AppMode.COMPLETION.value}: diff --git a/api/core/mcp/session/base_session.py b/api/core/mcp/session/base_session.py index 1c0f582501..7734b8fdd9 100644 --- a/api/core/mcp/session/base_session.py +++ b/api/core/mcp/session/base_session.py @@ -1,7 +1,7 @@ import logging import queue from collections.abc import Callable -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor, TimeoutError from contextlib import ExitStack from datetime import timedelta from types import TracebackType @@ -171,23 +171,41 @@ class BaseSession( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._exit_stack = ExitStack() + # Initialize executor and future to None for proper cleanup checks + self._executor: ThreadPoolExecutor | None = None + self._receiver_future: Future | None = None def __enter__(self) -> Self: - self._executor = ThreadPoolExecutor() + # The thread pool is dedicated to running `_receive_loop`. Setting `max_workers` to 1 + # ensures no unnecessary threads are created. + self._executor = ThreadPoolExecutor(max_workers=1) self._receiver_future = self._executor.submit(self._receive_loop) return self def check_receiver_status(self) -> None: - if self._receiver_future.done(): + """`check_receiver_status` ensures that any exceptions raised during the + execution of `_receive_loop` are retrieved and propagated.""" + if self._receiver_future and self._receiver_future.done(): self._receiver_future.result() def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None ) -> None: - self._exit_stack.close() self._read_stream.put(None) self._write_stream.put(None) + # Wait for the receiver loop to finish + if self._receiver_future: + try: + self._receiver_future.result(timeout=5.0) # Wait up to 5 seconds + except TimeoutError: + # If the receiver loop is still running after timeout, we'll force shutdown + pass + + # Shutdown the executor + if self._executor: + self._executor.shutdown(wait=True) + def send_request( self, request: SendRequestT, diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index b18a6905fe..db8fec4ee9 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -284,7 +284,8 @@ class AliyunDataTrace(BaseTraceInstance): else: node_span = self.build_workflow_task_span(trace_id, workflow_span_id, trace_info, node_execution) return node_span - except Exception: + except Exception as e: + logging.debug(f"Error occurred in build_workflow_node_span: {e}", exc_info=True) return None def get_workflow_node_status(self, node_execution: WorkflowNodeExecution) -> Status: @@ -306,7 +307,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.TASK.value, GEN_AI_FRAMEWORK: "dify", INPUT_VALUE: json.dumps(node_execution.inputs, ensure_ascii=False), @@ -381,7 +382,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(node_execution.created_at), end_time=convert_datetime_to_nanoseconds(node_execution.finished_at), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_SPAN_KIND: GenAISpanKind.LLM.value, GEN_AI_FRAMEWORK: "dify", GEN_AI_MODEL_NAME: process_data.get("model_name", ""), @@ -415,7 +416,7 @@ class AliyunDataTrace(BaseTraceInstance): start_time=convert_datetime_to_nanoseconds(trace_info.start_time), end_time=convert_datetime_to_nanoseconds(trace_info.end_time), attributes={ - GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id", ""), + GEN_AI_SESSION_ID: trace_info.metadata.get("conversation_id") or "", GEN_AI_USER_ID: str(user_id), GEN_AI_SPAN_KIND: GenAISpanKind.CHAIN.value, GEN_AI_FRAMEWORK: "dify", diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index ffda0885d4..8b3ce0c448 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -3,7 +3,7 @@ import json import logging import os from datetime import datetime, timedelta -from typing import Optional, Union, cast +from typing import Any, Optional, Union, cast from openinference.semconv.trace import OpenInferenceSpanKindValues, SpanAttributes from opentelemetry import trace @@ -142,11 +142,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): raise def workflow_trace(self, trace_info: WorkflowTraceInfo): - if trace_info.message_data is None: - return - workflow_metadata = { - "workflow_id": trace_info.workflow_run_id or "", + "workflow_run_id": trace_info.workflow_run_id or "", "message_id": trace_info.message_id or "", "workflow_app_log_id": trace_info.workflow_app_log_id or "", "status": trace_info.workflow_run_status or "", @@ -156,7 +153,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): } workflow_metadata.update(trace_info.metadata) - trace_id = uuid_to_trace_id(trace_info.message_id) + trace_id = uuid_to_trace_id(trace_info.workflow_run_id) span_id = RandomIdGenerator().generate_span_id() context = SpanContext( trace_id=trace_id, @@ -213,7 +210,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): if model: node_metadata["ls_model_name"] = model - outputs = json.loads(node_execution.outputs).get("usage", {}) + outputs = json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} usage_data = process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) if usage_data: node_metadata["total_tokens"] = usage_data.get("total_tokens", 0) @@ -236,31 +233,34 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.SESSION_ID: trace_info.conversation_id or "", }, start_time=datetime_to_nanos(created_at), + context=trace.set_span_in_context(trace.NonRecordingSpan(context)), ) try: if node_execution.node_type == "llm": + llm_attributes: dict[str, Any] = { + SpanAttributes.INPUT_VALUE: json.dumps(process_data.get("prompts", []), ensure_ascii=False), + } provider = process_data.get("model_provider") model = process_data.get("model_name") if provider: - node_span.set_attribute(SpanAttributes.LLM_PROVIDER, provider) + llm_attributes[SpanAttributes.LLM_PROVIDER] = provider if model: - node_span.set_attribute(SpanAttributes.LLM_MODEL_NAME, model) - - outputs = json.loads(node_execution.outputs).get("usage", {}) + llm_attributes[SpanAttributes.LLM_MODEL_NAME] = model + outputs = ( + json.loads(node_execution.outputs).get("usage", {}) if "outputs" in node_execution else {} + ) usage_data = ( process_data.get("usage", {}) if "usage" in process_data else outputs.get("usage", {}) ) if usage_data: - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_TOTAL, usage_data.get("total_tokens", 0) - ) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_PROMPT, usage_data.get("prompt_tokens", 0) - ) - node_span.set_attribute( - SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, usage_data.get("completion_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = usage_data.get("total_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_PROMPT] = usage_data.get("prompt_tokens", 0) + llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_COMPLETION] = usage_data.get( + "completion_tokens", 0 ) + llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) + node_span.set_attributes(llm_attributes) finally: node_span.end(end_time=datetime_to_nanos(finished_at)) finally: @@ -352,25 +352,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): SpanAttributes.METADATA: json.dumps(message_metadata, ensure_ascii=False), SpanAttributes.SESSION_ID: trace_info.message_data.conversation_id, } - - if isinstance(trace_info.inputs, list): - for i, msg in enumerate(trace_info.inputs): - if isinstance(msg, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get( - "role", "user" - ) - # todo: handle assistant and tool role messages, as they don't always - # have a text field, but may have a tool_calls field instead - # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', - # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} - elif isinstance(trace_info.inputs, dict): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(trace_info.inputs) - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - elif isinstance(trace_info.inputs, str): - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = trace_info.inputs - llm_attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" - + llm_attributes.update(self._construct_llm_attributes(trace_info.inputs)) if trace_info.total_tokens is not None and trace_info.total_tokens > 0: llm_attributes[SpanAttributes.LLM_TOKEN_COUNT_TOTAL] = trace_info.total_tokens if trace_info.message_tokens is not None and trace_info.message_tokens > 0: @@ -724,3 +706,24 @@ class ArizePhoenixDataTrace(BaseTraceInstance): .all() ) return workflow_nodes + + def _construct_llm_attributes(self, prompts: dict | list | str | None) -> dict[str, str]: + """Helper method to construct LLM attributes with passed prompts.""" + attributes = {} + if isinstance(prompts, list): + for i, msg in enumerate(prompts): + if isinstance(msg, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.content"] = msg.get("text", "") + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.{i}.message.role"] = msg.get("role", "user") + # todo: handle assistant and tool role messages, as they don't always + # have a text field, but may have a tool_calls field instead + # e.g. 'tool_calls': [{'id': '98af3a29-b066-45a5-b4b1-46c74ddafc58', + # 'type': 'function', 'function': {'name': 'current_time', 'arguments': '{}'}}]} + elif isinstance(prompts, dict): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = json.dumps(prompts) + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + elif isinstance(prompts, str): + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.content"] = prompts + attributes[f"{SpanAttributes.LLM_INPUT_MESSAGES}.0.message.role"] = "user" + + return attributes diff --git a/api/core/plugin/backwards_invocation/encrypt.py b/api/core/plugin/backwards_invocation/encrypt.py index 81a5d033a0..213f5c726a 100644 --- a/api/core/plugin/backwards_invocation/encrypt.py +++ b/api/core/plugin/backwards_invocation/encrypt.py @@ -1,16 +1,20 @@ +from core.helper.provider_cache import SingletonProviderCredentialsCache from core.plugin.entities.request import RequestInvokeEncrypt -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter from models.account import Tenant class PluginEncrypter: @classmethod def invoke_encrypt(cls, tenant: Tenant, payload: RequestInvokeEncrypt) -> dict: - encrypter = ProviderConfigEncrypter( + encrypter, cache = create_provider_encrypter( tenant_id=tenant.id, config=payload.config, - provider_type=payload.namespace, - provider_identity=payload.identity, + cache=SingletonProviderCredentialsCache( + tenant_id=tenant.id, + provider_type=payload.namespace, + provider_identity=payload.identity, + ), ) if payload.opt == "encrypt": @@ -22,7 +26,7 @@ class PluginEncrypter: "data": encrypter.decrypt(payload.data), } elif payload.opt == "clear": - encrypter.delete_tool_credentials_cache() + cache.delete() return { "data": {}, } diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index 1d62743f13..06773504d9 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from typing import Any +from typing import Any, Optional from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.plugin.backwards_invocation.base import BaseBackwardsInvocation @@ -23,6 +23,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Generator[ToolInvokeMessage, None, None]: """ invoke tool @@ -30,7 +31,7 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters + tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/parameters.py b/api/core/plugin/entities/parameters.py index 2be65d67a0..47290ee613 100644 --- a/api/core/plugin/entities/parameters.py +++ b/api/core/plugin/entities/parameters.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, Field, field_validator from core.entities.parameter_entities import CommonParameterType from core.tools.entities.common_entities import I18nObject +from core.workflow.nodes.base.entities import NumberType class PluginParameterOption(BaseModel): @@ -38,6 +39,7 @@ class PluginParameterType(enum.StrEnum): APP_SELECTOR = CommonParameterType.APP_SELECTOR.value MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value + ANY = CommonParameterType.ANY.value DYNAMIC_SELECT = CommonParameterType.DYNAMIC_SELECT.value # deprecated, should not use. @@ -151,6 +153,10 @@ def cast_parameter_value(typ: enum.StrEnum, value: Any, /): if value and not isinstance(value, list): raise ValueError("The tools selector must be a list.") return value + case PluginParameterType.ANY: + if value and not isinstance(value, str | dict | list | NumberType): + raise ValueError("The var selector must be a string, dictionary, list or number.") + return value case PluginParameterType.ARRAY: if not isinstance(value, list): # Try to parse JSON string for arrays diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index e5cf7ee03a..a07b58d9ea 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -135,17 +135,6 @@ class PluginEntity(PluginInstallation): return self -class GithubPackage(BaseModel): - repo: str - version: str - package: str - - -class GithubVersion(BaseModel): - repo: str - version: str - - class GenericProviderID: organization: str plugin_name: str diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index 89f595ec46..3a783dad3e 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -27,6 +27,20 @@ from core.workflow.nodes.question_classifier.entities import ( ) +class InvokeCredentials(BaseModel): + tool_credentials: dict[str, str] = Field( + default_factory=dict, + description="Map of tool provider to credential id, used to store the credential id for the tool provider.", + ) + + +class PluginInvokeContext(BaseModel): + credentials: Optional[InvokeCredentials] = Field( + default_factory=InvokeCredentials, + description="Credentials context for the plugin invocation or backward invocation.", + ) + + class RequestInvokeTool(BaseModel): """ Request to invoke a tool @@ -36,6 +50,7 @@ class RequestInvokeTool(BaseModel): provider: str tool: str tool_parameters: dict + credential_id: Optional[str] = None class BaseRequestInvokeModel(BaseModel): diff --git a/api/core/plugin/impl/agent.py b/api/core/plugin/impl/agent.py index 66b77c7489..9575c57ac8 100644 --- a/api/core/plugin/impl/agent.py +++ b/api/core/plugin/impl/agent.py @@ -6,6 +6,7 @@ from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginAgentProviderEntity, ) +from core.plugin.entities.request import PluginInvokeContext from core.plugin.impl.base import BasePluginClient @@ -83,6 +84,7 @@ class PluginAgentClient(BasePluginClient): conversation_id: Optional[str] = None, app_id: Optional[str] = None, message_id: Optional[str] = None, + context: Optional[PluginInvokeContext] = None, ) -> Generator[AgentInvokeMessage, None, None]: """ Invoke the agent with the given tenant, user, plugin, provider, name and parameters. @@ -99,6 +101,7 @@ class PluginAgentClient(BasePluginClient): "conversation_id": conversation_id, "app_id": app_id, "message_id": message_id, + "context": context.model_dump() if context else {}, "data": { "agent_strategy_provider": agent_provider_id.provider_name, "agent_strategy": agent_strategy, diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py index b006bf1d4b..d73e5d9f9e 100644 --- a/api/core/plugin/impl/oauth.py +++ b/api/core/plugin/impl/oauth.py @@ -15,27 +15,32 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], ) -> PluginOAuthAuthorizationUrlResponse: - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", - PluginOAuthAuthorizationUrlResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, + try: + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_authorization_url", + PluginOAuthAuthorizationUrlResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + }, }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", + }, + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting authorization URL: {e}") def get_credentials( self, @@ -43,6 +48,7 @@ class OAuthHandler(BasePluginClient): user_id: str, plugin_id: str, provider: str, + redirect_uri: str, system_credentials: Mapping[str, Any], request: Request, ) -> PluginOAuthCredentialsResponse: @@ -50,30 +56,33 @@ class OAuthHandler(BasePluginClient): Get credentials from the given request. """ - # encode request to raw http request - raw_request_bytes = self._convert_request_to_raw_data(request) - - response = self._request_with_plugin_daemon_response_stream( - "POST", - f"plugin/{tenant_id}/dispatch/oauth/get_credentials", - PluginOAuthCredentialsResponse, - data={ - "user_id": user_id, - "data": { - "provider": provider, - "system_credentials": system_credentials, - # for json serialization - "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + try: + # encode request to raw http request + raw_request_bytes = self._convert_request_to_raw_data(request) + response = self._request_with_plugin_daemon_response_stream( + "POST", + f"plugin/{tenant_id}/dispatch/oauth/get_credentials", + PluginOAuthCredentialsResponse, + data={ + "user_id": user_id, + "data": { + "provider": provider, + "redirect_uri": redirect_uri, + "system_credentials": system_credentials, + # for json serialization + "raw_http_request": binascii.hexlify(raw_request_bytes).decode(), + }, + }, + headers={ + "X-Plugin-ID": plugin_id, + "Content-Type": "application/json", }, - }, - headers={ - "X-Plugin-ID": plugin_id, - "Content-Type": "application/json", - }, - ) - for resp in response: - return resp - raise ValueError("No response received from plugin daemon for authorization URL request.") + ) + for resp in response: + return resp + raise ValueError("No response received from plugin daemon for authorization URL request.") + except Exception as e: + raise ValueError(f"Error getting credentials: {e}") def _convert_request_to_raw_data(self, request: Request) -> bytes: """ diff --git a/api/core/plugin/impl/tool.py b/api/core/plugin/impl/tool.py index 19b26c8fe3..04225f95ee 100644 --- a/api/core/plugin/impl/tool.py +++ b/api/core/plugin/impl/tool.py @@ -6,7 +6,7 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.impl.base import BasePluginClient -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter +from core.tools.entities.tool_entities import CredentialType, ToolInvokeMessage, ToolParameter class PluginToolManager(BasePluginClient): @@ -78,6 +78,7 @@ class PluginToolManager(BasePluginClient): tool_provider: str, tool_name: str, credentials: dict[str, Any], + credential_type: CredentialType, tool_parameters: dict[str, Any], conversation_id: Optional[str] = None, app_id: Optional[str] = None, @@ -102,6 +103,7 @@ class PluginToolManager(BasePluginClient): "provider": tool_provider_id.provider_name, "tool": tool_name, "credentials": credentials, + "credential_type": credential_type, "tool_parameters": tool_parameters, }, }, diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index 25964ae063..0f0fe65f27 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -158,7 +158,7 @@ class AdvancedPromptTransform(PromptTransform): if prompt_item.edition_type == "basic" or not prompt_item.edition_type: if self.with_variable_tmpl: - vp = VariablePool() + vp = VariablePool.empty() for k, v in inputs.items(): if k.startswith("#"): vp.add(k[1:-1].split("."), v) diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 47808928f7..e19c6419ca 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ b/api/core/prompt/simple_prompt_transform.py @@ -29,19 +29,6 @@ class ModelMode(enum.StrEnum): COMPLETION = "completion" CHAT = "chat" - @classmethod - def value_of(cls, value: str) -> "ModelMode": - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid mode value {value}") - prompt_file_contents: dict[str, Any] = {} @@ -65,7 +52,7 @@ class SimplePromptTransform(PromptTransform): ) -> tuple[list[PromptMessage], Optional[list[str]]]: inputs = {key: str(value) for key, value in inputs.items()} - model_mode = ModelMode.value_of(model_config.mode) + model_mode = ModelMode(model_config.mode) if model_mode == ModelMode.CHAT: prompt_messages, stops = self._get_chat_model_prompt_messages( app_mode=app_mode, diff --git a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py deleted file mode 100644 index 167a919e69..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_extra_whitespace_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_extra_whitespace - - # Returns "ITEM 1A: RISK FACTORS" - return clean_extra_whitespace(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py deleted file mode 100644 index 9c682d29db..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_group_broken_paragraphs_cleaner.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredGroupBrokenParagraphsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - import re - - from unstructured.cleaners.core import group_broken_paragraphs - - para_split_re = re.compile(r"(\s*\n\s*){3}") - - return group_broken_paragraphs(content, paragraph_split=para_split_re) diff --git a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py deleted file mode 100644 index 0cdbb171e1..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_non_ascii_chars_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.core import clean_non_ascii_chars - - # Returns "This text contains non-ascii characters!" - return clean_non_ascii_chars(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py deleted file mode 100644 index 9f42044a2d..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_replace_unicode_quotes_cleaner.py +++ /dev/null @@ -1,12 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredNonAsciiCharsCleaner(BaseCleaner): - def clean(self, content) -> str: - """Replaces unicode quote characters, such as the \x91 character in a string.""" - - from unstructured.cleaners.core import replace_unicode_quotes - - return replace_unicode_quotes(content) diff --git a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py b/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py deleted file mode 100644 index 32ae7217e8..0000000000 --- a/api/core/rag/cleaner/unstructured/unstructured_translate_text_cleaner.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Abstract interface for document clean implementations.""" - -from core.rag.cleaner.cleaner_base import BaseCleaner - - -class UnstructuredTranslateTextCleaner(BaseCleaner): - def clean(self, content) -> str: - """clean document content.""" - from unstructured.cleaners.translate import translate_text - - return translate_text(content) diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 095752ea8e..6f3e15d166 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI: vector=query_vector, content=None, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] @@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI: vector=None, content=query, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 44cc5d3e98..ad39717183 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = {"match": {Field.CONTENT_KEY.value: query}} + query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}} document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: - query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore + query_str = { + "bool": { + "must": {"match": {Field.CONTENT_KEY.value: query}}, + "filter": {"terms": {"metadata.document_id": document_ids_filter}}, + } + } + results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: diff --git a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py b/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py deleted file mode 100644 index 1e62b3c589..0000000000 --- a/api/core/rag/datasource/vdb/tidb_on_qdrant/tidb_entities.py +++ /dev/null @@ -1,17 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - - -class ClusterEntity(BaseModel): - """ - Model Config Entity. - """ - - name: str - cluster_id: str - displayName: str - region: str - spendingLimit: Optional[int] = 1000 - version: str - createdBy: str diff --git a/api/core/rag/extractor/blob/blob.py b/api/core/rag/extractor/blob/blob.py index e46ab8b7fd..01003a13b6 100644 --- a/api/core/rag/extractor/blob/blob.py +++ b/api/core/rag/extractor/blob/blob.py @@ -9,8 +9,7 @@ from __future__ import annotations import contextlib import mimetypes -from abc import ABC, abstractmethod -from collections.abc import Generator, Iterable, Mapping +from collections.abc import Generator, Mapping from io import BufferedReader, BytesIO from pathlib import Path, PurePath from typing import Any, Optional, Union @@ -143,21 +142,3 @@ class Blob(BaseModel): if self.source: str_repr += f" {self.source}" return str_repr - - -class BlobLoader(ABC): - """Abstract interface for blob loaders implementation. - - Implementer should be able to load raw content from a datasource system according - to some criteria and return the raw content lazily as a stream of blobs. - """ - - @abstractmethod - def yield_blobs( - self, - ) -> Iterable[Blob]: - """A lazy loader for raw data represented by Blob object. - - Returns: - A generator over blobs - """ diff --git a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py b/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py deleted file mode 100644 index dd8a979e70..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_pdf_extractor.py +++ /dev/null @@ -1,47 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredPDFExtractor(BaseExtractor): - """Load pdf files. - - - Args: - file_path: Path to the file to load. - - api_url: Unstructured API URL - - api_key: Unstructured API Key - """ - - def __init__(self, file_path: str, api_url: str, api_key: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - self._api_key = api_key - - def extract(self) -> list[Document]: - if self._api_url: - from unstructured.partition.api import partition_via_api - - elements = partition_via_api( - filename=self._file_path, api_url=self._api_url, api_key=self._api_key, strategy="auto" - ) - else: - from unstructured.partition.pdf import partition_pdf - - elements = partition_pdf(filename=self._file_path, strategy="auto") - - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py b/api/core/rag/extractor/unstructured/unstructured_text_extractor.py deleted file mode 100644 index 22dfdd2075..0000000000 --- a/api/core/rag/extractor/unstructured/unstructured_text_extractor.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging - -from core.rag.extractor.extractor_base import BaseExtractor -from core.rag.models.document import Document - -logger = logging.getLogger(__name__) - - -class UnstructuredTextExtractor(BaseExtractor): - """Load msg files. - - - Args: - file_path: Path to the file to load. - """ - - def __init__(self, file_path: str, api_url: str): - """Initialize with file path.""" - self._file_path = file_path - self._api_url = api_url - - def extract(self) -> list[Document]: - from unstructured.partition.text import partition_text - - elements = partition_text(filename=self._file_path) - from unstructured.chunking.title import chunk_by_title - - chunks = chunk_by_title(elements, max_characters=2000, combine_text_under_n_chars=2000) - documents = [] - for chunk in chunks: - text = chunk.text.strip() - documents.append(Document(page_content=text)) - - return documents diff --git a/api/core/rag/extractor/word_extractor.py b/api/core/rag/extractor/word_extractor.py index bff0acc48f..14363de7d4 100644 --- a/api/core/rag/extractor/word_extractor.py +++ b/api/core/rag/extractor/word_extractor.py @@ -238,9 +238,11 @@ class WordExtractor(BaseExtractor): paragraph_content = [] for run in paragraph.runs: if hasattr(run.element, "tag") and isinstance(run.element.tag, str) and run.element.tag.endswith("r"): + # Process drawing type images drawing_elements = run.element.findall( ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}drawing" ) + has_drawing = False for drawing in drawing_elements: blip_elements = drawing.findall( ".//{http://schemas.openxmlformats.org/drawingml/2006/main}blip" @@ -252,6 +254,34 @@ class WordExtractor(BaseExtractor): if embed_id: image_part = doc.part.related_parts.get(embed_id) if image_part in image_map: + has_drawing = True + paragraph_content.append(image_map[image_part]) + # Process pict type images + shape_elements = run.element.findall( + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}pict" + ) + for shape in shape_elements: + # Find image data in VML + shape_image = shape.find( + ".//{http://schemas.openxmlformats.org/wordprocessingml/2006/main}binData" + ) + if shape_image is not None and shape_image.text: + image_id = shape_image.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" + ) + if image_id and image_id in doc.part.rels: + image_part = doc.part.rels[image_id].target_part + if image_part in image_map and not has_drawing: + paragraph_content.append(image_map[image_part]) + # Find imagedata element in VML + image_data = shape.find(".//{urn:schemas-microsoft-com:vml}imagedata") + if image_data is not None: + image_id = image_data.get("id") or image_data.get( + "{http://schemas.openxmlformats.org/officeDocument/2006/relationships}id" + ) + if image_id and image_id in doc.part.rels: + image_part = doc.part.rels[image_id].target_part + if image_part in image_map and not has_drawing: paragraph_content.append(image_map[image_part]) if run.text.strip(): paragraph_content.append(run.text.strip()) diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 5c0360b064..3d0f0f97bc 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1137,7 +1137,7 @@ class DatasetRetrieval: def _get_prompt_template( self, model_config: ModelConfigWithCredentialsEntity, mode: str, metadata_fields: list, query: str ): - model_mode = ModelMode.value_of(mode) + model_mode = ModelMode(mode) input_text = query prompt_template: Union[CompletionModelPromptTemplate, list[ChatModelMessage]] diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 0fb1bcb2e0..bcaf299892 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -102,6 +102,7 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter) splits = text.split() else: splits = text.split(separator) + splits = [item + separator if i < len(splits) else item for i, item in enumerate(splits)] else: splits = list(text) splits = [s for s in splits if (s not in {"", "\n"})] diff --git a/api/core/rag/splitter/text_splitter.py b/api/core/rag/splitter/text_splitter.py index b711e8434a..529d8ccd27 100644 --- a/api/core/rag/splitter/text_splitter.py +++ b/api/core/rag/splitter/text_splitter.py @@ -10,7 +10,6 @@ from typing import ( Any, Literal, Optional, - TypedDict, TypeVar, Union, ) @@ -168,167 +167,6 @@ class TextSplitter(BaseDocumentTransformer, ABC): raise NotImplementedError -class CharacterTextSplitter(TextSplitter): - """Splitting text that looks at characters.""" - - def __init__(self, separator: str = "\n\n", **kwargs: Any) -> None: - """Create a new TextSplitter.""" - super().__init__(**kwargs) - self._separator = separator - - def split_text(self, text: str) -> list[str]: - """Split incoming text and return chunks.""" - # First we naively split the large input into a bunch of smaller ones. - splits = _split_text_with_regex(text, self._separator, self._keep_separator) - _separator = "" if self._keep_separator else self._separator - _good_splits_lengths = [] # cache the lengths of the splits - if splits: - _good_splits_lengths.extend(self._length_function(splits)) - return self._merge_splits(splits, _separator, _good_splits_lengths) - - -class LineType(TypedDict): - """Line type as typed dict.""" - - metadata: dict[str, str] - content: str - - -class HeaderType(TypedDict): - """Header type as typed dict.""" - - level: int - name: str - data: str - - -class MarkdownHeaderTextSplitter: - """Splitting markdown files based on specified headers.""" - - def __init__(self, headers_to_split_on: list[tuple[str, str]], return_each_line: bool = False): - """Create a new MarkdownHeaderTextSplitter. - - Args: - headers_to_split_on: Headers we want to track - return_each_line: Return each line w/ associated headers - """ - # Output line-by-line or aggregated into chunks w/ common headers - self.return_each_line = return_each_line - # Given the headers we want to split on, - # (e.g., "#, ##, etc") order by length - self.headers_to_split_on = sorted(headers_to_split_on, key=lambda split: len(split[0]), reverse=True) - - def aggregate_lines_to_chunks(self, lines: list[LineType]) -> list[Document]: - """Combine lines with common metadata into chunks - Args: - lines: Line of text / associated header metadata - """ - aggregated_chunks: list[LineType] = [] - - for line in lines: - if aggregated_chunks and aggregated_chunks[-1]["metadata"] == line["metadata"]: - # If the last line in the aggregated list - # has the same metadata as the current line, - # append the current content to the last lines's content - aggregated_chunks[-1]["content"] += " \n" + line["content"] - else: - # Otherwise, append the current line to the aggregated list - aggregated_chunks.append(line) - - return [Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in aggregated_chunks] - - def split_text(self, text: str) -> list[Document]: - """Split markdown file - Args: - text: Markdown file""" - - # Split the input text by newline character ("\n"). - lines = text.split("\n") - # Final output - lines_with_metadata: list[LineType] = [] - # Content and metadata of the chunk currently being processed - current_content: list[str] = [] - current_metadata: dict[str, str] = {} - # Keep track of the nested header structure - # header_stack: List[Dict[str, Union[int, str]]] = [] - header_stack: list[HeaderType] = [] - initial_metadata: dict[str, str] = {} - - for line in lines: - stripped_line = line.strip() - # Check each line against each of the header types (e.g., #, ##) - for sep, name in self.headers_to_split_on: - # Check if line starts with a header that we intend to split on - if stripped_line.startswith(sep) and ( - # Header with no text OR header is followed by space - # Both are valid conditions that sep is being used a header - len(stripped_line) == len(sep) or stripped_line[len(sep)] == " " - ): - # Ensure we are tracking the header as metadata - if name is not None: - # Get the current header level - current_header_level = sep.count("#") - - # Pop out headers of lower or same level from the stack - while header_stack and header_stack[-1]["level"] >= current_header_level: - # We have encountered a new header - # at the same or higher level - popped_header = header_stack.pop() - # Clear the metadata for the - # popped header in initial_metadata - if popped_header["name"] in initial_metadata: - initial_metadata.pop(popped_header["name"]) - - # Push the current header to the stack - header: HeaderType = { - "level": current_header_level, - "name": name, - "data": stripped_line[len(sep) :].strip(), - } - header_stack.append(header) - # Update initial_metadata with the current header - initial_metadata[name] = header["data"] - - # Add the previous line to the lines_with_metadata - # only if current_content is not empty - if current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - break - else: - if stripped_line: - current_content.append(stripped_line) - elif current_content: - lines_with_metadata.append( - { - "content": "\n".join(current_content), - "metadata": current_metadata.copy(), - } - ) - current_content.clear() - - current_metadata = initial_metadata.copy() - - if current_content: - lines_with_metadata.append({"content": "\n".join(current_content), "metadata": current_metadata}) - - # lines_with_metadata has each line with associated header metadata - # aggregate these into chunks based on common metadata - if not self.return_each_line: - return self.aggregate_lines_to_chunks(lines_with_metadata) - else: - return [ - Document(page_content=chunk["content"], metadata=chunk["metadata"]) for chunk in lines_with_metadata - ] - - -# should be in newer Python versions (3.10+) # @dataclass(frozen=True, kw_only=True, slots=True) @dataclass(frozen=True) class Tokenizer: diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 0b3e5eb424..c579ff4028 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,7 +6,6 @@ import json import logging from typing import Optional, Union -from sqlalchemy import select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -206,44 +205,3 @@ class SQLAlchemyWorkflowExecutionRepository(WorkflowExecutionRepository): # Update the in-memory cache for faster subsequent lookups logger.debug(f"Updating cache for execution_id: {db_model.id}") self._execution_cache[db_model.id] = db_model - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - # First check the cache - if execution_id in self._execution_cache: - logger.debug(f"Cache hit for execution_id: {execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._execution_cache[execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for execution_id: {execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowRun).where( - WorkflowRun.id == execution_id, - WorkflowRun.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowRun.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._execution_cache[execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None diff --git a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index a5feeb0d7c..d4a31390f8 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -7,7 +7,7 @@ import logging from collections.abc import Sequence from typing import Optional, Union -from sqlalchemy import UnaryExpression, asc, delete, desc, select +from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker @@ -218,47 +218,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) logger.debug(f"Updating cache for node_execution_id: {db_model.node_execution_id}") self._node_execution_cache[db_model.node_execution_id] = db_model - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - First checks the in-memory cache, and if not found, queries the database. - If found in the database, adds it to the cache for future lookups. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - # First check the cache - if node_execution_id in self._node_execution_cache: - logger.debug(f"Cache hit for node_execution_id: {node_execution_id}") - # Convert cached DB model to domain model - cached_db_model = self._node_execution_cache[node_execution_id] - return self._to_domain_model(cached_db_model) - - # If not in cache, query the database - logger.debug(f"Cache miss for node_execution_id: {node_execution_id}, querying database") - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.node_execution_id == node_execution_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_model = session.scalar(stmt) - if db_model: - # Add DB model to cache - self._node_execution_cache[node_execution_id] = db_model - - # Convert to domain model and return - return self._to_domain_model(db_model) - - return None - def get_db_models_by_workflow_run( self, workflow_run_id: str, @@ -344,68 +303,3 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) domain_models.append(domain_model) return domain_models - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - This method queries the database directly and updates the cache with any - retrieved executions that have a node_execution_id. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - with self._session_factory() as session: - stmt = select(WorkflowNodeExecutionModel).where( - WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, - WorkflowNodeExecutionModel.tenant_id == self._tenant_id, - WorkflowNodeExecutionModel.status == WorkflowNodeExecutionStatus.RUNNING, - WorkflowNodeExecutionModel.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - db_models = session.scalars(stmt).all() - domain_models = [] - - for model in db_models: - # Update cache if node_execution_id is present - if model.node_execution_id: - self._node_execution_cache[model.node_execution_id] = model - - # Convert to domain model - domain_model = self._to_domain_model(model) - domain_models.append(domain_model) - - return domain_models - - def clear(self) -> None: - """ - Clear all WorkflowNodeExecution records for the current tenant_id and app_id. - - This method deletes all WorkflowNodeExecution records that match the tenant_id - and app_id (if provided) associated with this repository instance. - It also clears the in-memory cache. - """ - with self._session_factory() as session: - stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.tenant_id == self._tenant_id) - - if self._app_id: - stmt = stmt.where(WorkflowNodeExecutionModel.app_id == self._app_id) - - result = session.execute(stmt) - session.commit() - - deleted_count = result.rowcount - logger.info( - f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}" - + (f" and app {self._app_id}" if self._app_id else "") - ) - - # Clear the in-memory cache - self._node_execution_cache.clear() - logger.info("Cleared in-memory node execution cache") diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index c9e157cb77..ddec7b1329 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -4,7 +4,7 @@ from openai import BaseModel from pydantic import Field from core.app.entities.app_invoke_entities import InvokeFrom -from core.tools.entities.tool_entities import ToolInvokeFrom +from core.tools.entities.tool_entities import CredentialType, ToolInvokeFrom class ToolRuntime(BaseModel): @@ -17,6 +17,7 @@ class ToolRuntime(BaseModel): invoke_from: Optional[InvokeFrom] = None tool_invoke_from: Optional[ToolInvokeFrom] = None credentials: dict[str, Any] = Field(default_factory=dict) + credential_type: CredentialType = Field(default=CredentialType.API_KEY) runtime_parameters: dict[str, Any] = Field(default_factory=dict) diff --git a/api/core/tools/builtin_tool/provider.py b/api/core/tools/builtin_tool/provider.py index cf75bd3d7e..a70ded9efd 100644 --- a/api/core/tools/builtin_tool/provider.py +++ b/api/core/tools/builtin_tool/provider.py @@ -7,7 +7,13 @@ from core.helper.module_import_helper import load_single_subclass_from_source from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool -from core.tools.entities.tool_entities import ToolEntity, ToolProviderEntity, ToolProviderType +from core.tools.entities.tool_entities import ( + CredentialType, + OAuthSchema, + ToolEntity, + ToolProviderEntity, + ToolProviderType, +) from core.tools.entities.values import ToolLabelEnum, default_tool_label_dict from core.tools.errors import ( ToolProviderNotFoundError, @@ -39,10 +45,18 @@ class BuiltinToolProviderController(ToolProviderController): credential_dict = provider_yaml.get("credentials_for_provider", {}).get(credential, {}) credentials_schema.append(credential_dict) + oauth_schema = None + if provider_yaml.get("oauth_schema", None) is not None: + oauth_schema = OAuthSchema( + client_schema=provider_yaml.get("oauth_schema", {}).get("client_schema", []), + credentials_schema=provider_yaml.get("oauth_schema", {}).get("credentials_schema", []), + ) + super().__init__( entity=ToolProviderEntity( identity=provider_yaml["identity"], credentials_schema=credentials_schema, + oauth_schema=oauth_schema, ), ) @@ -97,10 +111,39 @@ class BuiltinToolProviderController(ToolProviderController): :return: the credentials schema """ - if not self.entity.credentials_schema: - return [] + return self.get_credentials_schema_by_type(CredentialType.API_KEY.value) + + def get_credentials_schema_by_type(self, credential_type: str) -> list[ProviderConfig]: + """ + returns the credentials schema of the provider - return self.entity.credentials_schema.copy() + :param credential_type: the type of the credential + :return: the credentials schema of the provider + """ + if credential_type == CredentialType.OAUTH2.value: + return self.entity.oauth_schema.credentials_schema.copy() if self.entity.oauth_schema else [] + if credential_type == CredentialType.API_KEY.value: + return self.entity.credentials_schema.copy() if self.entity.credentials_schema else [] + raise ValueError(f"Invalid credential type: {credential_type}") + + def get_oauth_client_schema(self) -> list[ProviderConfig]: + """ + returns the oauth client schema of the provider + + :return: the oauth client schema + """ + return self.entity.oauth_schema.client_schema.copy() if self.entity.oauth_schema else [] + + def get_supported_credential_types(self) -> list[str]: + """ + returns the credential support type of the provider + """ + types = [] + if self.entity.credentials_schema is not None and len(self.entity.credentials_schema) > 0: + types.append(CredentialType.API_KEY.value) + if self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) > 0: + types.append(CredentialType.OAUTH2.value) + return types def get_tools(self) -> list[BuiltinTool]: """ @@ -123,7 +166,11 @@ class BuiltinToolProviderController(ToolProviderController): :return: whether the provider needs credentials """ - return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0 + return ( + self.entity.credentials_schema is not None + and len(self.entity.credentials_schema) != 0 + or (self.entity.oauth_schema is not None and len(self.entity.oauth_schema.credentials_schema) != 0) + ) @property def provider_type(self) -> ToolProviderType: diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 90134ba71d..27ce96b90e 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, field_validator from core.model_runtime.utils.encoders import jsonable_encoder from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject -from core.tools.entities.tool_entities import ToolProviderType +from core.tools.entities.tool_entities import CredentialType, ToolProviderType class ToolApiEntity(BaseModel): @@ -87,3 +87,22 @@ class ToolProviderApiEntity(BaseModel): def optional_field(self, key: str, value: Any) -> dict: """Return dict with key-value if value is truthy, empty dict otherwise.""" return {key: value} if value else {} + + +class ToolProviderCredentialApiEntity(BaseModel): + id: str = Field(description="The unique id of the credential") + name: str = Field(description="The name of the credential") + provider: str = Field(description="The provider of the credential") + credential_type: CredentialType = Field(description="The type of the credential") + is_default: bool = Field( + default=False, description="Whether the credential is the default credential for the provider in the workspace" + ) + credentials: dict = Field(description="The credentials of the provider") + + +class ToolProviderCredentialInfoApiEntity(BaseModel): + supported_credential_types: list[str] = Field(description="The supported credential types of the provider") + is_oauth_custom_client_enabled: bool = Field( + default=False, description="Whether the OAuth custom client is enabled for the provider" + ) + credentials: list[ToolProviderCredentialApiEntity] = Field(description="The credentials of the provider") diff --git a/api/core/tools/entities/tool_entities.py b/api/core/tools/entities/tool_entities.py index b5148e245f..5377cbbb69 100644 --- a/api/core/tools/entities/tool_entities.py +++ b/api/core/tools/entities/tool_entities.py @@ -16,6 +16,7 @@ from core.plugin.entities.parameters import ( cast_parameter_value, init_frontend_parameter, ) +from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.tools.entities.common_entities import I18nObject from core.tools.entities.constants import TOOL_SELECTOR_MODEL_IDENTITY @@ -179,6 +180,10 @@ class ToolInvokeMessage(BaseModel): data: Mapping[str, Any] = Field(..., description="Detailed log data") metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") + class RetrieverResourceMessage(BaseModel): + retriever_resources: list[RetrievalSourceMetadata] = Field(..., description="retriever resources") + context: str = Field(..., description="context") + class MessageType(Enum): TEXT = "text" IMAGE = "image" @@ -191,13 +196,22 @@ class ToolInvokeMessage(BaseModel): FILE = "file" LOG = "log" BLOB_CHUNK = "blob_chunk" + RETRIEVER_RESOURCES = "retriever_resources" type: MessageType = MessageType.TEXT """ plain text, image url or link url """ message: ( - JsonMessage | TextMessage | BlobChunkMessage | BlobMessage | LogMessage | FileMessage | None | VariableMessage + JsonMessage + | TextMessage + | BlobChunkMessage + | BlobMessage + | LogMessage + | FileMessage + | None + | VariableMessage + | RetrieverResourceMessage ) meta: dict[str, Any] | None = None @@ -243,6 +257,7 @@ class ToolParameter(PluginParameter): FILES = PluginParameterType.FILES.value APP_SELECTOR = PluginParameterType.APP_SELECTOR.value MODEL_SELECTOR = PluginParameterType.MODEL_SELECTOR.value + ANY = PluginParameterType.ANY.value DYNAMIC_SELECT = PluginParameterType.DYNAMIC_SELECT.value # MCP object and array type parameters @@ -355,10 +370,18 @@ class ToolEntity(BaseModel): return v or [] +class OAuthSchema(BaseModel): + client_schema: list[ProviderConfig] = Field(default_factory=list, description="The schema of the OAuth client") + credentials_schema: list[ProviderConfig] = Field( + default_factory=list, description="The schema of the OAuth credentials" + ) + + class ToolProviderEntity(BaseModel): identity: ToolProviderIdentity plugin_id: Optional[str] = None credentials_schema: list[ProviderConfig] = Field(default_factory=list) + oauth_schema: Optional[OAuthSchema] = None class ToolProviderEntityWithPlugin(ToolProviderEntity): @@ -438,6 +461,7 @@ class ToolSelector(BaseModel): options: Optional[list[PluginParameterOption]] = None provider_id: str = Field(..., description="The id of the provider") + credential_id: Optional[str] = Field(default=None, description="The id of the credential") tool_name: str = Field(..., description="The name of the tool") tool_description: str = Field(..., description="The description of the tool") tool_configuration: Mapping[str, Any] = Field(..., description="Configuration, type form") @@ -445,3 +469,36 @@ class ToolSelector(BaseModel): def to_plugin_parameter(self) -> dict[str, Any]: return self.model_dump() + + +class CredentialType(enum.StrEnum): + API_KEY = "api-key" + OAUTH2 = "oauth2" + + def get_name(self): + if self == CredentialType.API_KEY: + return "API KEY" + elif self == CredentialType.OAUTH2: + return "AUTH" + else: + return self.value.replace("-", " ").upper() + + def is_editable(self): + return self == CredentialType.API_KEY + + def is_validate_allowed(self): + return self == CredentialType.API_KEY + + @classmethod + def values(cls): + return [item.value for item in cls] + + @classmethod + def of(cls, credential_type: str) -> "CredentialType": + type_name = credential_type.lower() + if type_name == "api-key": + return cls.API_KEY + elif type_name == "oauth2": + return cls.OAUTH2 + else: + raise ValueError(f"Invalid credential type: {credential_type}") diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index d21e3d7d1c..aef2677c36 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -44,6 +44,7 @@ class PluginTool(Tool): tool_provider=self.entity.identity.provider, tool_name=self.entity.identity.name, credentials=self.runtime.credentials, + credential_type=self.runtime.credential_type, tool_parameters=tool_parameters, conversation_id=conversation_id, app_id=app_id, diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 22a9853b41..7822bc389c 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast from yarl import URL import contexts +from core.helper.provider_cache import ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController @@ -17,14 +18,14 @@ from core.tools.mcp_tool.provider import MCPToolProviderController from core.tools.mcp_tool.tool import MCPTool from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool +from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity - from configs import dify_config from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom @@ -41,16 +42,17 @@ from core.tools.entities.api_entities import ToolProviderApiEntity, ToolProvider from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolInvokeFrom, ToolParameter, ToolProviderType, ) -from core.tools.errors import ToolNotFoundError, ToolProviderNotFoundError +from core.tools.errors import ToolProviderNotFoundError from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.configuration import ( - ProviderConfigEncrypter, ToolParameterConfigurationManager, ) +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool from extensions.ext_database import db from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider @@ -68,8 +70,11 @@ class ToolManager: @classmethod def get_hardcoded_provider(cls, provider: str) -> BuiltinToolProviderController: """ + get the hardcoded provider + """ + if len(cls._hardcoded_providers) == 0: # init the builtin providers cls.load_hardcoded_providers_cache() @@ -113,7 +118,12 @@ class ToolManager: contexts.plugin_tool_providers.set({}) contexts.plugin_tool_providers_lock.set(Lock()) + plugin_tool_providers = contexts.plugin_tool_providers.get() + if provider in plugin_tool_providers: + return plugin_tool_providers[provider] + with contexts.plugin_tool_providers_lock.get(): + # double check plugin_tool_providers = contexts.plugin_tool_providers.get() if provider in plugin_tool_providers: return plugin_tool_providers[provider] @@ -131,25 +141,7 @@ class ToolManager: ) plugin_tool_providers[provider] = controller - - return controller - - @classmethod - def get_builtin_tool(cls, provider: str, tool_name: str, tenant_id: str) -> BuiltinTool | PluginTool | None: - """ - get the builtin tool - - :param provider: the name of the provider - :param tool_name: the name of the tool - :param tenant_id: the id of the tenant - :return: the provider, the tool - """ - provider_controller = cls.get_builtin_provider(provider, tenant_id) - tool = provider_controller.get_tool(tool_name) - if tool is None: - raise ToolNotFoundError(f"tool {tool_name} not found") - - return tool + return controller @classmethod def get_tool_runtime( @@ -160,6 +152,7 @@ class ToolManager: tenant_id: str, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, + credential_id: Optional[str] = None, ) -> Union[BuiltinTool, PluginTool, ApiTool, WorkflowTool, MCPTool]: """ get the tool runtime @@ -170,6 +163,7 @@ class ToolManager: :param tenant_id: the tenant id :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from + :param credential_id: the credential id :return: the tool """ @@ -193,49 +187,70 @@ class ToolManager: ) ), ) - + builtin_provider = None if isinstance(provider_controller, PluginToolProviderController): provider_id_entity = ToolProviderID(provider_id) - # get credentials - builtin_provider: BuiltinToolProvider | None = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == str(provider_id_entity)) - | (BuiltinToolProvider.provider == provider_id_entity.provider_name), - ) - .first() - ) - + # get specific credentials + if is_valid_uuid(credential_id): + try: + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + except Exception as e: + builtin_provider = None + logger.info(f"Error getting builtin provider {credential_id}:{e}", exc_info=True) + # if the provider has been deleted, raise an error + if builtin_provider is None: + raise ToolProviderNotFoundError(f"provider has been deleted: {credential_id}") + + # fallback to the default provider if builtin_provider is None: - raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") + # use the default provider + builtin_provider = ( + db.session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == str(provider_id_entity)) + | (BuiltinToolProvider.provider == provider_id_entity.provider_name), + ) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .first() + ) + if builtin_provider is None: + raise ToolProviderNotFoundError(f"no default provider for {provider_id}") else: builtin_provider = ( db.session.query(BuiltinToolProvider) .filter(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) .first() ) if builtin_provider is None: raise ToolProviderNotFoundError(f"builtin provider {provider_id} not found") - # decrypt the credentials - credentials = builtin_provider.credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(builtin_provider.credential_type) + ], + cache=ToolProviderCredentialsCache( + tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id + ), ) - - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( BuiltinTool, builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(builtin_provider.credentials), + credential_type=CredentialType.of(builtin_provider.credential_type), runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -245,22 +260,16 @@ class ToolManager: elif provider_type == ToolProviderType.API: api_provider, credentials = cls.get_api_provider_controller(tenant_id, provider_id) - - # decrypt the credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in api_provider.get_credentials_schema()], - provider_type=api_provider.provider_type.value, - provider_identity=api_provider.entity.identity.name, + controller=api_provider, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - return cast( ApiTool, api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, - credentials=decrypted_credentials, + credentials=encrypter.decrypt(credentials), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, ) @@ -320,6 +329,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=agent_tool.credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -362,6 +372,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=workflow_tool.credential_id, ) parameters = tool_runtime.get_merged_runtime_parameters() @@ -391,6 +402,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + credential_id: Optional[str] = None, ) -> Tool: """ get tool runtime from plugin @@ -402,6 +414,7 @@ class ToolManager: tenant_id=tenant_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=credential_id, ) runtime_parameters = {} parameters = tool_entity.get_merged_runtime_parameters() @@ -551,6 +564,22 @@ class ToolManager: return cls._builtin_tools_labels[tool_name] + @classmethod + def list_default_builtin_providers(cls, tenant_id: str) -> list[BuiltinToolProvider]: + """ + list all the builtin providers + """ + # according to multi credentials, select the one with is_default=True first, then created_at oldest + # for compatibility with old version + sql = """ + SELECT DISTINCT ON (tenant_id, provider) id + FROM tool_builtin_providers + WHERE tenant_id = :tenant_id + ORDER BY tenant_id, provider, is_default DESC, created_at DESC + """ + ids = [row.id for row in db.session.execute(db.text(sql), {"tenant_id": tenant_id}).all()] + return db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.id.in_(ids)).all() + @classmethod def list_providers_from_api( cls, user_id: str, tenant_id: str, typ: ToolProviderTypeApiLiteral @@ -565,21 +594,13 @@ class ToolManager: with db.session.no_autoflush: if "builtin" in filters: - # get builtin providers builtin_providers = cls.list_builtin_providers(tenant_id) - # get db builtin providers - db_builtin_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() - ) - - # rewrite db_builtin_providers - for db_provider in db_builtin_providers: - tool_provider_id = str(ToolProviderID(db_provider.provider)) - db_provider.provider = tool_provider_id - - def find_db_builtin_provider(provider): - return next((x for x in db_builtin_providers if x.provider == provider), None) + # key: provider name, value: provider + db_builtin_providers = { + str(ToolProviderID(provider.provider)): provider + for provider in cls.list_default_builtin_providers(tenant_id) + } # append builtin providers for provider in builtin_providers: @@ -591,10 +612,9 @@ class ToolManager: name_func=lambda x: x.identity.name, ): continue - user_provider = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider, - db_provider=find_db_builtin_provider(provider.entity.identity.name), + db_provider=db_builtin_providers.get(provider.entity.identity.name), decrypt_credentials=False, ) @@ -604,7 +624,6 @@ class ToolManager: result_providers[f"builtin_provider.{user_provider.name}"] = user_provider # get db api providers - if "api" in filters: db_api_providers: list[ApiToolProvider] = ( db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() @@ -764,15 +783,12 @@ class ToolManager: auth_type, ) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], - provider_type=controller.provider_type.value, - provider_identity=controller.entity.identity.name, + controller=controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(encrypter.decrypt(credentials)) try: icon = json.loads(provider_obj.icon) diff --git a/api/core/tools/utils/configuration.py b/api/core/tools/utils/configuration.py index 251fedf56e..aceba6e69f 100644 --- a/api/core/tools/utils/configuration.py +++ b/api/core/tools/utils/configuration.py @@ -1,12 +1,8 @@ from copy import deepcopy from typing import Any -from pydantic import BaseModel - -from core.entities.provider_entities import BasicProviderConfig from core.helper import encrypter from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType -from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ( ToolParameter, @@ -14,110 +10,6 @@ from core.tools.entities.tool_entities import ( ) -class ProviderConfigEncrypter(BaseModel): - tenant_id: str - config: list[BasicProviderConfig] - provider_type: str - provider_identity: str - - def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: - """ - deep copy data - """ - return deepcopy(data) - - def encrypt(self, data: dict[str, str]) -> dict[str, str]: - """ - encrypt tool credentials with tenant id - - return a deep copy of credentials with encrypted values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") - data[field_name] = encrypted - - return data - - def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: - """ - mask tool credentials - - return a deep copy of credentials with masked values - """ - data = self._deep_copy(data) - - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - if len(data[field_name]) > 6: - data[field_name] = ( - data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] - ) - else: - data[field_name] = "*" * len(data[field_name]) - - return data - - def decrypt(self, data: dict[str, str], use_cache: bool = True) -> dict[str, str]: - """ - decrypt tool credentials with tenant id - - return a deep copy of credentials with decrypted values - """ - if use_cache: - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cached_credentials = cache.get() - if cached_credentials: - return cached_credentials - data = self._deep_copy(data) - # get fields need to be decrypted - fields = dict[str, BasicProviderConfig]() - for credential in self.config: - fields[credential.name] = credential - - for field_name, field in fields.items(): - if field.type == BasicProviderConfig.Type.SECRET_INPUT: - if field_name in data: - try: - # if the value is None or empty string, skip decrypt - if not data[field_name]: - continue - - data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) - except Exception: - pass - - if use_cache: - cache.set(data) - return data - - def delete_tool_credentials_cache(self): - cache = ToolProviderCredentialsCache( - tenant_id=self.tenant_id, - identity_id=f"{self.provider_type}.{self.provider_identity}", - cache_type=ToolProviderCredentialsCacheType.PROVIDER, - ) - cache.delete() - - class ToolParameterConfigurationManager: """ Tool parameter configuration manager diff --git a/api/core/tools/utils/encryption.py b/api/core/tools/utils/encryption.py new file mode 100644 index 0000000000..5fdfd3b9d1 --- /dev/null +++ b/api/core/tools/utils/encryption.py @@ -0,0 +1,142 @@ +from copy import deepcopy +from typing import Any, Optional, Protocol + +from core.entities.provider_entities import BasicProviderConfig +from core.helper import encrypter +from core.helper.provider_cache import SingletonProviderCredentialsCache +from core.tools.__base.tool_provider import ToolProviderController + + +class ProviderConfigCache(Protocol): + """ + Interface for provider configuration cache operations + """ + + def get(self) -> Optional[dict]: + """Get cached provider configuration""" + ... + + def set(self, config: dict[str, Any]) -> None: + """Cache provider configuration""" + ... + + def delete(self) -> None: + """Delete cached provider configuration""" + ... + + +class ProviderConfigEncrypter: + tenant_id: str + config: list[BasicProviderConfig] + provider_config_cache: ProviderConfigCache + + def __init__( + self, + tenant_id: str, + config: list[BasicProviderConfig], + provider_config_cache: ProviderConfigCache, + ): + self.tenant_id = tenant_id + self.config = config + self.provider_config_cache = provider_config_cache + + def _deep_copy(self, data: dict[str, str]) -> dict[str, str]: + """ + deep copy data + """ + return deepcopy(data) + + def encrypt(self, data: dict[str, str]) -> dict[str, str]: + """ + encrypt tool credentials with tenant id + + return a deep copy of credentials with encrypted values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "") + data[field_name] = encrypted + + return data + + def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]: + """ + mask tool credentials + + return a deep copy of credentials with masked values + """ + data = self._deep_copy(data) + + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + if len(data[field_name]) > 6: + data[field_name] = ( + data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:] + ) + else: + data[field_name] = "*" * len(data[field_name]) + + return data + + def decrypt(self, data: dict[str, str]) -> dict[str, Any]: + """ + decrypt tool credentials with tenant id + + return a deep copy of credentials with decrypted values + """ + cached_credentials = self.provider_config_cache.get() + if cached_credentials: + return cached_credentials + + data = self._deep_copy(data) + # get fields need to be decrypted + fields = dict[str, BasicProviderConfig]() + for credential in self.config: + fields[credential.name] = credential + + for field_name, field in fields.items(): + if field.type == BasicProviderConfig.Type.SECRET_INPUT: + if field_name in data: + try: + # if the value is None or empty string, skip decrypt + if not data[field_name]: + continue + + data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name]) + except Exception: + pass + + self.provider_config_cache.set(data) + return data + + +def create_provider_encrypter(tenant_id: str, config: list[BasicProviderConfig], cache: ProviderConfigCache): + return ProviderConfigEncrypter(tenant_id=tenant_id, config=config, provider_config_cache=cache), cache + + +def create_tool_provider_encrypter(tenant_id: str, controller: ToolProviderController): + cache = SingletonProviderCredentialsCache( + tenant_id=tenant_id, + provider_type=controller.provider_type.value, + provider_identity=controller.entity.identity.name, + ) + encrypt = ProviderConfigEncrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in controller.get_credentials_schema()], + provider_config_cache=cache, + ) + return encrypt, cache diff --git a/api/core/tools/utils/system_oauth_encryption.py b/api/core/tools/utils/system_oauth_encryption.py new file mode 100644 index 0000000000..f3c946b95f --- /dev/null +++ b/api/core/tools/utils/system_oauth_encryption.py @@ -0,0 +1,187 @@ +import base64 +import hashlib +import logging +from collections.abc import Mapping +from typing import Any, Optional + +from Crypto.Cipher import AES +from Crypto.Random import get_random_bytes +from Crypto.Util.Padding import pad, unpad +from pydantic import TypeAdapter + +from configs import dify_config + +logger = logging.getLogger(__name__) + + +class OAuthEncryptionError(Exception): + """OAuth encryption/decryption specific error""" + + pass + + +class SystemOAuthEncrypter: + """ + A simple OAuth parameters encrypter using AES-CBC encryption. + + This class provides methods to encrypt and decrypt OAuth parameters + using AES-CBC mode with a key derived from the application's SECRET_KEY. + """ + + def __init__(self, secret_key: Optional[str] = None): + """ + Initialize the OAuth encrypter. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Raises: + ValueError: If SECRET_KEY is not configured or empty + """ + secret_key = secret_key or dify_config.SECRET_KEY or "" + + # Generate a fixed 256-bit key using SHA-256 + self.key = hashlib.sha256(secret_key.encode()).digest() + + def encrypt_oauth_params(self, oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters. + + Args: + oauth_params: OAuth parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"} + + Returns: + Base64-encoded encrypted string + + Raises: + OAuthEncryptionError: If encryption fails + ValueError: If oauth_params is invalid + """ + + try: + # Generate random IV (16 bytes) + iv = get_random_bytes(16) + + # Create AES cipher (CBC mode) + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Encrypt data + padded_data = pad(TypeAdapter(dict).dump_json(dict(oauth_params)), AES.block_size) + encrypted_data = cipher.encrypt(padded_data) + + # Combine IV and encrypted data + combined = iv + encrypted_data + + # Return base64 encoded string + return base64.b64encode(combined).decode() + + except Exception as e: + raise OAuthEncryptionError(f"Encryption failed: {str(e)}") from e + + def decrypt_oauth_params(self, encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + + Raises: + OAuthEncryptionError: If decryption fails + ValueError: If encrypted_data is invalid + """ + if not isinstance(encrypted_data, str): + raise ValueError("encrypted_data must be a string") + + if not encrypted_data: + raise ValueError("encrypted_data cannot be empty") + + try: + # Base64 decode + combined = base64.b64decode(encrypted_data) + + # Check minimum length (IV + at least one AES block) + if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data + raise ValueError("Invalid encrypted data format") + + # Separate IV and encrypted data + iv = combined[:16] + encrypted_data_bytes = combined[16:] + + # Create AES cipher + cipher = AES.new(self.key, AES.MODE_CBC, iv) + + # Decrypt data + decrypted_data = cipher.decrypt(encrypted_data_bytes) + unpadded_data = unpad(decrypted_data, AES.block_size) + + # Parse JSON + oauth_params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data) + + if not isinstance(oauth_params, dict): + raise ValueError("Decrypted data is not a valid dictionary") + + return oauth_params + + except Exception as e: + raise OAuthEncryptionError(f"Decryption failed: {str(e)}") from e + + +# Factory function for creating encrypter instances +def create_system_oauth_encrypter(secret_key: Optional[str] = None) -> SystemOAuthEncrypter: + """ + Create an OAuth encrypter instance. + + Args: + secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY + + Returns: + SystemOAuthEncrypter instance + """ + return SystemOAuthEncrypter(secret_key=secret_key) + + +# Global encrypter instance (for backward compatibility) +_oauth_encrypter: Optional[SystemOAuthEncrypter] = None + + +def get_system_oauth_encrypter() -> SystemOAuthEncrypter: + """ + Get the global OAuth encrypter instance. + + Returns: + SystemOAuthEncrypter instance + """ + global _oauth_encrypter + if _oauth_encrypter is None: + _oauth_encrypter = SystemOAuthEncrypter() + return _oauth_encrypter + + +# Convenience functions for backward compatibility +def encrypt_system_oauth_params(oauth_params: Mapping[str, Any]) -> str: + """ + Encrypt OAuth parameters using the global encrypter. + + Args: + oauth_params: OAuth parameters dictionary + + Returns: + Base64-encoded encrypted string + """ + return get_system_oauth_encrypter().encrypt_oauth_params(oauth_params) + + +def decrypt_system_oauth_params(encrypted_data: str) -> Mapping[str, Any]: + """ + Decrypt OAuth parameters using the global encrypter. + + Args: + encrypted_data: Base64-encoded encrypted string + + Returns: + Decrypted OAuth parameters dictionary + """ + return get_system_oauth_encrypter().decrypt_oauth_params(encrypted_data) diff --git a/api/core/tools/utils/uuid_utils.py b/api/core/tools/utils/uuid_utils.py index 3046c08c89..bdcc33259d 100644 --- a/api/core/tools/utils/uuid_utils.py +++ b/api/core/tools/utils/uuid_utils.py @@ -1,7 +1,9 @@ import uuid -def is_valid_uuid(uuid_str: str) -> bool: +def is_valid_uuid(uuid_str: str | None) -> bool: + if uuid_str is None or len(uuid_str) == 0: + return False try: uuid.UUID(uuid_str) return True diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index 6cf09e0372..13274f4e0e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,9 +1,9 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Any +from typing import Annotated, Any, TypeAlias -from pydantic import BaseModel, ConfigDict, field_validator +from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator from core.file import File @@ -11,6 +11,11 @@ from .types import SegmentType class Segment(BaseModel): + """Segment is runtime type used during the execution of workflow. + + Note: this class is abstract, you should use subclasses of this class instead. + """ + model_config = ConfigDict(frozen=True) value_type: SegmentType @@ -73,7 +78,7 @@ class StringSegment(Segment): class FloatSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.FLOAT value: float # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # The following tests cannot pass. @@ -92,7 +97,7 @@ class FloatSegment(Segment): class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.NUMBER + value_type: SegmentType = SegmentType.INTEGER value: int @@ -181,3 +186,46 @@ class ArrayFileSegment(ArraySegment): @property def text(self) -> str: return "" + + +def get_segment_discriminator(v: Any) -> SegmentType | None: + if isinstance(v, Segment): + return v.value_type + elif isinstance(v, dict): + value_type = v.get("value_type") + if value_type is None: + return None + try: + seg_type = SegmentType(value_type) + except ValueError: + return None + return seg_type + else: + # return None if the discriminator value isn't found + return None + + +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. +SegmentUnion: TypeAlias = Annotated[ + ( + Annotated[NoneSegment, Tag(SegmentType.NONE)] + | Annotated[StringSegment, Tag(SegmentType.STRING)] + | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] + | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] + | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] + | Annotated[FileSegment, Tag(SegmentType.FILE)] + | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 68d3d82883..e79b2410bf 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -1,8 +1,27 @@ +from collections.abc import Mapping from enum import StrEnum +from typing import Any, Optional + +from core.file.models import File + + +class ArrayValidation(StrEnum): + """Strategy for validating array elements""" + + # Skip element validation (only check array container) + NONE = "none" + + # Validate the first element (if array is non-empty) + FIRST = "first" + + # Validate all elements in the array. + ALL = "all" class SegmentType(StrEnum): NUMBER = "number" + INTEGER = "integer" + FLOAT = "float" STRING = "string" OBJECT = "object" SECRET = "secret" @@ -19,16 +38,139 @@ class SegmentType(StrEnum): GROUP = "group" - def is_array_type(self): + def is_array_type(self) -> bool: return self in _ARRAY_TYPES + @classmethod + def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. + """ + + if isinstance(value, list): + elem_types: set[SegmentType] = set() + for i in value: + segment_type = cls.infer_segment_type(i) + if segment_type is None: + return None + + elem_types.add(segment_type) + + if len(elem_types) != 1: + if elem_types.issubset(_NUMERICAL_TYPES): + return SegmentType.ARRAY_NUMBER + return SegmentType.ARRAY_ANY + elif all(i.is_array_type() for i in elem_types): + return SegmentType.ARRAY_ANY + match elem_types.pop(): + case SegmentType.STRING: + return SegmentType.ARRAY_STRING + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: + return SegmentType.ARRAY_NUMBER + case SegmentType.OBJECT: + return SegmentType.ARRAY_OBJECT + case SegmentType.FILE: + return SegmentType.ARRAY_FILE + case SegmentType.NONE: + return SegmentType.ARRAY_ANY + case _: + # This should be unreachable. + raise ValueError(f"not supported value {value}") + if value is None: + return SegmentType.NONE + elif isinstance(value, int) and not isinstance(value, bool): + return SegmentType.INTEGER + elif isinstance(value, float): + return SegmentType.FLOAT + elif isinstance(value, str): + return SegmentType.STRING + elif isinstance(value, dict): + return SegmentType.OBJECT + elif isinstance(value, File): + return SegmentType.FILE + else: + return None + + def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: + if not isinstance(value, list): + return False + # Skip element validation if array is empty + if len(value) == 0: + return True + if self == SegmentType.ARRAY_ANY: + return True + element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] + + if array_validation == ArrayValidation.NONE: + return True + elif array_validation == ArrayValidation.FIRST: + return element_type.is_valid(value[0]) + else: + return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value) + + def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool: + """ + Check if a value matches the segment type. + Users of `SegmentType` should call this method, instead of using + `isinstance` manually. + + Args: + value: The value to validate + array_validation: Validation strategy for array types (ignored for non-array types) + + Returns: + True if the value matches the type under the given validation strategy + """ + if self.is_array_type(): + return self._validate_array(value, array_validation) + elif self == SegmentType.NUMBER: + return isinstance(value, (int, float)) + elif self == SegmentType.STRING: + return isinstance(value, str) + elif self == SegmentType.OBJECT: + return isinstance(value, dict) + elif self == SegmentType.SECRET: + return isinstance(value, str) + elif self == SegmentType.FILE: + return isinstance(value, File) + elif self == SegmentType.NONE: + return value is None + else: + raise AssertionError("this statement should be unreachable.") + + def exposed_type(self) -> "SegmentType": + """Returns the type exposed to the frontend. + + The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. + """ + if self in (SegmentType.INTEGER, SegmentType.FLOAT): + return SegmentType.NUMBER + return self + + +_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { + # ARRAY_ANY does not have correpond element type. + SegmentType.ARRAY_STRING: SegmentType.STRING, + SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, + SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, + SegmentType.ARRAY_FILE: SegmentType.FILE, +} _ARRAY_TYPES = frozenset( - [ + list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) + + [ SegmentType.ARRAY_ANY, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_FILE, + ] +) + + +_NUMERICAL_TYPES = frozenset( + [ + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, ] ) diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index b650b1682e..a31ebc848e 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -1,8 +1,8 @@ from collections.abc import Sequence -from typing import cast +from typing import Annotated, TypeAlias, cast from uuid import uuid4 -from pydantic import Field +from pydantic import Discriminator, Field, Tag from core.helper import encrypter @@ -20,6 +20,7 @@ from .segments import ( ObjectSegment, Segment, StringSegment, + get_segment_discriminator, ) from .types import SegmentType @@ -27,6 +28,10 @@ from .types import SegmentType class Variable(Segment): """ A variable is a segment that has a name. + + It is mainly used to store segments and their selector in VariablePool. + + Note: this class is abstract, you should use subclasses of this class instead. """ id: str = Field( @@ -93,3 +98,28 @@ class FileVariable(FileSegment, Variable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass + + +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +VariableUnion: TypeAlias = Annotated[ + ( + Annotated[NoneVariable, Tag(SegmentType.NONE)] + | Annotated[StringVariable, Tag(SegmentType.STRING)] + | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] + | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] + | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] + | Annotated[FileVariable, Tag(SegmentType.FILE)] + | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] + | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] + | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] + | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] + | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] + | Annotated[SecretVariable, Tag(SegmentType.SECRET)] + ), + Discriminator(get_segment_discriminator), +] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 80dda2632d..fbb8df6b01 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -1,7 +1,7 @@ import re from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Union +from typing import Annotated, Any, Union, cast from pydantic import BaseModel, Field @@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager from core.variables import Segment, SegmentGroup, Variable from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.segments import FileSegment, NoneSegment +from core.variables.variables import VariableUnion from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories import variable_factory VariableValue = Union[str, int, float, dict, list, File] @@ -23,31 +24,31 @@ class VariablePool(BaseModel): # The first element of the selector is the node id, it's the first-level key in the dictionary. # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # elements of the selector except the first one. - variable_dictionary: dict[str, dict[int, Segment]] = Field( + variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field( description="Variables mapping", default=defaultdict(dict), ) - # TODO: This user inputs is not used for pool. + + # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. user_inputs: Mapping[str, Any] = Field( description="User inputs", default_factory=dict, ) - system_variables: Mapping[SystemVariableKey, Any] = Field( + system_variables: SystemVariable = Field( description="System variables", - default_factory=dict, ) - environment_variables: Sequence[Variable] = Field( + environment_variables: Sequence[VariableUnion] = Field( description="Environment variables.", default_factory=list, ) - conversation_variables: Sequence[Variable] = Field( + conversation_variables: Sequence[VariableUnion] = Field( description="Conversation variables.", default_factory=list, ) def model_post_init(self, context: Any, /) -> None: - for key, value in self.system_variables.items(): - self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) + # Create a mapping from field names to SystemVariableKey enum values + self._add_system_variables(self.system_variables) # Add environment variables to the variable pool for var in self.environment_variables: self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) @@ -83,8 +84,22 @@ class VariablePool(BaseModel): segment = variable_factory.build_segment(value) variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]][hash_key] = variable + key, hash_key = self._selector_to_keys(selector) + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) + + @classmethod + def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]: + return selector[0], hash(tuple(selector[1:])) + + def _has(self, selector: Sequence[str]) -> bool: + key, hash_key = self._selector_to_keys(selector) + if key not in self.variable_dictionary: + return False + if hash_key not in self.variable_dictionary[key]: + return False + return True def get(self, selector: Sequence[str], /) -> Segment | None: """ @@ -102,8 +117,8 @@ class VariablePool(BaseModel): if len(selector) < MIN_SELECTORS_LENGTH: return None - hash_key = hash(tuple(selector[1:])) - value = self.variable_dictionary[selector[0]].get(hash_key) + key, hash_key = self._selector_to_keys(selector) + value: Segment | None = self.variable_dictionary[key].get(hash_key) if value is None: selector, attr = selector[:-1], selector[-1] @@ -136,8 +151,8 @@ class VariablePool(BaseModel): if len(selector) == 1: self.variable_dictionary[selector[0]] = {} return - hash_key = hash(tuple(selector[1:])) - self.variable_dictionary[selector[0]].pop(hash_key, None) + key, hash_key = self._selector_to_keys(selector) + self.variable_dictionary[key].pop(hash_key, None) def convert_template(self, template: str, /): parts = VARIABLE_PATTERN.split(template) @@ -154,3 +169,20 @@ class VariablePool(BaseModel): if isinstance(segment, FileSegment): return segment return None + + def _add_system_variables(self, system_variable: SystemVariable): + sys_var_mapping = system_variable.to_dict() + for key, value in sys_var_mapping.items(): + if value is None: + continue + selector = (SYSTEM_VARIABLE_NODE_ID, key) + # If the system variable already exists, do not add it again. + # This ensures that we can keep the id of the system variables intact. + if self._has(selector): + continue + self.add(selector, value) # type: ignore + + @classmethod + def empty(cls) -> "VariablePool": + """Create an empty variable pool.""" + return cls(system_variables=SystemVariable.empty()) diff --git a/api/core/workflow/entities/workflow_entities.py b/api/core/workflow/entities/workflow_entities.py deleted file mode 100644 index 8896416f12..0000000000 --- a/api/core/workflow/entities/workflow_entities.py +++ /dev/null @@ -1,79 +0,0 @@ -from typing import Optional - -from pydantic import BaseModel - -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.nodes.base import BaseIterationState, BaseLoopState, BaseNode -from models.enums import UserFrom -from models.workflow import Workflow, WorkflowType - -from .node_entities import NodeRunResult -from .variable_pool import VariablePool - - -class WorkflowNodeAndResult: - node: BaseNode - result: Optional[NodeRunResult] = None - - def __init__(self, node: BaseNode, result: Optional[NodeRunResult] = None): - self.node = node - self.result = result - - -class WorkflowRunState: - tenant_id: str - app_id: str - workflow_id: str - workflow_type: WorkflowType - user_id: str - user_from: UserFrom - invoke_from: InvokeFrom - - workflow_call_depth: int - - start_at: float - variable_pool: VariablePool - - total_tokens: int = 0 - - workflow_nodes_and_results: list[WorkflowNodeAndResult] - - class NodeRun(BaseModel): - node_id: str - iteration_node_id: str - loop_node_id: str - - workflow_node_runs: list[NodeRun] - workflow_node_steps: int - - current_iteration_state: Optional[BaseIterationState] - current_loop_state: Optional[BaseLoopState] - - def __init__( - self, - workflow: Workflow, - start_at: float, - variable_pool: VariablePool, - user_id: str, - user_from: UserFrom, - invoke_from: InvokeFrom, - workflow_call_depth: int, - ): - self.workflow_id = workflow.id - self.tenant_id = workflow.tenant_id - self.app_id = workflow.app_id - self.workflow_type = WorkflowType.value_of(workflow.type) - self.user_id = user_id - self.user_from = user_from - self.invoke_from = invoke_from - self.workflow_call_depth = workflow_call_depth - - self.start_at = start_at - self.variable_pool = variable_pool - - self.total_tokens = 0 - - self.workflow_node_steps = 1 - self.workflow_node_runs = [] - self.current_iteration_state = None - self.current_loop_state = None diff --git a/api/core/workflow/errors.py b/api/core/workflow/errors.py index bd4ccc1072..594bb2b32e 100644 --- a/api/core/workflow/errors.py +++ b/api/core/workflow/errors.py @@ -2,7 +2,7 @@ from core.workflow.nodes.base import BaseNode class WorkflowNodeRunFailedError(Exception): - def __init__(self, node_instance: BaseNode, error: str): - self.node_instance = node_instance - self.error = error - super().__init__(f"Node {node_instance.node_data.title} run failed: {error}") + def __init__(self, node: BaseNode, err_msg: str): + self._node = node + self._error = err_msg + super().__init__(f"Node {node.title} run failed: {err_msg}") diff --git a/api/core/workflow/graph_engine/__init__.py b/api/core/workflow/graph_engine/__init__.py index 2fee3d7fad..12e1de464b 100644 --- a/api/core/workflow/graph_engine/__init__.py +++ b/api/core/workflow/graph_engine/__init__.py @@ -1,3 +1,4 @@ from .entities import Graph, GraphInitParams, GraphRuntimeState, RuntimeRouteState +from .graph_engine import GraphEngine -__all__ = ["Graph", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] +__all__ = ["Graph", "GraphEngine", "GraphInitParams", "GraphRuntimeState", "RuntimeRouteState"] diff --git a/api/core/workflow/graph_engine/entities/graph_runtime_state.py b/api/core/workflow/graph_engine/entities/graph_runtime_state.py index afc09bfac5..a62ffe46c9 100644 --- a/api/core/workflow/graph_engine/entities/graph_runtime_state.py +++ b/api/core/workflow/graph_engine/entities/graph_runtime_state.py @@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel): """total tokens""" llm_usage: LLMUsage = LLMUsage.empty_usage() """llm usage info""" + + # The `outputs` field stores the final output values generated by executing workflows or chatflows. + # + # Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent + # after a serialization and deserialization round trip. outputs: dict[str, Any] = {} - """outputs""" node_run_steps: int = 0 """node run steps""" diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index f37672aaad..6c2d828ad0 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -12,7 +12,7 @@ from typing import Any, Optional, cast from flask import Flask, current_app from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import AgentNodeStrategyInit, NodeRunResult from core.workflow.entities.variable_pool import VariablePool, VariableValue @@ -48,11 +48,9 @@ from core.workflow.nodes.agent.entities import AgentNodeData from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.base.entities import BaseNodeData from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent -from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom @@ -269,12 +267,16 @@ class GraphEngine: # convert to specific node node_type = NodeType(node_config.get("data", {}).get("type")) node_version = node_config.get("data", {}).get("version", "1") + + # Import here to avoid circular import + from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version] previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None # init workflow run state - node_instance = node_cls( # type: ignore + node = node_cls( id=route_node_state.id, config=node_config, graph_init_params=self.init_params, @@ -283,11 +285,11 @@ class GraphEngine: previous_node_id=previous_node_id, thread_pool_id=self.thread_pool_id, ) - node_instance = cast(BaseNode[BaseNodeData], node_instance) + node.init_node_data(node_config.get("data", {})) try: # run node generator = self._run_node( - node_instance=node_instance, + node=node, route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, @@ -315,16 +317,16 @@ class GraphEngine: route_node_state.failed_reason = str(e) yield NodeRunFailedEvent( error=str(e), - id=node_instance.id, + id=node.id, node_id=next_node_id, node_type=node_type, - node_data=node_instance.node_data, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=in_parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) raise e @@ -346,7 +348,7 @@ class GraphEngine: edge = edge_mappings[0] if ( previous_route_node_state.status == RouteNodeState.Status.EXCEPTION - and node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and edge.run_condition is None ): break @@ -422,8 +424,8 @@ class GraphEngine: next_node_id = final_node_id elif ( - node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - and node_instance.should_continue_on_error + node.continue_on_error + and node.error_strategy == ErrorStrategy.FAIL_BRANCH and previous_route_node_state.status == RouteNodeState.Status.EXCEPTION ): break @@ -606,7 +608,7 @@ class GraphEngine: def _run_node( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, route_node_state: RouteNodeState, parallel_id: Optional[str] = None, parallel_start_node_id: Optional[str] = None, @@ -620,29 +622,29 @@ class GraphEngine: # trigger node run start event agent_strategy = ( AgentNodeStrategyInit( - name=cast(AgentNodeData, node_instance.node_data).agent_strategy_name, - icon=cast(AgentNode, node_instance).agent_strategy_icon, + name=cast(AgentNodeData, node.get_base_node_data()).agent_strategy_name, + icon=cast(AgentNode, node).agent_strategy_icon, ) - if node_instance.node_type == NodeType.AGENT + if node.type_ == NodeType.AGENT else None ) yield NodeRunStartedEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, agent_strategy=agent_strategy, - node_version=node_instance.version(), + node_version=node.version(), ) - max_retries = node_instance.node_data.retry_config.max_retries - retry_interval = node_instance.node_data.retry_config.retry_interval_seconds + max_retries = node.retry_config.max_retries + retry_interval = node.retry_config.retry_interval_seconds retries = 0 should_continue_retry = True while should_continue_retry and retries <= max_retries: @@ -651,7 +653,7 @@ class GraphEngine: retry_start_at = datetime.now(UTC).replace(tzinfo=None) # yield control to other threads time.sleep(0.001) - event_stream = node_instance.run() + event_stream = node.run() for event in event_stream: if isinstance(event, GraphEngineEvent): # add parallel info to iteration event @@ -667,21 +669,21 @@ class GraphEngine: if run_result.status == WorkflowNodeExecutionStatus.FAILED: if ( retries == max_retries - and node_instance.node_type == NodeType.HTTP_REQUEST + and node.type_ == NodeType.HTTP_REQUEST and run_result.outputs - and not node_instance.should_continue_on_error + and not node.continue_on_error ): run_result.status = WorkflowNodeExecutionStatus.SUCCEEDED - if node_instance.should_retry and retries < max_retries: + if node.retry and retries < max_retries: retries += 1 route_node_state.node_run_result = run_result yield NodeRunRetryEvent( id=str(uuid.uuid4()), - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, - predecessor_node_id=node_instance.previous_node_id, + predecessor_node_id=node.previous_node_id, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, @@ -689,17 +691,17 @@ class GraphEngine: error=run_result.error or "Unknown error", retry_index=retries, start_at=retry_start_at, - node_version=node_instance.version(), + node_version=node.version(), ) time.sleep(retry_interval) break route_node_state.set_finished(run_result=run_result) if run_result.status == WorkflowNodeExecutionStatus.FAILED: - if node_instance.should_continue_on_error: + if node.continue_on_error: # if run failed, handle error run_result = self._handle_continue_on_error( - node_instance, + node, event.run_result, self.graph_runtime_state.variable_pool, handle_exceptions=handle_exceptions, @@ -710,44 +712,44 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) yield NodeRunExceptionEvent( error=run_result.error or "System Error", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False else: yield NodeRunFailedEvent( error=route_node_state.failed_reason or "Unknown error.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED: if ( - node_instance.should_continue_on_error - and self.graph.edge_mapping.get(node_instance.node_id) - and node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH + node.continue_on_error + and self.graph.edge_mapping.get(node.node_id) + and node.error_strategy is ErrorStrategy.FAIL_BRANCH ): run_result.edge_source_handle = FailBranchSourceHandle.SUCCESS if run_result.metadata and run_result.metadata.get( @@ -767,7 +769,7 @@ class GraphEngine: for variable_key, variable_value in run_result.outputs.items(): # append variables to variable pool recursively self._append_variables_recursively( - node_id=node_instance.node_id, + node_id=node.node_id, variable_key_list=[variable_key], variable_value=variable_value, ) @@ -792,26 +794,26 @@ class GraphEngine: run_result.metadata = metadata_dict yield NodeRunSucceededEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) should_continue_retry = False break elif isinstance(event, RunStreamChunkEvent): yield NodeRunStreamChunkEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), chunk_content=event.chunk_content, from_variable_selector=event.from_variable_selector, route_node_state=route_node_state, @@ -819,14 +821,14 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) elif isinstance(event, RunRetrieverResourceEvent): yield NodeRunRetrieverResourceEvent( - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), retriever_resources=event.retriever_resources, context=event.context, route_node_state=route_node_state, @@ -834,7 +836,7 @@ class GraphEngine: parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) except GenerateTaskStoppedError: # trigger node run failed event @@ -842,20 +844,20 @@ class GraphEngine: route_node_state.failed_reason = "Workflow stopped." yield NodeRunFailedEvent( error="Workflow stopped.", - id=node_instance.id, - node_id=node_instance.node_id, - node_type=node_instance.node_type, - node_data=node_instance.node_data, + id=node.id, + node_id=node.node_id, + node_type=node.type_, + node_data=node.get_base_node_data(), route_node_state=route_node_state, parallel_id=parallel_id, parallel_start_node_id=parallel_start_node_id, parent_parallel_id=parent_parallel_id, parent_parallel_start_node_id=parent_parallel_start_node_id, - node_version=node_instance.version(), + node_version=node.version(), ) return except Exception as e: - logger.exception(f"Node {node_instance.node_data.title} run failed") + logger.exception(f"Node {node.title} run failed") raise e def _append_variables_recursively(self, node_id: str, variable_key_list: list[str], variable_value: VariableValue): @@ -895,22 +897,14 @@ class GraphEngine: def _handle_continue_on_error( self, - node_instance: BaseNode[BaseNodeData], + node: BaseNode, error_result: NodeRunResult, variable_pool: VariablePool, handle_exceptions: list[str] = [], ) -> NodeRunResult: - """ - handle continue on error when self._should_continue_on_error is True - - - :param error_result (NodeRunResult): error run result - :param variable_pool (VariablePool): variable pool - :return: excption run result - """ # add error message and error type to variable pool - variable_pool.add([node_instance.node_id, "error_message"], error_result.error) - variable_pool.add([node_instance.node_id, "error_type"], error_result.error_type) + variable_pool.add([node.node_id, "error_message"], error_result.error) + variable_pool.add([node.node_id, "error_type"], error_result.error_type) # add error message to handle_exceptions handle_exceptions.append(error_result.error or "") node_error_args: dict[str, Any] = { @@ -918,21 +912,21 @@ class GraphEngine: "error": error_result.error, "inputs": error_result.inputs, "metadata": { - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node_instance.node_data.error_strategy, + WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: node.error_strategy, }, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: return NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": error_result.error, "error_type": error_result.error_type, }, ) - elif node_instance.node_data.error_strategy is ErrorStrategy.FAIL_BRANCH: - if self.graph.edge_mapping.get(node_instance.node_id): + elif node.error_strategy is ErrorStrategy.FAIL_BRANCH: + if self.graph.edge_mapping.get(node.node_id): node_error_args["edge_source_handle"] = FailBranchSourceHandle.FAILED return NodeRunResult( **node_error_args, diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 678b99d546..704eb6a3ac 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -1,62 +1,100 @@ import json -import uuid from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from packaging.version import Version +from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter from core.agent.strategy.plugin import PluginAgentStrategy +from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager +from core.model_runtime.entities.llm_entities import LLMUsage from core.model_runtime.entities.model_entities import AIModelEntity, ModelType +from core.plugin.entities.request import InvokeCredentials from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolProviderType +from core.tools.entities.tool_entities import ( + ToolIdentity, + ToolInvokeMessage, + ToolParameter, + ToolProviderType, +) from core.tools.tool_manager import ToolManager -from core.variables.segments import StringSegment +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.variables.segments import ArrayFileSegment, StringSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool -from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey +from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated -from core.workflow.nodes.base.entities import BaseNodeData -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import RunCompletedEvent -from core.workflow.nodes.tool.tool_node import ToolNode +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db +from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy +from models import ToolFile from models.model import Conversation +from services.tools.builtin_tools_manage_service import BuiltinToolManageService +from .exc import ( + AgentInputTypeError, + AgentInvocationError, + AgentMessageTransformError, + AgentVariableNotFoundError, + AgentVariableTypeError, + ToolFileNotFoundError, +) -class AgentNode(ToolNode): + +class AgentNode(BaseNode): """ Agent Node """ - _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + _node_data: AgentNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AgentNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data @classmethod def version(cls) -> str: return "1" def _run(self) -> Generator: - """ - Run the agent node - """ - node_data = cast(AgentNodeData, self.node_data) - try: strategy = get_plugin_agent_strategy( tenant_id=self.tenant_id, - agent_strategy_provider_name=node_data.agent_strategy_provider_name, - agent_strategy_name=node_data.agent_strategy_name, + agent_strategy_provider_name=self._node_data.agent_strategy_provider_name, + agent_strategy_name=self._node_data.agent_strategy_name, ) except Exception as e: yield RunCompletedEvent( @@ -74,16 +112,17 @@ class AgentNode(ToolNode): parameters = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, strategy=strategy, ) parameters_for_log = self._generate_agent_parameters( agent_parameters=agent_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=node_data, + node_data=self._node_data, for_log=True, strategy=strategy, ) + credentials = self._generate_credentials(parameters=parameters) # get conversation id conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) @@ -94,61 +133,42 @@ class AgentNode(ToolNode): user_id=self.user_id, app_id=self.app_id, conversation_id=conversation_id.text if conversation_id else None, + credentials=credentials, ) except Exception as e: + error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to invoke agent: {str(e)}", + error=str(error), ) ) return try: - # convert tool messages - agent_thoughts: list = [] - - thought_log_message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LOG, - message=ToolInvokeMessage.LogMessage( - id=str(uuid.uuid4()), - label=f"Agent Strategy: {cast(AgentNodeData, self.node_data).agent_strategy_name}", - parent_id=None, - error=None, - status=ToolInvokeMessage.LogMessage.LogStatus.START, - data={ - "strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, - "parameters": parameters_for_log, - "thought_process": "Agent strategy execution started", - }, - metadata={ - "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, - }, - ), - ) - - def enhanced_message_stream(): - yield thought_log_message - - yield from message_stream - yield from self._transform_message( - message_stream, - { + messages=message_stream, + tool_info={ "icon": self.agent_strategy_icon, - "agent_strategy": cast(AgentNodeData, self.node_data).agent_strategy_name, + "agent_strategy": cast(AgentNodeData, self._node_data).agent_strategy_name, }, - parameters_for_log, - agent_thoughts, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_type=self.type_, + node_id=self.node_id, + node_execution_id=self.id, ) except PluginDaemonClientSideError as e: + transform_error = AgentMessageTransformError( + f"Failed to transform agent message: {str(e)}", original_error=e + ) yield RunCompletedEvent( run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=parameters_for_log, - error=f"Failed to transform agent message: {str(e)}", + error=str(transform_error), ) ) @@ -185,7 +205,7 @@ class AgentNode(ToolNode): if agent_input.type == "variable": variable = variable_pool.get(agent_input.value) # type: ignore if variable is None: - raise ValueError(f"Variable {agent_input.value} does not exist") + raise AgentVariableNotFoundError(str(agent_input.value)) parameter_value = variable.value elif agent_input.type in {"mixed", "constant"}: # variable_pool.convert_template expects a string template, @@ -207,7 +227,7 @@ class AgentNode(ToolNode): except json.JSONDecodeError: parameter_value = parameter_value else: - raise ValueError(f"Unknown agent input type '{agent_input.type}'") + raise AgentInputTypeError(agent_input.type) value = parameter_value if parameter.type == "array[tools]": value = cast(list[dict[str, Any]], value) @@ -246,10 +266,18 @@ class AgentNode(ToolNode): tool_name=tool.get("tool_name", ""), tool_parameters=parameters, plugin_unique_identifier=tool.get("plugin_unique_identifier", None), + credential_id=tool.get("credential_id", None), ) extra = tool.get("extra", {}) - runtime_variable_pool = variable_pool if self.node_data.version != "1" else None + + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + runtime_variable_pool = variable_pool tool_runtime = ToolManager.get_agent_tool_runtime( self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool ) @@ -276,6 +304,7 @@ class AgentNode(ToolNode): { **tool_runtime.entity.model_dump(mode="json"), "runtime_parameters": runtime_parameters, + "credential_id": tool.get("credential_id", None), "provider_type": provider_type.value, } ) @@ -305,25 +334,41 @@ class AgentNode(ToolNode): return result + def _generate_credentials( + self, + parameters: dict[str, Any], + ) -> InvokeCredentials: + """ + Generate credentials based on the given agent parameters. + """ + + credentials = InvokeCredentials() + + # generate credentials for tools selector + credentials.tool_credentials = {} + for tool in parameters.get("tools", []): + if tool.get("credential_id"): + try: + identity = ToolIdentity.model_validate(tool.get("identity", {})) + credentials.tool_credentials[identity.provider] = tool.get("credential_id", None) + except ValidationError: + continue + return credentials + @classmethod def _extract_variable_selector_to_variable_mapping( cls, *, graph_config: Mapping[str, Any], node_id: str, - node_data: BaseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(AgentNodeData, node_data) + # Create typed NodeData from dict + typed_node_data = AgentNodeData.model_validate(node_data) + result: dict[str, Any] = {} - for parameter_name in node_data.agent_parameters: - input = node_data.agent_parameters[parameter_name] + for parameter_name in typed_node_data.agent_parameters: + input = typed_node_data.agent_parameters[parameter_name] if input.type in ["mixed", "constant"]: selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors() for selector in selectors: @@ -348,7 +393,7 @@ class AgentNode(ToolNode): plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" - == cast(AgentNodeData, self.node_data).agent_strategy_provider_name + == cast(AgentNodeData, self._node_data).agent_strategy_provider_name ) icon = current_plugin.declaration.icon except StopIteration: @@ -416,3 +461,236 @@ class AgentNode(ToolNode): return tools else: return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP.value] + + def _transform_message( + self, + messages: Generator[ToolInvokeMessage, None, None], + tool_info: Mapping[str, Any], + parameters_for_log: dict[str, Any], + user_id: str, + tenant_id: str, + node_type: NodeType, + node_id: str, + node_execution_id: str, + ) -> Generator: + """ + Convert ToolInvokeMessages into tuple[plain_text, files] + """ + # transform message and handle file storage + message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=user_id, + tenant_id=tenant_id, + conversation_id=None, + ) + + text = "" + files: list[File] = [] + json: list[dict] = [] + + agent_logs: list[AgentLogEvent] = [] + agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} + llm_usage: LLMUsage | None = None + variables: dict[str, Any] = {} + + for message in message_stream: + if message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + ToolInvokeMessage.MessageType.IMAGE, + }: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + + url = message.message.text + if message.meta: + transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + else: + transfer_method = FileTransferMethod.TOOL_FILE + + tool_file_id = str(url).split("/")[-1].split(".")[0] + + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "transfer_method": transfer_method, + "url": url, + } + file = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + files.append(file) + elif message.type == ToolInvokeMessage.MessageType.BLOB: + # get tool file id + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + assert message.meta + + tool_file_id = message.message.text.split("/")[-1].split(".")[0] + with Session(db.engine) as session: + stmt = select(ToolFile).where(ToolFile.id == tool_file_id) + tool_file = session.scalar(stmt) + if tool_file is None: + raise ToolFileNotFoundError(tool_file_id) + + mapping = { + "tool_file_id": tool_file_id, + "transfer_method": FileTransferMethod.TOOL_FILE, + } + + files.append( + file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + ) + ) + elif message.type == ToolInvokeMessage.MessageType.TEXT: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + text += message.message.text + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.JSON: + assert isinstance(message.message, ToolInvokeMessage.JsonMessage) + if node_type == NodeType.AGENT: + msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) + llm_usage = LLMUsage.from_metadata(msg_metadata) + agent_execution_metadata = { + WorkflowNodeExecutionMetadataKey(key): value + for key, value in msg_metadata.items() + if key in WorkflowNodeExecutionMetadataKey.__members__.values() + } + if message.message.json_object is not None: + json.append(message.message.json_object) + elif message.type == ToolInvokeMessage.MessageType.LINK: + assert isinstance(message.message, ToolInvokeMessage.TextMessage) + stream_text = f"Link: {message.message.text}\n" + text += stream_text + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) + elif message.type == ToolInvokeMessage.MessageType.VARIABLE: + assert isinstance(message.message, ToolInvokeMessage.VariableMessage) + variable_name = message.message.variable_name + variable_value = message.message.variable_value + if message.message.stream: + if not isinstance(variable_value, str): + raise AgentVariableTypeError( + "When 'stream' is True, 'variable_value' must be a string.", + variable_name=variable_name, + expected_type="str", + actual_type=type(variable_value).__name__, + ) + if variable_name not in variables: + variables[variable_name] = "" + variables[variable_name] += variable_value + + yield RunStreamChunkEvent( + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] + ) + else: + variables[variable_name] = variable_value + elif message.type == ToolInvokeMessage.MessageType.FILE: + assert message.meta is not None + assert isinstance(message.meta, File) + files.append(message.meta["file"]) + elif message.type == ToolInvokeMessage.MessageType.LOG: + assert isinstance(message.message, ToolInvokeMessage.LogMessage) + if message.message.metadata: + icon = tool_info.get("icon", "") + dict_metadata = dict(message.message.metadata) + if dict_metadata.get("provider"): + manager = PluginInstaller() + plugins = manager.list_plugins(tenant_id) + try: + current_plugin = next( + plugin + for plugin in plugins + if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] + ) + icon = current_plugin.declaration.icon + except StopIteration: + pass + icon_dark = None + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + user_id, + tenant_id, + ) + if provider.name == dict_metadata["provider"] + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + dict_metadata["icon"] = icon + dict_metadata["icon_dark"] = icon_dark + message.message.metadata = dict_metadata + agent_log = AgentLogEvent( + id=message.message.id, + node_execution_id=node_execution_id, + parent_id=message.message.parent_id, + error=message.message.error, + status=message.message.status.value, + data=message.message.data, + label=message.message.label, + metadata=message.message.metadata, + node_id=node_id, + ) + + # check if the agent log is already in the list + for log in agent_logs: + if log.id == agent_log.id: + # update the log + log.data = agent_log.data + log.status = agent_log.status + log.error = agent_log.error + log.label = agent_log.label + log.metadata = agent_log.metadata + break + else: + agent_logs.append(agent_log) + + yield agent_log + + # Add agent_logs to outputs['json'] to ensure frontend can access thinking process + json_output: list[dict[str, Any]] = [] + + # Step 1: append each agent log as its own dict. + if agent_logs: + for log in agent_logs: + json_output.append( + { + "id": log.id, + "parent_id": log.parent_id, + "error": log.error, + "status": log.status, + "data": log.data, + "label": log.label, + "metadata": log.metadata, + "node_id": log.node_id, + } + ) + # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] + if json: + json_output.extend(json) + else: + json_output.append({"data": []}) + + yield RunCompletedEvent( + run_result=NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, + metadata={ + **agent_execution_metadata, + WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, + WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, + }, + inputs=parameters_for_log, + llm_usage=llm_usage, + ) + ) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 075a41fb2f..11b11068e7 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData): agent_strategy_name: str agent_strategy_label: str # redundancy memory: MemoryConfig | None = None + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None class AgentInput(BaseModel): value: Union[list[str], list[ToolSelector], Any] diff --git a/api/core/workflow/nodes/agent/exc.py b/api/core/workflow/nodes/agent/exc.py new file mode 100644 index 0000000000..d5955bdd7d --- /dev/null +++ b/api/core/workflow/nodes/agent/exc.py @@ -0,0 +1,124 @@ +from typing import Optional + + +class AgentNodeError(Exception): + """Base exception for all agent node errors.""" + + def __init__(self, message: str): + self.message = message + super().__init__(self.message) + + +class AgentStrategyError(AgentNodeError): + """Exception raised when there's an error with the agent strategy.""" + + def __init__(self, message: str, strategy_name: Optional[str] = None, provider_name: Optional[str] = None): + self.strategy_name = strategy_name + self.provider_name = provider_name + super().__init__(message) + + +class AgentStrategyNotFoundError(AgentStrategyError): + """Exception raised when the specified agent strategy is not found.""" + + def __init__(self, strategy_name: str, provider_name: Optional[str] = None): + super().__init__( + f"Agent strategy '{strategy_name}' not found" + + (f" for provider '{provider_name}'" if provider_name else ""), + strategy_name, + provider_name, + ) + + +class AgentInvocationError(AgentNodeError): + """Exception raised when there's an error invoking the agent.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentParameterError(AgentNodeError): + """Exception raised when there's an error with agent parameters.""" + + def __init__(self, message: str, parameter_name: Optional[str] = None): + self.parameter_name = parameter_name + super().__init__(message) + + +class AgentVariableError(AgentNodeError): + """Exception raised when there's an error with variables in the agent node.""" + + def __init__(self, message: str, variable_name: Optional[str] = None): + self.variable_name = variable_name + super().__init__(message) + + +class AgentVariableNotFoundError(AgentVariableError): + """Exception raised when a variable is not found in the variable pool.""" + + def __init__(self, variable_name: str): + super().__init__(f"Variable '{variable_name}' does not exist", variable_name) + + +class AgentInputTypeError(AgentNodeError): + """Exception raised when an unknown agent input type is encountered.""" + + def __init__(self, input_type: str): + super().__init__(f"Unknown agent input type '{input_type}'") + + +class ToolFileError(AgentNodeError): + """Exception raised when there's an error with a tool file.""" + + def __init__(self, message: str, file_id: Optional[str] = None): + self.file_id = file_id + super().__init__(message) + + +class ToolFileNotFoundError(ToolFileError): + """Exception raised when a tool file is not found.""" + + def __init__(self, file_id: str): + super().__init__(f"Tool file '{file_id}' does not exist", file_id) + + +class AgentMessageTransformError(AgentNodeError): + """Exception raised when there's an error transforming agent messages.""" + + def __init__(self, message: str, original_error: Optional[Exception] = None): + self.original_error = original_error + super().__init__(message) + + +class AgentModelError(AgentNodeError): + """Exception raised when there's an error with the model used by the agent.""" + + def __init__(self, message: str, model_name: Optional[str] = None, provider: Optional[str] = None): + self.model_name = model_name + self.provider = provider + super().__init__(message) + + +class AgentMemoryError(AgentNodeError): + """Exception raised when there's an error with the agent's memory.""" + + def __init__(self, message: str, conversation_id: Optional[str] = None): + self.conversation_id = conversation_id + super().__init__(message) + + +class AgentVariableTypeError(AgentNodeError): + """Exception raised when a variable has an unexpected type.""" + + def __init__( + self, + message: str, + variable_name: Optional[str] = None, + expected_type: Optional[str] = None, + actual_type: Optional[str] = None, + ): + self.variable_name = variable_name + self.expected_type = expected_type + self.actual_type = actual_type + super().__init__(message) diff --git a/api/core/workflow/nodes/answer/answer_node.py b/api/core/workflow/nodes/answer/answer_node.py index 0877257d61..70d7500d47 100644 --- a/api/core/workflow/nodes/answer/answer_node.py +++ b/api/core/workflow/nodes/answer/answer_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast from core.variables import ArrayFileSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult @@ -12,14 +12,37 @@ from core.workflow.nodes.answer.entities import ( VarGenerateRouteChunk, ) from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.utils.variable_template_parser import VariableTemplateParser -class AnswerNode(BaseNode[AnswerNodeData]): - _node_data_cls = AnswerNodeData +class AnswerNode(BaseNode): _node_type = NodeType.ANSWER + _node_data: AnswerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = AnswerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -30,7 +53,7 @@ class AnswerNode(BaseNode[AnswerNodeData]): :return: """ # generate routes - generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self.node_data) + generate_routes = AnswerStreamGeneratorRouter.extract_generate_route_from_node_data(self._node_data) answer = "" files = [] @@ -68,16 +91,12 @@ class AnswerNode(BaseNode[AnswerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: AnswerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_template_parser = VariableTemplateParser(template=node_data.answer) + # Create typed NodeData from dict + typed_node_data = AnswerNodeData.model_validate(node_data) + + variable_template_parser = VariableTemplateParser(template=typed_node_data.answer) variable_selectors = variable_template_parser.extract_variable_selectors() variable_mapping = {} diff --git a/api/core/workflow/nodes/base/entities.py b/api/core/workflow/nodes/base/entities.py index d853eb71be..dcfed5eed2 100644 --- a/api/core/workflow/nodes/base/entities.py +++ b/api/core/workflow/nodes/base/entities.py @@ -122,13 +122,13 @@ class RetryConfig(BaseModel): class BaseNodeData(ABC, BaseModel): title: str desc: Optional[str] = None + version: str = "1" error_strategy: Optional[ErrorStrategy] = None default_value: Optional[list[DefaultValue]] = None - version: str = "1" retry_config: RetryConfig = RetryConfig() @property - def default_value_dict(self): + def default_value_dict(self) -> dict[str, Any]: if self.default_value: return {item.key: item.value for item in self.default_value} return {} diff --git a/api/core/workflow/nodes/base/node.py b/api/core/workflow/nodes/base/node.py index 6973401429..fb5ec55453 100644 --- a/api/core/workflow/nodes/base/node.py +++ b/api/core/workflow/nodes/base/node.py @@ -1,28 +1,22 @@ import logging from abc import abstractmethod from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar, Union, cast +from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import CONTINUE_ON_ERROR_NODE_TYPE, RETRY_ON_ERROR_NODE_TYPE, NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent -from .entities import BaseNodeData - if TYPE_CHECKING: + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState from core.workflow.graph_engine.entities.event import InNodeEvent - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState logger = logging.getLogger(__name__) -GenericNodeData = TypeVar("GenericNodeData", bound=BaseNodeData) - -class BaseNode(Generic[GenericNodeData]): - _node_data_cls: type[GenericNodeData] +class BaseNode: _node_type: ClassVar[NodeType] def __init__( @@ -56,8 +50,8 @@ class BaseNode(Generic[GenericNodeData]): self.node_id = node_id - node_data = self._node_data_cls.model_validate(config.get("data", {})) - self.node_data = node_data + @abstractmethod + def init_node_data(self, data: Mapping[str, Any]) -> None: ... @abstractmethod def _run(self) -> NodeRunResult | Generator[Union[NodeEvent, "InNodeEvent"], None, None]: @@ -130,9 +124,9 @@ class BaseNode(Generic[GenericNodeData]): if not node_id: raise ValueError("Node ID is required when extracting variable selector to variable mapping.") - node_data = cls._node_data_cls(**config.get("data", {})) + # Pass raw dict data instead of creating NodeData instance data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, node_id=node_id, node_data=cast(GenericNodeData, node_data) + graph_config=graph_config, node_id=node_id, node_data=config.get("data", {}) ) return data @@ -142,32 +136,16 @@ class BaseNode(Generic[GenericNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: GenericNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ return {} @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ return {} @property - def node_type(self) -> NodeType: - """ - Get node type - :return: - """ + def type_(self) -> NodeType: return self._node_type @classmethod @@ -181,19 +159,68 @@ class BaseNode(Generic[GenericNodeData]): raise NotImplementedError("subclasses of BaseNode must implement `version` method.") @property - def should_continue_on_error(self) -> bool: - """judge if should continue on error + def continue_on_error(self) -> bool: + return False - Returns: - bool: if should continue on error - """ - return self.node_data.error_strategy is not None and self.node_type in CONTINUE_ON_ERROR_NODE_TYPE + @property + def retry(self) -> bool: + return False + + # Abstract methods that subclasses must implement to provide access + # to BaseNodeData properties in a type-safe way + + @abstractmethod + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + ... + @abstractmethod + def _get_retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + ... + + @abstractmethod + def _get_title(self) -> str: + """Get the node title.""" + ... + + @abstractmethod + def _get_description(self) -> Optional[str]: + """Get the node description.""" + ... + + @abstractmethod + def _get_default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + ... + + @abstractmethod + def get_base_node_data(self) -> BaseNodeData: + """Get the BaseNodeData object for this node.""" + ... + + # Public interface properties that delegate to abstract methods @property - def should_retry(self) -> bool: - """judge if should retry + def error_strategy(self) -> Optional[ErrorStrategy]: + """Get the error strategy for this node.""" + return self._get_error_strategy() - Returns: - bool: if should retry - """ - return self.node_data.retry_config.retry_enabled and self.node_type in RETRY_ON_ERROR_NODE_TYPE + @property + def retry_config(self) -> RetryConfig: + """Get the retry configuration for this node.""" + return self._get_retry_config() + + @property + def title(self) -> str: + """Get the node title.""" + return self._get_title() + + @property + def description(self) -> Optional[str]: + """Get the node description.""" + return self._get_description() + + @property + def default_value_dict(self) -> dict[str, Any]: + """Get the default values dictionary for this node.""" + return self._get_default_value_dict() diff --git a/api/core/workflow/nodes/code/code_node.py b/api/core/workflow/nodes/code/code_node.py index 22ed9e2651..fdf3932827 100644 --- a/api/core/workflow/nodes/code/code_node.py +++ b/api/core/workflow/nodes/code/code_node.py @@ -1,4 +1,5 @@ from collections.abc import Mapping, Sequence +from decimal import Decimal from typing import Any, Optional from configs import dify_config @@ -10,8 +11,9 @@ from core.variables.segments import ArrayFileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.code.entities import CodeNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .exc import ( CodeNodeError, @@ -20,10 +22,32 @@ from .exc import ( ) -class CodeNode(BaseNode[CodeNodeData]): - _node_data_cls = CodeNodeData +class CodeNode(BaseNode): _node_type = NodeType.CODE + _node_data: CodeNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = CodeNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -46,12 +70,12 @@ class CodeNode(BaseNode[CodeNodeData]): def _run(self) -> NodeRunResult: # Get code language - code_language = self.node_data.code_language - code = self.node_data.code + code_language = self._node_data.code_language + code = self._node_data.code # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) if isinstance(variable, ArrayFileSegment): @@ -67,7 +91,7 @@ class CodeNode(BaseNode[CodeNodeData]): ) # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) + result = self._transform_result(result=result, output_schema=self._node_data.outputs) except (CodeExecutionError, CodeNodeError) as e: return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ @@ -114,8 +138,10 @@ class CodeNode(BaseNode[CodeNodeData]): ) if isinstance(value, float): + decimal_value = Decimal(str(value)).normalize() + precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] # raise error if precision is too high - if len(str(value).split(".")[1]) > dify_config.CODE_MAX_PRECISION: + if precision > dify_config.CODE_MAX_PRECISION: raise OutputValidationError( f"Output variable `{variable}` has too high precision," f" it must be less than {dify_config.CODE_MAX_PRECISION} digits." @@ -331,16 +357,20 @@ class CodeNode(BaseNode[CodeNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: CodeNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = CodeNodeData.model_validate(node_data) + return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 8e6150f9cc..ab5964ebd4 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -5,7 +5,7 @@ import logging import os import tempfile from collections.abc import Mapping, Sequence -from typing import Any, cast +from typing import Any, Optional, cast import chardet import docx @@ -28,7 +28,8 @@ from core.variables.segments import ArrayStringSegment, FileSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import DocumentExtractorNodeData from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError @@ -36,21 +37,43 @@ from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, logger = logging.getLogger(__name__) -class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): +class DocumentExtractorNode(BaseNode): """ Extracts text content from various file types. Supports plain text, PDF, and DOC/DOCX files. """ - _node_data_cls = DocumentExtractorNodeData _node_type = NodeType.DOCUMENT_EXTRACTOR + _node_data: DocumentExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = DocumentExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" def _run(self): - variable_selector = self.node_data.variable_selector + variable_selector = self._node_data.variable_selector variable = self.graph_runtime_state.variable_pool.get(variable_selector) if variable is None: @@ -97,16 +120,12 @@ class DocumentExtractorNode(BaseNode[DocumentExtractorNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: DocumentExtractorNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - return {node_id + ".files": node_data.variable_selector} + # Create typed NodeData from dict + typed_node_data = DocumentExtractorNodeData.model_validate(node_data) + + return {node_id + ".files": typed_node_data.variable_selector} def _extract_text_by_mime_type(*, file_content: bytes, mime_type: str) -> str: diff --git a/api/core/workflow/nodes/end/end_node.py b/api/core/workflow/nodes/end/end_node.py index 17a0b3adeb..f86f2e8129 100644 --- a/api/core/workflow/nodes/end/end_node.py +++ b/api/core/workflow/nodes/end/end_node.py @@ -1,14 +1,40 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.end.entities import EndNodeData -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType -class EndNode(BaseNode[EndNodeData]): - _node_data_cls = EndNodeData +class EndNode(BaseNode): _node_type = NodeType.END + _node_data: EndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = EndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -18,7 +44,7 @@ class EndNode(BaseNode[EndNodeData]): Run node :return: """ - output_variables = self.node_data.outputs + output_variables = self._node_data.outputs outputs = {} for variable_selector in output_variables: diff --git a/api/core/workflow/nodes/enums.py b/api/core/workflow/nodes/enums.py index 73b43eeaf7..7cf9ab9107 100644 --- a/api/core/workflow/nodes/enums.py +++ b/api/core/workflow/nodes/enums.py @@ -35,7 +35,3 @@ class ErrorStrategy(StrEnum): class FailBranchSourceHandle(StrEnum): FAILED = "fail-branch" SUCCESS = "success-branch" - - -CONTINUE_ON_ERROR_NODE_TYPE = [NodeType.LLM, NodeType.CODE, NodeType.TOOL, NodeType.HTTP_REQUEST] -RETRY_ON_ERROR_NODE_TYPE = CONTINUE_ON_ERROR_NODE_TYPE diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 971e0f73e7..6799d5c63c 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -11,7 +11,8 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.http_request.executor import Executor from core.workflow.utils import variable_template_parser from factories import file_factory @@ -32,10 +33,32 @@ HTTP_REQUEST_DEFAULT_TIMEOUT = HttpRequestNodeTimeout( logger = logging.getLogger(__name__) -class HttpRequestNode(BaseNode[HttpRequestNodeData]): - _node_data_cls = HttpRequestNodeData +class HttpRequestNode(BaseNode): _node_type = NodeType.HTTP_REQUEST + _node_data: HttpRequestNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = HttpRequestNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict[str, Any]] = None) -> dict: return { @@ -69,8 +92,8 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): process_data = {} try: http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), + node_data=self._node_data, + timeout=self._get_request_timeout(self._node_data), variable_pool=self.graph_runtime_state.variable_pool, max_retries=0, ) @@ -78,7 +101,7 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): response = http_executor.invoke() files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.should_continue_on_error or self.should_retry): + if not response.response.is_success and (self.continue_on_error or self.retry): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, outputs={ @@ -131,15 +154,18 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: HttpRequestNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = HttpRequestNodeData.model_validate(node_data) + selectors: list[VariableSelector] = [] - selectors += variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.url) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.headers) + selectors += variable_template_parser.extract_selectors_from_template(typed_node_data.params) + if typed_node_data.body: + body_type = typed_node_data.body.type + data = typed_node_data.body.data match body_type: case "binary": if len(data) != 1: @@ -217,3 +243,11 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]): files.append(file) return ArrayFileSegment(value=files) + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/if_else/if_else_node.py b/api/core/workflow/nodes/if_else/if_else_node.py index 22b748030c..86e703dc68 100644 --- a/api/core/workflow/nodes/if_else/if_else_node.py +++ b/api/core/workflow/nodes/if_else/if_else_node.py @@ -1,5 +1,5 @@ from collections.abc import Mapping, Sequence -from typing import Any, Literal +from typing import Any, Literal, Optional from typing_extensions import deprecated @@ -7,16 +7,39 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.if_else.entities import IfElseNodeData from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.processor import ConditionProcessor -class IfElseNode(BaseNode[IfElseNodeData]): - _node_data_cls = IfElseNodeData +class IfElseNode(BaseNode): _node_type = NodeType.IF_ELSE + _node_data: IfElseNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IfElseNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -36,8 +59,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): condition_processor = ConditionProcessor() try: # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: + if self._node_data.cases: + for case in self._node_data.cases: input_conditions, group_result, final_result = condition_processor.process_conditions( variable_pool=self.graph_runtime_state.variable_pool, conditions=case.conditions, @@ -63,8 +86,8 @@ class IfElseNode(BaseNode[IfElseNodeData]): input_conditions, group_result, final_result = _should_not_use_old_function( condition_processor=condition_processor, variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", + conditions=self._node_data.conditions or [], + operator=self._node_data.logical_operator or "and", ) selected_case_id = "true" if final_result else "false" @@ -98,10 +121,13 @@ class IfElseNode(BaseNode[IfElseNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IfElseNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = IfElseNodeData.model_validate(node_data) + var_mapping: dict[str, list[str]] = {} - for case in node_data.cases or []: + for case in typed_node_data.cases or []: for condition in case.conditions: key = "{}.#{}#".format(node_id, ".".join(condition.variable_selector)) var_mapping[key] = condition.variable_selector diff --git a/api/core/workflow/nodes/iteration/iteration_node.py b/api/core/workflow/nodes/iteration/iteration_node.py index 8b566c83cd..5842c8d64b 100644 --- a/api/core/workflow/nodes/iteration/iteration_node.py +++ b/api/core/workflow/nodes/iteration/iteration_node.py @@ -36,7 +36,8 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.iteration.entities import ErrorHandleMode, IterationNodeData from factories.variable_factory import build_segment @@ -56,14 +57,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class IterationNode(BaseNode[IterationNodeData]): +class IterationNode(BaseNode): """ Iteration Node. """ - _node_data_cls = IterationNodeData _node_type = NodeType.ITERATION + _node_data: IterationNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: return { @@ -83,10 +106,10 @@ class IterationNode(BaseNode[IterationNodeData]): """ Run the node. """ - variable = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.iterator_selector) if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") + raise IteratorVariableNotFoundError(f"iterator variable {self._node_data.iterator_selector} not found") if not isinstance(variable, ArrayVariable) and not isinstance(variable, NoneVariable): raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") @@ -116,10 +139,10 @@ class IterationNode(BaseNode[IterationNodeData]): graph_config = self.graph_config - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self.node_id} not found") - root_node_id = self.node_data.start_node_id + root_node_id = self._node_data.start_node_id # init graph iteration_graph = Graph.init(graph_config=graph_config, root_node_id=root_node_id) @@ -161,8 +184,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunStartedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"iterator_length": len(iterator_list_value)}, @@ -172,8 +195,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=0, pre_iteration_output=None, duration=None, @@ -181,11 +204,11 @@ class IterationNode(BaseNode[IterationNodeData]): iter_run_map: dict[str, float] = {} outputs: list[Any] = [None] * len(iterator_list_value) try: - if self.node_data.is_parallel: + if self._node_data.is_parallel: futures: list[Future] = [] q: Queue = Queue() thread_pool = GraphEngineThreadPool( - max_workers=self.node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT + max_workers=self._node_data.parallel_nums, max_submit_count=dify_config.MAX_SUBMIT_COUNT ) for index, item in enumerate(iterator_list_value): future: Future = thread_pool.submit( @@ -242,7 +265,7 @@ class IterationNode(BaseNode[IterationNodeData]): iteration_graph=iteration_graph, iter_run_map=iter_run_map, ) - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + if self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: outputs = [output for output in outputs if output is not None] # Flatten the list of lists @@ -253,8 +276,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunSucceededEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -278,8 +301,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -305,21 +328,17 @@ class IterationNode(BaseNode[IterationNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: IterationNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = IterationNodeData.model_validate(node_data) + variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, + f"{node_id}.input_selector": typed_node_data.iterator_selector, } # init graph - iteration_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + iteration_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not iteration_graph: raise IterationGraphNotFoundError("iteration graph not found") @@ -375,7 +394,7 @@ class IterationNode(BaseNode[IterationNodeData]): """ if not isinstance(event, BaseNodeEvent): return event - if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): + if self._node_data.is_parallel and isinstance(event, NodeRunStartedEvent): event.parallel_mode_run_id = parallel_mode_run_id iter_metadata = { @@ -438,12 +457,12 @@ class IterationNode(BaseNode[IterationNodeData]): elif isinstance(event, BaseGraphEvent): if isinstance(event, GraphRunFailedEvent): # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -456,8 +475,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -478,7 +497,7 @@ class IterationNode(BaseNode[IterationNodeData]): event=event, iter_run_index=current_index, parallel_mode_run_id=parallel_mode_run_id ) if isinstance(event, NodeRunFailedEvent): - if self.node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: + if self._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -491,15 +510,15 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: + elif self._node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -512,15 +531,15 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=None, duration=duration, ) return - elif self.node_data.error_handle_mode == ErrorHandleMode.TERMINATED: + elif self._node_data.error_handle_mode == ErrorHandleMode.TERMINATED: yield NodeInIterationFailedEvent( **metadata_event.model_dump(), ) @@ -531,12 +550,12 @@ class IterationNode(BaseNode[IterationNodeData]): variable_pool.remove([node_id]) # iteration run failed - if self.node_data.is_parallel: + if self._node_data.is_parallel: yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, parallel_mode_run_id=parallel_mode_run_id, start_at=start_at, inputs=inputs, @@ -549,8 +568,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": outputs}, @@ -569,7 +588,7 @@ class IterationNode(BaseNode[IterationNodeData]): return yield metadata_event - current_output_segment = variable_pool.get(self.node_data.output_selector) + current_output_segment = variable_pool.get(self._node_data.output_selector) if current_output_segment is None: raise IterationNodeError("iteration output selector not found") current_iteration_output = current_output_segment.value @@ -588,8 +607,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunNextEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, index=next_index, parallel_mode_run_id=parallel_mode_run_id, pre_iteration_output=current_iteration_output or None, @@ -601,8 +620,8 @@ class IterationNode(BaseNode[IterationNodeData]): yield IterationRunFailedEvent( iteration_id=self.id, iteration_node_id=self.node_id, - iteration_node_type=self.node_type, - iteration_node_data=self.node_data, + iteration_node_type=self.type_, + iteration_node_data=self._node_data, start_at=start_at, inputs=inputs, outputs={"output": None}, diff --git a/api/core/workflow/nodes/iteration/iteration_start_node.py b/api/core/workflow/nodes/iteration/iteration_start_node.py index 9900aa225d..b82c29291a 100644 --- a/api/core/workflow/nodes/iteration/iteration_start_node.py +++ b/api/core/workflow/nodes/iteration/iteration_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.iteration.entities import IterationStartNodeData -class IterationStartNode(BaseNode[IterationStartNodeData]): +class IterationStartNode(BaseNode): """ Iteration Start Node. """ - _node_data_cls = IterationStartNodeData _node_type = NodeType.ITERATION_START + _node_data: IterationStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = IterationStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index 19bdee4fe2..f1767bdf9e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,10 +1,10 @@ from collections.abc import Sequence -from typing import Any, Literal, Optional +from typing import Literal, Optional from pydantic import BaseModel, Field from core.workflow.nodes.base import BaseNodeData -from core.workflow.nodes.llm.entities import VisionConfig +from core.workflow.nodes.llm.entities import ModelConfig, VisionConfig class RerankingModelConfig(BaseModel): @@ -56,17 +56,6 @@ class MultipleRetrievalConfig(BaseModel): weights: Optional[WeightedScoreConfig] = None -class ModelConfig(BaseModel): - """ - Model Config. - """ - - provider: str - name: str - mode: str - completion_params: dict[str, Any] = {} - - class SingleRetrievalConfig(BaseModel): """ Single Retrieval Config. diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index f05d93d83e..5f092dc2f1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -4,7 +4,7 @@ import re import time from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import cast as sqlalchemy_cast @@ -15,20 +15,31 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.model_manager import ModelInstance, ModelManager -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.model_runtime.entities.message_entities import ( + PromptMessageRole, +) +from core.model_runtime.entities.model_entities import ( + ModelFeature, + ModelType, +) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.prompt.simple_prompt_transform import ModelMode from core.rag.datasource.retrieval_service import RetrievalService from core.rag.entities.metadata_entities import Condition, MetadataCondition from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from core.variables import StringSegment +from core.variables import ( + StringSegment, +) from core.variables.segments import ArrayObjectSegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType -from core.workflow.nodes.event.event import ModelInvokeCompletedEvent +from core.workflow.nodes.base import BaseNode +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType +from core.workflow.nodes.event import ( + ModelInvokeCompletedEvent, +) from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_ASSISTANT_PROMPT_1, METADATA_FILTER_ASSISTANT_PROMPT_2, @@ -38,7 +49,8 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import ( METADATA_FILTER_USER_PROMPT_2, METADATA_FILTER_USER_PROMPT_3, ) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate +from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.nodes.llm.node import LLMNode from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,7 +58,7 @@ from libs.json_in_md_parser import parse_and_check_json_markdown from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog from services.feature_service import FeatureService -from .entities import KnowledgeRetrievalNodeData, ModelConfig +from .entities import KnowledgeRetrievalNodeData from .exc import ( InvalidModelTypeError, KnowledgeRetrievalNodeError, @@ -56,6 +68,10 @@ from .exc import ( ModelQuotaExceededError, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState + logger = logging.getLogger(__name__) default_retrieval_model = { @@ -67,18 +83,76 @@ default_retrieval_model = { } -class KnowledgeRetrievalNode(LLMNode): - _node_data_cls = KnowledgeRetrievalNodeData # type: ignore +class KnowledgeRetrievalNode(BaseNode): _node_type = NodeType.KNOWLEDGE_RETRIEVAL + _node_data: KnowledgeRetrievalNodeData + + # Instance attributes specific to LLMNode. + # Output variable for file + _file_outputs: list["File"] + + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = KnowledgeRetrievalNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls): return "1" def _run(self) -> NodeRunResult: # type: ignore - node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables - variable = self.graph_runtime_state.variable_pool.get(node_data.query_variable_selector) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.query_variable_selector) if not isinstance(variable, StringSegment): return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, @@ -119,7 +193,7 @@ class KnowledgeRetrievalNode(LLMNode): # retrieve knowledge try: - results = self._fetch_dataset_retriever(node_data=node_data, query=query) + results = self._fetch_dataset_retriever(node_data=self._node_data, query=query) outputs = {"result": ArrayObjectSegment(value=results)} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -435,20 +509,17 @@ class KnowledgeRetrievalNode(LLMNode): # get all metadata field metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] - # get metadata model config - metadata_model_config = node_data.metadata_model_config - if metadata_model_config is None: + if node_data.metadata_model_config is None: raise ValueError("metadata_model_config is required") - # get metadata model instance - # fetch model config - model_instance, model_config = self.get_model_config(metadata_model_config) + # get metadata model instance and fetch model config + model_instance, model_config = self.get_model_config(node_data.metadata_model_config) # fetch prompt messages prompt_template = self._get_prompt_template( node_data=node_data, metadata_fields=all_metadata_fields, query=query or "", ) - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query=query, memory=None, @@ -458,16 +529,23 @@ class KnowledgeRetrievalNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=self.graph_runtime_state.variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" try: # handle invoke result - generator = self._invoke_llm( - node_data_model=node_data.metadata_model_config, # type: ignore + generator = LLMNode.invoke_llm( + node_data_model=node_data.metadata_model_config, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -557,17 +635,13 @@ class KnowledgeRetrievalNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: KnowledgeRetrievalNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = KnowledgeRetrievalNodeData.model_validate(node_data) + variable_mapping = {} - variable_mapping[node_id + ".query"] = node_data.query_variable_selector + variable_mapping[node_id + ".query"] = typed_node_data.query_variable_selector return variable_mapping def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: @@ -629,7 +703,7 @@ class KnowledgeRetrievalNode(LLMNode): ) def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode.value_of(node_data.metadata_model_config.mode) # type: ignore + model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore input_text = query prompt_messages: list[LLMNodeChatModelMessage] = [] diff --git a/api/core/workflow/nodes/list_operator/node.py b/api/core/workflow/nodes/list_operator/node.py index 3c9ba44cf1..ae9401b056 100644 --- a/api/core/workflow/nodes/list_operator/node.py +++ b/api/core/workflow/nodes/list_operator/node.py @@ -1,5 +1,5 @@ -from collections.abc import Callable, Sequence -from typing import Any, Literal, Union +from collections.abc import Callable, Mapping, Sequence +from typing import Any, Literal, Optional, Union from core.file import File from core.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment @@ -7,16 +7,39 @@ from core.variables.segments import ArrayAnySegment, ArraySegment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from .entities import ListOperatorNodeData from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError -class ListOperatorNode(BaseNode[ListOperatorNodeData]): - _node_data_cls = ListOperatorNodeData +class ListOperatorNode(BaseNode): _node_type = NodeType.LIST_OPERATOR + _node_data: ListOperatorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ListOperatorNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -26,9 +49,9 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): process_data: dict[str, list] = {} outputs: dict[str, Any] = {} - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) + variable = self.graph_runtime_state.variable_pool.get(self._node_data.variable) if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" + error_message = f"Variable not found for selector: {self._node_data.variable}" return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs ) @@ -48,7 +71,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) if not isinstance(variable, ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment): error_message = ( - f"Variable {self.node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " + f"Variable {self._node_data.variable} is not an ArrayFileSegment, ArrayNumberSegment " "or ArrayStringSegment" ) return NodeRunResult( @@ -64,19 +87,19 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): try: # Filter - if self.node_data.filter_by.enabled: + if self._node_data.filter_by.enabled: variable = self._apply_filter(variable) # Extract - if self.node_data.extract_by.enabled: + if self._node_data.extract_by.enabled: variable = self._extract_slice(variable) # Order - if self.node_data.order_by.enabled: + if self._node_data.order_by.enabled: variable = self._apply_order(variable) # Slice - if self.node_data.limit.enabled: + if self._node_data.limit.enabled: variable = self._apply_slice(variable) outputs = { @@ -104,7 +127,7 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: filter_func: Callable[[Any], bool] result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: + for condition in self._node_data.filter_by.conditions: if isinstance(variable, ArrayStringSegment): if not isinstance(condition.value, str): raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") @@ -137,14 +160,14 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: if isinstance(variable, ArrayStringSegment): - result = _order_string(order=self.node_data.order_by.value, array=variable.value) + result = _order_string(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayNumberSegment): - result = _order_number(order=self.node_data.order_by.value, array=variable.value) + result = _order_number(order=self._node_data.order_by.value, array=variable.value) variable = variable.model_copy(update={"value": result}) elif isinstance(variable, ArrayFileSegment): result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value + order=self._node_data.order_by.value, order_by=self._node_data.order_by.key, array=variable.value ) variable = variable.model_copy(update={"value": result}) return variable @@ -152,13 +175,13 @@ class ListOperatorNode(BaseNode[ListOperatorNodeData]): def _apply_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - result = variable.value[: self.node_data.limit.size] + result = variable.value[: self._node_data.limit.size] return variable.model_copy(update={"value": result}) def _extract_slice( self, variable: Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment] ) -> Union[ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment]: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) + value = int(self.graph_runtime_state.variable_pool.convert_template(self._node_data.extract_by.serial).text) if value < 1: raise ValueError(f"Invalid serial index: must be >= 1, got {value}") value -= 1 diff --git a/api/core/workflow/nodes/llm/entities.py b/api/core/workflow/nodes/llm/entities.py index 36d0688807..4bb62d35a2 100644 --- a/api/core/workflow/nodes/llm/entities.py +++ b/api/core/workflow/nodes/llm/entities.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from collections.abc import Mapping, Sequence from typing import Any, Optional from pydantic import BaseModel, Field, field_validator @@ -65,7 +65,7 @@ class LLMNodeData(BaseNodeData): memory: Optional[MemoryConfig] = None context: ContextConfig vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: dict | None = None + structured_output: Mapping[str, Any] | None = None # We used 'structured_output_enabled' in the past, but it's not a good name. structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 9bfb402dc8..91e7312805 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -59,7 +59,8 @@ from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ( ModelInvokeCompletedEvent, NodeEvent, @@ -90,17 +91,16 @@ from .file_saver import FileSaverImpl, LLMFileSaver if TYPE_CHECKING: from core.file.models import File - from core.workflow.graph_engine.entities.graph import Graph - from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams - from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState logger = logging.getLogger(__name__) -class LLMNode(BaseNode[LLMNodeData]): - _node_data_cls = LLMNodeData +class LLMNode(BaseNode): _node_type = NodeType.LLM + _node_data: LLMNodeData + # Instance attributes specific to LLMNode. # Output variable for file _file_outputs: list["File"] @@ -138,6 +138,27 @@ class LLMNode(BaseNode[LLMNodeData]): ) self._llm_file_saver = llm_file_saver + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LLMNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -152,13 +173,13 @@ class LLMNode(BaseNode[LLMNodeData]): try: # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) + self._node_data.prompt_template = self._transform_chat_messages(self._node_data.prompt_template) # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) + inputs = self._fetch_inputs(node_data=self._node_data) # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) + jinja_inputs = self._fetch_jinja_inputs(node_data=self._node_data) # merge inputs inputs.update(jinja_inputs) @@ -169,9 +190,9 @@ class LLMNode(BaseNode[LLMNodeData]): files = ( llm_utils.fetch_files( variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, + selector=self._node_data.vision.configs.variable_selector, ) - if self.node_data.vision.enabled + if self._node_data.vision.enabled else [] ) @@ -179,7 +200,7 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#files#"] = [file.to_dict() for file in files] # fetch context value - generator = self._fetch_context(node_data=self.node_data) + generator = self._fetch_context(node_data=self._node_data) context = None for event in generator: if isinstance(event, RunRetrieverResourceEvent): @@ -189,44 +210,54 @@ class LLMNode(BaseNode[LLMNodeData]): node_inputs["#context#"] = context # fetch model config - model_instance, model_config = self._fetch_model_config(self.node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=self._node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, app_id=self.app_id, - node_data_memory=self.node_data.memory, + node_data_memory=self._node_data.memory, model_instance=model_instance, ) query = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template + if self._node_data.memory: + query = self._node_data.memory.query_prompt_template if not query and ( query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) ): query = query_variable.text - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( sys_query=query, sys_files=files, context=context, memory=memory, model_config=model_config, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, + prompt_template=self._node_data.prompt_template, + memory_config=self._node_data.memory, + vision_enabled=self._node_data.vision.enabled, + vision_detail=self._node_data.vision.configs.detail, variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, + jinja2_variables=self._node_data.prompt_config.jinja2_variables, + tenant_id=self.tenant_id, ) # handle invoke result - generator = self._invoke_llm( - node_data_model=self.node_data.model, + generator = LLMNode.invoke_llm( + node_data_model=self._node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=self._node_data.structured_output_enabled, + structured_output=self._node_data.structured_output, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) structured_output: LLMStructuredOutput | None = None @@ -296,12 +327,19 @@ class LLMNode(BaseNode[LLMNodeData]): ) ) - def _invoke_llm( - self, + @staticmethod + def invoke_llm( + *, node_data_model: ModelConfig, model_instance: ModelInstance, prompt_messages: Sequence[PromptMessage], stop: Optional[Sequence[str]] = None, + user_id: str, + structured_output_enabled: bool, + structured_output: Optional[Mapping[str, Any]] = None, + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: model_schema = model_instance.model_type_instance.get_model_schema( node_data_model.name, model_instance.credentials @@ -309,8 +347,10 @@ class LLMNode(BaseNode[LLMNodeData]): if not model_schema: raise ValueError(f"Model schema not found for {node_data_model.name}") - if self.node_data.structured_output_enabled: - output_schema = self._fetch_structured_output_schema() + if structured_output_enabled: + output_schema = LLMNode.fetch_structured_output_schema( + structured_output=structured_output or {}, + ) invoke_result = invoke_llm_with_structured_output( provider=model_instance.provider, model_schema=model_schema, @@ -320,7 +360,7 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) else: invoke_result = model_instance.invoke_llm( @@ -328,17 +368,31 @@ class LLMNode(BaseNode[LLMNodeData]): model_parameters=node_data_model.completion_params, stop=list(stop or []), stream=True, - user=self.user_id, + user=user_id, ) - return self._handle_invoke_result(invoke_result=invoke_result) + return LLMNode.handle_invoke_result( + invoke_result=invoke_result, + file_saver=file_saver, + file_outputs=file_outputs, + node_id=node_id, + ) - def _handle_invoke_result( - self, invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None] + @staticmethod + def handle_invoke_result( + *, + invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], + file_saver: LLMFileSaver, + file_outputs: list["File"], + node_id: str, ) -> Generator[NodeEvent | LLMStructuredOutput, None, None]: # For blocking mode if isinstance(invoke_result, LLMResult): - event = self._handle_blocking_result(invoke_result=invoke_result) + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=file_saver, + file_outputs=file_outputs, + ) yield event return @@ -356,11 +410,13 @@ class LLMNode(BaseNode[LLMNodeData]): yield result if isinstance(result, LLMResultChunk): contents = result.delta.message.content - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=contents, + file_saver=file_saver, + file_outputs=file_outputs, + ): full_text_buffer.write(text_part) - yield RunStreamChunkEvent( - chunk_content=text_part, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[node_id, "text"]) # Update the whole metadata if not model and result.model: @@ -378,7 +434,8 @@ class LLMNode(BaseNode[LLMNodeData]): yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason) - def _image_file_to_markdown(self, file: "File", /): + @staticmethod + def _image_file_to_markdown(file: "File", /): text_chunk = f"![]({file.generate_url()})" return text_chunk @@ -539,11 +596,14 @@ class LLMNode(BaseNode[LLMNodeData]): return None + @staticmethod def _fetch_model_config( - self, node_data_model: ModelConfig + *, + node_data_model: ModelConfig, + tenant_id: str, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: model, model_config_with_cred = llm_utils.fetch_model_config( - tenant_id=self.tenant_id, node_data_model=node_data_model + tenant_id=tenant_id, node_data_model=node_data_model ) completion_params = model_config_with_cred.parameters @@ -556,8 +616,8 @@ class LLMNode(BaseNode[LLMNodeData]): node_data_model.completion_params = completion_params return model, model_config_with_cred - def _fetch_prompt_messages( - self, + @staticmethod + def fetch_prompt_messages( *, sys_query: str | None = None, sys_files: Sequence["File"], @@ -570,13 +630,14 @@ class LLMNode(BaseNode[LLMNodeData]): vision_detail: ImagePromptMessageContent.DETAIL, variable_pool: VariablePool, jinja2_variables: Sequence[VariableSelector], + tenant_id: str, ) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]: prompt_messages: list[PromptMessage] = [] if isinstance(prompt_template, list): # For chat model prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=prompt_template, context=context, jinja2_variables=jinja2_variables, @@ -602,7 +663,7 @@ class LLMNode(BaseNode[LLMNodeData]): edition_type="basic", ) prompt_messages.extend( - self._handle_list_messages( + LLMNode.handle_list_messages( messages=[message], context="", jinja2_variables=[], @@ -731,7 +792,7 @@ class LLMNode(BaseNode[LLMNodeData]): ) model = ModelManager().get_model_instance( - tenant_id=self.tenant_id, + tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model, @@ -750,10 +811,12 @@ class LLMNode(BaseNode[LLMNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LLMNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - prompt_template = node_data.prompt_template + # Create typed NodeData from dict + typed_node_data = LLMNodeData.model_validate(node_data) + prompt_template = typed_node_data.prompt_template variable_selectors = [] if isinstance(prompt_template, list) and all( isinstance(prompt, LLMNodeChatModelMessage) for prompt in prompt_template @@ -773,7 +836,7 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - memory = node_data.memory + memory = typed_node_data.memory if memory and memory.query_prompt_template: query_variable_selectors = VariableTemplateParser( template=memory.query_prompt_template @@ -781,16 +844,16 @@ class LLMNode(BaseNode[LLMNodeData]): for variable_selector in query_variable_selectors: variable_mapping[variable_selector.variable] = variable_selector.value_selector - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector + if typed_node_data.context.enabled: + variable_mapping["#context#"] = typed_node_data.context.variable_selector - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector + if typed_node_data.vision.enabled: + variable_mapping["#files#"] = typed_node_data.vision.configs.variable_selector - if node_data.memory: + if typed_node_data.memory: variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY.value] - if node_data.prompt_config: + if typed_node_data.prompt_config: enable_jinja = False if isinstance(prompt_template, list): @@ -803,7 +866,7 @@ class LLMNode(BaseNode[LLMNodeData]): enable_jinja = True if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: + for variable_selector in typed_node_data.prompt_config.jinja2_variables or []: variable_mapping[variable_selector.variable] = variable_selector.value_selector variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} @@ -835,8 +898,8 @@ class LLMNode(BaseNode[LLMNodeData]): }, } - def _handle_list_messages( - self, + @staticmethod + def handle_list_messages( *, messages: Sequence[LLMNodeChatModelMessage], context: Optional[str], @@ -849,7 +912,7 @@ class LLMNode(BaseNode[LLMNodeData]): if message.edition_type == "jinja2": result_text = _render_jinja2_message( template=message.jinja2_text or "", - jinjia2_variables=jinja2_variables, + jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) prompt_message = _combine_message_content_with_role( @@ -897,9 +960,19 @@ class LLMNode(BaseNode[LLMNodeData]): return prompt_messages - def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent: + @staticmethod + def handle_blocking_result( + *, + invoke_result: LLMResult, + saver: LLMFileSaver, + file_outputs: list["File"], + ) -> ModelInvokeCompletedEvent: buffer = io.StringIO() - for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content): + for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( + contents=invoke_result.message.content, + file_saver=saver, + file_outputs=file_outputs, + ): buffer.write(text_part) return ModelInvokeCompletedEvent( @@ -908,7 +981,12 @@ class LLMNode(BaseNode[LLMNodeData]): finish_reason=None, ) - def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File": + @staticmethod + def save_multimodal_image_output( + *, + content: ImagePromptMessageContent, + file_saver: LLMFileSaver, + ) -> "File": """_save_multimodal_output saves multi-modal contents generated by LLM plugins. There are two kinds of multimodal outputs: @@ -918,26 +996,21 @@ class LLMNode(BaseNode[LLMNodeData]): Currently, only image files are supported. """ - # Inject the saver somehow... - _saver = self._llm_file_saver - - # If this if content.url != "": - saved_file = _saver.save_remote_url(content.url, FileType.IMAGE) + saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) else: - saved_file = _saver.save_binary_string( + saved_file = file_saver.save_binary_string( data=base64.b64decode(content.base64_data), mime_type=content.mime_type, file_type=FileType.IMAGE, ) - self._file_outputs.append(saved_file) return saved_file def _fetch_model_schema(self, provider: str) -> AIModelEntity | None: """ Fetch model schema """ - model_name = self.node_data.model.name + model_name = self._node_data.model.name model_manager = ModelManager() model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider, model=model_name @@ -948,16 +1021,20 @@ class LLMNode(BaseNode[LLMNodeData]): model_schema = model_type_instance.get_model_schema(model_name, model_credentials) return model_schema - def _fetch_structured_output_schema(self) -> dict[str, Any]: + @staticmethod + def fetch_structured_output_schema( + *, + structured_output: Mapping[str, Any], + ) -> dict[str, Any]: """ Fetch the structured output schema from the node data. Returns: dict[str, Any]: The structured output schema """ - if not self.node_data.structured_output: + if not structured_output: raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(self.node_data.structured_output.get("schema", {}), ensure_ascii=False) + structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) if not structured_output_schema: raise LLMNodeError("Please provide a valid structured output schema") @@ -969,9 +1046,12 @@ class LLMNode(BaseNode[LLMNodeData]): except json.JSONDecodeError: raise LLMNodeError("structured_output_schema is not valid JSON format") + @staticmethod def _save_multimodal_output_and_convert_result_to_markdown( - self, + *, contents: str | list[PromptMessageContentUnionTypes] | None, + file_saver: LLMFileSaver, + file_outputs: list["File"], ) -> Generator[str, None, None]: """Convert intermediate prompt messages into strings and yield them to the caller. @@ -994,9 +1074,12 @@ class LLMNode(BaseNode[LLMNodeData]): if isinstance(item, TextPromptMessageContent): yield item.data elif isinstance(item, ImagePromptMessageContent): - file = self._save_multimodal_image_output(item) - self._file_outputs.append(file) - yield self._image_file_to_markdown(file) + file = LLMNode.save_multimodal_image_output( + content=item, + file_saver=file_saver, + ) + file_outputs.append(file) + yield LLMNode._image_file_to_markdown(file) else: logger.warning("unknown item type encountered, type=%s", type(item)) yield str(item) @@ -1004,6 +1087,14 @@ class LLMNode(BaseNode[LLMNodeData]): logger.warning("unknown contents type encountered, type=%s", type(contents)) yield str(contents) + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled + def _combine_message_content_with_role( *, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole @@ -1021,20 +1112,20 @@ def _combine_message_content_with_role( def _render_jinja2_message( *, template: str, - jinjia2_variables: Sequence[VariableSelector], + jinja2_variables: Sequence[VariableSelector], variable_pool: VariablePool, ): if not template: return "" - jinjia2_inputs = {} - for jinja2_variable in jinjia2_variables: + jinja2_inputs = {} + for jinja2_variable in jinja2_variables: variable = variable_pool.get(jinja2_variable.value_selector) - jinjia2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" + jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" code_execute_resp = CodeExecutor.execute_workflow_code_template( language=CodeLanguage.JINJA2, code=template, - inputs=jinjia2_inputs, + inputs=jinja2_inputs, ) result_text = code_execute_resp["result"] return result_text @@ -1130,7 +1221,7 @@ def _handle_completion_template( if template.edition_type == "jinja2": result_text = _render_jinja2_message( template=template.jinja2_text or "", - jinjia2_variables=jinja2_variables, + jinja2_variables=jinja2_variables, variable_pool=variable_pool, ) else: diff --git a/api/core/workflow/nodes/loop/entities.py b/api/core/workflow/nodes/loop/entities.py index 3f4a5edab9..d04e0bfae1 100644 --- a/api/core/workflow/nodes/loop/entities.py +++ b/api/core/workflow/nodes/loop/entities.py @@ -1,11 +1,29 @@ from collections.abc import Mapping -from typing import Any, Literal, Optional +from typing import Annotated, Any, Literal, Optional -from pydantic import BaseModel, Field +from pydantic import AfterValidator, BaseModel, Field +from core.variables.types import SegmentType from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.utils.condition.entities import Condition +_VALID_VAR_TYPE = frozenset( + [ + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.OBJECT, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + ] +) + + +def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: + if seg_type not in _VALID_VAR_TYPE: + raise ValueError(...) + return seg_type + class LoopVariableData(BaseModel): """ @@ -13,7 +31,7 @@ class LoopVariableData(BaseModel): """ label: str - var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] + var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] value_type: Literal["variable", "constant"] value: Optional[Any | list[str]] = None diff --git a/api/core/workflow/nodes/loop/loop_end_node.py b/api/core/workflow/nodes/loop/loop_end_node.py index b144021bab..53cadc5251 100644 --- a/api/core/workflow/nodes/loop/loop_end_node.py +++ b/api/core/workflow/nodes/loop/loop_end_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopEndNodeData -class LoopEndNode(BaseNode[LoopEndNodeData]): +class LoopEndNode(BaseNode): """ Loop End Node. """ - _node_data_cls = LoopEndNodeData _node_type = NodeType.LOOP_END + _node_data: LoopEndNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopEndNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/loop/loop_node.py b/api/core/workflow/nodes/loop/loop_node.py index 11fd7b6c2d..655de9362f 100644 --- a/api/core/workflow/nodes/loop/loop_node.py +++ b/api/core/workflow/nodes/loop/loop_node.py @@ -3,18 +3,13 @@ import logging import time from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, Literal, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, cast from configs import dify_config from core.variables import ( - ArrayNumberSegment, - ArrayObjectSegment, - ArrayStringSegment, IntegerSegment, - ObjectSegment, Segment, SegmentType, - StringSegment, ) from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus @@ -35,10 +30,12 @@ from core.workflow.graph_engine.entities.event import ( ) from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.utils.condition.processor import ConditionProcessor +from factories.variable_factory import TypeMismatchError, build_segment_with_type if TYPE_CHECKING: from core.workflow.entities.variable_pool import VariablePool @@ -47,14 +44,36 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class LoopNode(BaseNode[LoopNodeData]): +class LoopNode(BaseNode): """ Loop Node. """ - _node_data_cls = LoopNodeData _node_type = NodeType.LOOP + _node_data: LoopNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -62,17 +81,17 @@ class LoopNode(BaseNode[LoopNodeData]): def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]: """Run the node.""" # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator + loop_count = self._node_data.loop_count + break_conditions = self._node_data.break_conditions + logical_operator = self._node_data.logical_operator inputs = {"loop_count": loop_count} - if not self.node_data.start_node_id: + if not self._node_data.start_node_id: raise ValueError(f"field start_node_id in loop {self.node_id} not found") # Initialize graph - loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self.node_data.start_node_id) + loop_graph = Graph.init(graph_config=self.graph_config, root_node_id=self._node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -82,8 +101,8 @@ class LoopNode(BaseNode[LoopNodeData]): # Initialize loop variables loop_variable_selectors = {} - if self.node_data.loop_variables: - for loop_variable in self.node_data.loop_variables: + if self._node_data.loop_variables: + for loop_variable in self._node_data.loop_variables: value_processor = { "constant": lambda var=loop_variable: self._get_segment_for_constant(var.var_type, var.value), "variable": lambda var=loop_variable: variable_pool.get(var.value), @@ -131,8 +150,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunStartedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, metadata={"loop_length": loop_count}, @@ -188,11 +207,11 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunSucceededEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, steps=loop_count, metadata={ WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: graph_engine.graph_runtime_state.total_tokens, @@ -210,7 +229,7 @@ class LoopNode(BaseNode[LoopNodeData]): WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, }, - outputs=self.node_data.outputs, + outputs=self._node_data.outputs, inputs=inputs, ) ) @@ -221,8 +240,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=loop_count, @@ -324,8 +343,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -355,8 +374,8 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunFailedEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, start_at=start_at, inputs=inputs, steps=current_index, @@ -392,7 +411,7 @@ class LoopNode(BaseNode[LoopNodeData]): _outputs[loop_variable_key] = None _outputs["loop_round"] = current_index + 1 - self.node_data.outputs = _outputs + self._node_data.outputs = _outputs if check_break_result: return {"check_break_result": True} @@ -404,10 +423,10 @@ class LoopNode(BaseNode[LoopNodeData]): yield LoopRunNextEvent( loop_id=self.id, loop_node_id=self.node_id, - loop_node_type=self.node_type, - loop_node_data=self.node_data, + loop_node_type=self.type_, + loop_node_data=self._node_data, index=next_index, - pre_loop_output=self.node_data.outputs, + pre_loop_output=self._node_data.outputs, ) return {"check_break_result": False} @@ -442,19 +461,15 @@ class LoopNode(BaseNode[LoopNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: LoopNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = LoopNodeData.model_validate(node_data) + variable_mapping = {} # init graph - loop_graph = Graph.init(graph_config=graph_config, root_node_id=node_data.start_node_id) + loop_graph = Graph.init(graph_config=graph_config, root_node_id=typed_node_data.start_node_id) if not loop_graph: raise ValueError("loop graph not found") @@ -490,7 +505,7 @@ class LoopNode(BaseNode[LoopNodeData]): variable_mapping.update(sub_node_variable_mapping) - for loop_variable in node_data.loop_variables or []: + for loop_variable in typed_node_data.loop_variables or []: if loop_variable.value_type == "variable": assert loop_variable.value is not None, "Loop variable value must be provided for variable type" # add loop variable to variable mapping @@ -505,23 +520,21 @@ class LoopNode(BaseNode[LoopNodeData]): return variable_mapping @staticmethod - def _get_segment_for_constant(var_type: str, value: Any) -> Segment: + def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment: """Get the appropriate segment type for a constant value.""" - segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = { - "string": (StringSegment, SegmentType.STRING), - "number": (IntegerSegment, SegmentType.NUMBER), - "object": (ObjectSegment, SegmentType.OBJECT), - "array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING), - "array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER), - "array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT), - } if var_type in ["array[string]", "array[number]", "array[object]"]: - if value: + if value and isinstance(value, str): value = json.loads(value) else: value = [] - segment_info = segment_mapping.get(var_type) - if not segment_info: - raise ValueError(f"Invalid variable type: {var_type}") - segment_class, value_type = segment_info - return segment_class(value=value, value_type=value_type) + try: + return build_segment_with_type(var_type, value) + except TypeMismatchError as type_exc: + # Attempt to parse the value as a JSON-encoded string, if applicable. + if not isinstance(value, str): + raise + try: + value = json.loads(value) + except ValueError: + raise type_exc + return build_segment_with_type(var_type, value) diff --git a/api/core/workflow/nodes/loop/loop_start_node.py b/api/core/workflow/nodes/loop/loop_start_node.py index f5e38b7516..29b45ea0c3 100644 --- a/api/core/workflow/nodes/loop/loop_start_node.py +++ b/api/core/workflow/nodes/loop/loop_start_node.py @@ -1,18 +1,44 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.loop.entities import LoopStartNodeData -class LoopStartNode(BaseNode[LoopStartNodeData]): +class LoopStartNode(BaseNode): """ Loop Start Node. """ - _node_data_cls = LoopStartNodeData _node_type = NodeType.LOOP_START + _node_data: LoopStartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = LoopStartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index ccfaec4a8c..294b47670b 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -73,6 +73,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.TOOL: { LATEST_VERSION: ToolNode, + # This is an issue that caused problems before. + # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. "2": ToolNode, "1": ToolNode, }, @@ -123,6 +126,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.AGENT: { LATEST_VERSION: AgentNode, + # This is an issue that caused problems before. + # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. "2": AgentNode, "1": AgentNode, }, diff --git a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py index 25a534256b..a23d284626 100644 --- a/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py +++ b/api/core/workflow/nodes/parameter_extractor/parameter_extractor_node.py @@ -29,8 +29,9 @@ from core.variables.types import SegmentType from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig from core.workflow.nodes.base.node import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.llm import ModelConfig, llm_utils from core.workflow.utils import variable_template_parser from factories.variable_factory import build_segment_with_type @@ -91,10 +92,31 @@ class ParameterExtractorNode(BaseNode): Parameter Extractor Node. """ - # FIXME: figure out why here is different from super class - _node_data_cls = ParameterExtractorNodeData # type: ignore _node_type = NodeType.PARAMETER_EXTRACTOR + _node_data: ParameterExtractorNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ParameterExtractorNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + _model_instance: Optional[ModelInstance] = None _model_config: Optional[ModelConfigWithCredentialsEntity] = None @@ -119,7 +141,7 @@ class ParameterExtractorNode(BaseNode): """ Run the node. """ - node_data = cast(ParameterExtractorNodeData, self.node_data) + node_data = cast(ParameterExtractorNodeData, self._node_data) variable = self.graph_runtime_state.variable_pool.get(node_data.query) query = variable.text if variable else "" @@ -398,7 +420,7 @@ class ParameterExtractorNode(BaseNode): """ Generate prompt engineering prompt. """ - model_mode = ModelMode.value_of(data.model.mode) + model_mode = ModelMode(data.model.mode) if model_mode == ModelMode.COMPLETION: return self._generate_prompt_engineering_completion_prompt( @@ -694,7 +716,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ) -> list[ChatModelMessage]: - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -721,7 +743,7 @@ class ParameterExtractorNode(BaseNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) input_text = query memory_str = "" instruction = variable_pool.convert_template(node_data.instruction or "").text @@ -827,19 +849,15 @@ class ParameterExtractorNode(BaseNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ParameterExtractorNodeData, # type: ignore + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} + # Create typed NodeData from dict + typed_node_data = ParameterExtractorNodeData.model_validate(node_data) + + variable_mapping: dict[str, Sequence[str]] = {"query": typed_node_data.query} - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) + if typed_node_data.instruction: + selectors = variable_template_parser.extract_selectors_from_template(typed_node_data.instruction) for selector in selectors: variable_mapping[selector.variable] = selector.value_selector diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 74024ed90c..15012fa48d 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -1,6 +1,6 @@ import json from collections.abc import Mapping, Sequence -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -11,8 +11,11 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.simple_prompt_transform import ModelMode from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.variable_entities import VariableSelector from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import ModelInvokeCompletedEvent from core.workflow.nodes.llm import ( LLMNode, @@ -20,6 +23,7 @@ from core.workflow.nodes.llm import ( LLMNodeCompletionModelPromptTemplate, llm_utils, ) +from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver from core.workflow.utils.variable_template_parser import VariableTemplateParser from libs.json_in_md_parser import parse_and_check_json_markdown @@ -35,17 +39,77 @@ from .template_prompts import ( QUESTION_CLASSIFIER_USER_PROMPT_3, ) +if TYPE_CHECKING: + from core.file.models import File + from core.workflow.graph_engine import Graph, GraphInitParams, GraphRuntimeState -class QuestionClassifierNode(LLMNode): - _node_data_cls = QuestionClassifierNodeData # type: ignore + +class QuestionClassifierNode(BaseNode): _node_type = NodeType.QUESTION_CLASSIFIER + _node_data: QuestionClassifierNodeData + + _file_outputs: list["File"] + _llm_file_saver: LLMFileSaver + + def __init__( + self, + id: str, + config: Mapping[str, Any], + graph_init_params: "GraphInitParams", + graph: "Graph", + graph_runtime_state: "GraphRuntimeState", + previous_node_id: Optional[str] = None, + thread_pool_id: Optional[str] = None, + *, + llm_file_saver: LLMFileSaver | None = None, + ) -> None: + super().__init__( + id=id, + config=config, + graph_init_params=graph_init_params, + graph=graph, + graph_runtime_state=graph_runtime_state, + previous_node_id=previous_node_id, + thread_pool_id=thread_pool_id, + ) + # LLM file outputs, used for MultiModal outputs. + self._file_outputs: list[File] = [] + + if llm_file_saver is None: + llm_file_saver = FileSaverImpl( + user_id=graph_init_params.user_id, + tenant_id=graph_init_params.tenant_id, + ) + self._llm_file_saver = llm_file_saver + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = QuestionClassifierNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls): return "1" def _run(self): - node_data = cast(QuestionClassifierNodeData, self.node_data) + node_data = cast(QuestionClassifierNodeData, self._node_data) variable_pool = self.graph_runtime_state.variable_pool # extract variables @@ -53,7 +117,10 @@ class QuestionClassifierNode(LLMNode): query = variable.value if variable else None variables = {"query": query} # fetch model config - model_instance, model_config = self._fetch_model_config(node_data.model) + model_instance, model_config = LLMNode._fetch_model_config( + node_data_model=node_data.model, + tenant_id=self.tenant_id, + ) # fetch memory memory = llm_utils.fetch_memory( variable_pool=variable_pool, @@ -91,7 +158,7 @@ class QuestionClassifierNode(LLMNode): # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, # two consecutive user prompts will be generated, causing model's error. # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = self._fetch_prompt_messages( + prompt_messages, stop = LLMNode.fetch_prompt_messages( prompt_template=prompt_template, sys_query="", memory=memory, @@ -101,6 +168,7 @@ class QuestionClassifierNode(LLMNode): vision_detail=node_data.vision.configs.detail, variable_pool=variable_pool, jinja2_variables=[], + tenant_id=self.tenant_id, ) result_text = "" @@ -109,11 +177,17 @@ class QuestionClassifierNode(LLMNode): try: # handle invoke result - generator = self._invoke_llm( + generator = LLMNode.invoke_llm( node_data_model=node_data.model, model_instance=model_instance, prompt_messages=prompt_messages, stop=stop, + user_id=self.user_id, + structured_output_enabled=False, + structured_output=None, + file_saver=self._llm_file_saver, + file_outputs=self._file_outputs, + node_id=self.node_id, ) for event in generator: @@ -183,23 +257,18 @@ class QuestionClassifierNode(LLMNode): *, graph_config: Mapping[str, Any], node_id: str, - node_data: Any, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - node_data = cast(QuestionClassifierNodeData, node_data) - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors = [] - if node_data.instruction: - variable_template_parser = VariableTemplateParser(template=node_data.instruction) + # Create typed NodeData from dict + typed_node_data = QuestionClassifierNodeData.model_validate(node_data) + + variable_mapping = {"query": typed_node_data.query_variable_selector} + variable_selectors: list[VariableSelector] = [] + if typed_node_data.instruction: + variable_template_parser = VariableTemplateParser(template=typed_node_data.instruction) variable_selectors.extend(variable_template_parser.extract_variable_selectors()) for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector + variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} @@ -265,7 +334,7 @@ class QuestionClassifierNode(LLMNode): memory: Optional[TokenBufferMemory], max_token_limit: int = 2000, ): - model_mode = ModelMode.value_of(node_data.model.mode) + model_mode = ModelMode(node_data.model.mode) classes = node_data.classes categories = [] for class_ in classes: diff --git a/api/core/workflow/nodes/start/start_node.py b/api/core/workflow/nodes/start/start_node.py index 5ee9bc331f..9e401e76bb 100644 --- a/api/core/workflow/nodes/start/start_node.py +++ b/api/core/workflow/nodes/start/start_node.py @@ -1,22 +1,48 @@ +from collections.abc import Mapping +from typing import Any, Optional + from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.start.entities import StartNodeData -class StartNode(BaseNode[StartNodeData]): - _node_data_cls = StartNodeData +class StartNode(BaseNode): _node_type = NodeType.START + _node_data: StartNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = StartNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" def _run(self) -> NodeRunResult: node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables + system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() # TODO: System variables should be directly accessible, no need for special handling # Set system variables as node outputs. diff --git a/api/core/workflow/nodes/template_transform/template_transform_node.py b/api/core/workflow/nodes/template_transform/template_transform_node.py index ba573074c3..1962c82db1 100644 --- a/api/core/workflow/nodes/template_transform/template_transform_node.py +++ b/api/core/workflow/nodes/template_transform/template_transform_node.py @@ -6,16 +6,39 @@ from core.helper.code_executor.code_executor import CodeExecutionError, CodeExec from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.template_transform.entities import TemplateTransformNodeData MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = int(os.environ.get("TEMPLATE_TRANSFORM_MAX_LENGTH", "80000")) -class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): - _node_data_cls = TemplateTransformNodeData +class TemplateTransformNode(BaseNode): _node_type = NodeType.TEMPLATE_TRANSFORM + _node_data: TemplateTransformNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = TemplateTransformNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def get_default_config(cls, filters: Optional[dict] = None) -> dict: """ @@ -35,14 +58,14 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): def _run(self) -> NodeRunResult: # Get variables variables = {} - for variable_selector in self.node_data.variables: + for variable_selector in self._node_data.variables: variable_name = variable_selector.variable value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) variables[variable_name] = value.to_object() if value else None # Run code try: result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, code=self.node_data.template, inputs=variables + language=CodeLanguage.JINJA2, code=self._node_data.template, inputs=variables ) except CodeExecutionError as e: return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) @@ -60,16 +83,12 @@ class TemplateTransformNode(BaseNode[TemplateTransformNodeData]): @classmethod def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData + cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: Mapping[str, Any] ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ + # Create typed NodeData from dict + typed_node_data = TemplateTransformNodeData.model_validate(node_data) + return { node_id + "." + variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables + for variable_selector in typed_node_data.variables } diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 691f6e0196..f0a44d919b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -14,6 +14,7 @@ class ToolEntity(BaseModel): tool_name: str tool_label: str # redundancy tool_configurations: dict[str, Any] + credential_id: str | None = None plugin_unique_identifier: str | None = None # redundancy @field_validator("tool_configurations", mode="before") @@ -58,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity): return typ tool_parameters: dict[str, ToolInput] + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None @field_validator("tool_parameters", mode="before") @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 48627a229d..140fe71f60 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,7 +6,6 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod -from core.model_runtime.entities.llm_entities import LLMUsage from core.plugin.impl.exc import PluginDaemonClientSideError from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter @@ -19,9 +18,9 @@ from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.enums import SystemVariableKey -from core.workflow.graph_engine.entities.event import AgentLogEvent from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent from core.workflow.utils.variable_template_parser import VariableTemplateParser from extensions.ext_database import db @@ -37,14 +36,18 @@ from .exc import ( ) -class ToolNode(BaseNode[ToolNodeData]): +class ToolNode(BaseNode): """ Tool Node """ - _node_data_cls = ToolNodeData _node_type = NodeType.TOOL + _node_data: ToolNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = ToolNodeData.model_validate(data) + @classmethod def version(cls) -> str: return "1" @@ -54,7 +57,7 @@ class ToolNode(BaseNode[ToolNodeData]): Run the tool node """ - node_data = cast(ToolNodeData, self.node_data) + node_data = cast(ToolNodeData, self._node_data) # fetch tool icon tool_info = { @@ -67,9 +70,15 @@ class ToolNode(BaseNode[ToolNodeData]): try: from core.tools.tool_manager import ToolManager - variable_pool = self.graph_runtime_state.variable_pool if self.node_data.version != "1" else None + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( - self.tenant_id, self.app_id, self.node_id, self.node_data, self.invoke_from, variable_pool + self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) except ToolNodeError as e: yield RunCompletedEvent( @@ -88,12 +97,12 @@ class ToolNode(BaseNode[ToolNodeData]): parameters = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, ) parameters_for_log = self._generate_parameters( tool_parameters=tool_parameters, variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, + node_data=self._node_data, for_log=True, ) # get conversation id @@ -124,7 +133,14 @@ class ToolNode(BaseNode[ToolNodeData]): try: # convert tool messages - yield from self._transform_message(message_stream, tool_info, parameters_for_log) + yield from self._transform_message( + messages=message_stream, + tool_info=tool_info, + parameters_for_log=parameters_for_log, + user_id=self.user_id, + tenant_id=self.tenant_id, + node_id=self.node_id, + ) except (PluginDaemonClientSideError, ToolInvokeError) as e: yield RunCompletedEvent( run_result=NodeRunResult( @@ -191,7 +207,9 @@ class ToolNode(BaseNode[ToolNodeData]): messages: Generator[ToolInvokeMessage, None, None], tool_info: Mapping[str, Any], parameters_for_log: dict[str, Any], - agent_thoughts: Optional[list] = None, + user_id: str, + tenant_id: str, + node_id: str, ) -> Generator: """ Convert ToolInvokeMessages into tuple[plain_text, files] @@ -199,8 +217,8 @@ class ToolNode(BaseNode[ToolNodeData]): # transform message and handle file storage message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( messages=messages, - user_id=self.user_id, - tenant_id=self.tenant_id, + user_id=user_id, + tenant_id=tenant_id, conversation_id=None, ) @@ -208,9 +226,6 @@ class ToolNode(BaseNode[ToolNodeData]): files: list[File] = [] json: list[dict] = [] - agent_logs: list[AgentLogEvent] = [] - agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - llm_usage: LLMUsage | None = None variables: dict[str, Any] = {} for message in message_stream: @@ -243,7 +258,7 @@ class ToolNode(BaseNode[ToolNodeData]): } file = file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: @@ -266,45 +281,36 @@ class ToolNode(BaseNode[ToolNodeData]): files.append( file_factory.build_from_mapping( mapping=mapping, - tenant_id=self.tenant_id, + tenant_id=tenant_id, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: assert isinstance(message.message, ToolInvokeMessage.TextMessage) text += message.message.text - yield RunStreamChunkEvent( - chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"] - ) + yield RunStreamChunkEvent(chunk_content=message.message.text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.JSON: assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - if self.node_type == NodeType.AGENT: - msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {}) - llm_usage = LLMUsage.from_metadata(msg_metadata) - agent_execution_metadata = { - WorkflowNodeExecutionMetadataKey(key): value - for key, value in msg_metadata.items() - if key in WorkflowNodeExecutionMetadataKey.__members__.values() - } + # JSON message handling for tool node if message.message.json_object is not None: json.append(message.message.json_object) elif message.type == ToolInvokeMessage.MessageType.LINK: assert isinstance(message.message, ToolInvokeMessage.TextMessage) stream_text = f"Link: {message.message.text}\n" text += stream_text - yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"]) + yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[node_id, "text"]) elif message.type == ToolInvokeMessage.MessageType.VARIABLE: assert isinstance(message.message, ToolInvokeMessage.VariableMessage) variable_name = message.message.variable_name variable_value = message.message.variable_value if message.message.stream: if not isinstance(variable_value, str): - raise ValueError("When 'stream' is True, 'variable_value' must be a string.") + raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") if variable_name not in variables: variables[variable_name] = "" variables[variable_name] += variable_value yield RunStreamChunkEvent( - chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name] + chunk_content=variable_value, from_variable_selector=[node_id, variable_name] ) else: variables[variable_name] = variable_value @@ -319,7 +325,7 @@ class ToolNode(BaseNode[ToolNodeData]): dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): manager = PluginInstaller() - plugins = manager.list_plugins(self.tenant_id) + plugins = manager.list_plugins(tenant_id) try: current_plugin = next( plugin @@ -334,8 +340,8 @@ class ToolNode(BaseNode[ToolNodeData]): builtin_tool = next( provider for provider in BuiltinToolManageService.list_builtin_tools( - self.user_id, - self.tenant_id, + user_id, + tenant_id, ) if provider.name == dict_metadata["provider"] ) @@ -347,51 +353,10 @@ class ToolNode(BaseNode[ToolNodeData]): dict_metadata["icon"] = icon dict_metadata["icon_dark"] = icon_dark message.message.metadata = dict_metadata - agent_log = AgentLogEvent( - id=message.message.id, - node_execution_id=self.id, - parent_id=message.message.parent_id, - error=message.message.error, - status=message.message.status.value, - data=message.message.data, - label=message.message.label, - metadata=message.message.metadata, - node_id=self.node_id, - ) - - # check if the agent log is already in the list - for log in agent_logs: - if log.id == agent_log.id: - # update the log - log.data = agent_log.data - log.status = agent_log.status - log.error = agent_log.error - log.label = agent_log.label - log.metadata = agent_log.metadata - break - else: - agent_logs.append(agent_log) - - yield agent_log # Add agent_logs to outputs['json'] to ensure frontend can access thinking process json_output: list[dict[str, Any]] = [] - # Step 1: append each agent log as its own dict. - if agent_logs: - for log in agent_logs: - json_output.append( - { - "id": log.id, - "parent_id": log.parent_id, - "error": log.error, - "status": log.status, - "data": log.data, - "label": log.label, - "metadata": log.metadata, - "node_id": log.node_id, - } - ) # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] if json: json_output.extend(json) @@ -403,12 +368,9 @@ class ToolNode(BaseNode[ToolNodeData]): status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs={"text": text, "files": ArrayFileSegment(value=files), "json": json_output, **variables}, metadata={ - **agent_execution_metadata, WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs, }, inputs=parameters_for_log, - llm_usage=llm_usage, ) ) @@ -418,7 +380,7 @@ class ToolNode(BaseNode[ToolNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: ToolNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: """ Extract variable selector to variable mapping @@ -427,9 +389,12 @@ class ToolNode(BaseNode[ToolNodeData]): :param node_data: node data :return: """ + # Create typed NodeData from dict + typed_node_data = ToolNodeData.model_validate(node_data) + result = {} - for parameter_name in node_data.tool_parameters: - input = node_data.tool_parameters[parameter_name] + for parameter_name in typed_node_data.tool_parameters: + input = typed_node_data.tool_parameters[parameter_name] if input.type == "mixed": assert isinstance(input.value, str) selectors = VariableTemplateParser(input.value).extract_variable_selectors() @@ -443,3 +408,29 @@ class ToolNode(BaseNode[ToolNodeData]): result = {node_id + "." + key: value for key, value in result.items()} return result + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + + @property + def continue_on_error(self) -> bool: + return self._node_data.error_strategy is not None + + @property + def retry(self) -> bool: + return self._node_data.retry_config.retry_enabled diff --git a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py index 96bb3e793a..98127bbeb6 100644 --- a/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py +++ b/api/core/workflow/nodes/variable_aggregator/variable_aggregator_node.py @@ -1,17 +1,41 @@ from collections.abc import Mapping +from typing import Any, Optional from core.variables.segments import Segment from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_aggregator.entities import VariableAssignerNodeData -class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAggregatorNode(BaseNode): _node_type = NodeType.VARIABLE_AGGREGATOR + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData(**data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + @classmethod def version(cls) -> str: return "1" @@ -21,8 +45,8 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): outputs: dict[str, Segment | Mapping[str, Segment]] = {} inputs = {} - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: + if not self._node_data.advanced_settings or not self._node_data.advanced_settings.group_enabled: + for selector in self._node_data.variables: variable = self.graph_runtime_state.variable_pool.get(selector) if variable is not None: outputs = {"output": variable} @@ -30,7 +54,7 @@ class VariableAggregatorNode(BaseNode[VariableAssignerNodeData]): inputs = {".".join(selector[1:]): variable.to_object()} break else: - for group in self.node_data.advanced_settings.groups: + for group in self._node_data.advanced_settings.groups: for selector in group.variables: variable = self.graph_runtime_state.variable_pool.get(selector) diff --git a/api/core/workflow/nodes/variable_assigner/v1/node.py b/api/core/workflow/nodes/variable_assigner/v1/node.py index be5083c9c1..51383fa588 100644 --- a/api/core/workflow/nodes/variable_assigner/v1/node.py +++ b/api/core/workflow/nodes/variable_assigner/v1/node.py @@ -7,7 +7,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from factories import variable_factory @@ -22,11 +23,33 @@ if TYPE_CHECKING: _CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] -class VariableAssignerNode(BaseNode[VariableAssignerData]): - _node_data_cls = VariableAssignerData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER _conv_var_updater_factory: _CONV_VAR_UPDATER_FACTORY + _node_data: VariableAssignerData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def __init__( self, id: str, @@ -59,36 +82,39 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerData.model_validate(node_data) + mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] + assigned_variable_node_id = typed_node_data.assigned_variable_selector[0] if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) + selector_key = ".".join(typed_node_data.assigned_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector + mapping[key] = typed_node_data.assigned_variable_selector - selector_key = ".".join(node_data.input_variable_selector) + selector_key = ".".join(typed_node_data.input_variable_selector) key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector + mapping[key] = typed_node_data.input_variable_selector return mapping def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector + assigned_variable_selector = self._node_data.assigned_variable_selector # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) if not isinstance(original_variable, Variable): raise VariableOperatorNodeError("assigned variable not found") - match self.node_data.write_mode: + match self._node_data.write_mode: case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_variable = original_variable.model_copy(update={"value": income_value.value}) case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) + income_value = self.graph_runtime_state.variable_pool.get(self._node_data.input_variable_selector) if not income_value: raise VariableOperatorNodeError("input value not found") updated_value = original_variable.value + [income_value.value] @@ -101,7 +127,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) case _: - raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}") + raise VariableOperatorNodeError(f"unsupported write mode: {self._node_data.write_mode}") # Over write the variable. self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) @@ -130,6 +156,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]): def get_zero_value(t: SegmentType): + # TODO(QuantumGhost): this should be a method of `SegmentType`. match t: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: return variable_factory.build_segment([]) @@ -137,6 +164,10 @@ def get_zero_value(t: SegmentType): return variable_factory.build_segment({}) case SegmentType.STRING: return variable_factory.build_segment("") + case SegmentType.INTEGER: + return variable_factory.build_segment(0) + case SegmentType.FLOAT: + return variable_factory.build_segment(0.0) case SegmentType.NUMBER: return variable_factory.build_segment(0) case _: diff --git a/api/core/workflow/nodes/variable_assigner/v2/constants.py b/api/core/workflow/nodes/variable_assigner/v2/constants.py index 3797bfa77a..7f760e5baa 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/constants.py +++ b/api/core/workflow/nodes/variable_assigner/v2/constants.py @@ -1,5 +1,6 @@ from core.variables import SegmentType +# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy. EMPTY_VALUE_MAPPING = { SegmentType.STRING: "", SegmentType.NUMBER: 0, diff --git a/api/core/workflow/nodes/variable_assigner/v2/helpers.py b/api/core/workflow/nodes/variable_assigner/v2/helpers.py index 8fb2a27388..7a20975b15 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/helpers.py +++ b/api/core/workflow/nodes/variable_assigner/v2/helpers.py @@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation): case Operation.OVER_WRITE | Operation.CLEAR: return True case Operation.SET: - return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} + return variable_type in { + SegmentType.OBJECT, + SegmentType.STRING, + SegmentType.NUMBER, + SegmentType.INTEGER, + SegmentType.FLOAT, + } case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: # Only number variable can be added, subtracted, multiplied or divided - return variable_type == SegmentType.NUMBER + return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} case Operation.APPEND | Operation.EXTEND: # Only array variable can be appended or extended return variable_type in { @@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat match variable_type: case SegmentType.STRING | SegmentType.OBJECT: return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return operation in { Operation.OVER_WRITE, Operation.SET, @@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va case SegmentType.STRING: return isinstance(value, str) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: if not isinstance(value, int | float): return False if operation == Operation.DIVIDE and value == 0: diff --git a/api/core/workflow/nodes/variable_assigner/v2/node.py b/api/core/workflow/nodes/variable_assigner/v2/node.py index 9292da6f1c..c0215cae71 100644 --- a/api/core/workflow/nodes/variable_assigner/v2/node.py +++ b/api/core/workflow/nodes/variable_assigner/v2/node.py @@ -1,6 +1,6 @@ import json -from collections.abc import Callable, Mapping, MutableMapping, Sequence -from typing import Any, TypeAlias, cast +from collections.abc import Mapping, MutableMapping, Sequence +from typing import Any, Optional, cast from core.app.entities.app_invoke_entities import InvokeFrom from core.variables import SegmentType, Variable @@ -10,7 +10,8 @@ from core.workflow.conversation_variable_updater import ConversationVariableUpda from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus from core.workflow.nodes.base import BaseNode -from core.workflow.nodes.enums import NodeType +from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig +from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.variable_assigner.common import helpers as common_helpers from core.workflow.nodes.variable_assigner.common.exc import VariableOperatorNodeError from core.workflow.nodes.variable_assigner.common.impl import conversation_variable_updater_factory @@ -28,8 +29,6 @@ from .exc import ( VariableNotFoundError, ) -_CONV_VAR_UPDATER_FACTORY: TypeAlias = Callable[[], ConversationVariableUpdater] - def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): selector_node_id = item.variable_selector[0] @@ -54,10 +53,32 @@ def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_ mapping[key] = selector -class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): - _node_data_cls = VariableAssignerNodeData +class VariableAssignerNode(BaseNode): _node_type = NodeType.VARIABLE_ASSIGNER + _node_data: VariableAssignerNodeData + + def init_node_data(self, data: Mapping[str, Any]) -> None: + self._node_data = VariableAssignerNodeData.model_validate(data) + + def _get_error_strategy(self) -> Optional[ErrorStrategy]: + return self._node_data.error_strategy + + def _get_retry_config(self) -> RetryConfig: + return self._node_data.retry_config + + def _get_title(self) -> str: + return self._node_data.title + + def _get_description(self) -> Optional[str]: + return self._node_data.desc + + def _get_default_value_dict(self) -> dict[str, Any]: + return self._node_data.default_value_dict + + def get_base_node_data(self) -> BaseNodeData: + return self._node_data + def _conv_var_updater_factory(self) -> ConversationVariableUpdater: return conversation_variable_updater_factory() @@ -71,22 +92,25 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]): *, graph_config: Mapping[str, Any], node_id: str, - node_data: VariableAssignerNodeData, + node_data: Mapping[str, Any], ) -> Mapping[str, Sequence[str]]: + # Create typed NodeData from dict + typed_node_data = VariableAssignerNodeData.model_validate(node_data) + var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: + for item in typed_node_data.items: _target_mapping_from_item(var_mapping, node_id, item) _source_mapping_from_item(var_mapping, node_id, item) return var_mapping def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() + inputs = self._node_data.model_dump() process_data: dict[str, Any] = {} # NOTE: This node has no outputs updated_variable_selectors: list[Sequence[str]] = [] try: - for item in self.node_data.items: + for item in self._node_data.items: variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) # ==================== Validation Part diff --git a/api/core/workflow/repositories/workflow_execution_repository.py b/api/core/workflow/repositories/workflow_execution_repository.py index 5917310c8b..bcbd253392 100644 --- a/api/core/workflow/repositories/workflow_execution_repository.py +++ b/api/core/workflow/repositories/workflow_execution_repository.py @@ -1,4 +1,4 @@ -from typing import Optional, Protocol +from typing import Protocol from core.workflow.entities.workflow_execution import WorkflowExecution @@ -28,15 +28,3 @@ class WorkflowExecutionRepository(Protocol): execution: The WorkflowExecution instance to save or update """ ... - - def get(self, execution_id: str) -> Optional[WorkflowExecution]: - """ - Retrieve a WorkflowExecution by its ID. - - Args: - execution_id: The workflow execution ID - - Returns: - The WorkflowExecution instance if found, None otherwise - """ - ... diff --git a/api/core/workflow/repositories/workflow_node_execution_repository.py b/api/core/workflow/repositories/workflow_node_execution_repository.py index 1908a6b190..8bf81f5442 100644 --- a/api/core/workflow/repositories/workflow_node_execution_repository.py +++ b/api/core/workflow/repositories/workflow_node_execution_repository.py @@ -39,18 +39,6 @@ class WorkflowNodeExecutionRepository(Protocol): """ ... - def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]: - """ - Retrieve a NodeExecution by its node_execution_id. - - Args: - node_execution_id: The node execution ID - - Returns: - The NodeExecution instance if found, None otherwise - """ - ... - def get_by_workflow_run( self, workflow_run_id: str, @@ -69,24 +57,3 @@ class WorkflowNodeExecutionRepository(Protocol): A list of NodeExecution instances """ ... - - def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all running NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - - Returns: - A list of running NodeExecution instances - """ - ... - - def clear(self) -> None: - """ - Clear all NodeExecution records based on implementation-specific criteria. - - This method is intended to be used for bulk deletion operations, such as removing - all records associated with a specific app_id and tenant_id in multi-tenant implementations. - """ - ... diff --git a/api/core/workflow/system_variable.py b/api/core/workflow/system_variable.py new file mode 100644 index 0000000000..df90c16596 --- /dev/null +++ b/api/core/workflow/system_variable.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from typing import Any + +from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator + +from core.file.models import File +from core.workflow.enums import SystemVariableKey + + +class SystemVariable(BaseModel): + """A model for managing system variables. + + Fields with a value of `None` are treated as absent and will not be included + in the variable pool. + """ + + model_config = ConfigDict( + extra="forbid", + serialize_by_alias=True, + validate_by_alias=True, + ) + + user_id: str | None = None + + # Ideally, `app_id` and `workflow_id` should be required and not `None`. + # However, there are scenarios in the codebase where these fields are not set. + # To maintain compatibility, they are marked as optional here. + app_id: str | None = None + workflow_id: str | None = None + + files: Sequence[File] = Field(default_factory=list) + + # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. + # To maintain compatibility with existing workflows, it must be serialized + # as `workflow_run_id` in dictionaries or JSON objects, and also referenced + # as `workflow_run_id` in the variable pool. + workflow_execution_id: str | None = Field( + validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), + serialization_alias="workflow_run_id", + default=None, + ) + # Chatflow related fields. + query: str | None = None + conversation_id: str | None = None + dialogue_count: int | None = None + + @model_validator(mode="before") + @classmethod + def validate_json_fields(cls, data): + if isinstance(data, dict): + # For JSON validation, only allow workflow_run_id + if "workflow_execution_id" in data and "workflow_run_id" not in data: + # This is likely from direct instantiation, allow it + return data + elif "workflow_execution_id" in data and "workflow_run_id" in data: + # Both present, remove workflow_execution_id + data = data.copy() + data.pop("workflow_execution_id") + return data + return data + + @classmethod + def empty(cls) -> "SystemVariable": + return cls() + + def to_dict(self) -> dict[SystemVariableKey, Any]: + # NOTE: This method is provided for compatibility with legacy code. + # New code should use the `SystemVariable` object directly instead of converting + # it to a dictionary, as this conversion results in the loss of type information + # for each key, making static analysis more difficult. + + d: dict[SystemVariableKey, Any] = { + SystemVariableKey.FILES: self.files, + } + if self.user_id is not None: + d[SystemVariableKey.USER_ID] = self.user_id + if self.app_id is not None: + d[SystemVariableKey.APP_ID] = self.app_id + if self.workflow_id is not None: + d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id + if self.workflow_execution_id is not None: + d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id + if self.query is not None: + d[SystemVariableKey.QUERY] = self.query + if self.conversation_id is not None: + d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id + if self.dialogue_count is not None: + d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count + return d diff --git a/api/core/workflow/workflow_cycle_manager.py b/api/core/workflow/workflow_cycle_manager.py index 0aab2426af..f844aada95 100644 --- a/api/core/workflow/workflow_cycle_manager.py +++ b/api/core/workflow/workflow_cycle_manager.py @@ -1,6 +1,6 @@ from collections.abc import Mapping from dataclasses import dataclass -from datetime import UTC, datetime +from datetime import datetime from typing import Any, Optional, Union from uuid import uuid4 @@ -26,6 +26,7 @@ from core.workflow.entities.workflow_node_execution import ( from core.workflow.enums import SystemVariableKey from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from libs.datetime_utils import naive_utc_now @@ -43,7 +44,7 @@ class WorkflowCycleManager: self, *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], - workflow_system_variables: dict[SystemVariableKey, Any], + workflow_system_variables: SystemVariable, workflow_info: CycleManagerWorkflowInfo, workflow_execution_repository: WorkflowExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository, @@ -54,19 +55,15 @@ class WorkflowCycleManager: self._workflow_execution_repository = workflow_execution_repository self._workflow_node_execution_repository = workflow_node_execution_repository - def handle_workflow_run_start(self) -> WorkflowExecution: - inputs = {**self._application_generate_entity.inputs} - for key, value in (self._workflow_system_variables or {}).items(): - if key.value == "conversation": - continue - inputs[f"sys.{key.value}"] = value + # Initialize caches for workflow execution cycle + # These caches avoid redundant repository calls during a single workflow execution + self._workflow_execution_cache: dict[str, WorkflowExecution] = {} + self._node_execution_cache: dict[str, WorkflowNodeExecution] = {} - # handle special values - inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) + def handle_workflow_run_start(self) -> WorkflowExecution: + inputs = self._prepare_workflow_inputs() + execution_id = self._get_or_generate_execution_id() - # init workflow run - # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this - execution_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_EXECUTION_ID) or uuid4()) execution = WorkflowExecution.new( id_=execution_id, workflow_id=self._workflow_info.workflow_id, @@ -74,12 +71,10 @@ class WorkflowCycleManager: workflow_version=self._workflow_info.version, graph=self._workflow_info.graph_data, inputs=inputs, - started_at=datetime.now(UTC).replace(tzinfo=None), + started_at=naive_utc_now(), ) - self._workflow_execution_repository.save(execution) - - return execution + return self._save_and_cache_workflow_execution(execution) def handle_workflow_run_success( self, @@ -93,23 +88,15 @@ class WorkflowCycleManager: ) -> WorkflowExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(outputs) + self._update_workflow_execution_completion( + workflow_execution, + status=WorkflowExecutionStatus.SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + ) - workflow_execution.status = WorkflowExecutionStatus.SUCCEEDED - workflow_execution.outputs = outputs or {} - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - workflow_execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -126,24 +113,17 @@ class WorkflowCycleManager: trace_manager: Optional[TraceQueueManager] = None, ) -> WorkflowExecution: execution = self._get_workflow_execution_or_raise_error(workflow_run_id) - # outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) - execution.status = WorkflowExecutionStatus.PARTIAL_SUCCEEDED - execution.outputs = outputs or {} - execution.total_tokens = total_tokens - execution.total_steps = total_steps - execution.finished_at = datetime.now(UTC).replace(tzinfo=None) - execution.exceptions_count = exceptions_count + self._update_workflow_execution_completion( + execution, + status=WorkflowExecutionStatus.PARTIAL_SUCCEEDED, + outputs=outputs, + total_tokens=total_tokens, + total_steps=total_steps, + exceptions_count=exceptions_count, + ) - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._add_trace_task_if_needed(trace_manager, execution, conversation_id) self._workflow_execution_repository.save(execution) return execution @@ -163,39 +143,18 @@ class WorkflowCycleManager: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_run_id) now = naive_utc_now() - workflow_execution.status = WorkflowExecutionStatus(status.value) - workflow_execution.error_message = error_message - workflow_execution.total_tokens = total_tokens - workflow_execution.total_steps = total_steps - workflow_execution.finished_at = now - workflow_execution.exceptions_count = exceptions_count - - # Use the instance repository to find running executions for a workflow run - running_node_executions = self._workflow_node_execution_repository.get_running_executions( - workflow_run_id=workflow_execution.id_ + self._update_workflow_execution_completion( + workflow_execution, + status=status, + total_tokens=total_tokens, + total_steps=total_steps, + error_message=error_message, + exceptions_count=exceptions_count, + finished_at=now, ) - # Update the domain models - for node_execution in running_node_executions: - if node_execution.node_execution_id: - # Update the domain model - node_execution.status = WorkflowNodeExecutionStatus.FAILED - node_execution.error = error_message - node_execution.finished_at = now - node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() - - # Update the repository with the domain model - self._workflow_node_execution_repository.save(node_execution) - - if trace_manager: - trace_manager.add_trace_task( - TraceTask( - TraceTaskName.WORKFLOW_TRACE, - workflow_execution=workflow_execution, - conversation_id=conversation_id, - user_id=trace_manager.user_id, - ) - ) + self._fail_running_node_executions(workflow_execution.id_, error_message, now) + self._add_trace_task_if_needed(trace_manager, workflow_execution, conversation_id) self._workflow_execution_repository.save(workflow_execution) return workflow_execution @@ -208,65 +167,24 @@ class WorkflowCycleManager: ) -> WorkflowNodeExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - # Create a domain model - created_at = datetime.now(UTC).replace(tzinfo=None) - metadata = { - WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, - WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, - WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, - } - - domain_execution = WorkflowNodeExecution( - id=str(uuid4()), - workflow_id=workflow_execution.workflow_id, - workflow_execution_id=workflow_execution.id_, - predecessor_node_id=event.predecessor_node_id, - index=event.node_run_index, - node_execution_id=event.node_execution_id, - node_id=event.node_id, - node_type=event.node_type, - title=event.node_data.title, + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, status=WorkflowNodeExecutionStatus.RUNNING, - metadata=metadata, - created_at=created_at, ) - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) - - return domain_execution + return self._save_and_cache_node_execution(domain_execution) def handle_workflow_node_execution_success(self, *, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) - # Process data - inputs = event.inputs - process_data = event.process_data - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - event.start_at).total_seconds() - - # Update domain model - domain_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED - domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + self._update_node_execution_completion( + domain_execution, + event=event, + status=WorkflowNodeExecutionStatus.SUCCEEDED, ) - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = elapsed_time - # Update the repository with the domain model self._workflow_node_execution_repository.save(domain_execution) - return domain_execution def handle_workflow_node_execution_failed( @@ -282,96 +200,251 @@ class WorkflowCycleManager: :param event: queue node failed event :return: """ - # Get the domain model from repository - domain_execution = self._workflow_node_execution_repository.get_by_node_execution_id(event.node_execution_id) - if not domain_execution: - raise ValueError(f"Domain node execution not found: {event.node_execution_id}") - - # Process data - inputs = WorkflowEntry.handle_special_values(event.inputs) - process_data = WorkflowEntry.handle_special_values(event.process_data) - outputs = event.outputs - - # Convert metadata keys to strings - execution_metadata_dict = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - event.start_at).total_seconds() + domain_execution = self._get_node_execution_from_cache(event.node_execution_id) - # Update domain model - domain_execution.status = ( - WorkflowNodeExecutionStatus.FAILED - if not isinstance(event, QueueNodeExceptionEvent) - else WorkflowNodeExecutionStatus.EXCEPTION + status = ( + WorkflowNodeExecutionStatus.EXCEPTION + if isinstance(event, QueueNodeExceptionEvent) + else WorkflowNodeExecutionStatus.FAILED ) - domain_execution.error = event.error - domain_execution.update_from_mapping( - inputs=inputs, process_data=process_data, outputs=outputs, metadata=execution_metadata_dict + + self._update_node_execution_completion( + domain_execution, + event=event, + status=status, + error=event.error, + handle_special_values=True, ) - domain_execution.finished_at = finished_at - domain_execution.elapsed_time = elapsed_time - # Update the repository with the domain model self._workflow_node_execution_repository.save(domain_execution) - return domain_execution def handle_workflow_node_execution_retried( self, *, workflow_execution_id: str, event: QueueNodeRetryEvent ) -> WorkflowNodeExecution: workflow_execution = self._get_workflow_execution_or_raise_error(workflow_execution_id) - created_at = event.start_at - finished_at = datetime.now(UTC).replace(tzinfo=None) - elapsed_time = (finished_at - created_at).total_seconds() + + domain_execution = self._create_node_execution_from_event( + workflow_execution=workflow_execution, + event=event, + status=WorkflowNodeExecutionStatus.RETRY, + error=event.error, + created_at=event.start_at, + ) + + # Handle inputs and outputs inputs = WorkflowEntry.handle_special_values(event.inputs) outputs = event.outputs + metadata = self._merge_event_metadata(event) - # Convert metadata keys to strings - origin_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, + domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=metadata) + + return self._save_and_cache_node_execution(domain_execution) + + def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: + # Check cache first + if id in self._workflow_execution_cache: + return self._workflow_execution_cache[id] + + raise WorkflowRunNotFoundError(id) + + def _prepare_workflow_inputs(self) -> dict[str, Any]: + """Prepare workflow inputs by merging application inputs with system variables.""" + inputs = {**self._application_generate_entity.inputs} + + if self._workflow_system_variables: + for field_name, value in self._workflow_system_variables.to_dict().items(): + if field_name != SystemVariableKey.CONVERSATION_ID: + inputs[f"sys.{field_name}"] = value + + return dict(WorkflowEntry.handle_special_values(inputs) or {}) + + def _get_or_generate_execution_id(self) -> str: + """Get execution ID from system variables or generate a new one.""" + if self._workflow_system_variables and self._workflow_system_variables.workflow_execution_id: + return str(self._workflow_system_variables.workflow_execution_id) + return str(uuid4()) + + def _save_and_cache_workflow_execution(self, execution: WorkflowExecution) -> WorkflowExecution: + """Save workflow execution to repository and cache it.""" + self._workflow_execution_repository.save(execution) + self._workflow_execution_cache[execution.id_] = execution + return execution + + def _save_and_cache_node_execution(self, execution: WorkflowNodeExecution) -> WorkflowNodeExecution: + """Save node execution to repository and cache it if it has an ID.""" + self._workflow_node_execution_repository.save(execution) + if execution.node_execution_id: + self._node_execution_cache[execution.node_execution_id] = execution + return execution + + def _get_node_execution_from_cache(self, node_execution_id: str) -> WorkflowNodeExecution: + """Get node execution from cache or raise error if not found.""" + domain_execution = self._node_execution_cache.get(node_execution_id) + if not domain_execution: + raise ValueError(f"Domain node execution not found: {node_execution_id}") + return domain_execution + + def _update_workflow_execution_completion( + self, + execution: WorkflowExecution, + *, + status: WorkflowExecutionStatus, + total_tokens: int, + total_steps: int, + outputs: Mapping[str, Any] | None = None, + error_message: Optional[str] = None, + exceptions_count: int = 0, + finished_at: Optional[datetime] = None, + ) -> None: + """Update workflow execution with completion data.""" + execution.status = status + execution.outputs = outputs or {} + execution.total_tokens = total_tokens + execution.total_steps = total_steps + execution.finished_at = finished_at or naive_utc_now() + execution.exceptions_count = exceptions_count + if error_message: + execution.error_message = error_message + + def _add_trace_task_if_needed( + self, + trace_manager: Optional[TraceQueueManager], + workflow_execution: WorkflowExecution, + conversation_id: Optional[str], + ) -> None: + """Add trace task if trace manager is provided.""" + if trace_manager: + trace_manager.add_trace_task( + TraceTask( + TraceTaskName.WORKFLOW_TRACE, + workflow_execution=workflow_execution, + conversation_id=conversation_id, + user_id=trace_manager.user_id, + ) + ) + + def _fail_running_node_executions( + self, + workflow_execution_id: str, + error_message: str, + now: datetime, + ) -> None: + """Fail all running node executions for a workflow.""" + running_node_executions = [ + node_exec + for node_exec in self._node_execution_cache.values() + if node_exec.workflow_execution_id == workflow_execution_id + and node_exec.status == WorkflowNodeExecutionStatus.RUNNING + ] + + for node_execution in running_node_executions: + if node_execution.node_execution_id: + node_execution.status = WorkflowNodeExecutionStatus.FAILED + node_execution.error = error_message + node_execution.finished_at = now + node_execution.elapsed_time = (now - node_execution.created_at).total_seconds() + self._workflow_node_execution_repository.save(node_execution) + + def _create_node_execution_from_event( + self, + *, + workflow_execution: WorkflowExecution, + event: Union[QueueNodeStartedEvent, QueueNodeRetryEvent], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + created_at: Optional[datetime] = None, + ) -> WorkflowNodeExecution: + """Create a node execution from an event.""" + now = naive_utc_now() + created_at = created_at or now + + metadata = { WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, } - # Convert execution metadata keys to strings - execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} - if event.execution_metadata: - for key, value in event.execution_metadata.items(): - execution_metadata_dict[key] = value - - merged_metadata = {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata - - # Create a domain model domain_execution = WorkflowNodeExecution( id=str(uuid4()), workflow_id=workflow_execution.workflow_id, workflow_execution_id=workflow_execution.id_, predecessor_node_id=event.predecessor_node_id, + index=event.node_run_index, node_execution_id=event.node_execution_id, node_id=event.node_id, node_type=event.node_type, title=event.node_data.title, - status=WorkflowNodeExecutionStatus.RETRY, + status=status, + metadata=metadata, created_at=created_at, - finished_at=finished_at, - elapsed_time=elapsed_time, - error=event.error, - index=event.node_run_index, + error=error, ) - # Update with mappings - domain_execution.update_from_mapping(inputs=inputs, outputs=outputs, metadata=merged_metadata) - - # Use the instance repository to save the domain model - self._workflow_node_execution_repository.save(domain_execution) + if status == WorkflowNodeExecutionStatus.RETRY: + domain_execution.finished_at = now + domain_execution.elapsed_time = (now - created_at).total_seconds() return domain_execution - def _get_workflow_execution_or_raise_error(self, id: str, /) -> WorkflowExecution: - execution = self._workflow_execution_repository.get(id) - if not execution: - raise WorkflowRunNotFoundError(id) - return execution + def _update_node_execution_completion( + self, + domain_execution: WorkflowNodeExecution, + *, + event: Union[ + QueueNodeSucceededEvent, + QueueNodeFailedEvent, + QueueNodeInIterationFailedEvent, + QueueNodeInLoopFailedEvent, + QueueNodeExceptionEvent, + ], + status: WorkflowNodeExecutionStatus, + error: Optional[str] = None, + handle_special_values: bool = False, + ) -> None: + """Update node execution with completion data.""" + finished_at = naive_utc_now() + elapsed_time = (finished_at - event.start_at).total_seconds() + + # Process data + if handle_special_values: + inputs = WorkflowEntry.handle_special_values(event.inputs) + process_data = WorkflowEntry.handle_special_values(event.process_data) + else: + inputs = event.inputs + process_data = event.process_data + + outputs = event.outputs + + # Convert metadata + execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, Any] = {} + if event.execution_metadata: + execution_metadata_dict.update(event.execution_metadata) + + # Update domain model + domain_execution.status = status + domain_execution.update_from_mapping( + inputs=inputs, + process_data=process_data, + outputs=outputs, + metadata=execution_metadata_dict, + ) + domain_execution.finished_at = finished_at + domain_execution.elapsed_time = elapsed_time + + if error: + domain_execution.error = error + + def _merge_event_metadata(self, event: QueueNodeRetryEvent) -> dict[WorkflowNodeExecutionMetadataKey, str | None]: + """Merge event metadata with origin metadata.""" + origin_metadata = { + WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id, + WorkflowNodeExecutionMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, + WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id, + } + + execution_metadata_dict: dict[WorkflowNodeExecutionMetadataKey, str | None] = {} + if event.execution_metadata: + execution_metadata_dict.update(event.execution_metadata) + + return {**execution_metadata_dict, **origin_metadata} if execution_metadata_dict else origin_metadata diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2868dcb7de..d2375da39c 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -5,7 +5,7 @@ from collections.abc import Generator, Mapping, Sequence from typing import Any, Optional, cast from configs import dify_config -from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError +from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.workflow.callbacks import WorkflowCallback @@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType from core.workflow.nodes.base import BaseNode from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.system_variable import SystemVariable from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from factories import file_factory from models.enums import UserFrom @@ -145,7 +146,7 @@ class WorkflowEntry: graph = Graph.init(graph_config=workflow.graph_dict) # init workflow run state - node_instance = node_cls( + node = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -189,17 +190,11 @@ class WorkflowEntry: try: # run node - generator = node_instance.run() + generator = node.run() except Exception as e: - logger.exception( - "error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s", - workflow.id, - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) - return node_instance, generator + logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) + return node, generator @classmethod def run_free_node( @@ -254,14 +249,14 @@ class WorkflowEntry: # init variable pool variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs={}, environment_variables=[], ) node_cls = cast(type[BaseNode], node_cls) # init workflow run state - node_instance: BaseNode = node_cls( + node: BaseNode = node_cls( id=str(uuid.uuid4()), config=node_config, graph_init_params=GraphInitParams( @@ -296,17 +291,12 @@ class WorkflowEntry: ) # run node - generator = node_instance.run() + generator = node.run() - return node_instance, generator + return node, generator except Exception as e: - logger.exception( - "error while running node_instance, node_id=%s, type=%s, version=%s", - node_instance.id, - node_instance.node_type, - node_instance.version(), - ) - raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) + logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}") + raise WorkflowNodeRunFailedError(node=node, err_msg=str(e)) @staticmethod def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None: diff --git a/api/core/workflow/workflow_type_encoder.py b/api/core/workflow/workflow_type_encoder.py index 0123fdac18..2c634d25ec 100644 --- a/api/core/workflow/workflow_type_encoder.py +++ b/api/core/workflow/workflow_type_encoder.py @@ -1,4 +1,3 @@ -import json from collections.abc import Mapping from typing import Any @@ -8,18 +7,6 @@ from core.file.models import File from core.variables import Segment -class WorkflowRuntimeTypeEncoder(json.JSONEncoder): - def default(self, o: Any): - if isinstance(o, Segment): - return o.value - elif isinstance(o, File): - return o.to_dict() - elif isinstance(o, BaseModel): - return o.model_dump(mode="json") - else: - return super().default(o) - - class WorkflowRuntimeTypeConverter: def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: result = self._to_json_encodable_recursive(value) diff --git a/api/events/event_handlers/create_document_index.py b/api/events/event_handlers/create_document_index.py index 8a677f6b6f..cb48bd92a0 100644 --- a/api/events/event_handlers/create_document_index.py +++ b/api/events/event_handlers/create_document_index.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -8,6 +7,7 @@ from werkzeug.exceptions import NotFound from core.indexing_runner import DocumentIsPausedError, IndexingRunner from events.event_handlers.document_index_event import document_index_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Document @@ -33,7 +33,7 @@ def handle(sender, **kwargs): raise NotFound("Document not found") document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index 249bd14429..6c9fc0bf1d 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -20,6 +20,7 @@ def handle(sender, **kwargs): provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, tenant_id=app.tenant_id, + credential_id=tool_entity.credential_id, ) manager = ToolParameterConfigurationManager( tenant_id=app.tenant_id, diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index ddc2158a02..600e336c19 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -18,6 +18,7 @@ def init_app(app: DifyApp): reset_email, reset_encrypt_key_pair, reset_password, + setup_system_tool_oauth_client, upgrade_db, vdb_migrate, ) @@ -40,6 +41,7 @@ def init_app(app: DifyApp): clear_free_plan_tenant_expired_logs, clear_orphaned_file_records, remove_orphaned_files_on_storage, + setup_system_tool_oauth_client, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index b62b0b60d6..b027a165f9 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -193,13 +193,22 @@ def init_app(app: DifyApp): insecure=True, ) else: + headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None + + trace_endpoint = dify_config.OTLP_TRACE_ENDPOINT + if not trace_endpoint: + trace_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/traces" exporter = HTTPSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/traces", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=trace_endpoint, + headers=headers, ) + + metric_endpoint = dify_config.OTLP_METRIC_ENDPOINT + if not metric_endpoint: + metric_endpoint = dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics" metric_exporter = HTTPMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT + "/v1/metrics", - headers={"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"}, + endpoint=metric_endpoint, + headers=headers, ) else: exporter = ConsoleSpanExporter() diff --git a/api/extensions/storage/azure_blob_storage.py b/api/extensions/storage/azure_blob_storage.py index 7448fd4a6b..81eec94da4 100644 --- a/api/extensions/storage/azure_blob_storage.py +++ b/api/extensions/storage/azure_blob_storage.py @@ -1,5 +1,5 @@ from collections.abc import Generator -from datetime import UTC, datetime, timedelta +from datetime import timedelta from typing import Optional from azure.identity import ChainedTokenCredential, DefaultAzureCredential @@ -8,6 +8,7 @@ from azure.storage.blob import AccountSasPermissions, BlobServiceClient, Resourc from configs import dify_config from extensions.ext_redis import redis_client from extensions.storage.base_storage import BaseStorage +from libs.datetime_utils import naive_utc_now class AzureBlobStorage(BaseStorage): @@ -78,7 +79,7 @@ class AzureBlobStorage(BaseStorage): account_key=self.account_key or "", resource_types=ResourceTypes(service=True, container=True, object=True), permission=AccountSasPermissions(read=True, write=True, delete=True, list=True, add=True, create=True), - expiry=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), + expiry=naive_utc_now() + timedelta(hours=1), ) redis_client.set(cache_key, sas_token, ex=3000) return BlobServiceClient(account_url=self.account_url or "", credential=sas_token) diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py index 25d1390492..c974dbb700 100644 --- a/api/factories/file_factory.py +++ b/api/factories/file_factory.py @@ -148,9 +148,7 @@ def _build_from_local_file( if strict_type_validation and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("id"), @@ -199,9 +197,7 @@ def _build_from_remote_url( raise ValueError("Detected file type does not match the specified type. Please verify the file.") file_type = ( - FileType(specified_type) - if specified_type and specified_type != FileType.CUSTOM.value - else detected_file_type + FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type ) return File( @@ -286,9 +282,7 @@ def _build_from_tool_file( if strict_type_validation and specified_type and detected_file_type.value != specified_type: raise ValueError("Detected file type does not match the specified type. Please verify the file.") - file_type = ( - FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM.value else detected_file_type - ) + file_type = FileType(specified_type) if specified_type and specified_type != FileType.CUSTOM else detected_file_type return File( id=mapping.get("id"), diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 250ee4695e..39ebd009d5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -91,9 +91,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen result = StringVariable.model_validate(mapping) case SegmentType.SECRET: result = SecretVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, int): + case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.INTEGER result = IntegerVariable.model_validate(mapping) - case SegmentType.NUMBER if isinstance(value, float): + case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float): + mapping = dict(mapping) + mapping["value_type"] = SegmentType.FLOAT result = FloatVariable.model_validate(mapping) case SegmentType.NUMBER if not isinstance(value, float | int): raise VariableError(f"invalid number value {value}") @@ -119,6 +123,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType: def build_segment(value: Any, /) -> Segment: + # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` + # below if value is None: return NoneSegment() if isinstance(value, str): @@ -134,12 +140,17 @@ def build_segment(value: Any, /) -> Segment: if isinstance(value, list): items = [build_segment(item) for item in value] types = {item.value_type for item in items} - if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): + if all(isinstance(item, ArraySegment) for item in items): return ArrayAnySegment(value=value) + elif len(types) != 1: + if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): + return ArrayNumberSegment(value=value) + return ArrayAnySegment(value=value) + match types.pop(): case SegmentType.STRING: return ArrayStringSegment(value=value) - case SegmentType.NUMBER: + case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: return ArrayNumberSegment(value=value) case SegmentType.OBJECT: return ArrayObjectSegment(value=value) @@ -153,6 +164,22 @@ def build_segment(value: Any, /) -> Segment: raise ValueError(f"not supported value {value}") +_segment_factory: Mapping[SegmentType, type[Segment]] = { + SegmentType.NONE: NoneSegment, + SegmentType.STRING: StringSegment, + SegmentType.INTEGER: IntegerSegment, + SegmentType.FLOAT: FloatSegment, + SegmentType.FILE: FileSegment, + SegmentType.OBJECT: ObjectSegment, + # Array types + SegmentType.ARRAY_ANY: ArrayAnySegment, + SegmentType.ARRAY_STRING: ArrayStringSegment, + SegmentType.ARRAY_NUMBER: ArrayNumberSegment, + SegmentType.ARRAY_OBJECT: ArrayObjectSegment, + SegmentType.ARRAY_FILE: ArrayFileSegment, +} + + def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: """ Build a segment with explicit type checking. @@ -190,7 +217,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: if segment_type == SegmentType.NONE: return NoneSegment() else: - raise TypeMismatchError(f"Expected {segment_type}, but got None") + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") # Handle empty list special case for array types if isinstance(value, list) and len(value) == 0: @@ -205,21 +232,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: elif segment_type == SegmentType.ARRAY_FILE: return ArrayFileSegment(value=value) else: - raise TypeMismatchError(f"Expected {segment_type}, but got empty list") - - # Build segment using existing logic to infer actual type - inferred_segment = build_segment(value) - inferred_type = inferred_segment.value_type + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") + inferred_type = SegmentType.infer_segment_type(value) # Type compatibility checking + if inferred_type is None: + raise TypeMismatchError( + f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" + ) if inferred_type == segment_type: - return inferred_segment - - # Type mismatch - raise error with descriptive message - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but value '{value}' " - f"(type: {type(value).__name__}) corresponds to {inferred_type}" - ) + segment_class = _segment_factory[segment_type] + return segment_class(value_type=segment_type, value=value) + elif segment_type == SegmentType.NUMBER and inferred_type in ( + SegmentType.INTEGER, + SegmentType.FLOAT, + ): + segment_class = _segment_factory[inferred_type] + return segment_class(value_type=inferred_type, value=value) + else: + raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") def segment_to_variable( @@ -247,6 +278,6 @@ def segment_to_variable( name=name, description=description, value=segment.value, - selector=selector, + selector=list(selector), ), ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py new file mode 100644 index 0000000000..8288bd54a3 --- /dev/null +++ b/api/fields/_value_type_serializer.py @@ -0,0 +1,15 @@ +from typing import TypedDict + +from core.variables.segments import Segment +from core.variables.types import SegmentType + + +class _VarTypedDict(TypedDict, total=False): + value_type: SegmentType + + +def serialize_value_type(v: _VarTypedDict | Segment) -> str: + if isinstance(v, Segment): + return v.value_type.exposed_type().value + else: + return v["value_type"].exposed_type().value diff --git a/api/fields/app_fields.py b/api/fields/app_fields.py index 73c224542a..b6d85e0e24 100644 --- a/api/fields/app_fields.py +++ b/api/fields/app_fields.py @@ -188,6 +188,7 @@ app_detail_fields_with_site = { "site": fields.Nested(site_fields), "api_base_url": fields.String, "use_icon_as_answer_icon": fields.Boolean, + "max_active_requests": fields.Integer, "created_by": fields.String, "created_at": TimestampField, "updated_by": fields.String, diff --git a/api/fields/conversation_variable_fields.py b/api/fields/conversation_variable_fields.py index 71785e7d67..c5a0c9a49d 100644 --- a/api/fields/conversation_variable_fields.py +++ b/api/fields/conversation_variable_fields.py @@ -2,10 +2,12 @@ from flask_restful import fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.String, "description": fields.String, "created_at": TimestampField, diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index f00ea71c54..930e59cc1c 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable from fields.member_fields import simple_account_fields from libs.helper import TimestampField +from ._value_type_serializer import serialize_value_type + ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) @@ -24,11 +26,16 @@ class EnvironmentVariableField(fields.Raw): "id": value.id, "name": value.name, "value": value.value, - "value_type": value.value_type.value, + "value_type": value.value_type.exposed_type().value, "description": value.description, } if isinstance(value, dict): - value_type = value.get("value_type") + value_type_str = value.get("value_type") + if not isinstance(value_type_str, str): + raise TypeError( + f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}" + ) + value_type = SegmentType(value_type_str).exposed_type() if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: raise ValueError(f"Unsupported environment variable value type: {value_type}") return value @@ -37,7 +44,7 @@ class EnvironmentVariableField(fields.Raw): conversation_variable_fields = { "id": fields.String, "name": fields.String, - "value_type": fields.String(attribute="value_type.value"), + "value_type": fields.String(attribute=serialize_value_type), "value": fields.Raw, "description": fields.String, } diff --git a/api/libs/helper.py b/api/libs/helper.py index 48126461a3..00772d530a 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -148,25 +148,6 @@ class StrLen: return value -class FloatRange: - """Restrict input to an float in a range (inclusive)""" - - def __init__(self, low, high, argument="argument"): - self.low = low - self.high = high - self.argument = argument - - def __call__(self, value): - value = _get_float(value) - if value < self.low or value > self.high: - error = "Invalid {arg}: {val}. {arg} must be within the range {lo} - {hi}".format( - arg=self.argument, val=value, lo=self.low, hi=self.high - ) - raise ValueError(error) - - return value - - class DatetimeString: def __init__(self, format, argument="argument"): self.format = format diff --git a/api/libs/jsonutil.py b/api/libs/jsonutil.py deleted file mode 100644 index fa29671034..0000000000 --- a/api/libs/jsonutil.py +++ /dev/null @@ -1,11 +0,0 @@ -import json - -from pydantic import BaseModel - - -class PydanticModelEncoder(json.JSONEncoder): - def default(self, o): - if isinstance(o, BaseModel): - return o.model_dump() - else: - super().default(o) diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index 218109522d..78f827584c 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -1,4 +1,3 @@ -import datetime import urllib.parse from typing import Any @@ -6,6 +5,7 @@ import requests from flask_login import current_user from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.source import DataSourceOauthBinding @@ -75,7 +75,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -115,7 +115,7 @@ class NotionOAuth(OAuthDataSource): if data_source_binding: data_source_binding.source_info = source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: new_data_source_binding = DataSourceOauthBinding( @@ -154,7 +154,7 @@ class NotionOAuth(OAuthDataSource): } data_source_binding.source_info = new_source_info data_source_binding.disabled = False - data_source_binding.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + data_source_binding.updated_at = naive_utc_now() db.session.commit() else: raise ValueError("Data source binding not found") diff --git a/api/libs/rsa.py b/api/libs/rsa.py index 637bcc4a1d..da279eb32b 100644 --- a/api/libs/rsa.py +++ b/api/libs/rsa.py @@ -1,4 +1,5 @@ import hashlib +from typing import Union from Crypto.Cipher import AES from Crypto.PublicKey import RSA @@ -9,7 +10,7 @@ from extensions.ext_storage import storage from libs import gmpy2_pkcs10aep_cipher -def generate_key_pair(tenant_id): +def generate_key_pair(tenant_id: str) -> str: private_key = RSA.generate(2048) public_key = private_key.publickey() @@ -26,7 +27,7 @@ def generate_key_pair(tenant_id): prefix_hybrid = b"HYBRID:" -def encrypt(text, public_key): +def encrypt(text: str, public_key: Union[str, bytes]) -> bytes: if isinstance(public_key, str): public_key = public_key.encode() @@ -38,14 +39,14 @@ def encrypt(text, public_key): rsa_key = RSA.import_key(public_key) cipher_rsa = gmpy2_pkcs10aep_cipher.new(rsa_key) - enc_aes_key = cipher_rsa.encrypt(aes_key) + enc_aes_key: bytes = cipher_rsa.encrypt(aes_key) encrypted_data = enc_aes_key + cipher_aes.nonce + tag + ciphertext return prefix_hybrid + encrypted_data -def get_decrypt_decoding(tenant_id): +def get_decrypt_decoding(tenant_id: str) -> tuple[RSA.RsaKey, object]: filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) @@ -64,7 +65,7 @@ def get_decrypt_decoding(tenant_id): return rsa_key, cipher_rsa -def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): +def decrypt_token_with_decoding(encrypted_text: bytes, rsa_key: RSA.RsaKey, cipher_rsa) -> str: if encrypted_text.startswith(prefix_hybrid): encrypted_text = encrypted_text[len(prefix_hybrid) :] @@ -83,10 +84,10 @@ def decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa): return decrypted_text.decode() -def decrypt(encrypted_text, tenant_id): +def decrypt(encrypted_text: bytes, tenant_id: str) -> str: rsa_key, cipher_rsa = get_decrypt_decoding(tenant_id) - return decrypt_token_with_decoding(encrypted_text, rsa_key, cipher_rsa) + return decrypt_token_with_decoding(encrypted_text=encrypted_text, rsa_key=rsa_key, cipher_rsa=cipher_rsa) class PrivkeyNotFoundError(Exception): diff --git a/api/libs/uuid_utils.py b/api/libs/uuid_utils.py new file mode 100644 index 0000000000..a8190011ed --- /dev/null +++ b/api/libs/uuid_utils.py @@ -0,0 +1,164 @@ +import secrets +import struct +import time +import uuid + +# Reference for UUIDv7 specification: +# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7 + +# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian). +# +# For details on the `struct.pack` format, refer to: +# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment +_PACK_TIMESTAMP = ">Q" + +# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7) +# into an unsigned 16-bit integer (big-endian). +_PACK_RAND_A = ">H" + + +def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes: + """Create UUIDv7 byte structure with given timestamp and random bytes. + + This is a private helper function that handles the common logic for creating + UUIDv7 byte structure according to RFC 9562 specification. + + UUIDv7 Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + The function performs the following operations: + 1. Creates a 128-bit (16-byte) UUID structure + 2. Packs the timestamp into the first 48 bits (6 bytes) + 3. Sets the version bits to 7 (0111) in the correct position + 4. Sets the variant bits to 10 (binary) in the correct position + 5. Fills the remaining bits with the provided random bytes + + Args: + timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits). + random_bytes: Random bytes to use for the random portions (must be 10 bytes). + First 2 bytes are used for random data A (12 bits after version). + Last 8 bytes are used for random data B (62 bits after variant). + + Returns: + A 16-byte bytes object representing the complete UUIDv7 structure. + + Note: + This function assumes the random_bytes parameter is exactly 10 bytes. + The caller is responsible for providing appropriate random data. + """ + # Create the 128-bit UUID structure + uuid_bytes = bytearray(16) + + # Pack timestamp (48 bits) into first 6 bytes + uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian + + # Next 16 bits: random data A (12 bits) + version (4 bits) + # Take first 2 random bytes and set version to 7 + rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0] + # Clear the highest 4 bits to make room for the version field + # by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111). + rand_a = rand_a & 0x0FFF + # Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000). + rand_a = rand_a | 0x7000 + uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a) + + # Last 64 bits: random data B (62 bits) + variant (2 bits) + # Use remaining 8 random bytes and set variant to 10 (binary) + uuid_bytes[8:16] = random_bytes[2:10] + + # Set variant bits (first 2 bits of byte 8 should be '10') + uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx + + return bytes(uuid_bytes) + + +def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID: + """Generate a UUID version 7 according to RFC 9562 specification. + + UUIDv7 features a time-ordered value field derived from the widely + implemented and well known Unix Epoch timestamp source, the number of + milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded. + + Structure: + - 48 bits: timestamp (milliseconds since Unix epoch) + - 12 bits: random data A (with version bits) + - 62 bits: random data B (with variant bits) + + Args: + timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified. + Should be an integer representing milliseconds since Unix epoch. + + Returns: + A UUID object representing a UUIDv7. + + Example: + >>> import time + >>> # Generate UUIDv7 with current time + >>> uuid_current = uuidv7() + >>> # Generate UUIDv7 with specific timestamp + >>> uuid_specific = uuidv7(int(time.time() * 1000)) + """ + if timestamp_ms is None: + timestamp_ms = int(time.time() * 1000) + + # Generate 10 random bytes for the random portions + random_bytes = secrets.token_bytes(10) + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes) + + return uuid.UUID(bytes=uuid_bytes) + + +def uuidv7_timestamp(id_: uuid.UUID) -> int: + """Extract the timestamp from a UUIDv7. + + UUIDv7 contains a 48-bit timestamp field representing milliseconds since + the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and + returns that timestamp as an integer representing milliseconds since the epoch. + + Args: + id_: A UUID object that should be a UUIDv7 (version 7). + + Returns: + The timestamp as an integer representing milliseconds since Unix epoch. + + Raises: + ValueError: If the provided UUID is not version 7. + + Example: + >>> uuid_v7 = uuidv7() + >>> timestamp = uuidv7_timestamp(uuid_v7) + >>> print(f"UUID was created at: {timestamp} ms") + """ + # Verify this is a UUIDv7 + if id_.version != 7: + raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}") + + # Extract the UUID bytes + uuid_bytes = id_.bytes + + # Extract the first 48 bits (6 bytes) as the timestamp in milliseconds + # Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long) + timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6] + ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0] + + # Return timestamp directly in milliseconds as integer + assert isinstance(ts_in_ms, int) + return ts_in_ms + + +def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID: + """Generate a non-random uuidv7 with the given timestamp (first 48 bits) and + all random bits to 0. As the smallest possible uuidv7 for that timestamp, + it may be used as a boundary for partitions. + """ + # Use zero bytes for all random portions + zero_random_bytes = b"\x00" * 10 + + # Create UUIDv7 bytes using the helper function + uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes) + + return uuid.UUID(bytes=uuid_bytes) diff --git a/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py new file mode 100644 index 0000000000..2bbbb3d28e --- /dev/null +++ b/api/migrations/versions/2025_07_02_2332-1c9ba48be8e4_add_uuidv7_function_in_sql.py @@ -0,0 +1,86 @@ +"""add uuidv7 function in SQL + +Revision ID: 1c9ba48be8e4 +Revises: 58eb7bdb93fe +Create Date: 2025-07-02 23:32:38.484499 + +""" + +""" +The functions in this files comes from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications. + +LICENSE: + +# Copyright and License + +Copyright (c) 2024, Daniel Vérité + +Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies. + +In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage. + +Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications. +""" + +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1c9ba48be8e4' +down_revision = '58eb7bdb93fe' +branch_labels: None = None +depends_on: None = None + + +def upgrade(): + # This implementation differs slightly from the original uuidv7 function in + # https://github.com/dverite/postgres-uuidv7-sql/. + # The ability to specify source timestamp has been removed because its type signature is incompatible with + # PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be + # generated and controlled within the application layer. + op.execute(sa.text(r""" +/* Main function to generate a uuidv7 value with millisecond precision */ +CREATE FUNCTION uuidv7() RETURNS uuid +AS +$$ + -- Replace the first 48 bits of a uuidv4 with the current + -- number of milliseconds since 1970-01-01 UTC + -- and set the "ver" field to 7 by setting additional bits +SELECT encode( + set_bit( + set_bit( + overlay(uuid_send(gen_random_uuid()) placing + substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from + 3) + from 1 for 6), + 52, 1), + 53, 1), 'hex')::uuid; +$$ LANGUAGE SQL VOLATILE PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7 IS + 'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness'; +""")) + + op.execute(sa.text(r""" +CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid +AS +$$ + /* uuid fields: version=0b0111, variant=0b10 */ +SELECT encode( + overlay('\x00000000000070008000000000000000'::bytea + placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3) + from 1 for 6), + 'hex')::uuid; +$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE; + +COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS + 'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.'; +""" +)) + + +def downgrade(): + op.execute(sa.text("DROP FUNCTION uuidv7")) + op.execute(sa.text("DROP FUNCTION uuidv7_boundary")) diff --git a/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py new file mode 100644 index 0000000000..df4fbf0a0e --- /dev/null +++ b/api/migrations/versions/2025_07_04_1705-71f5020c6470_tool_oauth.py @@ -0,0 +1,62 @@ +"""tool oauth + +Revision ID: 71f5020c6470 +Revises: 4474872b0ee6 +Create Date: 2025-06-24 17:05:43.118647 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '71f5020c6470' +down_revision = '1c9ba48be8e4' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('tool_oauth_system_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_system_client_pkey'), + sa.UniqueConstraint('plugin_id', 'provider', name='tool_oauth_system_client_plugin_id_provider_idx') + ) + op.create_table('tool_oauth_tenant_clients', + sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False), + sa.Column('tenant_id', models.types.StringUUID(), nullable=False), + sa.Column('plugin_id', sa.String(length=512), nullable=False), + sa.Column('provider', sa.String(length=255), nullable=False), + sa.Column('enabled', sa.Boolean(), server_default=sa.text('true'), nullable=False), + sa.Column('encrypted_oauth_params', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('id', name='tool_oauth_tenant_client_pkey'), + sa.UniqueConstraint('tenant_id', 'plugin_id', 'provider', name='unique_tool_oauth_tenant_client') + ) + + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.add_column(sa.Column('name', sa.String(length=256), server_default=sa.text("'API KEY 1'::character varying"), nullable=False)) + batch_op.add_column(sa.Column('is_default', sa.Boolean(), server_default=sa.text('false'), nullable=False)) + batch_op.add_column(sa.Column('credential_type', sa.String(length=32), server_default=sa.text("'api-key'::character varying"), nullable=False)) + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider', 'name']) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op: + batch_op.drop_constraint(batch_op.f('unique_builtin_tool_provider'), type_='unique') + batch_op.create_unique_constraint(batch_op.f('unique_builtin_tool_provider'), ['tenant_id', 'provider']) + batch_op.drop_column('credential_type') + batch_op.drop_column('is_default') + batch_op.drop_column('name') + + op.drop_table('tool_oauth_tenant_clients') + op.drop_table('tool_oauth_system_clients') + # ### end Alembic commands ### diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py new file mode 100644 index 0000000000..3bdbafda7c --- /dev/null +++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py @@ -0,0 +1,51 @@ +"""update models + +Revision ID: 1a83934ad6d1 +Revises: 71f5020c6470 +Create Date: 2025-07-21 09:35:48.774794 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1a83934ad6d1' +down_revision = '71f5020c6470' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.VARCHAR(length=24), + type_=sa.String(length=64), + existing_nullable=False) + + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=128), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.String(length=128), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.String(length=64), + type_=sa.VARCHAR(length=24), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/account.py b/api/models/account.py index 7ffeefa980..1af571bc01 100644 --- a/api/models/account.py +++ b/api/models/account.py @@ -196,7 +196,7 @@ class Tenant(Base): __tablename__ = "tenants" __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) - id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) + id: Mapped[str] = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) name = db.Column(db.String(255), nullable=False) encrypt_public_key = db.Column(db.Text) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) diff --git a/api/models/dataset.py b/api/models/dataset.py index 1ec27203a0..57e54b72a7 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -255,7 +255,7 @@ class Dataset(Base): @staticmethod def gen_collection_name_by_id(dataset_id: str) -> str: normalized_dataset_id = dataset_id.replace("-", "_") - return f"Vector_index_{normalized_dataset_id}_Node" + return f"{dify_config.VECTOR_INDEX_NAME_PREFIX}_{normalized_dataset_id}_Node" class DatasetProcessRule(Base): diff --git a/api/models/model.py b/api/models/model.py index 7e9e91727d..2377aeed8a 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -610,14 +610,6 @@ class InstalledApp(Base): return tenant -class ConversationSource(StrEnum): - """This enumeration is designed for use with `Conversation.from_source`.""" - - # NOTE(QuantumGhost): The enumeration members may not cover all possible cases. - API = "api" - CONSOLE = "console" - - class Conversation(Base): __tablename__ = "conversations" __table_args__ = ( diff --git a/api/models/task.py b/api/models/task.py index d853c1dd9a..1a4b606ff5 100644 --- a/api/models/task.py +++ b/api/models/task.py @@ -1,7 +1,6 @@ -from datetime import UTC, datetime - from celery import states # type: ignore +from libs.datetime_utils import naive_utc_now from models.base import Base from .engine import db @@ -18,8 +17,8 @@ class CeleryTask(Base): result = db.Column(db.PickleType, nullable=True) date_done = db.Column( db.DateTime, - default=lambda: datetime.now(UTC).replace(tzinfo=None), - onupdate=lambda: datetime.now(UTC).replace(tzinfo=None), + default=lambda: naive_utc_now(), + onupdate=lambda: naive_utc_now(), nullable=True, ) traceback = db.Column(db.Text, nullable=True) @@ -39,4 +38,4 @@ class CeleryTaskSet(Base): id = db.Column(db.Integer, db.Sequence("taskset_id_sequence"), autoincrement=True, primary_key=True) taskset_id = db.Column(db.String(155), unique=True) result = db.Column(db.PickleType, nullable=True) - date_done = db.Column(db.DateTime, default=lambda: datetime.now(UTC).replace(tzinfo=None), nullable=True) + date_done = db.Column(db.DateTime, default=lambda: naive_utc_now(), nullable=True) diff --git a/api/models/tools.py b/api/models/tools.py index 9d2c3baea5..f5fae8b796 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -21,6 +21,43 @@ from .model import Account, App, Tenant from .types import StringUUID +# system level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthSystemClient(Base): + __tablename__ = "tool_oauth_system_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_system_client_pkey"), + db.UniqueConstraint("plugin_id", "provider", name="tool_oauth_system_client_plugin_id_provider_idx"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + +# tenant level tool oauth client params (client_id, client_secret, etc.) +class ToolOAuthTenantClient(Base): + __tablename__ = "tool_oauth_tenant_clients" + __table_args__ = ( + db.PrimaryKeyConstraint("id", name="tool_oauth_tenant_client_pkey"), + db.UniqueConstraint("tenant_id", "plugin_id", "provider", name="unique_tool_oauth_tenant_client"), + ) + + id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + # tenant id + tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False) + plugin_id: Mapped[str] = mapped_column(db.String(512), nullable=False) + provider: Mapped[str] = mapped_column(db.String(255), nullable=False) + enabled: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("true")) + # oauth params of the tool provider + encrypted_oauth_params: Mapped[str] = mapped_column(db.Text, nullable=False) + + @property + def oauth_params(self) -> dict: + return cast(dict, json.loads(self.encrypted_oauth_params or "{}")) + + class BuiltinToolProvider(Base): """ This table stores the tool provider information for built-in tools for each tenant. @@ -29,12 +66,14 @@ class BuiltinToolProvider(Base): __tablename__ = "tool_builtin_providers" __table_args__ = ( db.PrimaryKeyConstraint("id", name="tool_builtin_provider_pkey"), - # one tenant can only have one tool provider with the same name - db.UniqueConstraint("tenant_id", "provider", name="unique_builtin_tool_provider"), + db.UniqueConstraint("tenant_id", "provider", "name", name="unique_builtin_tool_provider"), ) # id of the tool provider id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) + name: Mapped[str] = mapped_column( + db.String(256), nullable=False, server_default=db.text("'API KEY 1'::character varying") + ) # id of the tenant tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=True) # who created this tool provider @@ -49,6 +88,11 @@ class BuiltinToolProvider(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)") ) + is_default: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false")) + # credential type, e.g., "api-key", "oauth2" + credential_type: Mapped[str] = mapped_column( + db.String(32), nullable=False, server_default=db.text("'api-key'::character varying") + ) @property def credentials(self) -> dict: @@ -68,7 +112,7 @@ class ApiToolProvider(Base): id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) # name of the api provider - name = db.Column(db.String(255), nullable=False) + name = db.Column(db.String(255), nullable=False, server_default=db.text("'API KEY 1'::character varying")) # icon icon = db.Column(db.String(255), nullable=False) # original schema @@ -210,7 +254,7 @@ class MCPToolProvider(Base): # name of the mcp provider name: Mapped[str] = mapped_column(db.String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False) + server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) # encrypted url of the mcp provider server_url: Mapped[str] = mapped_column(db.Text, nullable=False) # hash of server_url for uniqueness check @@ -281,18 +325,19 @@ class MCPToolProvider(Base): @property def decrypted_credentials(self) -> dict: + from core.helper.provider_cache import NoOpProviderCredentialCache from core.tools.mcp_tool.provider import MCPToolProviderController - from core.tools.utils.configuration import ProviderConfigEncrypter + from core.tools.utils.encryption import create_provider_encrypter provider_controller = MCPToolProviderController._from_db(self) - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=self.tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], + cache=NoOpProviderCredentialCache(), ) - return tool_configuration.decrypt(self.credentials, use_cache=False) + + return encrypter.decrypt(self.credentials) # type: ignore class ToolModelInvoke(Base): @@ -313,7 +358,7 @@ class ToolModelInvoke(Base): # type tool_type = db.Column(db.String(40), nullable=False) # tool name - tool_name = db.Column(db.String(40), nullable=False) + tool_name = db.Column(db.String(128), nullable=False) # invoke parameters model_parameters = db.Column(db.Text, nullable=False) # prompt messages diff --git a/api/models/workflow.py b/api/models/workflow.py index 77d48bec4f..124fb3bb4c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -1,7 +1,7 @@ import json import logging from collections.abc import Mapping, Sequence -from datetime import UTC, datetime +from datetime import datetime from enum import Enum, StrEnum from typing import TYPE_CHECKING, Any, Optional, Union from uuid import uuid4 @@ -12,9 +12,11 @@ from sqlalchemy import orm from core.file.constants import maybe_file_object from core.file.models import File from core.variables import utils as variable_utils +from core.variables.variables import FloatVariable, IntegerVariable, StringVariable from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.nodes.enums import NodeType from factories.variable_factory import TypeMismatchError, build_segment_with_type +from libs.datetime_utils import naive_utc_now from libs.helper import extract_tenant_id from ._workflow_exc import NodeNotFoundError, WorkflowDataError @@ -137,7 +139,7 @@ class Workflow(Base): updated_at: Mapped[datetime] = mapped_column( db.DateTime, nullable=False, - default=datetime.now(UTC).replace(tzinfo=None), + default=naive_utc_now(), server_onupdate=func.current_timestamp(), ) _environment_variables: Mapped[str] = mapped_column( @@ -178,7 +180,7 @@ class Workflow(Base): workflow.conversation_variables = conversation_variables or [] workflow.marked_name = marked_name workflow.marked_comment = marked_comment - workflow.created_at = datetime.now(UTC).replace(tzinfo=None) + workflow.created_at = naive_utc_now() workflow.updated_at = workflow.created_at return workflow @@ -347,7 +349,7 @@ class Workflow(Base): ) @property - def environment_variables(self) -> Sequence[Variable]: + def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]: # TODO: find some way to init `self._environment_variables` when instance created. if self._environment_variables is None: self._environment_variables = "{}" @@ -367,11 +369,15 @@ class Workflow(Base): def decrypt_func(var): if isinstance(var, SecretVariable): return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) - else: + elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)): return var + else: + raise AssertionError("this statement should be unreachable.") - results = list(map(decrypt_func, results)) - return results + decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list( + map(decrypt_func, results) + ) + return decrypted_results @environment_variables.setter def environment_variables(self, value: Sequence[Variable]): @@ -902,7 +908,7 @@ _EDITABLE_SYSTEM_VARIABLE = frozenset(["query", "files"]) def _naive_utc_datetime(): - return datetime.now(UTC).replace(tzinfo=None) + return naive_utc_now() class WorkflowDraftVariable(Base): diff --git a/api/services/account_service.py b/api/services/account_service.py index 2ba6f4345b..352efb2f0c 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -17,6 +17,7 @@ from constants.languages import language_timezone_mapping, languages from events.tenant_event import tenant_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client, redis_fallback +from libs.datetime_utils import naive_utc_now from libs.helper import RateLimiter, TokenManager from libs.passport import PassportService from libs.password import compare_password, hash_password, valid_password @@ -52,8 +53,14 @@ from services.errors.workspace import WorkSpaceNotAllowedCreateError, Workspaces from services.feature_service import FeatureService from tasks.delete_account_task import delete_account_task from tasks.mail_account_deletion_task import send_account_deletion_verification_code +from tasks.mail_change_mail_task import send_change_mail_task from tasks.mail_email_code_login import send_email_code_login_mail_task from tasks.mail_invite_member_task import send_invite_member_mail_task +from tasks.mail_owner_transfer_task import ( + send_new_owner_transfer_notify_email_task, + send_old_owner_transfer_notify_email_task, + send_owner_transfer_confirm_task, +) from tasks.mail_reset_password_task import send_reset_password_mail_task @@ -75,8 +82,13 @@ class AccountService: email_code_account_deletion_rate_limiter = RateLimiter( prefix="email_code_account_deletion_rate_limit", max_attempts=1, time_window=60 * 1 ) + change_email_rate_limiter = RateLimiter(prefix="change_email_rate_limit", max_attempts=1, time_window=60 * 1) + owner_transfer_rate_limiter = RateLimiter(prefix="owner_transfer_rate_limit", max_attempts=1, time_window=60 * 1) + LOGIN_MAX_ERROR_LIMITS = 5 FORGOT_PASSWORD_MAX_ERROR_LIMITS = 5 + CHANGE_EMAIL_MAX_ERROR_LIMITS = 5 + OWNER_TRANSFER_MAX_ERROR_LIMITS = 5 @staticmethod def _get_refresh_token_key(refresh_token: str) -> str: @@ -124,8 +136,8 @@ class AccountService: available_ta.current = True db.session.commit() - if datetime.now(UTC).replace(tzinfo=None) - account.last_active_at > timedelta(minutes=10): - account.last_active_at = datetime.now(UTC).replace(tzinfo=None) + if naive_utc_now() - account.last_active_at > timedelta(minutes=10): + account.last_active_at = naive_utc_now() db.session.commit() return cast(Account, account) @@ -169,7 +181,7 @@ class AccountService: if account.status == AccountStatus.PENDING.value: account.status = AccountStatus.ACTIVE.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() db.session.commit() @@ -307,7 +319,7 @@ class AccountService: # If it exists, update the record account_integrate.open_id = open_id account_integrate.encrypted_token = "" # todo - account_integrate.updated_at = datetime.now(UTC).replace(tzinfo=None) + account_integrate.updated_at = naive_utc_now() else: # If it does not exist, create a new record account_integrate = AccountIntegrate( @@ -342,7 +354,7 @@ class AccountService: @staticmethod def update_login_info(account: Account, *, ip_address: str) -> None: """Update last login time and ip""" - account.last_login_at = datetime.now(UTC).replace(tzinfo=None) + account.last_login_at = naive_utc_now() account.last_login_ip = ip_address db.session.add(account) db.session.commit() @@ -419,6 +431,101 @@ class AccountService: cls.reset_password_rate_limiter.increment_rate_limit(account_email) return token + @classmethod + def send_change_email_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + old_email: Optional[str] = None, + language: Optional[str] = "en-US", + phase: Optional[str] = None, + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.change_email_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import EmailChangeRateLimitExceededError + + raise EmailChangeRateLimitExceededError() + + code, token = cls.generate_change_email_token(account_email, account, old_email=old_email) + + send_change_mail_task.delay( + language=language, + to=account_email, + code=code, + phase=phase, + ) + cls.change_email_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_owner_transfer_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + if cls.owner_transfer_rate_limiter.is_rate_limited(account_email): + from controllers.console.auth.error import OwnerTransferRateLimitExceededError + + raise OwnerTransferRateLimitExceededError() + + code, token = cls.generate_owner_transfer_token(account_email, account) + + send_owner_transfer_confirm_task.delay( + language=language, + to=account_email, + code=code, + workspace=workspace_name, + ) + cls.owner_transfer_rate_limiter.increment_rate_limit(account_email) + return token + + @classmethod + def send_old_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + new_owner_email: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_old_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + new_owner_email=new_owner_email, + ) + + @classmethod + def send_new_owner_transfer_notify_email( + cls, + account: Optional[Account] = None, + email: Optional[str] = None, + language: Optional[str] = "en-US", + workspace_name: Optional[str] = "", + ): + account_email = account.email if account else email + if account_email is None: + raise ValueError("Email must be provided.") + + send_new_owner_transfer_notify_email_task.delay( + language=language, + to=account_email, + workspace=workspace_name, + ) + @classmethod def generate_reset_password_token( cls, @@ -435,14 +542,64 @@ class AccountService: ) return code, token + @classmethod + def generate_change_email_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + old_email: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + additional_data["old_email"] = old_email + token = TokenManager.generate_token( + account=account, email=email, token_type="change_email", additional_data=additional_data + ) + return code, token + + @classmethod + def generate_owner_transfer_token( + cls, + email: str, + account: Optional[Account] = None, + code: Optional[str] = None, + additional_data: dict[str, Any] = {}, + ): + if not code: + code = "".join([str(secrets.randbelow(exclusive_upper_bound=10)) for _ in range(6)]) + additional_data["code"] = code + token = TokenManager.generate_token( + account=account, email=email, token_type="owner_transfer", additional_data=additional_data + ) + return code, token + @classmethod def revoke_reset_password_token(cls, token: str): TokenManager.revoke_token(token, "reset_password") + @classmethod + def revoke_change_email_token(cls, token: str): + TokenManager.revoke_token(token, "change_email") + + @classmethod + def revoke_owner_transfer_token(cls, token: str): + TokenManager.revoke_token(token, "owner_transfer") + @classmethod def get_reset_password_data(cls, token: str) -> Optional[dict[str, Any]]: return TokenManager.get_token_data(token, "reset_password") + @classmethod + def get_change_email_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "change_email") + + @classmethod + def get_owner_transfer_data(cls, token: str) -> Optional[dict[str, Any]]: + return TokenManager.get_token_data(token, "owner_transfer") + @classmethod def send_email_code_login_email( cls, account: Optional[Account] = None, email: Optional[str] = None, language: Optional[str] = "en-US" @@ -552,6 +709,62 @@ class AccountService: key = f"forgot_password_error_rate_limit:{email}" redis_client.delete(key) + @staticmethod + @redis_fallback(default_return=None) + def add_change_email_error_rate_limit(email: str) -> None: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.CHANGE_EMAIL_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_change_email_error_rate_limit(email: str) -> bool: + key = f"change_email_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.CHANGE_EMAIL_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_change_email_error_rate_limit(email: str): + key = f"change_email_error_rate_limit:{email}" + redis_client.delete(key) + + @staticmethod + @redis_fallback(default_return=None) + def add_owner_transfer_error_rate_limit(email: str) -> None: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + count = 0 + count = int(count) + 1 + redis_client.setex(key, dify_config.OWNER_TRANSFER_LOCKOUT_DURATION, count) + + @staticmethod + @redis_fallback(default_return=False) + def is_owner_transfer_error_rate_limit(email: str) -> bool: + key = f"owner_transfer_error_rate_limit:{email}" + count = redis_client.get(key) + if count is None: + return False + count = int(count) + if count > AccountService.OWNER_TRANSFER_MAX_ERROR_LIMITS: + return True + return False + + @staticmethod + @redis_fallback(default_return=None) + def reset_owner_transfer_error_rate_limit(email: str): + key = f"owner_transfer_error_rate_limit:{email}" + redis_client.delete(key) + @staticmethod @redis_fallback(default_return=False) def is_email_send_ip_limit(ip_address: str): @@ -593,6 +806,10 @@ class AccountService: return False + @staticmethod + def check_email_unique(email: str) -> bool: + return db.session.query(Account).filter_by(email=email).first() is None + class TenantService: @staticmethod @@ -850,21 +1067,21 @@ class TenantService: target_member_join.role = new_role db.session.commit() - @staticmethod - def dissolve_tenant(tenant: Tenant, operator: Account) -> None: - """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, "remove"): - raise NoPermissionError("No permission to dissolve tenant.") - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db.session.delete(tenant) - db.session.commit() - @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) return cast(dict, tenant.custom_config_dict) + @staticmethod + def is_owner(account: Account, tenant: Tenant) -> bool: + return TenantService.get_user_role(account, tenant) == TenantAccountRole.OWNER + + @staticmethod + def is_member(account: Account, tenant: Tenant) -> bool: + """Check if the account is a member of the tenant""" + return TenantService.get_user_role(account, tenant) is not None + class RegisterService: @classmethod @@ -892,7 +1109,7 @@ class RegisterService: ) account.last_login_ip = ip_address - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() TenantService.create_owner_tenant_if_not_exist(account=account, is_setup=True) @@ -933,7 +1150,7 @@ class RegisterService: is_setup=is_setup, ) account.status = AccountStatus.ACTIVE.value if not status else status.value - account.initialized_at = datetime.now(UTC).replace(tzinfo=None) + account.initialized_at = naive_utc_now() if open_id is not None and provider is not None: AccountService.link_account_integrate(provider, open_id, account) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 20257fa345..08e13c588e 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -575,13 +575,26 @@ class AppDslService: raise ValueError("Missing draft workflow configuration, please check.") workflow_dict = workflow.to_dict(include_secret=include_secret) + # TODO: refactor: we need a better way to filter workspace related data from nodes for node in workflow_dict.get("graph", {}).get("nodes", []): - if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value: - dataset_ids = node["data"].get("dataset_ids", []) - node["data"]["dataset_ids"] = [ + node_data = node.get("data", {}) + if not node_data: + continue + data_type = node_data.get("type", "") + if data_type == NodeType.KNOWLEDGE_RETRIEVAL.value: + dataset_ids = node_data.get("dataset_ids", []) + node_data["dataset_ids"] = [ cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=app_model.tenant_id) for dataset_id in dataset_ids ] + # filter credential id from tool node + if not include_secret and data_type == NodeType.TOOL.value: + node_data.pop("credential_id", None) + # filter credential id from agent node + if not include_secret and data_type == NodeType.AGENT.value: + for tool in node_data.get("agent_parameters", {}).get("tools", {}).get("value", []): + tool.pop("credential_id", None) + export_data["workflow"] = workflow_dict dependencies = cls._extract_dependencies_from_workflow(workflow) export_data["dependencies"] = [ @@ -602,7 +615,15 @@ class AppDslService: if not app_model_config: raise ValueError("Missing app configuration, please check.") - export_data["model_config"] = app_model_config.to_dict() + model_config = app_model_config.to_dict() + + # TODO: refactor: we need a better way to filter workspace related data from model config + # filter credential id from model config + for tool in model_config.get("agent_mode", {}).get("tools", []): + tool.pop("credential_id", None) + + export_data["model_config"] = model_config + dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict()) export_data["dependencies"] = [ jsonable_encoder(d.model_dump()) diff --git a/api/services/app_service.py b/api/services/app_service.py index db0f8cd414..3494b2796b 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -1,6 +1,5 @@ import json import logging -from datetime import UTC, datetime from typing import Optional, cast from flask_login import current_user @@ -17,6 +16,7 @@ from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager from events.app_event import app_was_created from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode, AppModelConfig, Site from models.tools import ApiToolProvider @@ -233,8 +233,9 @@ class AppService: app.icon = args.get("icon") app.icon_background = args.get("icon_background") app.use_icon_as_answer_icon = args.get("use_icon_as_answer_icon", False) + app.max_active_requests = args.get("max_active_requests") app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -248,7 +249,7 @@ class AppService: """ app.name = name app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -264,7 +265,7 @@ class AppService: app.icon = icon app.icon_background = icon_background app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -281,7 +282,7 @@ class AppService: app.enable_site = enable_site app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app @@ -298,7 +299,7 @@ class AppService: app.enable_api = enable_api app.updated_by = current_user.id - app.updated_at = datetime.now(UTC).replace(tzinfo=None) + app.updated_at = naive_utc_now() db.session.commit() return app diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index afdaa49465..40097d5ed5 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -1,5 +1,4 @@ from collections.abc import Callable, Sequence -from datetime import UTC, datetime from typing import Optional, Union from sqlalchemy import asc, desc, func, or_, select @@ -8,6 +7,7 @@ from sqlalchemy.orm import Session from core.app.entities.app_invoke_entities import InvokeFrom from core.llm_generator.llm_generator import LLMGenerator from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ConversationVariable from models.account import Account @@ -113,7 +113,7 @@ class ConversationService: return cls.auto_generate_name(app_model, conversation) else: conversation.name = name - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() return conversation @@ -169,7 +169,7 @@ class ConversationService: conversation = cls.get_conversation(app_model, conversation_id, user) conversation.is_deleted = True - conversation.updated_at = datetime.now(UTC).replace(tzinfo=None) + conversation.updated_at = naive_utc_now() db.session.commit() @classmethod diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index e42b5ace75..09cdd66e04 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -26,6 +26,7 @@ from events.document_event import document_was_deleted from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper +from libs.datetime_utils import naive_utc_now from models.account import Account, TenantAccountRole from models.dataset import ( AppDatasetJoin, @@ -428,7 +429,7 @@ class DatasetService: # Add metadata fields filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + filtered_data["updated_at"] = naive_utc_now() # update Retrieval model filtered_data["retrieval_model"] = data["retrieval_model"] @@ -994,7 +995,7 @@ class DocumentService: # update document to be paused document.is_paused = True document.paused_by = current_user.id - document.paused_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.paused_at = naive_utc_now() db.session.add(document) db.session.commit() diff --git a/api/services/entities/knowledge_entities/knowledge_entities.py b/api/services/entities/knowledge_entities/knowledge_entities.py index 88d4224e97..344c67885e 100644 --- a/api/services/entities/knowledge_entities/knowledge_entities.py +++ b/api/services/entities/knowledge_entities/knowledge_entities.py @@ -4,13 +4,6 @@ from typing import Literal, Optional from pydantic import BaseModel -class SegmentUpdateEntity(BaseModel): - content: str - answer: Optional[str] = None - keywords: Optional[list[str]] = None - enabled: Optional[bool] = None - - class ParentMode(StrEnum): FULL_DOC = "full-doc" PARAGRAPH = "paragraph" @@ -153,10 +146,6 @@ class MetadataUpdateArgs(BaseModel): value: Optional[str | int | float] = None -class MetadataValueUpdateArgs(BaseModel): - fields: list[MetadataUpdateArgs] - - class MetadataDetail(BaseModel): id: str name: str diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index eb50d79494..06a4c22117 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -1,6 +1,5 @@ import json from copy import deepcopy -from datetime import UTC, datetime from typing import Any, Optional, Union, cast from urllib.parse import urlparse @@ -11,6 +10,7 @@ from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import ( Dataset, ExternalKnowledgeApis, @@ -120,7 +120,7 @@ class ExternalDatasetService: external_knowledge_api.description = args.get("description", "") external_knowledge_api.settings = json.dumps(args.get("settings"), ensure_ascii=False) external_knowledge_api.updated_by = user_id - external_knowledge_api.updated_at = datetime.now(UTC).replace(tzinfo=None) + external_knowledge_api.updated_at = naive_utc_now() db.session.commit() return external_knowledge_api diff --git a/api/services/feature_service.py b/api/services/feature_service.py index 188caf3505..1441e6ce16 100644 --- a/api/services/feature_service.py +++ b/api/services/feature_service.py @@ -123,7 +123,7 @@ class FeatureModel(BaseModel): dataset_operator_enabled: bool = False webapp_copyright_enabled: bool = False workspace_members: LicenseLimitationModel = LicenseLimitationModel(enabled=False, size=0, limit=0) - + is_allow_transfer_workspace: bool = True # pydantic configs model_config = ConfigDict(protected_namespaces=()) @@ -149,6 +149,7 @@ class SystemFeatureModel(BaseModel): branding: BrandingModel = BrandingModel() webapp_auth: WebAppAuthModel = WebAppAuthModel() plugin_installation_permission: PluginInstallationPermissionModel = PluginInstallationPermissionModel() + enable_change_email: bool = True class FeatureService: @@ -186,6 +187,7 @@ class FeatureService: if dify_config.ENTERPRISE_ENABLED: system_features.branding.enabled = True system_features.webapp_auth.enabled = True + system_features.enable_change_email = False cls._fulfill_params_from_enterprise(system_features) if dify_config.MARKETPLACE_ENABLED: @@ -228,6 +230,8 @@ class FeatureService: if features.billing.subscription.plan != "sandbox": features.webapp_copyright_enabled = True + else: + features.is_allow_transfer_workspace = False if "members" in billing_info: features.members.size = billing_info["members"]["size"] diff --git a/api/services/plugin/plugin_parameter_service.py b/api/services/plugin/plugin_parameter_service.py index 393213c0e2..a1c5639e00 100644 --- a/api/services/plugin/plugin_parameter_service.py +++ b/api/services/plugin/plugin_parameter_service.py @@ -6,7 +6,7 @@ from sqlalchemy.orm import Session from core.plugin.entities.parameters import PluginParameterOption from core.plugin.impl.dynamic_select import DynamicSelectClient from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from extensions.ext_database import db from models.tools import BuiltinToolProvider @@ -38,11 +38,9 @@ class PluginParameterService: case "tool": provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # check if credentials are required @@ -63,7 +61,7 @@ class PluginParameterService: if db_record is None: raise ValueError(f"Builtin provider {provider} not found when fetching credentials") - credentials = tool_configuration.decrypt(db_record.credentials) + credentials = encrypter.decrypt(db_record.credentials) case _: raise ValueError(f"Invalid provider type: {provider_type}") diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 0f22afd8dd..0a5bc44b64 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -196,6 +196,17 @@ class PluginService: manager = PluginInstaller() return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) + @staticmethod + def is_plugin_verified(tenant_id: str, plugin_unique_identifier: str) -> bool: + """ + Check if the plugin is verified + """ + manager = PluginInstaller() + try: + return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier).verified + except Exception: + return False + @staticmethod def fetch_install_tasks(tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]: """ diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 6f848d49c4..80badf2335 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -18,7 +18,7 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser from extensions.ext_database import db from models.tools import ApiToolProvider @@ -164,15 +164,11 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # encrypt credentials - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - - encrypted_credentials = tool_configuration.encrypt(credentials) - db_provider.credentials_str = json.dumps(encrypted_credentials) + db_provider.credentials_str = json.dumps(encrypter.encrypt(credentials)) db.session.add(db_provider) db.session.commit() @@ -297,28 +293,26 @@ class ApiToolManageService: provider_controller.load_bundled_tools(tool_bundles) # get original credentials if exists - tool_configuration = ProviderConfigEncrypter( + encrypter, cache = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) + original_credentials = encrypter.decrypt(provider.credentials) + masked_credentials = encrypter.mask_tool_credentials(original_credentials) # check if the credential has changed, save the original credential for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = original_credentials[name] - credentials = tool_configuration.encrypt(credentials) + credentials = encrypter.encrypt(credentials) provider.credentials_str = json.dumps(credentials) db.session.add(provider) db.session.commit() # delete cache - tool_configuration.delete_tool_credentials_cache() + cache.delete() # update labels ToolLabelManager.update_tool_labels(provider_controller, labels) @@ -416,15 +410,13 @@ class ApiToolManageService: # decrypt credentials if db_provider.id: - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=tenant_id, - config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) - decrypted_credentials = tool_configuration.decrypt(credentials) + decrypted_credentials = encrypter.decrypt(credentials) # check if the credential has changed, save the original credential - masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) + masked_credentials = encrypter.mask_tool_credentials(decrypted_credentials) for name, value in credentials.items(): if name in masked_credentials and value == masked_credentials[name]: credentials[name] = decrypted_credentials[name] @@ -446,7 +438,7 @@ class ApiToolManageService: return {"result": result or "empty response"} @staticmethod - def list_api_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: + def list_api_tools(tenant_id: str) -> list[ToolProviderApiEntity]: """ list api tools """ @@ -474,7 +466,7 @@ class ApiToolManageService: for tool in tools or []: user_provider.tools.append( ToolTransformService.convert_tool_entity_to_api_entity( - tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels + tenant_id=tenant_id, tool=tool, labels=labels ) ) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 58a4b2f179..430575b532 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -1,28 +1,84 @@ import json import logging +import re +from collections.abc import Mapping from pathlib import Path +from typing import Any, Optional from sqlalchemy.orm import Session from configs import dify_config +from constants import HIDDEN_VALUE, UNKNOWN_VALUE from core.helper.position_helper import is_filtered -from core.model_runtime.utils.encoders import jsonable_encoder +from core.helper.provider_cache import NoOpProviderCredentialCache, ToolProviderCredentialsCache from core.plugin.entities.plugin import ToolProviderID -from core.plugin.impl.exc import PluginDaemonClientSideError +from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity -from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError +from core.tools.entities.api_entities import ( + ToolApiEntity, + ToolProviderApiEntity, + ToolProviderCredentialApiEntity, + ToolProviderCredentialInfoApiEntity, +) +from core.tools.entities.tool_entities import CredentialType +from core.tools.errors import ToolProviderNotFoundError +from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter +from core.tools.utils.system_oauth_encryption import decrypt_system_oauth_params from extensions.ext_database import db -from models.tools import BuiltinToolProvider +from extensions.ext_redis import redis_client +from models.tools import BuiltinToolProvider, ToolOAuthSystemClient, ToolOAuthTenantClient +from services.plugin.plugin_service import PluginService from services.tools.tools_transform_service import ToolTransformService logger = logging.getLogger(__name__) class BuiltinToolManageService: + __MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100 + + @staticmethod + def delete_custom_oauth_client_params(tenant_id: str, provider: str): + """ + delete custom oauth client params + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine) as session: + session.query(ToolOAuthTenantClient).filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + ).delete() + session.commit() + return {"result": "success"} + + @staticmethod + def get_builtin_tool_provider_oauth_client_schema(tenant_id: str, provider_name: str): + """ + get builtin tool provider oauth client schema + """ + provider = ToolManager.get_builtin_provider(provider_name, tenant_id) + verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + + is_oauth_custom_client_enabled = BuiltinToolManageService.is_oauth_custom_client_enabled( + tenant_id, provider_name + ) + is_system_oauth_params_exists = verified and BuiltinToolManageService.is_oauth_system_client_exists( + provider_name + ) + result = { + "schema": provider.get_oauth_client_schema(), + "is_oauth_custom_client_enabled": is_oauth_custom_client_enabled, + "is_system_oauth_params_exists": is_system_oauth_params_exists, + "client_params": BuiltinToolManageService.get_custom_oauth_client_params(tenant_id, provider_name), + "redirect_uri": f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_name}/tool/callback", + } + return result + @staticmethod def list_builtin_tool_provider_tools(tenant_id: str, provider: str) -> list[ToolApiEntity]: """ @@ -36,27 +92,11 @@ class BuiltinToolManageService: provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) tools = provider_controller.get_tools() - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) - # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) - result: list[ToolApiEntity] = [] for tool in tools or []: result.append( ToolTransformService.convert_tool_entity_to_api_entity( tool=tool, - credentials=credentials, tenant_id=tenant_id, labels=ToolLabelManager.get_tool_labels(provider_controller), ) @@ -65,25 +105,15 @@ class BuiltinToolManageService: return result @staticmethod - def get_builtin_tool_provider_info(user_id: str, tenant_id: str, provider: str): + def get_builtin_tool_provider_info(tenant_id: str, provider: str): """ get builtin tool provider info """ provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) - tool_provider_configurations = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) # check if user has added the provider - builtin_provider = BuiltinToolManageService._fetch_builtin_provider(provider, tenant_id) - - credentials = {} - if builtin_provider is not None: - # get credentials - credentials = builtin_provider.credentials - credentials = tool_provider_configurations.decrypt(credentials) + builtin_provider = BuiltinToolManageService.get_builtin_provider(provider, tenant_id) + if builtin_provider is None: + raise ValueError(f"you have not added provider {provider}") entity = ToolTransformService.builtin_provider_to_user_provider( provider_controller=provider_controller, @@ -92,127 +122,406 @@ class BuiltinToolManageService: ) entity.original_credentials = {} - return entity @staticmethod - def list_builtin_provider_credentials_schema(provider_name: str, tenant_id: str): + def list_builtin_provider_credentials_schema(provider_name: str, credential_type: CredentialType, tenant_id: str): """ list builtin provider credentials schema + :param credential_type: credential type :param provider_name: the name of the provider :param tenant_id: the id of the tenant :return: the list of tool providers """ provider = ToolManager.get_builtin_provider(provider_name, tenant_id) - return jsonable_encoder(provider.get_credentials_schema()) + return provider.get_credentials_schema_by_type(credential_type) @staticmethod def update_builtin_tool_provider( - session: Session, user_id: str, tenant_id: str, provider_name: str, credentials: dict + user_id: str, + tenant_id: str, + provider: str, + credential_id: str, + credentials: dict | None = None, + name: str | None = None, ): """ update builtin tool provider """ - # get if the provider exists - provider = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with Session(db.engine) as session: + # get if the provider exists + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") + + try: + if CredentialType.of(db_provider.credential_type).is_editable() and credentials: + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") + + encrypter, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) + + original_credentials = encrypter.decrypt(db_provider.credentials) + new_credentials: dict = { + key: value if value != HIDDEN_VALUE else original_credentials.get(key, UNKNOWN_VALUE) + for key, value in credentials.items() + } + + if CredentialType.of(db_provider.credential_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, new_credentials) + # encrypt credentials + db_provider.encrypted_credentials = json.dumps(encrypter.encrypt(new_credentials)) + + cache.delete() + + # update name if provided + if name and name != db_provider.name: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + db_provider.name = name + + session.commit() + except Exception as e: + session.rollback() + raise ValueError(str(e)) + return {"result": "success"} + + @staticmethod + def add_builtin_tool_provider( + user_id: str, + api_type: CredentialType, + tenant_id: str, + provider: str, + credentials: dict, + name: str | None = None, + ): + """ + add builtin tool provider + """ try: - # get provider - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - if not provider_controller.need_credentials: - raise ValueError(f"provider {provider_name} does not need credentials") - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, - ) + with Session(db.engine) as session: + lock = f"builtin_tool_provider_create_lock:{tenant_id}_{provider}" + with redis_client.lock(lock, timeout=20): + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller.need_credentials: + raise ValueError(f"provider {provider} does not need credentials") + + provider_count = ( + session.query(BuiltinToolProvider).filter_by(tenant_id=tenant_id, provider=provider).count() + ) + + # check if the provider count is reached the limit + if provider_count >= BuiltinToolManageService.__MAX_BUILTIN_TOOL_PROVIDER_COUNT__: + raise ValueError(f"you have reached the maximum number of providers for {provider}") + + # validate credentials if allowed + if CredentialType.of(api_type).is_validate_allowed(): + provider_controller.validate_credentials(user_id, credentials) + + # generate name if not provided + if name is None or name == "": + name = BuiltinToolManageService.generate_builtin_tool_provider_name( + session=session, tenant_id=tenant_id, provider=provider, credential_type=api_type + ) + else: + # check if the name is already used + if ( + session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider, name=name) + .count() + > 0 + ): + raise ValueError(f"the credential name '{name}' is already used") + + # create encrypter + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(api_type) + ], + cache=NoOpProviderCredentialCache(), + ) + + db_provider = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=user_id, + provider=provider, + encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), + credential_type=api_type.value, + name=name, + ) - # get original credentials if exists - if provider is not None: - original_credentials = tool_configuration.decrypt(provider.credentials) - masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) - # check if the credential has changed, save the original credential - for name, value in credentials.items(): - if name in masked_credentials and value == masked_credentials[name]: - credentials[name] = original_credentials[name] - # validate credentials - provider_controller.validate_credentials(user_id, credentials) - # encrypt credentials - credentials = tool_configuration.encrypt(credentials) - except ( - PluginDaemonClientSideError, - ToolProviderNotFoundError, - ToolNotFoundError, - ToolProviderCredentialValidationError, - ) as e: + session.add(db_provider) + session.commit() + except Exception as e: + session.rollback() raise ValueError(str(e)) + return {"result": "success"} - if provider is None: - # create provider - provider = BuiltinToolProvider( - tenant_id=tenant_id, - user_id=user_id, - provider=provider_name, - encrypted_credentials=json.dumps(credentials), + @staticmethod + def create_tool_encrypter( + tenant_id: str, + db_provider: BuiltinToolProvider, + provider: str, + provider_controller: BuiltinToolProviderController, + ): + encrypter, cache = create_provider_encrypter( + tenant_id=tenant_id, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type(db_provider.credential_type) + ], + cache=ToolProviderCredentialsCache(tenant_id=tenant_id, provider=provider, credential_id=db_provider.id), + ) + return encrypter, cache + + @staticmethod + def generate_builtin_tool_provider_name( + session: Session, tenant_id: str, provider: str, credential_type: CredentialType + ) -> str: + try: + db_providers = ( + session.query(BuiltinToolProvider) + .filter_by( + tenant_id=tenant_id, + provider=provider, + credential_type=credential_type.value, + ) + .order_by(BuiltinToolProvider.created_at.desc()) + .all() ) - db.session.add(provider) - else: - provider.encrypted_credentials = json.dumps(credentials) + # Get the default name pattern + default_pattern = f"{credential_type.get_name()}" - # delete cache - tool_configuration.delete_tool_credentials_cache() + # Find all names that match the default pattern: "{default_pattern} {number}" + pattern = rf"^{re.escape(default_pattern)}\s+(\d+)$" + numbers = [] - db.session.commit() - return {"result": "success"} + for db_provider in db_providers: + if db_provider.name: + match = re.match(pattern, db_provider.name.strip()) + if match: + numbers.append(int(match.group(1))) + + # If no default pattern names found, start with 1 + if not numbers: + return f"{default_pattern} 1" + + # Find the next number + max_number = max(numbers) + return f"{default_pattern} {max_number + 1}" + except Exception as e: + logger.warning(f"Error generating next provider name for {provider}: {str(e)}") + # fallback + return f"{credential_type.get_name()} 1" @staticmethod - def get_builtin_tool_provider_credentials(tenant_id: str, provider_name: str): + def get_builtin_tool_provider_credentials( + tenant_id: str, provider_name: str + ) -> list[ToolProviderCredentialApiEntity]: """ get builtin tool provider credentials """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with db.session.no_autoflush: + providers = ( + db.session.query(BuiltinToolProvider) + .filter_by(tenant_id=tenant_id, provider=provider_name) + .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) + .all() + ) - if provider_obj is None: - return {} + if len(providers) == 0: + return [] + + default_provider = providers[0] + default_provider.is_default = True + provider_controller = ToolManager.get_builtin_provider(default_provider.provider, tenant_id) + + credentials: list[ToolProviderCredentialApiEntity] = [] + encrypters = {} + for provider in providers: + credential_type = provider.credential_type + if credential_type not in encrypters: + encrypters[credential_type] = BuiltinToolManageService.create_tool_encrypter( + tenant_id, provider, provider.provider, provider_controller + )[0] + encrypter = encrypters[credential_type] + decrypt_credential = encrypter.mask_tool_credentials(encrypter.decrypt(provider.credentials)) + credential_entity = ToolTransformService.convert_builtin_provider_to_credential_entity( + provider=provider, + credentials=decrypt_credential, + ) + credentials.append(credential_entity) + return credentials - provider_controller = ToolManager.get_builtin_provider(provider_obj.provider, tenant_id) - tool_configuration = ProviderConfigEncrypter( - tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + @staticmethod + def get_builtin_tool_provider_credential_info(tenant_id: str, provider: str) -> ToolProviderCredentialInfoApiEntity: + """ + get builtin tool provider credential info + """ + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + supported_credential_types = provider_controller.get_supported_credential_types() + credentials = BuiltinToolManageService.get_builtin_tool_provider_credentials(tenant_id, provider) + credential_info = ToolProviderCredentialInfoApiEntity( + supported_credential_types=supported_credential_types, + is_oauth_custom_client_enabled=BuiltinToolManageService.is_oauth_custom_client_enabled(tenant_id, provider), + credentials=credentials, ) - credentials = tool_configuration.decrypt(provider_obj.credentials) - credentials = tool_configuration.mask_tool_credentials(credentials) - return credentials + + return credential_info @staticmethod - def delete_builtin_tool_provider(user_id: str, tenant_id: str, provider_name: str): + def delete_builtin_tool_provider(tenant_id: str, provider: str, credential_id: str): """ delete tool provider """ - provider_obj = BuiltinToolManageService._fetch_builtin_provider(provider_name, tenant_id) + with Session(db.engine) as session: + db_provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.id == credential_id, + ) + .first() + ) - if provider_obj is None: - raise ValueError(f"you have not added provider {provider_name}") + if db_provider is None: + raise ValueError(f"you have not added provider {provider}") - db.session.delete(provider_obj) - db.session.commit() + session.delete(db_provider) + session.commit() - # delete cache - provider_controller = ToolManager.get_builtin_provider(provider_name, tenant_id) - tool_configuration = ProviderConfigEncrypter( + # delete cache + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + _, cache = BuiltinToolManageService.create_tool_encrypter( + tenant_id, db_provider, provider, provider_controller + ) + cache.delete() + + return {"result": "success"} + + @staticmethod + def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str): + """ + set default provider + """ + with Session(db.engine) as session: + # get provider + target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first() + if target_provider is None: + raise ValueError("provider not found") + + # clear default provider + session.query(BuiltinToolProvider).filter_by( + tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True + ).update({"is_default": False}) + + # set new default provider + target_provider.is_default = True + session.commit() + return {"result": "success"} + + @staticmethod + def is_oauth_system_client_exists(provider_name: str) -> bool: + """ + check if oauth system client exists + """ + tool_provider = ToolProviderID(provider_name) + with Session(db.engine).no_autoflush as session: + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + return system_client is not None + + @staticmethod + def is_oauth_custom_client_enabled(tenant_id: str, provider: str) -> bool: + """ + check if oauth custom client is enabled + """ + tool_provider = ToolProviderID(provider) + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + return user_client is not None and user_client.enabled + + @staticmethod + def get_oauth_client(tenant_id: str, provider: str) -> Mapping[str, Any] | None: + """ + get builtin tool provider + """ + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + encrypter, _ = create_provider_encrypter( tenant_id=tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), ) - tool_configuration.delete_tool_credentials_cache() + with Session(db.engine).no_autoflush as session: + user_client: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + provider=tool_provider.provider_name, + plugin_id=tool_provider.plugin_id, + enabled=True, + ) + .first() + ) + oauth_params: Mapping[str, Any] | None = None + if user_client: + oauth_params = encrypter.decrypt(user_client.oauth_params) + return oauth_params + + # only verified provider can use custom oauth client + is_verified = not isinstance(provider, PluginToolProviderController) or PluginService.is_plugin_verified( + tenant_id, provider.plugin_unique_identifier + ) + if not is_verified: + return oauth_params - return {"result": "success"} + system_client: ToolOAuthSystemClient | None = ( + session.query(ToolOAuthSystemClient) + .filter_by(plugin_id=tool_provider.plugin_id, provider=tool_provider.provider_name) + .first() + ) + if system_client: + try: + oauth_params = decrypt_system_oauth_params(system_client.encrypted_oauth_params) + except Exception as e: + raise ValueError(f"Error decrypting system oauth params: {e}") + + return oauth_params @staticmethod def get_builtin_tool_provider_icon(provider: str): @@ -234,9 +543,7 @@ class BuiltinToolManageService: with db.session.no_autoflush: # get all user added providers - db_providers: list[BuiltinToolProvider] = ( - db.session.query(BuiltinToolProvider).filter(BuiltinToolProvider.tenant_id == tenant_id).all() or [] - ) + db_providers: list[BuiltinToolProvider] = ToolManager.list_default_builtin_providers(tenant_id) # rewrite db_providers for db_provider in db_providers: @@ -275,7 +582,6 @@ class BuiltinToolManageService: ToolTransformService.convert_tool_entity_to_api_entity( tenant_id=tenant_id, tool=tool, - credentials=user_builtin_provider.original_credentials, labels=ToolLabelManager.get_tool_labels(provider_controller), ) ) @@ -287,43 +593,153 @@ class BuiltinToolManageService: return BuiltinToolProviderSort.sort(result) @staticmethod - def _fetch_builtin_provider(provider_name: str, tenant_id: str) -> BuiltinToolProvider | None: - try: - full_provider_name = provider_name - provider_id_entity = ToolProviderID(provider_name) - provider_name = provider_id_entity.provider_name - if provider_id_entity.organization != "langgenius": - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - BuiltinToolProvider.provider == full_provider_name, + def get_builtin_provider(provider_name: str, tenant_id: str) -> Optional[BuiltinToolProvider]: + """ + This method is used to fetch the builtin provider from the database + 1.if the default provider exists, return the default provider + 2.if the default provider does not exist, return the oldest provider + """ + with Session(db.engine) as session: + try: + full_provider_name = provider_name + provider_id_entity = ToolProviderID(provider_name) + provider_name = provider_id_entity.provider_name + + if provider_id_entity.organization != "langgenius": + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + BuiltinToolProvider.provider == full_provider_name, + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() ) - .first() - ) - else: - provider_obj = ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name) - | (BuiltinToolProvider.provider == full_provider_name), + else: + provider = ( + session.query(BuiltinToolProvider) + .filter( + BuiltinToolProvider.tenant_id == tenant_id, + (BuiltinToolProvider.provider == provider_name) + | (BuiltinToolProvider.provider == full_provider_name), + ) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first + ) + .first() + ) + + if provider is None: + return None + + provider.provider = ToolProviderID(provider.provider).to_string() + return provider + except Exception: + # it's an old provider without organization + return ( + session.query(BuiltinToolProvider) + .filter(BuiltinToolProvider.tenant_id == tenant_id, BuiltinToolProvider.provider == provider_name) + .order_by( + BuiltinToolProvider.is_default.desc(), # default=True first + BuiltinToolProvider.created_at.asc(), # oldest first ) .first() ) - if provider_obj is None: - return None + @staticmethod + def save_custom_oauth_client_params( + tenant_id: str, + provider: str, + client_params: Optional[dict] = None, + enable_oauth_custom_client: Optional[bool] = None, + ): + """ + setup oauth custom client + """ + if client_params is None and enable_oauth_custom_client is None: + return {"result": "success"} + + tool_provider = ToolProviderID(provider) + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") - provider_obj.provider = ToolProviderID(provider_obj.provider).to_string() - return provider_obj - except Exception: - # it's an old provider without organization - return ( - db.session.query(BuiltinToolProvider) - .filter( - BuiltinToolProvider.tenant_id == tenant_id, - (BuiltinToolProvider.provider == provider_name), + if not isinstance(provider_controller, (BuiltinToolProviderController, PluginToolProviderController)): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + with Session(db.engine) as session: + custom_client_params = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, ) .first() ) + + # if the record does not exist, create a basic record + if custom_client_params is None: + custom_client_params = ToolOAuthTenantClient( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + session.add(custom_client_params) + + if client_params is not None: + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + original_params = encrypter.decrypt(custom_client_params.oauth_params) + new_params: dict = { + key: value if value != HIDDEN_VALUE else original_params.get(key, UNKNOWN_VALUE) + for key, value in client_params.items() + } + custom_client_params.encrypted_oauth_params = json.dumps(encrypter.encrypt(new_params)) + + if enable_oauth_custom_client is not None: + custom_client_params.enabled = enable_oauth_custom_client + + session.commit() + return {"result": "success"} + + @staticmethod + def get_custom_oauth_client_params(tenant_id: str, provider: str): + """ + get custom oauth client params + """ + with Session(db.engine) as session: + tool_provider = ToolProviderID(provider) + custom_oauth_client_params: ToolOAuthTenantClient | None = ( + session.query(ToolOAuthTenantClient) + .filter_by( + tenant_id=tenant_id, + plugin_id=tool_provider.plugin_id, + provider=tool_provider.provider_name, + ) + .first() + ) + if custom_oauth_client_params is None: + return {} + + provider_controller = ToolManager.get_builtin_provider(provider, tenant_id) + if not provider_controller: + raise ToolProviderNotFoundError(f"Provider {provider} not found") + + if not isinstance(provider_controller, BuiltinToolProviderController): + raise ValueError(f"Provider {provider} is not a builtin or plugin provider") + + encrypter, _ = create_provider_encrypter( + tenant_id=tenant_id, + config=[x.to_basic_provider_config() for x in provider_controller.get_oauth_client_schema()], + cache=NoOpProviderCredentialCache(), + ) + + return encrypter.mask_tool_credentials(encrypter.decrypt(custom_oauth_client_params.oauth_params)) diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_manage_service.py similarity index 93% rename from api/services/tools/mcp_tools_mange_service.py rename to api/services/tools/mcp_tools_manage_service.py index 7c23abda4b..e0e256912e 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -7,13 +7,14 @@ from sqlalchemy import or_ from sqlalchemy.exc import IntegrityError from core.helper import encrypter +from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.error import MCPAuthError, MCPError from core.mcp.mcp_client import MCPClient from core.tools.entities.api_entities import ToolProviderApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType from core.tools.mcp_tool.provider import MCPToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import ProviderConfigEncrypter from extensions.ext_database import db from models.tools import MCPToolProvider from services.tools.tools_transform_service import ToolTransformService @@ -75,9 +76,9 @@ class MCPToolManageService: if existing_provider: if existing_provider.name == name: raise ValueError(f"MCP tool {name} already exists") - elif existing_provider.server_url_hash == server_url_hash: + if existing_provider.server_url_hash == server_url_hash: raise ValueError(f"MCP tool {server_url} already exists") - elif existing_provider.server_identifier == server_identifier: + if existing_provider.server_identifier == server_identifier: raise ValueError(f"MCP tool {server_identifier} already exists") encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_tool = MCPToolProvider( @@ -109,15 +110,14 @@ class MCPToolManageService: ] @classmethod - def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): + def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - try: with MCPClient( mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True ) as mcp_client: tools = mcp_client.list_tools() - except MCPAuthError as e: + except MCPAuthError: raise ValueError("Please auth the tool first") except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") @@ -182,12 +182,11 @@ class MCPToolManageService: error_msg = str(e.orig) if "unique_mcp_provider_name" in error_msg: raise ValueError(f"MCP tool {name} already exists") - elif "unique_mcp_provider_server_url" in error_msg: + if "unique_mcp_provider_server_url" in error_msg: raise ValueError(f"MCP tool {server_url} already exists") - elif "unique_mcp_provider_server_identifier" in error_msg: + if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - else: - raise + raise @classmethod def update_mcp_provider_credentials( @@ -197,8 +196,7 @@ class MCPToolManageService: tool_configuration = ProviderConfigEncrypter( tenant_id=mcp_provider.tenant_id, config=list(provider_controller.get_credentials_schema()), - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.provider_id, + provider_config_cache=NoOpProviderCredentialCache(), ) credentials = tool_configuration.encrypt(credentials) mcp_provider.updated_at = datetime.now() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index 3d0c35cd9b..2d192e6f7f 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -5,21 +5,23 @@ from typing import Any, Optional, Union, cast from yarl import URL from configs import dify_config +from core.helper.provider_cache import ToolProviderCredentialsCache from core.mcp.types import Tool as MCPTool from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.custom_tool.provider import ApiToolProviderController -from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity +from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity, ToolProviderCredentialApiEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( ApiProviderAuthType, + CredentialType, ToolParameter, ToolProviderType, ) from core.tools.plugin_tool.provider import PluginToolProviderController -from core.tools.utils.configuration import ProviderConfigEncrypter +from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider @@ -119,7 +121,12 @@ class ToolTransformService: result.plugin_unique_identifier = provider_controller.plugin_unique_identifier # get credentials schema - schema = {x.to_basic_provider_config().name: x for x in provider_controller.get_credentials_schema()} + schema = { + x.to_basic_provider_config().name: x + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) if db_provider else CredentialType.API_KEY + ) + } for name, value in schema.items(): if result.masked_credentials: @@ -136,15 +143,23 @@ class ToolTransformService: credentials = db_provider.credentials # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + config=[ + x.to_basic_provider_config() + for x in provider_controller.get_credentials_schema_by_type( + CredentialType.of(db_provider.credential_type) + ) + ], + cache=ToolProviderCredentialsCache( + tenant_id=db_provider.tenant_id, + provider=db_provider.provider, + credential_id=db_provider.id, + ), ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials result.original_credentials = decrypted_credentials @@ -287,16 +302,14 @@ class ToolTransformService: if decrypt_credentials: # init tool configuration - tool_configuration = ProviderConfigEncrypter( + encrypter, _ = create_tool_provider_encrypter( tenant_id=db_provider.tenant_id, - config=[x.to_basic_provider_config() for x in provider_controller.get_credentials_schema()], - provider_type=provider_controller.provider_type.value, - provider_identity=provider_controller.entity.identity.name, + controller=provider_controller, ) # decrypt the credentials and mask the credentials - decrypted_credentials = tool_configuration.decrypt(data=credentials) - masked_credentials = tool_configuration.mask_tool_credentials(data=decrypted_credentials) + decrypted_credentials = encrypter.decrypt(data=credentials) + masked_credentials = encrypter.mask_tool_credentials(data=decrypted_credentials) result.masked_credentials = masked_credentials @@ -306,7 +319,6 @@ class ToolTransformService: def convert_tool_entity_to_api_entity( tool: Union[ApiToolBundle, WorkflowTool, Tool], tenant_id: str, - credentials: dict | None = None, labels: list[str] | None = None, ) -> ToolApiEntity: """ @@ -316,27 +328,39 @@ class ToolTransformService: # fork tool runtime tool = tool.fork_tool_runtime( runtime=ToolRuntime( - credentials=credentials or {}, + credentials={}, tenant_id=tenant_id, ) ) # get tool parameters - parameters = tool.entity.parameters or [] + base_parameters = tool.entity.parameters or [] # get tool runtime parameters runtime_parameters = tool.get_runtime_parameters() - # override parameters - current_parameters = parameters.copy() - for runtime_parameter in runtime_parameters: - found = False - for index, parameter in enumerate(current_parameters): - if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form: - current_parameters[index] = runtime_parameter - found = True - break - if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: - current_parameters.append(runtime_parameter) + # merge parameters using a functional approach to avoid type issues + merged_parameters: list[ToolParameter] = [] + + # create a mapping of runtime parameters for quick lookup + runtime_param_map = {(rp.name, rp.form): rp for rp in runtime_parameters} + + # process base parameters, replacing with runtime versions if they exist + for base_param in base_parameters: + key = (base_param.name, base_param.form) + if key in runtime_param_map: + merged_parameters.append(runtime_param_map[key]) + else: + merged_parameters.append(base_param) + + # add any runtime parameters that weren't in base parameters + for runtime_parameter in runtime_parameters: + if runtime_parameter.form == ToolParameter.ToolParameterForm.FORM: + # check if this parameter is already in merged_parameters + already_exists = any( + p.name == runtime_parameter.name and p.form == runtime_parameter.form for p in merged_parameters + ) + if not already_exists: + merged_parameters.append(runtime_parameter) return ToolApiEntity( author=tool.entity.identity.author, @@ -344,10 +368,10 @@ class ToolTransformService: label=tool.entity.identity.label, description=tool.entity.description.human if tool.entity.description else I18nObject(en_US=""), output_schema=tool.entity.output_schema, - parameters=current_parameters, + parameters=merged_parameters, labels=labels or [], ) - if isinstance(tool, ApiToolBundle): + elif isinstance(tool, ApiToolBundle): return ToolApiEntity( author=tool.author, name=tool.operation_id or "", @@ -356,6 +380,22 @@ class ToolTransformService: parameters=tool.parameters, labels=labels or [], ) + else: + # Handle WorkflowTool case + raise ValueError(f"Unsupported tool type: {type(tool)}") + + @staticmethod + def convert_builtin_provider_to_credential_entity( + provider: BuiltinToolProvider, credentials: dict + ) -> ToolProviderCredentialApiEntity: + return ToolProviderCredentialApiEntity( + id=provider.id, + name=provider.name, + provider=provider.provider, + credential_type=CredentialType.of(provider.credential_type), + is_default=provider.is_default, + credentials=credentials, + ) @staticmethod def convert_mcp_schema_to_parameter(schema: dict) -> list["ToolParameter"]: diff --git a/api/services/website_service.py b/api/services/website_service.py index 6720932a3a..991b669737 100644 --- a/api/services/website_service.py +++ b/api/services/website_service.py @@ -1,6 +1,7 @@ import datetime import json -from typing import Any +from dataclasses import dataclass +from typing import Any, Optional import requests from flask_login import current_user @@ -13,241 +14,392 @@ from extensions.ext_storage import storage from services.auth.api_key_auth_service import ApiKeyAuthService +@dataclass +class CrawlOptions: + """Options for crawling operations.""" + + limit: int = 1 + crawl_sub_pages: bool = False + only_main_content: bool = False + includes: Optional[str] = None + excludes: Optional[str] = None + max_depth: Optional[int] = None + use_sitemap: bool = True + + def get_include_paths(self) -> list[str]: + """Get list of include paths from comma-separated string.""" + return self.includes.split(",") if self.includes else [] + + def get_exclude_paths(self) -> list[str]: + """Get list of exclude paths from comma-separated string.""" + return self.excludes.split(",") if self.excludes else [] + + +@dataclass +class CrawlRequest: + """Request container for crawling operations.""" + + url: str + provider: str + options: CrawlOptions + + +@dataclass +class ScrapeRequest: + """Request container for scraping operations.""" + + provider: str + url: str + tenant_id: str + only_main_content: bool + + +@dataclass +class WebsiteCrawlApiRequest: + """Request container for website crawl API arguments.""" + + provider: str + url: str + options: dict[str, Any] + + def to_crawl_request(self) -> CrawlRequest: + """Convert API request to internal CrawlRequest.""" + options = CrawlOptions( + limit=self.options.get("limit", 1), + crawl_sub_pages=self.options.get("crawl_sub_pages", False), + only_main_content=self.options.get("only_main_content", False), + includes=self.options.get("includes"), + excludes=self.options.get("excludes"), + max_depth=self.options.get("max_depth"), + use_sitemap=self.options.get("use_sitemap", True), + ) + return CrawlRequest(url=self.url, provider=self.provider, options=options) + + @classmethod + def from_args(cls, args: dict) -> "WebsiteCrawlApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") + url = args.get("url") + options = args.get("options", {}) + + if not provider: + raise ValueError("Provider is required") + if not url: + raise ValueError("URL is required") + if not options: + raise ValueError("Options are required") + + return cls(provider=provider, url=url, options=options) + + +@dataclass +class WebsiteCrawlStatusApiRequest: + """Request container for website crawl status API arguments.""" + + provider: str + job_id: str + + @classmethod + def from_args(cls, args: dict, job_id: str) -> "WebsiteCrawlStatusApiRequest": + """Create from Flask-RESTful parsed arguments.""" + provider = args.get("provider") + + if not provider: + raise ValueError("Provider is required") + if not job_id: + raise ValueError("Job ID is required") + + return cls(provider=provider, job_id=job_id) + + class WebsiteService: + """Service class for website crawling operations using different providers.""" + @classmethod - def document_create_args_validate(cls, args: dict): - if "url" not in args or not args["url"]: - raise ValueError("url is required") - if "options" not in args or not args["options"]: - raise ValueError("options is required") - if "limit" not in args["options"] or not args["options"]["limit"]: - raise ValueError("limit is required") + def _get_credentials_and_config(cls, tenant_id: str, provider: str) -> tuple[dict, dict]: + """Get and validate credentials for a provider.""" + credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) + if not credentials or "config" not in credentials: + raise ValueError("No valid credentials found for the provider") + return credentials, credentials["config"] @classmethod - def crawl_url(cls, args: dict) -> dict: - provider = args.get("provider", "") - url = args.get("url") - options = args.get("options", "") - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - crawl_sub_pages = options.get("crawl_sub_pages", False) - only_main_content = options.get("only_main_content", False) - if not crawl_sub_pages: - params = { - "includePaths": [], - "excludePaths": [], - "limit": 1, - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - else: - includes = options.get("includes").split(",") if options.get("includes") else [] - excludes = options.get("excludes").split(",") if options.get("excludes") else [] - params = { - "includePaths": includes, - "excludePaths": excludes, - "limit": options.get("limit", 1), - "scrapeOptions": {"onlyMainContent": only_main_content}, - } - if options.get("max_depth"): - params["maxDepth"] = options.get("max_depth") - job_id = firecrawl_app.crawl_url(url, params) - website_crawl_time_cache_key = f"website_crawl_{job_id}" - time = str(datetime.datetime.now().timestamp()) - redis_client.setex(website_crawl_time_cache_key, 3600, time) - return {"status": "active", "job_id": job_id} - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).crawl_url(url, options) + def _get_decrypted_api_key(cls, tenant_id: str, config: dict) -> str: + """Decrypt and return the API key from config.""" + api_key = config.get("api_key") + if not api_key: + raise ValueError("API key not found in configuration") + return encrypter.decrypt_token(tenant_id=tenant_id, token=api_key) - elif provider == "jinareader": - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - crawl_sub_pages = options.get("crawl_sub_pages", False) - if not crawl_sub_pages: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "data": response.json().get("data")} - else: - response = requests.post( - "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", - json={ - "url": url, - "maxPages": options.get("limit", 1), - "useSitemap": options.get("use_sitemap", True), - }, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {api_key}", - }, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + @classmethod + def document_create_args_validate(cls, args: dict) -> None: + """Validate arguments for document creation.""" + try: + WebsiteCrawlApiRequest.from_args(args) + except ValueError as e: + raise ValueError(f"Invalid arguments: {e}") + + @classmethod + def crawl_url(cls, api_request: WebsiteCrawlApiRequest) -> dict[str, Any]: + """Crawl a URL using the specified provider with typed request.""" + request = api_request.to_crawl_request() + + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if request.provider == "firecrawl": + return cls._crawl_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._crawl_with_watercrawl(request=request, api_key=api_key, config=config) + elif request.provider == "jinareader": + return cls._crawl_with_jinareader(request=request, api_key=api_key) else: raise ValueError("Invalid provider") @classmethod - def get_crawl_status(cls, job_id: str, provider: str) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(current_user.current_tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") - ) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - crawl_status_data = { - "status": result.get("status", "active"), - "job_id": job_id, - "total": result.get("total", 0), - "current": result.get("current", 0), - "data": result.get("data", []), + def _crawl_with_firecrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + + if not request.options.crawl_sub_pages: + params = { + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, } - if crawl_status_data["status"] == "completed": - website_crawl_time_cache_key = f"website_crawl_{job_id}" - start_time = redis_client.get(website_crawl_time_cache_key) - if start_time: - end_time = datetime.datetime.now().timestamp() - time_consuming = abs(end_time - float(start_time)) - crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" - redis_client.delete(website_crawl_time_cache_key) - elif provider == "watercrawl": - # decrypt api_key - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + else: + params = { + "includePaths": request.options.get_include_paths(), + "excludePaths": request.options.get_exclude_paths(), + "limit": request.options.limit, + "scrapeOptions": {"onlyMainContent": request.options.only_main_content}, + } + if request.options.max_depth: + params["maxDepth"] = request.options.max_depth + + job_id = firecrawl_app.crawl_url(request.url, params) + website_crawl_time_cache_key = f"website_crawl_{job_id}" + time = str(datetime.datetime.now().timestamp()) + redis_client.setex(website_crawl_time_cache_key, 3600, time) + return {"status": "active", "job_id": job_id} + + @classmethod + def _crawl_with_watercrawl(cls, request: CrawlRequest, api_key: str, config: dict) -> dict[str, Any]: + # Convert CrawlOptions back to dict format for WaterCrawlProvider + options = { + "limit": request.options.limit, + "crawl_sub_pages": request.options.crawl_sub_pages, + "only_main_content": request.options.only_main_content, + "includes": request.options.includes, + "excludes": request.options.excludes, + "max_depth": request.options.max_depth, + "use_sitemap": request.options.use_sitemap, + } + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).crawl_url( + url=request.url, options=options + ) + + @classmethod + def _crawl_with_jinareader(cls, request: CrawlRequest, api_key: str) -> dict[str, Any]: + if not request.options.crawl_sub_pages: + response = requests.get( + f"https://r.jina.ai/{request.url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, ) - crawl_status_data = WaterCrawlProvider( - api_key, credentials.get("config").get("base_url", None) - ).get_crawl_status(job_id) - elif provider == "jinareader": - api_key = encrypter.decrypt_token( - tenant_id=current_user.current_tenant_id, token=credentials.get("config").get("api_key") + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "data": response.json().get("data")} + else: + response = requests.post( + "https://adaptivecrawl-kir3wx7b3a-uc.a.run.app", + json={ + "url": request.url, + "maxPages": request.options.limit, + "useSitemap": request.options.use_sitemap, + }, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {api_key}", + }, ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return {"status": "active", "job_id": response.json().get("data", {}).get("taskId")} + + @classmethod + def get_crawl_status(cls, job_id: str, provider: str) -> dict[str, Any]: + """Get crawl status using string parameters.""" + api_request = WebsiteCrawlStatusApiRequest(provider=provider, job_id=job_id) + return cls.get_crawl_status_typed(api_request) + + @classmethod + def get_crawl_status_typed(cls, api_request: WebsiteCrawlStatusApiRequest) -> dict[str, Any]: + """Get crawl status using typed request.""" + _, config = cls._get_credentials_and_config(current_user.current_tenant_id, api_request.provider) + api_key = cls._get_decrypted_api_key(current_user.current_tenant_id, config) + + if api_request.provider == "firecrawl": + return cls._get_firecrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "watercrawl": + return cls._get_watercrawl_status(api_request.job_id, api_key, config) + elif api_request.provider == "jinareader": + return cls._get_jinareader_status(api_request.job_id, api_key) + else: + raise ValueError("Invalid provider") + + @classmethod + def _get_firecrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + crawl_status_data = { + "status": result.get("status", "active"), + "job_id": job_id, + "total": result.get("total", 0), + "current": result.get("current", 0), + "data": result.get("data", []), + } + if crawl_status_data["status"] == "completed": + website_crawl_time_cache_key = f"website_crawl_{job_id}" + start_time = redis_client.get(website_crawl_time_cache_key) + if start_time: + end_time = datetime.datetime.now().timestamp() + time_consuming = abs(end_time - float(start_time)) + crawl_status_data["time_consuming"] = f"{time_consuming:.2f}" + redis_client.delete(website_crawl_time_cache_key) + return crawl_status_data + + @classmethod + def _get_watercrawl_status(cls, job_id: str, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_status(job_id) + + @classmethod + def _get_jinareader_status(cls, job_id: str, api_key: str) -> dict[str, Any]: + response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + data = response.json().get("data", {}) + crawl_status_data = { + "status": data.get("status", "active"), + "job_id": job_id, + "total": len(data.get("urls", [])), + "current": len(data.get("processed", [])) + len(data.get("failed", [])), + "data": [], + "time_consuming": data.get("duration", 0) / 1000, + } + + if crawl_status_data["status"] == "completed": response = requests.post( "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id}, + json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, ) data = response.json().get("data", {}) - crawl_status_data = { - "status": data.get("status", "active"), - "job_id": job_id, - "total": len(data.get("urls", [])), - "current": len(data.get("processed", [])) + len(data.get("failed", [])), - "data": [], - "time_consuming": data.get("duration", 0) / 1000, - } - - if crawl_status_data["status"] == "completed": - response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(data.get("processed", {}).keys())}, - ) - data = response.json().get("data", {}) - formatted_data = [ - { - "title": item.get("data", {}).get("title"), - "source_url": item.get("data", {}).get("url"), - "description": item.get("data", {}).get("description"), - "markdown": item.get("data", {}).get("content"), - } - for item in data.get("processed", {}).values() - ] - crawl_status_data["data"] = formatted_data - else: - raise ValueError("Invalid provider") + formatted_data = [ + { + "title": item.get("data", {}).get("title"), + "source_url": item.get("data", {}).get("url"), + "description": item.get("data", {}).get("description"), + "markdown": item.get("data", {}).get("content"), + } + for item in data.get("processed", {}).values() + ] + crawl_status_data["data"] = formatted_data return crawl_status_data @classmethod def get_crawl_url_data(cls, job_id: str, provider: str, url: str, tenant_id: str) -> dict[str, Any] | None: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) + _, config = cls._get_credentials_and_config(tenant_id, provider) + api_key = cls._get_decrypted_api_key(tenant_id, config) if provider == "firecrawl": - crawl_data: list[dict[str, Any]] | None = None - file_key = "website_files/" + job_id + ".txt" - if storage.exists(file_key): - stored_data = storage.load_once(file_key) - if stored_data: - crawl_data = json.loads(stored_data.decode("utf-8")) - else: - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - result = firecrawl_app.check_crawl_status(job_id) - if result.get("status") != "completed": - raise ValueError("Crawl job is not completed") - crawl_data = result.get("data") - - if crawl_data: - for item in crawl_data: - if item.get("source_url") == url: - return dict(item) - return None + return cls._get_firecrawl_url_data(job_id, url, api_key, config) elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).get_crawl_url_data( - job_id, url - ) + return cls._get_watercrawl_url_data(job_id, url, api_key, config) elif provider == "jinareader": - if not job_id: - response = requests.get( - f"https://r.jina.ai/{url}", - headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, - ) - if response.json().get("code") != 200: - raise ValueError("Failed to crawl") - return dict(response.json().get("data", {})) - else: - # Get crawl status first - status_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id}, - ) - status_data = status_response.json().get("data", {}) - if status_data.get("status") != "completed": - raise ValueError("Crawl job is not completed") - - # Get processed data - data_response = requests.post( - "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", - headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, - json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, - ) - processed_data = data_response.json().get("data", {}) - for item in processed_data.get("processed", {}).values(): - if item.get("data", {}).get("url") == url: - return dict(item.get("data", {})) - return None + return cls._get_jinareader_url_data(job_id, url, api_key) else: raise ValueError("Invalid provider") @classmethod - def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict: - credentials = ApiKeyAuthService.get_auth_credentials(tenant_id, "website", provider) - if provider == "firecrawl": - # decrypt api_key - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - firecrawl_app = FirecrawlApp(api_key=api_key, base_url=credentials.get("config").get("base_url", None)) - params = {"onlyMainContent": only_main_content} - result = firecrawl_app.scrape_url(url, params) - return result - elif provider == "watercrawl": - api_key = encrypter.decrypt_token(tenant_id=tenant_id, token=credentials.get("config").get("api_key")) - return WaterCrawlProvider(api_key, credentials.get("config").get("base_url", None)).scrape_url(url) + def _get_firecrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + crawl_data: list[dict[str, Any]] | None = None + file_key = "website_files/" + job_id + ".txt" + if storage.exists(file_key): + stored_data = storage.load_once(file_key) + if stored_data: + crawl_data = json.loads(stored_data.decode("utf-8")) + else: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + result = firecrawl_app.check_crawl_status(job_id) + if result.get("status") != "completed": + raise ValueError("Crawl job is not completed") + crawl_data = result.get("data") + + if crawl_data: + for item in crawl_data: + if item.get("source_url") == url: + return dict(item) + return None + + @classmethod + def _get_watercrawl_url_data(cls, job_id: str, url: str, api_key: str, config: dict) -> dict[str, Any] | None: + return WaterCrawlProvider(api_key, config.get("base_url")).get_crawl_url_data(job_id, url) + + @classmethod + def _get_jinareader_url_data(cls, job_id: str, url: str, api_key: str) -> dict[str, Any] | None: + if not job_id: + response = requests.get( + f"https://r.jina.ai/{url}", + headers={"Accept": "application/json", "Authorization": f"Bearer {api_key}"}, + ) + if response.json().get("code") != 200: + raise ValueError("Failed to crawl") + return dict(response.json().get("data", {})) + else: + # Get crawl status first + status_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id}, + ) + status_data = status_response.json().get("data", {}) + if status_data.get("status") != "completed": + raise ValueError("Crawl job is not completed") + + # Get processed data + data_response = requests.post( + "https://adaptivecrawlstatus-kir3wx7b3a-uc.a.run.app", + headers={"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"}, + json={"taskId": job_id, "urls": list(status_data.get("processed", {}).keys())}, + ) + processed_data = data_response.json().get("data", {}) + for item in processed_data.get("processed", {}).values(): + if item.get("data", {}).get("url") == url: + return dict(item.get("data", {})) + return None + + @classmethod + def get_scrape_url_data(cls, provider: str, url: str, tenant_id: str, only_main_content: bool) -> dict[str, Any]: + request = ScrapeRequest(provider=provider, url=url, tenant_id=tenant_id, only_main_content=only_main_content) + + _, config = cls._get_credentials_and_config(tenant_id=request.tenant_id, provider=request.provider) + api_key = cls._get_decrypted_api_key(tenant_id=request.tenant_id, config=config) + + if request.provider == "firecrawl": + return cls._scrape_with_firecrawl(request=request, api_key=api_key, config=config) + elif request.provider == "watercrawl": + return cls._scrape_with_watercrawl(request=request, api_key=api_key, config=config) else: raise ValueError("Invalid provider") + + @classmethod + def _scrape_with_firecrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + firecrawl_app = FirecrawlApp(api_key=api_key, base_url=config.get("base_url")) + params = {"onlyMainContent": request.only_main_content} + return firecrawl_app.scrape_url(url=request.url, params=params) + + @classmethod + def _scrape_with_watercrawl(cls, request: ScrapeRequest, api_key: str, config: dict) -> dict[str, Any]: + return WaterCrawlProvider(api_key=api_key, base_url=config.get("base_url")).scrape_url(request.url) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0149d50346..403e559743 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -2,8 +2,7 @@ import json import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import UTC, datetime -from typing import Any, Optional +from typing import Any, Optional, cast from uuid import uuid4 from sqlalchemy import select @@ -15,10 +14,10 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File from core.repositories import DifyCoreRepositoryFactory from core.variables import Variable +from core.variables.variables import VariableUnion from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.graph_engine.entities.event import InNodeEvent from core.workflow.nodes import NodeType @@ -28,10 +27,12 @@ from core.workflow.nodes.event import RunCompletedEvent from core.workflow.nodes.event.types import NodeEvent from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.start.entities import StartNodeData +from core.workflow.system_variable import SystemVariable from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings +from libs.datetime_utils import naive_utc_now from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider @@ -231,7 +232,7 @@ class WorkflowService: workflow.graph = json.dumps(graph) workflow.features = json.dumps(features) workflow.updated_by = account.id - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = naive_utc_now() workflow.environment_variables = environment_variables workflow.conversation_variables = conversation_variables @@ -267,7 +268,7 @@ class WorkflowService: tenant_id=app_model.tenant_id, app_id=app_model.id, type=draft_workflow.type, - version=Workflow.version_from_datetime(datetime.now(UTC).replace(tzinfo=None)), + version=Workflow.version_from_datetime(naive_utc_now()), graph=draft_workflow.graph, features=draft_workflow.features, created_by=account.id, @@ -369,7 +370,7 @@ class WorkflowService: else: variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable.empty(), user_inputs=user_inputs, environment_variables=draft_workflow.environment_variables, conversation_variables=[], @@ -464,10 +465,10 @@ class WorkflowService: node_id: str, ) -> WorkflowNodeExecution: try: - node_instance, generator = invoke_node_fn() + node, node_events = invoke_node_fn() node_run_result: NodeRunResult | None = None - for event in generator: + for event in node_events: if isinstance(event, RunCompletedEvent): node_run_result = event.run_result @@ -478,18 +479,18 @@ class WorkflowService: if not node_run_result: raise ValueError("Node run failed with no run result") # single step debug mode error handling return - if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node_instance.should_continue_on_error: + if node_run_result.status == WorkflowNodeExecutionStatus.FAILED and node.continue_on_error: node_error_args: dict[str, Any] = { "status": WorkflowNodeExecutionStatus.EXCEPTION, "error": node_run_result.error, "inputs": node_run_result.inputs, - "metadata": {"error_strategy": node_instance.node_data.error_strategy}, + "metadata": {"error_strategy": node.error_strategy}, } - if node_instance.node_data.error_strategy is ErrorStrategy.DEFAULT_VALUE: + if node.error_strategy is ErrorStrategy.DEFAULT_VALUE: node_run_result = NodeRunResult( **node_error_args, outputs={ - **node_instance.node_data.default_value_dict, + **node.default_value_dict, "error_message": node_run_result.error, "error_type": node_run_result.error_type, }, @@ -508,10 +509,10 @@ class WorkflowService: ) error = node_run_result.error if not run_succeeded else None except WorkflowNodeRunFailedError as e: - node_instance = e.node_instance + node = e._node run_succeeded = False node_run_result = None - error = e.error + error = e._error # Create a NodeExecution domain model node_execution = WorkflowNodeExecution( @@ -519,11 +520,11 @@ class WorkflowService: workflow_id="", # This is a single-step execution, so no workflow ID index=1, node_id=node_id, - node_type=node_instance.node_type, - title=node_instance.node_data.title, + node_type=node.type_, + title=node.title, elapsed_time=time.perf_counter() - start_at, - created_at=datetime.now(UTC).replace(tzinfo=None), - finished_at=datetime.now(UTC).replace(tzinfo=None), + created_at=naive_utc_now(), + finished_at=naive_utc_now(), ) if run_succeeded and node_run_result: @@ -620,7 +621,7 @@ class WorkflowService: setattr(workflow, field, value) workflow.updated_by = account_id - workflow.updated_at = datetime.now(UTC).replace(tzinfo=None) + workflow.updated_at = naive_utc_now() return workflow @@ -685,36 +686,30 @@ def _setup_variable_pool( ): # Only inject system variables for START node type. if node_type == NodeType.START: - # Create a variable pool. - system_inputs: dict[SystemVariableKey, Any] = { - # From inputs: - SystemVariableKey.FILES: files, - SystemVariableKey.USER_ID: user_id, - # From workflow model - SystemVariableKey.APP_ID: workflow.app_id, - SystemVariableKey.WORKFLOW_ID: workflow.id, - # Randomly generated. - SystemVariableKey.WORKFLOW_EXECUTION_ID: str(uuid.uuid4()), - } + system_variable = SystemVariable( + user_id=user_id, + app_id=workflow.app_id, + workflow_id=workflow.id, + files=files or [], + workflow_execution_id=str(uuid.uuid4()), + ) # Only add chatflow-specific variables for non-workflow types if workflow.type != WorkflowType.WORKFLOW.value: - system_inputs.update( - { - SystemVariableKey.QUERY: query, - SystemVariableKey.CONVERSATION_ID: conversation_id, - SystemVariableKey.DIALOGUE_COUNT: 0, - } - ) + system_variable.query = query + system_variable.conversation_id = conversation_id + system_variable.dialogue_count = 0 else: - system_inputs = {} + system_variable = SystemVariable.empty() # init variable pool variable_pool = VariablePool( - system_variables=system_inputs, + system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=conversation_variables, + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/services/workspace_service.py b/api/services/workspace_service.py index 125e0c1b1e..bb35645c50 100644 --- a/api/services/workspace_service.py +++ b/api/services/workspace_service.py @@ -31,7 +31,7 @@ class WorkspaceService: assert tenant_account_join is not None, "TenantAccountJoin not found" tenant_info["role"] = tenant_account_join.role - can_replace_logo = FeatureService.get_features(tenant_info["id"]).can_replace_logo + can_replace_logo = FeatureService.get_features(tenant.id).can_replace_logo if can_replace_logo and TenantService.has_roles(tenant, [TenantAccountRole.OWNER, TenantAccountRole.ADMIN]): base_url = dify_config.FILES_URL diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 55cac6a9af..a85aab0bb7 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -1,4 +1,3 @@ -import datetime import logging import time @@ -8,6 +7,7 @@ from celery import shared_task # type: ignore from configs import dify_config from core.indexing_runner import DocumentIsPausedError, IndexingRunner from extensions.ext_database import db +from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document from services.feature_service import FeatureService @@ -53,7 +53,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "error" document.error = str(e) - document.stopped_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.stopped_at = naive_utc_now() db.session.add(document) db.session.commit() db.session.close() @@ -68,7 +68,7 @@ def document_indexing_task(dataset_id: str, document_ids: list): if document: document.indexing_status = "parsing" - document.processing_started_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + document.processing_started_at = naive_utc_now() documents.append(document) db.session.add(document) db.session.commit() diff --git a/api/tasks/mail_change_mail_task.py b/api/tasks/mail_change_mail_task.py new file mode 100644 index 0000000000..da44040b7d --- /dev/null +++ b/api/tasks/mail_change_mail_task.py @@ -0,0 +1,78 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail +from services.feature_service import FeatureService + + +@shared_task(queue="mail") +def send_change_mail_task(language: str, to: str, code: str, phase: str): + """ + Async Send change email mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Change email code + :param phase: Change email phase (new_email, old_email) + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + + email_config = { + "zh-Hans": { + "old_email": { + "subject": "检测您现在的邮箱", + "template_with_brand": "change_mail_confirm_old_template_zh-CN.html", + "template_without_brand": "without-brand/change_mail_confirm_old_template_zh-CN.html", + }, + "new_email": { + "subject": "确认您的邮箱地址变更", + "template_with_brand": "change_mail_confirm_new_template_zh-CN.html", + "template_without_brand": "without-brand/change_mail_confirm_new_template_zh-CN.html", + }, + }, + "en": { + "old_email": { + "subject": "Check your current email", + "template_with_brand": "change_mail_confirm_old_template_en-US.html", + "template_without_brand": "without-brand/change_mail_confirm_old_template_en-US.html", + }, + "new_email": { + "subject": "Confirm your new email address", + "template_with_brand": "change_mail_confirm_new_template_en-US.html", + "template_without_brand": "without-brand/change_mail_confirm_new_template_en-US.html", + }, + }, + } + + # send change email mail using different languages + try: + system_features = FeatureService.get_system_features() + lang_key = "zh-Hans" if language == "zh-Hans" else "en" + + if phase not in ["old_email", "new_email"]: + raise ValueError("Invalid phase") + + config = email_config[lang_key][phase] + subject = config["subject"] + + if system_features.branding.enabled: + template = config["template_without_brand"] + else: + template = config["template_with_brand"] + + html_content = render_template(template, to=to, code=code) + mail.send(to=to, subject=subject, html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style("Send change email mail to {} succeeded: latency: {}".format(to, end_at - start_at), fg="green") + ) + except Exception: + logging.exception("Send change email mail to {} failed".format(to)) diff --git a/api/tasks/mail_owner_transfer_task.py b/api/tasks/mail_owner_transfer_task.py new file mode 100644 index 0000000000..8d05c6dc0f --- /dev/null +++ b/api/tasks/mail_owner_transfer_task.py @@ -0,0 +1,152 @@ +import logging +import time + +import click +from celery import shared_task # type: ignore +from flask import render_template + +from extensions.ext_mail import mail +from services.feature_service import FeatureService + + +@shared_task(queue="mail") +def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = "transfer_workspace_owner_confirm_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_owner_confirm_template_zh-CN.html" + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content) + else: + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="验证您转移工作空间所有权的请求", html=html_content) + else: + template = "transfer_workspace_owner_confirm_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_owner_confirm_template_en-US.html" + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content) + else: + html_content = render_template(template, to=to, code=code, WorkspaceName=workspace) + mail.send(to=to, subject="Verify Your Request to Transfer Workspace Ownership", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param workspace: Workspace name + :param new_owner_email: New owner email + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = "transfer_workspace_old_owner_notify_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_old_owner_notify_template_zh-CN.html" + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="工作区所有权已转移", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="工作区所有权已转移", html=html_content) + else: + template = "transfer_workspace_old_owner_notify_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_old_owner_notify_template_en-US.html" + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace, NewOwnerEmail=new_owner_email) + mail.send(to=to, subject="Workspace ownership has been transferred", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) + + +@shared_task(queue="mail") +def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str): + """ + Async Send owner transfer confirm mail + :param language: Language in which the email should be sent (e.g., 'en', 'zh') + :param to: Recipient email address + :param code: Change email code + :param workspace: Workspace name + """ + if not mail.is_inited(): + return + + logging.info(click.style("Start change email mail to {}".format(to), fg="green")) + start_at = time.perf_counter() + # send change email mail using different languages + try: + if language == "zh-Hans": + template = "transfer_workspace_new_owner_notify_template_zh-CN.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_new_owner_notify_template_zh-CN.html" + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"您现在是 {workspace} 的所有者", html=html_content) + else: + template = "transfer_workspace_new_owner_notify_template_en-US.html" + system_features = FeatureService.get_system_features() + if system_features.branding.enabled: + template = "without-brand/transfer_workspace_new_owner_notify_template_en-US.html" + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content) + else: + html_content = render_template(template, to=to, WorkspaceName=workspace) + mail.send(to=to, subject=f"You are now the owner of {workspace}", html=html_content) + + end_at = time.perf_counter() + logging.info( + click.style( + "Send owner transfer confirm mail to {} succeeded: latency: {}".format(to, end_at - start_at), + fg="green", + ) + ) + except Exception: + logging.exception("owner transfer confirm email mail to {} failed".format(to)) diff --git a/api/templates/change_mail_confirm_new_template_en-US.html b/api/templates/change_mail_confirm_new_template_en-US.html new file mode 100644 index 0000000000..88721e787c --- /dev/null +++ b/api/templates/change_mail_confirm_new_template_en-US.html @@ -0,0 +1,125 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Confirm Your New Email Address

+
+

You’re updating the email address linked to your Dify account.

+

To confirm this action, please use the verification code below.

+

This code will only be valid for the next 5 minutes:

+
+
+ {{code}} +
+

If you didn’t make this request, please ignore this email or contact support immediately.

+
+ + + + diff --git a/api/templates/change_mail_confirm_new_template_zh-CN.html b/api/templates/change_mail_confirm_new_template_zh-CN.html new file mode 100644 index 0000000000..25336ea1a1 --- /dev/null +++ b/api/templates/change_mail_confirm_new_template_zh-CN.html @@ -0,0 +1,125 @@ + + + + + + + + +
+
+ + Dify Logo +
+

确认您的邮箱地址变更

+
+

您正在更新与您的 Dify 账户关联的邮箱地址。

+

为了确认此操作,请使用以下验证码。

+

此验证码仅在接下来的5分钟内有效:

+
+
+ {{code}} +
+

如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。

+
+ + + + diff --git a/api/templates/change_mail_confirm_old_template_en-US.html b/api/templates/change_mail_confirm_old_template_en-US.html new file mode 100644 index 0000000000..b20306aa87 --- /dev/null +++ b/api/templates/change_mail_confirm_old_template_en-US.html @@ -0,0 +1,125 @@ + + + + + + + + +
+
+ + Dify Logo +
+

Verify Your Request to Change Email

+
+

We received a request to change the email address associated with your Dify account.

+

To confirm this action, please use the verification code below.

+

This code will only be valid for the next 5 minutes:

+
+
+ {{code}} +
+

If you didn’t make this request, please ignore this email or contact support immediately.

+
+ + + + diff --git a/api/templates/change_mail_confirm_old_template_zh-CN.html b/api/templates/change_mail_confirm_old_template_zh-CN.html new file mode 100644 index 0000000000..23c9e46652 --- /dev/null +++ b/api/templates/change_mail_confirm_old_template_zh-CN.html @@ -0,0 +1,124 @@ + + + + + + + + +
+
+ + Dify Logo +
+

验证您的邮箱变更请求

+
+

我们收到了一个变更您 Dify 账户关联邮箱地址的请求。

+

此验证码仅在接下来的5分钟内有效:

+
+
+ {{code}} +
+

如果您没有请求变更邮箱地址,请忽略此邮件或立即联系支持。

+
+ + + + diff --git a/api/templates/clean_document_job_mail_template-US.html b/api/templates/clean_document_job_mail_template-US.html index 2d8f78b46a..b26e494f80 100644 --- a/api/templates/clean_document_job_mail_template-US.html +++ b/api/templates/clean_document_job_mail_template-US.html @@ -6,94 +6,136 @@ Documents Disabled Notification -