From 8f8465cd9ffbca471e9a5d5006b498153f3128de Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 20 Jun 2025 22:48:13 +0800 Subject: [PATCH] feat(api): load _storage_key for file types while loading draft variables --- .../app/apps/advanced_chat/app_generator.py | 2 ++ api/core/app/apps/workflow/app_generator.py | 2 ++ .../workflow_draft_variable_service.py | 31 ++++++++++--------- api/services/workflow_service.py | 1 + .../test_workflow_draft_variable_service.py | 6 ++-- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index dde0698c42..afecd99978 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -271,6 +271,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, ) draft_var_srv = WorkflowDraftVariableService(db.session()) draft_var_srv.prefill_conversation_variable_default_values(workflow) @@ -353,6 +354,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, ) draft_var_srv = WorkflowDraftVariableService(db.session()) draft_var_srv.prefill_conversation_variable_default_values(workflow) diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 70cc144075..369fa0e48c 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -317,6 +317,7 @@ class WorkflowAppGenerator(BaseAppGenerator): var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, ) return self._generate( @@ -400,6 +401,7 @@ class WorkflowAppGenerator(BaseAppGenerator): var_loader = DraftVarLoader( engine=db.engine, app_id=application_generate_entity.app_config.app_id, + tenant_id=application_generate_entity.app_config.tenant_id, ) return self._generate( app_model=app_model, diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 4f4e749d41..095321251f 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -14,13 +14,14 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.file.models import File from core.variables import Segment, StringSegment, Variable from core.variables.consts import MIN_SELECTORS_LENGTH -from core.variables.segments import ArrayFileSegment +from core.variables.segments import ArrayFileSegment, FileSegment from core.variables.types import SegmentType from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.enums import SystemVariableKey from core.workflow.nodes import NodeType from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables from core.workflow.variable_loader import VariableLoader +from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation from models.enums import DraftVariableType @@ -56,16 +57,19 @@ class DraftVarLoader(VariableLoader): _engine: Engine # Application ID for which variables are being loaded. _app_id: str + _tenant_id: str _fallback_variables: Sequence[Variable] def __init__( self, engine: Engine, app_id: str, + tenant_id: str, fallback_variables: Sequence[Variable] | None = None, ) -> None: self._engine = engine self._app_id = app_id + self._tenant_id = tenant_id self._fallback_variables = fallback_variables or [] def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: @@ -94,20 +98,17 @@ class DraftVarLoader(VariableLoader): selector_tuple = self._selector_to_tuple(variable.selector) variable_by_selector[selector_tuple] = variable - # If a conversation variable is referenced but not present in the draft variables table, - # fall back to returning the variable with its default value. - - fallback_var_by_selector = {} - for variable in self._fallback_variables: - selector_tuple = self._selector_to_tuple(variable.selector) - fallback_var_by_selector[selector_tuple] = variable - - for selector in selectors: - selector_tuple = self._selector_to_tuple(selector) - if selector_tuple in variable_by_selector: - continue - if selector_tuple in fallback_var_by_selector: - variable_by_selector[selector_tuple] = fallback_var_by_selector[selector_tuple] + # Important: + files: list[File] = [] + for draft_var in draft_vars: + value = draft_var.get_value() + if isinstance(value, FileSegment): + files.append(value.value) + elif isinstance(value, ArrayFileSegment): + files.extend(value.value) + with Session(bind=self._engine) as session: + storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader.load_storage_keys(files) return list(variable_by_selector.values()) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index ac9db869cc..53a22c8e76 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -367,6 +367,7 @@ class WorkflowService: variable_loader = DraftVarLoader( engine=db.engine, app_id=app_model.id, + tenant_id=app_model.tenant_id, ) eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index c8ffeaf09c..30cd2e60cb 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -170,12 +170,14 @@ class TestWorkflowDraftVariableService(unittest.TestCase): @pytest.mark.usefixtures("flask_req_ctx") class TestDraftVariableLoader(unittest.TestCase): _test_app_id: str + _test_tenant_id: str _node1_id = "test_loader_node_1" _node_exec_id = str(uuid.uuid4()) def setUp(self): self._test_app_id = str(uuid.uuid4()) + self._test_tenant_id = str(uuid.uuid4()) sys_var = WorkflowDraftVariable.new_sys_variable( app_id=self._test_app_id, name="sys_var", @@ -218,12 +220,12 @@ class TestDraftVariableLoader(unittest.TestCase): session.commit() def test_variable_loader_with_empty_selector(self): - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) variables = var_loader.load_variables([]) assert len(variables) == 0 def test_variable_loader_with_non_empty_selector(self): - var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) + var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id) variables = var_loader.load_variables( [ [SYSTEM_VARIABLE_NODE_ID, "sys_var"],