fix(api): fix the issue that nested objects are not added to variable pool properly

pull/21478/head
QuantumGhost 11 months ago
parent 8ea27bc341
commit 09605dec89

@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.utils import variable_utils
from libs.flask_utils import preserve_flask_contexts from libs.flask_utils import preserve_flask_contexts
from models.enums import UserFrom from models.enums import UserFrom
from models.workflow import WorkflowType from models.workflow import WorkflowType
@ -856,16 +857,12 @@ class GraphEngine:
:param variable_value: variable value :param variable_value: variable value
:return: :return:
""" """
self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) variable_utils.append_variables_recursively(
self.graph_runtime_state.variable_pool,
# if variable_value is a dict, then recursively append variables node_id,
if isinstance(variable_value, dict): variable_key_list,
for key, value in variable_value.items(): variable_value,
# construct new key list )
new_key_list = variable_key_list + [key]
self._append_variables_recursively(
node_id=node_id, variable_key_list=new_key_list, variable_value=value
)
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
""" """

@ -0,0 +1,28 @@
from core.variables.segments import ObjectSegment, Segment
from core.workflow.entities.variable_pool import VariablePool, VariableValue
def append_variables_recursively(
pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment
):
"""
Append variables recursively
:param node_id: node id
:param variable_key_list: variable key list
:param variable_value: variable value
:return:
"""
pool.add([node_id] + variable_key_list, variable_value)
# if variable_value is a dict, then recursively append variables
if isinstance(variable_value, ObjectSegment):
variable_dict = variable_value.value
elif isinstance(variable_value, dict):
variable_dict = variable_value
else:
return
for key, value in variable_dict.items():
# construct new key list
new_key_list = variable_key_list + [key]
append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value)

@ -3,7 +3,9 @@ from collections.abc import Mapping, Sequence
from typing import Any, Protocol from typing import Any, Protocol
from core.variables import Variable from core.variables import Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.utils import variable_utils
class VariableLoader(Protocol): class VariableLoader(Protocol):
@ -76,4 +78,7 @@ def load_into_variable_pool(
variables_to_load.append(list(selector)) variables_to_load.append(list(selector))
loaded = variable_loader.load_variables(variables_to_load) loaded = variable_loader.load_variables(variables_to_load)
for var in loaded: for var in loaded:
variable_pool.add(var.selector, var) assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}"
variable_utils.append_variables_recursively(
variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var
)

@ -129,7 +129,8 @@ class WorkflowDraftVariableService:
) -> list[WorkflowDraftVariable]: ) -> list[WorkflowDraftVariable]:
ors = [] ors = []
for selector in selectors: for selector in selectors:
node_id, name = selector assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}"
node_id, name = selector[:2]
ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name))
# NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as

Loading…
Cancel
Save