feat(api): load _storage_key for file types while loading draft variables

pull/20699/head
QuantumGhost 11 months ago
parent 2db7815098
commit 8f8465cd9f

@ -271,6 +271,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)
@ -353,6 +354,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
draft_var_srv = WorkflowDraftVariableService(db.session())
draft_var_srv.prefill_conversation_variable_default_values(workflow)

@ -317,6 +317,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
@ -400,6 +401,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
tenant_id=application_generate_entity.app_config.tenant_id,
)
return self._generate(
app_model=app_model,

@ -14,13 +14,14 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import ArrayFileSegment
from core.variables.segments import ArrayFileSegment, FileSegment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.variable_assigner.common.helpers import get_updated_variables
from core.workflow.variable_loader import VariableLoader
from factories.file_factory import StorageKeyLoader
from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation
from models.enums import DraftVariableType
@ -56,16 +57,19 @@ class DraftVarLoader(VariableLoader):
_engine: Engine
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
) -> None:
self._engine = engine
self._app_id = app_id
self._tenant_id = tenant_id
self._fallback_variables = fallback_variables or []
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
@ -94,20 +98,17 @@ class DraftVarLoader(VariableLoader):
selector_tuple = self._selector_to_tuple(variable.selector)
variable_by_selector[selector_tuple] = variable
# If a conversation variable is referenced but not present in the draft variables table,
# fall back to returning the variable with its default value.
fallback_var_by_selector = {}
for variable in self._fallback_variables:
selector_tuple = self._selector_to_tuple(variable.selector)
fallback_var_by_selector[selector_tuple] = variable
for selector in selectors:
selector_tuple = self._selector_to_tuple(selector)
if selector_tuple in variable_by_selector:
continue
if selector_tuple in fallback_var_by_selector:
variable_by_selector[selector_tuple] = fallback_var_by_selector[selector_tuple]
# Important:
files: list[File] = []
for draft_var in draft_vars:
value = draft_var.get_value()
if isinstance(value, FileSegment):
files.append(value.value)
elif isinstance(value, ArrayFileSegment):
files.extend(value.value)
with Session(bind=self._engine) as session:
storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id)
storage_key_loader.load_storage_keys(files)
return list(variable_by_selector.values())

@ -367,6 +367,7 @@ class WorkflowService:
variable_loader = DraftVarLoader(
engine=db.engine,
app_id=app_model.id,
tenant_id=app_model.tenant_id,
)
eclosing_node_type_and_id = draft_workflow.get_enclosing_node_type_and_id(node_config)

@ -170,12 +170,14 @@ class TestWorkflowDraftVariableService(unittest.TestCase):
@pytest.mark.usefixtures("flask_req_ctx")
class TestDraftVariableLoader(unittest.TestCase):
_test_app_id: str
_test_tenant_id: str
_node1_id = "test_loader_node_1"
_node_exec_id = str(uuid.uuid4())
def setUp(self):
self._test_app_id = str(uuid.uuid4())
self._test_tenant_id = str(uuid.uuid4())
sys_var = WorkflowDraftVariable.new_sys_variable(
app_id=self._test_app_id,
name="sys_var",
@ -218,12 +220,12 @@ class TestDraftVariableLoader(unittest.TestCase):
session.commit()
def test_variable_loader_with_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables([])
assert len(variables) == 0
def test_variable_loader_with_non_empty_selector(self):
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id)
var_loader = DraftVarLoader(engine=db.engine, app_id=self._test_app_id, tenant_id=self._test_tenant_id)
variables = var_loader.load_variables(
[
[SYSTEM_VARIABLE_NODE_ID, "sys_var"],

Loading…
Cancel
Save