feat(api): regenerate the url signature when serializing File object.

pull/20699/head
QuantumGhost 12 months ago
parent 222087e3be
commit 83cd796b4d

@ -1,5 +1,5 @@
import logging
from typing import NoReturn
from typing import Any, NoReturn
from flask import Response
from flask_restful import Resource, fields, inputs, marshal_with, reqparse
@ -13,6 +13,8 @@ from controllers.console.app.error import (
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from factories.variable_factory import build_segment
from libs.login import current_user, login_required
@ -24,6 +26,32 @@ from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
def _convert_values_to_json_serializable_object(value: Segment) -> Any:
if isinstance(value, FileSegment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
return value.value
def _serialize_var_value(variable: WorkflowDraftVariable) -> Any:
value = variable.get_value()
# create a copy of the value to avoid affecting the model cache.
value = value.model_copy(deep=True)
# Refresh the url signature before returning it to client.
if isinstance(value, FileSegment):
file = value.value
file.remote_url = file.generate_url()
elif isinstance(value, ArrayFileSegment):
files = value.value
for file in files:
file.remote_url = file.generate_url()
return _convert_values_to_json_serializable_object(value)
def _create_pagination_parser():
parser = reqparse.RequestParser()
parser.add_argument(
@ -51,7 +79,7 @@ _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
value=fields.Raw(attribute=lambda variable: variable.get_value().value),
value=fields.Raw(attribute=_serialize_var_value),
)
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {

@ -309,19 +309,17 @@ class WorkflowService:
def run_draft_workflow_node(
self,
app_model: App,
draft_workflow: Workflow,
node_id: str,
user_inputs: dict,
account: Account,
query: str = "",
files: list[File] | None = None,
files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run draft workflow node
"""
# fetch draft workflow by app_model
draft_workflow = self.get_draft_workflow(app_model=app_model)
if not draft_workflow:
raise ValueError("Workflow not initialized")
files = files or []
with Session(bind=db.engine, expire_on_commit=False) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)

Loading…
Cancel
Save