pull/22771/merge
sayThQ199 10 months ago committed by GitHub
commit b61bd08992
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -152,7 +152,9 @@ class GraphEngine:
try: try:
if self.init_params.workflow_type == WorkflowType.CHAT: if self.init_params.workflow_type == WorkflowType.CHAT:
stream_processor = AnswerStreamProcessor( stream_processor = AnswerStreamProcessor(
graph=self.graph, variable_pool=self.graph_runtime_state.variable_pool graph=self.graph,
variable_pool=self.graph_runtime_state.variable_pool,
node_run_state=self.graph_runtime_state.node_run_state,
) )
else: else:
stream_processor = EndStreamProcessor( stream_processor = EndStreamProcessor(

@ -6,7 +6,7 @@ from core.workflow.nodes.answer.entities import (
TextGenerateRouteChunk, TextGenerateRouteChunk,
VarGenerateRouteChunk, VarGenerateRouteChunk,
) )
from core.workflow.nodes.enums import ErrorStrategy, NodeType from core.workflow.nodes.enums import NodeType
from core.workflow.utils.variable_template_parser import VariableTemplateParser from core.workflow.utils.variable_template_parser import VariableTemplateParser
@ -18,8 +18,10 @@ class AnswerStreamGeneratorRouter:
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined] reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
) -> AnswerStreamGenerateRoute: ) -> AnswerStreamGenerateRoute:
""" """
Get stream generate routes. Initializes the stream generation routes for all Answer nodes in the workflow.
:return: This method performs a static analysis of the graph to parse the answer templates.
The old logic for pre-calculating static dependencies has been deprecated and removed,
as the decision logic is now handled dynamically at runtime by the AnswerStreamProcessor.
""" """
# parse stream output node value selectors of answer nodes # parse stream output node value selectors of answer nodes
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {} answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
@ -31,16 +33,8 @@ class AnswerStreamGeneratorRouter:
generate_route = cls._extract_generate_route_selectors(node_config) generate_route = cls._extract_generate_route_selectors(node_config)
answer_generate_route[answer_node_id] = generate_route answer_generate_route[answer_node_id] = generate_route
# fetch answer dependencies
answer_node_ids = list(answer_generate_route.keys())
answer_dependencies = cls._fetch_answers_dependencies(
answer_node_ids=answer_node_ids,
reverse_edge_mapping=reverse_edge_mapping,
node_id_config_mapping=node_id_config_mapping,
)
return AnswerStreamGenerateRoute( return AnswerStreamGenerateRoute(
answer_generate_route=answer_generate_route, answer_dependencies=answer_dependencies answer_generate_route=answer_generate_route
) )
@classmethod @classmethod
@ -90,85 +84,12 @@ class AnswerStreamGeneratorRouter:
:return: :return:
""" """
node_data = AnswerNodeData(**config.get("data", {})) node_data = AnswerNodeData(**config.get("data", {}))
# Trim whitespace from the answer template to prevent parsing issues with leading/trailing spaces.
if node_data.answer:
node_data.answer = node_data.answer.strip()
return cls.extract_generate_route_from_node_data(node_data) return cls.extract_generate_route_from_node_data(node_data)
@classmethod @classmethod
def _is_variable(cls, part, variable_keys): def _is_variable(cls, part, variable_keys):
cleaned_part = part.replace("{{", "").replace("}}", "") cleaned_part = part.replace("{{", "").replace("}}", "")
return part.startswith("{{") and cleaned_part in variable_keys return part.startswith("{{") and cleaned_part in variable_keys
@classmethod
def _fetch_answers_dependencies(
cls,
answer_node_ids: list[str],
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
node_id_config_mapping: dict[str, dict],
) -> dict[str, list[str]]:
"""
Build the dependency list of every answer node.

Each id in `answer_node_ids` gets a list in the returned mapping; the lists are
filled in place by `_recursive_fetch_answer_dependencies`, which starts at the
answer node and follows `reverse_edge_mapping` towards its source nodes.

:param answer_node_ids: answer node ids
:param reverse_edge_mapping: reverse edge mapping (looked up by node id; each edge exposes `source_node_id`)
:param node_id_config_mapping: node id config mapping (node id -> node config dict)
:return: mapping of answer node id -> ids of the nodes it depends on
"""
answer_dependencies: dict[str, list[str]] = {}
for answer_node_id in answer_node_ids:
# give the answer node an empty list the first time its id is seen
if answer_dependencies.get(answer_node_id) is None:
answer_dependencies[answer_node_id] = []
# the walk starts at the answer node itself, so current_node_id == answer_node_id here
cls._recursive_fetch_answer_dependencies(
current_node_id=answer_node_id,
answer_node_id=answer_node_id,
node_id_config_mapping=node_id_config_mapping,
reverse_edge_mapping=reverse_edge_mapping,
answer_dependencies=answer_dependencies,
)
return answer_dependencies
@classmethod
def _recursive_fetch_answer_dependencies(
    cls,
    current_node_id: str,
    answer_node_id: str,
    node_id_config_mapping: dict[str, dict],
    reverse_edge_mapping: dict[str, list["GraphEdge"]],  # type: ignore[name-defined]
    answer_dependencies: dict[str, list[str]],
) -> None:
    """
    Collect the dependencies of one answer node by walking the graph backwards.

    Starting at `current_node_id`, incoming edges are followed towards their source
    nodes. A source node is recorded as a dependency of `answer_node_id` (and the walk
    stops there) when its type is ANSWER, IF_ELSE, QUESTION_CLASSIFIER, ITERATION, LOOP
    or VARIABLE_ASSIGNER, or when its `error_strategy` is FAIL_BRANCH. Any other source
    node is walked through. Source nodes missing from `node_id_config_mapping` are ignored.

    The walk is depth-first and visits edges in their stored order. It is implemented
    with an explicit stack (the method name is kept for callers) and expands every node
    at most once, so cyclic graphs terminate and a dependency reachable through several
    paths is recorded only once.

    :param current_node_id: current node id (where the backward walk starts)
    :param answer_node_id: answer node id whose dependency list is being filled
    :param node_id_config_mapping: node id config mapping
    :param reverse_edge_mapping: reverse edge mapping (node id -> incoming edges)
    :param answer_dependencies: answer dependencies, updated in place
    :return:
    """
    boundary_node_types = {
        NodeType.ANSWER,
        NodeType.IF_ELSE,
        NodeType.QUESTION_CLASSIFIER,
        NodeType.ITERATION,
        NodeType.LOOP,
        NodeType.VARIABLE_ASSIGNER,
    }

    # setdefault keeps the in-place contract but no longer requires the caller to pre-seed the list.
    dependencies = answer_dependencies.setdefault(answer_node_id, [])

    # Nodes whose incoming edges have already been pushed; guards against cycles and
    # against re-walking a shared upstream path.
    expanded_node_ids = {current_node_id}

    # Stack of edge iterators: the top iterator is resumed after a deeper node is finished,
    # which reproduces the visiting order of the recursive formulation.
    pending_edges = [iter(reverse_edge_mapping.get(current_node_id, []))]
    while pending_edges:
        for edge in pending_edges[-1]:
            source_node_id = edge.source_node_id
            if source_node_id not in node_id_config_mapping:
                continue

            source_node_data = node_id_config_mapping[source_node_id].get("data", {})
            is_boundary = (
                source_node_data.get("type") in boundary_node_types
                or source_node_data.get("error_strategy") == ErrorStrategy.FAIL_BRANCH
            )
            if is_boundary:
                # Record once only: consumers remove a finished dependency with a single
                # list.remove(), so a duplicate entry would never be cleared.
                if source_node_id not in dependencies:
                    dependencies.append(source_node_id)
            elif source_node_id not in expanded_node_ids:
                expanded_node_ids.add(source_node_id)
                pending_edges.append(iter(reverse_edge_mapping.get(source_node_id, [])))
                break  # descend into the source node before looking at further edges
        else:
            # every edge of the node on top of the stack has been handled
            pending_edges.pop()

@ -11,6 +11,7 @@ from core.workflow.graph_engine.entities.event import (
NodeRunSucceededEvent, NodeRunSucceededEvent,
) )
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
@ -18,8 +19,8 @@ logger = logging.getLogger(__name__)
class AnswerStreamProcessor(StreamProcessor): class AnswerStreamProcessor(StreamProcessor):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None:
super().__init__(graph, variable_pool) super().__init__(graph, variable_pool, node_run_state)
self.generate_routes = graph.answer_stream_generate_routes self.generate_routes = graph.answer_stream_generate_routes
self.route_position = {} self.route_position = {}
for answer_node_id in self.generate_routes.answer_generate_route: for answer_node_id in self.generate_routes.answer_generate_route:
@ -73,6 +74,80 @@ class AnswerStreamProcessor(StreamProcessor):
self.rest_node_ids = self.graph.node_ids.copy() self.rest_node_ids = self.graph.node_ids.copy()
self.current_stream_chunk_generating_node_ids = {} self.current_stream_chunk_generating_node_ids = {}
def _is_dynamic_dependencies_met(self, start_node_id: str) -> bool:
    """
    Check at runtime whether everything upstream of `start_node_id` has completed.

    The graph is traversed breadth-first and backwards (via `graph.reverse_edge_mapping`)
    from `start_node_id`. For every parent node reached, its most recent run state is looked
    up in `node_run_state.node_state_mapping`:

    - no recorded state, or a state that is RUNNING, FAILED or EXCEPTION -> the check fails;
    - for branch nodes ('if-else', 'question-classifier'), an edge whose `branch_identify`
      differs from the handle the node actually chose is ignored, so paths behind a branch
      that was not taken do not count as dependencies;
    - otherwise the parent's own parents are checked in turn.

    :param start_node_id: The node ID from which to begin the backward traversal.
    :return: True if no blocking parent was found on the traversed paths, False otherwise.
    """
    from collections import deque

    # Index the run states once instead of rescanning the mapping for every parent edge.
    # Later entries overwrite earlier ones, so the LAST state recorded for a node wins.
    # NOTE(review): this assumes the mapping's iteration order follows creation order,
    # as the original "last found state is the latest" comment intended -- confirm.
    latest_state_by_node_id = {}
    for state in self.node_run_state.node_state_mapping.values():
        latest_state_by_node_id[state.node_id] = state

    blocking_statuses = (
        RouteNodeState.Status.RUNNING,
        RouteNodeState.Status.FAILED,
        RouteNodeState.Status.EXCEPTION,
    )
    branch_node_types = ("if-else", "question-classifier")

    pending_node_ids = deque([start_node_id])
    visited = {start_node_id}  # nodes already queued; prevents looping on cycles

    while pending_node_ids:
        current_node_id = pending_node_ids.popleft()

        for edge in self.graph.reverse_edge_mapping.get(current_node_id, []):
            parent_node_id = edge.source_node_id
            if parent_node_id in visited:
                continue

            # NOTE(review): a parent that never ran because it sits behind an untaken branch
            # further upstream also has no state and therefore fails the check here.
            parent_state = latest_state_by_node_id.get(parent_node_id)
            if parent_state is None or parent_state.status in blocking_statuses:
                return False

            parent_node_config = self.graph.node_id_config_mapping.get(parent_node_id, {})
            parent_node_type = parent_node_config.get("data", {}).get("type")
            if parent_node_type in branch_node_types:
                run_result = parent_state.node_run_result
                chosen_handle = run_result.edge_source_handle if run_result else None
                required_handle = edge.run_condition.branch_identify if edge.run_condition else None
                if chosen_handle and required_handle and chosen_handle != required_handle:
                    # This edge belongs to a branch that was not taken. The parent is NOT
                    # marked visited: another edge from the same branch node may be the one
                    # that was taken, and must still be able to queue it.
                    continue

            visited.add(parent_node_id)
            pending_node_ids.append(parent_node_id)

    return True
def _generate_stream_outputs_when_node_finished( def _generate_stream_outputs_when_node_finished(
self, event: NodeRunSucceededEvent self, event: NodeRunSucceededEvent
) -> Generator[GraphEngineEvent, None, None]: ) -> Generator[GraphEngineEvent, None, None]:
@ -85,10 +160,7 @@ class AnswerStreamProcessor(StreamProcessor):
# all depends on answer node id not in rest node ids # all depends on answer node id not in rest node ids
if event.route_node_state.node_id != answer_node_id and ( if event.route_node_state.node_id != answer_node_id and (
answer_node_id not in self.rest_node_ids answer_node_id not in self.rest_node_ids
or not all( or not self._is_dynamic_dependencies_met(answer_node_id) # Using dynamic check for final output as well
dep_id not in self.rest_node_ids
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]
)
): ):
continue continue
@ -156,33 +228,13 @@ class AnswerStreamProcessor(StreamProcessor):
for answer_node_id, route_position in self.route_position.items(): for answer_node_id, route_position in self.route_position.items():
if answer_node_id not in self.rest_node_ids: if answer_node_id not in self.rest_node_ids:
continue continue
# Remove current node id from answer dependencies to support stream output if it is a success branch
answer_dependencies = self.generate_routes.answer_dependencies # New dynamic dependency check, replacing the old static dependency list.
edge_mapping = self.graph.edge_mapping.get(event.node_id) source_node_id_for_check = event.from_variable_selector[0]
success_edge = ( all_deps_finished = self._is_dynamic_dependencies_met(start_node_id=source_node_id_for_check)
next(
( if all_deps_finished:
edge if route_position >= len(self.generate_routes.answer_generate_route.get(answer_node_id, [])):
for edge in edge_mapping
if edge.run_condition
and edge.run_condition.type == "branch_identify"
and edge.run_condition.branch_identify == "success-branch"
),
None,
)
if edge_mapping
else None
)
if (
event.node_id in answer_dependencies[answer_node_id]
and success_edge
and success_edge.target_node_id == answer_node_id
):
answer_dependencies[answer_node_id].remove(event.node_id)
answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
# all depends on answer node id not in rest node ids
if all(dep_id not in self.rest_node_ids for dep_id in answer_dependencies_ids):
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
continue continue
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position] route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]

@ -6,14 +6,16 @@ from typing import Optional
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunExceptionEvent, NodeRunSucceededEvent
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class StreamProcessor(ABC): class StreamProcessor(ABC):
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None: def __init__(self, graph: Graph, variable_pool: VariablePool, node_run_state: RuntimeRouteState) -> None:
self.graph = graph self.graph = graph
self.variable_pool = variable_pool self.variable_pool = variable_pool
self.node_run_state = node_run_state
self.rest_node_ids = graph.node_ids.copy() self.rest_node_ids = graph.node_ids.copy()
@abstractmethod @abstractmethod
@ -21,6 +23,16 @@ class StreamProcessor(ABC):
raise NotImplementedError raise NotImplementedError
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None: def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent | NodeRunExceptionEvent) -> None:
"""
Prunes unreachable branches from the `rest_node_ids` list after a branch node has executed.
This method implements a conservative, non-recursive pruning strategy to prevent a critical bug
where the pruning process would incorrectly "spread" across join points (nodes with multiple inputs)
and erroneously remove shared downstream nodes that should have been preserved.
By only removing the immediate first node of each determined unreachable branch, we ensure that
the integrity of shared paths in complex graph topologies is maintained.
"""
finished_node_id = event.route_node_state.node_id finished_node_id = event.route_node_state.node_id
if finished_node_id not in self.rest_node_ids: if finished_node_id not in self.rest_node_ids:
return return
@ -68,9 +80,14 @@ class StreamProcessor(ABC):
): ):
continue continue
unreachable_first_node_ids.append(edge.target_node_id) unreachable_first_node_ids.append(edge.target_node_id)
unreachable_first_node_ids = list(set(unreachable_first_node_ids) - set(reachable_node_ids))
for node_id in unreachable_first_node_ids: # Instead of recursively removing the entire unreachable branch,
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids) # which can cause issues with complex join points,
# we will only remove the immediate first node of the unreachable branch.
# This prevents the removal logic from incorrectly pruning shared paths downstream.
for node_id in list(set(unreachable_first_node_ids) - set(reachable_node_ids)):
if node_id in self.rest_node_ids:
self.rest_node_ids.remove(node_id)
def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]: def _fetch_node_ids_in_reachable_branch(self, node_id: str, branch_identify: Optional[str] = None) -> list[str]:
if node_id not in self.rest_node_ids: if node_id not in self.rest_node_ids:

@ -57,9 +57,6 @@ class AnswerStreamGenerateRoute(BaseModel):
AnswerStreamGenerateRoute entity AnswerStreamGenerateRoute entity
""" """
answer_dependencies: dict[str, list[str]] = Field(
..., description="answer dependencies (answer node id -> dependent answer node ids)"
)
answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field(
..., description="answer generate route (answer node id -> generate route chunks)" ..., description="answer generate route (answer node id -> generate route chunks)"
) )

Loading…
Cancel
Save