From 83cd796b4d4338145413330eaa803af553cd0b1c Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 5 Jun 2025 00:59:31 +0800 Subject: [PATCH] feat(api): regenerate the url signature when serializing File object. --- .../console/app/workflow_draft_variable.py | 32 +++++++++++++++++-- api/services/workflow_service.py | 8 ++--- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index c2652fcc04..16de3dbe42 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -1,5 +1,5 @@ import logging -from typing import NoReturn +from typing import Any, NoReturn from flask import Response from flask_restful import Resource, fields, inputs, marshal_with, reqparse @@ -13,6 +13,8 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError +from core.variables.segment_group import SegmentGroup +from core.variables.segments import ArrayFileSegment, FileSegment, Segment from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment from libs.login import current_user, login_required @@ -24,6 +26,32 @@ from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) +def _convert_values_to_json_serializable_object(value: Segment) -> Any: + if isinstance(value, FileSegment): + return value.value.model_dump() + elif isinstance(value, ArrayFileSegment): + return [i.model_dump() for i in value.value] + elif isinstance(value, SegmentGroup): + return [_convert_values_to_json_serializable_object(i) for i in value.value] + else: + return value.value + + +def _serialize_var_value(variable: WorkflowDraftVariable) -> Any: + value = variable.get_value() + # create a copy of the value to avoid affecting the model cache. + value = value.model_copy(deep=True) + # Refresh the url signature before returning it to client. + if isinstance(value, FileSegment): + file = value.value + file.remote_url = file.generate_url() + elif isinstance(value, ArrayFileSegment): + files = value.value + for file in files: + file.remote_url = file.generate_url() + return _convert_values_to_json_serializable_object(value) + + def _create_pagination_parser(): parser = reqparse.RequestParser() parser.add_argument( @@ -51,7 +79,7 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { _WORKFLOW_DRAFT_VARIABLE_FIELDS = dict( _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, - value=fields.Raw(attribute=lambda variable: variable.get_value().value), + value=fields.Raw(attribute=_serialize_var_value), ) _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = { diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 0d3aa85c30..3841a2ec56 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -309,19 +309,17 @@ class WorkflowService: def run_draft_workflow_node( self, app_model: App, + draft_workflow: Workflow, node_id: str, user_inputs: dict, account: Account, query: str = "", - files: list[File] | None = None, + files: Sequence[File] | None = None, ) -> WorkflowNodeExecutionModel: """ Run draft workflow node """ - # fetch draft workflow by app_model - draft_workflow = self.get_draft_workflow(app_model=app_model) - if not draft_workflow: - raise ValueError("Workflow not initialized") + files = files or [] with Session(bind=db.engine, expire_on_commit=False) as session, session.begin(): draft_var_srv = WorkflowDraftVariableService(session)