feat(api): load _storage_key for file types while loading draft variables

pull/20699/head
QuantumGhost 11 months ago
parent 2db7815098
commit 8f8465cd9f

@ -271,6 +271,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
) )
draft_var_srv = WorkflowDraftVariableService(db.session()) draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow) draft_var_srv.prefill_conversation_variable_default_values(workflow)
@ -353,6 +354,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
) )
draft_var_srv = WorkflowDraftVariableService(db.session()) draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow) draft_var_srv.prefill_conversation_variable_default_values(workflow)

@ -317,6 +317,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
) )
return self._generate( return self._generate(
@ -400,6 +401,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
var_loader = DraftVarLoader( var_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
) )
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,

@ -14,13 +14,14 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File from core.file.models import File
from core.variables import Segment, StringSegment, Variable from core.variables import Segment, StringSegment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.types import SegmentType from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation from models import App, Conversation
from models.enums import DraftVariableType from models.enums import DraftVariableType
@ -56,16 +57,19 @@ class DraftVarLoader(VariableLoader):
_engine: Engine _engine: Engine
# Application ID for which variables are being loaded. # Application ID for which variables are being loaded.
_app_id: str _app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable] _fallback_variables: Sequence[Variable]
def __init__( def __init__(
self, self,
engine: Engine, engine: Engine,
app_id: str, app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None, fallback_variables: Sequence[Variable] | None = None,
) -> None: ) -> None:
self._engine = engine self._engine = engine
self._app_id = app_id self._app_id = app_id
self._tenant_id = tenant_id
self._fallback_variables = fallback_variables or [] self._fallback_variables = fallback_variables or []
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]: def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
@ -94,20 +98,17 @@ class DraftVarLoader(VariableLoader):
selector_tuple = self._selector_to_tuple(variable.selector) selector_tuple = self._selector_to_tuple(variable.selector)
variable_by_selector[selector_tuple] = variable variable_by_selector[selector_tuple] = variable
# If a conversation variable is referenced but not present in the draft variables table, # Important:
# fall back to returning the variable with its default value. files: list[File] = []
for draft_var in draft_vars:
fallback_var_by_selector = {} value = draft_var.get_value()
for variable in self._fallback_variables: if isinstance(value, FileSegment):
selector_tuple = self._selector_to_tuple(variable.selector) files.append(value.value)
fallback_var_by_selector[selector_tuple] = variable elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
for selector in selectors: with Session(bind=self._engine) as session:
selector_tuple = self._selector_to_tuple(selector) storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
if selector_tuple in variable_by_selector: storage_key_loader.load_storage_keys(files)
continue
if selector_tuple in fallback_var_by_selector:
variable_by_selector[selector_tuple] = fallback_var_by_selector[selector_tuple]
return list(variable_by_selector.values()) return list(variable_by_selector.values())

@ -367,6 +367,7 @@ class WorkflowService:
variable_loader = DraftVarLoader( variable_loader = DraftVarLoader(
engine=db.engine, engine=db.engine,
app_id=app_model.id, app_id=app_model.id,
tenant_id=app_model.tenant_id,
) )
eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config) eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)

@ -170,12 +170,14 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
@pytest.mark.usefixtures("flask_req_ctx") @pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase): class TestDraftVariableLoader(unittest.TestCase):
_test_app_id: str _test_app_id: str
_test_tenant_id: str
_node1_id = "test_loader_node_1" _node1_id = "test_loader_node_1"
_node_exec_id = str(uuid.uuid4()) _node_exec_id = str(uuid.uuid4())
def setUp(self): def setUp(self):
self._test_app_id = str(uuid.uuid4()) self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable( sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id, app_id=self._test_app_id,
name="sys_var", name="sys_var",
@ -218,12 +220,12 @@ class TestDraftVariableLoader(unittest.TestCase):
session.commit() session.commit()
def test_variable_loader_with_empty_selector(self): def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables([]) variables = var_loader.load_variables([])
assert len(variables) == 0 assert len(variables) == 0
def test_variable_loader_with_non_empty_selector(self): def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id) var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables( variables = var_loader.load_variables(
[ [
[SYSTEM_VARIABLE_NODE_ID, "sys_var"], [SYSTEM_VARIABLE_NODE_ID, "sys_var"],

Loading…
Cancel
Save