Merge main
commit
b0d53c0ac4
@ -1,203 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager._publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
@ -1,200 +0,0 @@
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowEventTriggerCallback(WorkflowCallback):
|
||||
|
||||
def __init__(self, queue_manager: AppQueueManager, workflow: Workflow):
|
||||
self._queue_manager = queue_manager
|
||||
|
||||
def on_workflow_run_started(self) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowStartedEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowSucceededEvent(),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueWorkflowFailedEvent(
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeStartedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
node_run_index=node_run_index,
|
||||
predecessor_node_id=predecessor_node_id
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeSucceededEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
process_data=process_data,
|
||||
outputs=outputs,
|
||||
execution_metadata=execution_metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueNodeFailedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
process_data=process_data,
|
||||
error=error
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueTextChunkEvent(
|
||||
text=text,
|
||||
metadata={
|
||||
"node_id": node_id,
|
||||
**metadata
|
||||
}
|
||||
), PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: dict = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationStartEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
node_data=node_data,
|
||||
inputs=inputs,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
metadata=metadata
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any]) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationNextEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
index=index,
|
||||
node_run_index=node_run_index,
|
||||
output=output
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
self._queue_manager.publish(
|
||||
QueueIterationCompletedEvent(
|
||||
node_id=node_id,
|
||||
node_type=node_type,
|
||||
node_run_index=node_run_index,
|
||||
outputs=outputs
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
"""
|
||||
pass
|
||||
@ -0,0 +1,379 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from extensions.ext_database import db
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner(AppRunner):
|
||||
def __init__(self, queue_manager: AppQueueManager):
|
||||
self.queue_manager = queue_manager
|
||||
|
||||
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
|
||||
"""
|
||||
Init graph
|
||||
"""
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
return graph
|
||||
|
||||
def _get_graph_and_variable_pool_of_single_iteration(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict,
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
"""
|
||||
Get variable pool of single iteration
|
||||
"""
|
||||
# fetch workflow graph
|
||||
graph_config = workflow.graph_dict
|
||||
if not graph_config:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if 'nodes' not in graph_config or 'edges' not in graph_config:
|
||||
raise ValueError('nodes or edges not found in workflow graph')
|
||||
|
||||
if not isinstance(graph_config.get('nodes'), list):
|
||||
raise ValueError('nodes in workflow graph must be a list')
|
||||
|
||||
if not isinstance(graph_config.get('edges'), list):
|
||||
raise ValueError('edges in workflow graph must be a list')
|
||||
|
||||
# filter nodes only in iteration
|
||||
node_configs = [
|
||||
node for node in graph_config.get('nodes', [])
|
||||
if node.get('id') == node_id or node.get('data', {}).get('iteration_id', '') == node_id
|
||||
]
|
||||
|
||||
graph_config['nodes'] = node_configs
|
||||
|
||||
node_ids = [node.get('id') for node in node_configs]
|
||||
|
||||
# filter edges only in iteration
|
||||
edge_configs = [
|
||||
edge for edge in graph_config.get('edges', [])
|
||||
if (edge.get('source') is None or edge.get('source') in node_ids)
|
||||
and (edge.get('target') is None or edge.get('target') in node_ids)
|
||||
]
|
||||
|
||||
graph_config['edges'] = edge_configs
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_id
|
||||
)
|
||||
|
||||
if not graph:
|
||||
raise ValueError('graph not found in workflow')
|
||||
|
||||
# fetch node config from node id
|
||||
iteration_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get('id') == node_id:
|
||||
iteration_node_config = node
|
||||
break
|
||||
|
||||
if not iteration_node_config:
|
||||
raise ValueError('iteration node id not found in workflow graph')
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(iteration_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=iteration_node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
WorkflowEntry.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=IterationNodeData(**iteration_node_config.get('data', {}))
|
||||
)
|
||||
|
||||
return graph, variable_pool
|
||||
|
||||
def _handle_event(self, workflow_entry: WorkflowEntry, event: GraphEngineEvent) -> None:
|
||||
"""
|
||||
Handle event
|
||||
:param workflow_entry: workflow entry
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowStartedEvent(
|
||||
graph_runtime_state=workflow_entry.graph_engine.graph_runtime_state
|
||||
)
|
||||
)
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowSucceededEvent(outputs=event.outputs)
|
||||
)
|
||||
elif isinstance(event, GraphRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueWorkflowFailedEvent(error=event.error)
|
||||
)
|
||||
elif isinstance(event, NodeRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeStartedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
node_run_index=event.route_node_state.index,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueNodeSucceededEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
execution_metadata=event.route_node_state.node_run_result.metadata
|
||||
if event.route_node_state.node_run_result else {},
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueNodeFailedEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.route_node_state.start_at,
|
||||
inputs=event.route_node_state.node_run_result.inputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
process_data=event.route_node_state.node_run_result.process_data
|
||||
if event.route_node_state.node_run_result else {},
|
||||
outputs=event.route_node_state.node_run_result.outputs
|
||||
if event.route_node_state.node_run_result else {},
|
||||
error=event.route_node_state.node_run_result.error
|
||||
if event.route_node_state.node_run_result
|
||||
and event.route_node_state.node_run_result.error
|
||||
else "Unknown error",
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
self._publish_event(
|
||||
QueueTextChunkEvent(
|
||||
text=event.chunk_content,
|
||||
from_variable_selector=event.from_variable_selector,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetrieverResourceEvent):
|
||||
self._publish_event(
|
||||
QueueRetrieverResourcesEvent(
|
||||
retriever_resources=event.retriever_resources,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunStartedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunSucceededEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id
|
||||
)
|
||||
)
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
self._publish_event(
|
||||
QueueParallelBranchRunFailedEvent(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
in_iteration_id=event.in_iteration_id,
|
||||
error=event.error
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunStartedEvent):
|
||||
self._publish_event(
|
||||
QueueIterationStartEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, IterationRunNextEvent):
|
||||
self._publish_event(
|
||||
QueueIterationNextEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
index=event.index,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
output=event.pre_iteration_output,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, (IterationRunSucceededEvent | IterationRunFailedEvent)):
|
||||
self._publish_event(
|
||||
QueueIterationCompletedEvent(
|
||||
node_execution_id=event.iteration_id,
|
||||
node_id=event.iteration_node_id,
|
||||
node_type=event.iteration_node_type,
|
||||
node_data=event.iteration_node_data,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
start_at=event.start_at,
|
||||
node_run_index=workflow_entry.graph_engine.graph_runtime_state.node_run_steps,
|
||||
inputs=event.inputs,
|
||||
outputs=event.outputs,
|
||||
metadata=event.metadata,
|
||||
steps=event.steps,
|
||||
error=event.error if isinstance(event, IterationRunFailedEvent) else None
|
||||
)
|
||||
)
|
||||
|
||||
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
|
||||
"""
|
||||
Get workflow
|
||||
"""
|
||||
# fetch workflow by workflow_id
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.filter(
|
||||
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# return workflow
|
||||
return workflow
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent) -> None:
|
||||
self.queue_manager.publish(
|
||||
event,
|
||||
PublishFrom.APPLICATION_MANAGER
|
||||
)
|
||||
@ -1,16 +0,0 @@
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import AdvancedChatTaskState, WorkflowTaskState
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from models.account import Account
|
||||
from models.model import EndUser
|
||||
from models.workflow import Workflow
|
||||
|
||||
|
||||
class WorkflowCycleStateManager:
|
||||
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
|
||||
_workflow: Workflow
|
||||
_user: Union[Account, EndUser]
|
||||
_task_state: Union[AdvancedChatTaskState, WorkflowTaskState]
|
||||
_workflow_system_variables: dict[SystemVariableKey, Any]
|
||||
@ -1,290 +0,0 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Union
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
NodeExecutionInfo,
|
||||
WorkflowIterationState,
|
||||
)
|
||||
from core.app.task_pipeline.workflow_cycle_state_manager import WorkflowCycleStateManager
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.workflow_engine_manager import WorkflowEngineManager
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
)
|
||||
|
||||
|
||||
class WorkflowIterationCycleManage(WorkflowCycleStateManager):
|
||||
_iteration_state: WorkflowIterationState = None
|
||||
|
||||
def _init_iteration_state(self) -> WorkflowIterationState:
|
||||
if not self._iteration_state:
|
||||
self._iteration_state = WorkflowIterationState(
|
||||
current_iterations={}
|
||||
)
|
||||
|
||||
def _handle_iteration_to_stream_response(self, task_id: str, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) \
|
||||
-> Union[IterationNodeStartStreamResponse, IterationNodeNextStreamResponse, IterationNodeCompletedStreamResponse]:
|
||||
"""
|
||||
Handle iteration to stream response
|
||||
:param task_id: task id
|
||||
:param event: iteration event
|
||||
:return:
|
||||
"""
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs,
|
||||
metadata=event.metadata
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={}
|
||||
)
|
||||
)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
error=None,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
|
||||
def _init_iteration_execution_from_workflow_run(self,
|
||||
workflow_run: WorkflowRun,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_title: str,
|
||||
node_run_index: int = 1,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution(
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
workflow_run_id=workflow_run.id,
|
||||
predecessor_node_id=predecessor_node_id,
|
||||
index=node_run_index,
|
||||
node_id=node_id,
|
||||
node_type=node_type.value,
|
||||
inputs=json.dumps(inputs) if inputs else None,
|
||||
title=node_title,
|
||||
status=WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
created_by_role=workflow_run.created_by_role,
|
||||
created_by=workflow_run.created_by,
|
||||
execution_metadata=json.dumps({
|
||||
'started_run_index': node_run_index + 1,
|
||||
'current_index': 0,
|
||||
'steps_boundary': [],
|
||||
}),
|
||||
created_at=datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
)
|
||||
|
||||
db.session.add(workflow_node_execution)
|
||||
db.session.commit()
|
||||
db.session.refresh(workflow_node_execution)
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_operation(self, event: QueueIterationStartEvent | QueueIterationNextEvent | QueueIterationCompletedEvent) -> WorkflowNodeExecution:
|
||||
if isinstance(event, QueueIterationStartEvent):
|
||||
return self._handle_iteration_started(event)
|
||||
elif isinstance(event, QueueIterationNextEvent):
|
||||
return self._handle_iteration_next(event)
|
||||
elif isinstance(event, QueueIterationCompletedEvent):
|
||||
return self._handle_iteration_completed(event)
|
||||
|
||||
def _handle_iteration_started(self, event: QueueIterationStartEvent) -> WorkflowNodeExecution:
|
||||
self._init_iteration_state()
|
||||
|
||||
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == self._task_state.workflow_run_id).first()
|
||||
workflow_node_execution = self._init_iteration_execution_from_workflow_run(
|
||||
workflow_run=workflow_run,
|
||||
node_id=event.node_id,
|
||||
node_type=NodeType.ITERATION,
|
||||
node_title=event.node_data.title,
|
||||
node_run_index=event.node_run_index,
|
||||
inputs=event.inputs,
|
||||
predecessor_node_id=event.predecessor_node_id
|
||||
)
|
||||
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.ran_node_execution_infos[event.node_id] = latest_node_execution_info
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
self._iteration_state.current_iterations[event.node_id] = WorkflowIterationState.Data(
|
||||
parent_iteration_id=None,
|
||||
iteration_id=event.node_id,
|
||||
current_index=0,
|
||||
iteration_steps_boundary=[],
|
||||
node_execution_id=workflow_node_execution.id,
|
||||
started_at=time.perf_counter(),
|
||||
inputs=event.inputs,
|
||||
total_tokens=0,
|
||||
node_data=event.node_data
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_iteration_next(self, event: QueueIterationNextEvent) -> WorkflowNodeExecution:
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
current_iteration.current_index = event.index
|
||||
current_iteration.iteration_steps_boundary.append(event.node_run_index)
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['current_index'] = event.index
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_completed(self, event: QueueIterationCompletedEvent):
|
||||
if event.node_id not in self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
current_iteration = self._iteration_state.current_iterations[event.node_id]
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.outputs = json.dumps(WorkflowEngineManager.handle_special_values(event.outputs)) if event.outputs else None
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
original_node_execution_metadata = workflow_node_execution.execution_metadata_dict
|
||||
if original_node_execution_metadata:
|
||||
original_node_execution_metadata['steps_boundary'] = current_iteration.iteration_steps_boundary
|
||||
original_node_execution_metadata['total_tokens'] = current_iteration.total_tokens
|
||||
workflow_node_execution.execution_metadata = json.dumps(original_node_execution_metadata)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
# remove current iteration
|
||||
self._iteration_state.current_iterations.pop(event.node_id, None)
|
||||
|
||||
# set latest node execution info
|
||||
latest_node_execution_info = NodeExecutionInfo(
|
||||
workflow_node_execution_id=workflow_node_execution.id,
|
||||
node_type=NodeType.ITERATION,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self._task_state.latest_node_execution_info = latest_node_execution_info
|
||||
|
||||
db.session.close()
|
||||
|
||||
def _handle_iteration_exception(self, task_id: str, error: str) -> Generator[IterationNodeCompletedStreamResponse, None, None]:
|
||||
"""
|
||||
Handle iteration exception
|
||||
"""
|
||||
if not self._iteration_state or not self._iteration_state.current_iterations:
|
||||
return
|
||||
|
||||
for node_id, current_iteration in self._iteration_state.current_iterations.items():
|
||||
workflow_node_execution: WorkflowNodeExecution = db.session.query(WorkflowNodeExecution).filter(
|
||||
WorkflowNodeExecution.id == current_iteration.node_execution_id
|
||||
).first()
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.elapsed_time = time.perf_counter() - current_iteration.started_at
|
||||
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
yield IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=self._task_state.workflow_run_id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=node_id,
|
||||
node_id=node_id,
|
||||
node_type=NodeType.ITERATION.value,
|
||||
title=current_iteration.node_data.title,
|
||||
outputs={},
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=current_iteration.inputs,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error,
|
||||
elapsed_time=time.perf_counter() - current_iteration.started_at,
|
||||
total_tokens=current_iteration.total_tokens,
|
||||
execution_metadata={
|
||||
'total_tokens': current_iteration.total_tokens,
|
||||
},
|
||||
finished_at=int(time.time()),
|
||||
steps=current_iteration.current_index
|
||||
)
|
||||
)
|
||||
@ -1,116 +1,15 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.app.entities.queue_entities import AppQueueEvent
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent
|
||||
|
||||
|
||||
class WorkflowCallback(ABC):
|
||||
@abstractmethod
|
||||
def on_workflow_run_started(self) -> None:
|
||||
def on_event(
|
||||
self,
|
||||
event: GraphEngineEvent
|
||||
) -> None:
|
||||
"""
|
||||
Workflow run started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_run_succeeded(self) -> None:
|
||||
"""
|
||||
Workflow run succeeded
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_run_failed(self, error: str) -> None:
|
||||
"""
|
||||
Workflow run failed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_started(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
node_run_index: int = 1,
|
||||
predecessor_node_id: Optional[str] = None) -> None:
|
||||
"""
|
||||
Workflow node execute started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_succeeded(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
inputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
execution_metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute succeeded
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_node_execute_failed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData,
|
||||
error: str,
|
||||
inputs: Optional[dict] = None,
|
||||
outputs: Optional[dict] = None,
|
||||
process_data: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Workflow node execute failed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_node_text_chunk(self, node_id: str, text: str, metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish text chunk
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_started(self,
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int = 1,
|
||||
node_data: Optional[BaseNodeData] = None,
|
||||
inputs: Optional[dict] = None,
|
||||
predecessor_node_id: Optional[str] = None,
|
||||
metadata: Optional[dict] = None) -> None:
|
||||
"""
|
||||
Publish iteration started
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_next(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
index: int,
|
||||
node_run_index: int,
|
||||
output: Optional[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Publish iteration next
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_workflow_iteration_completed(self, node_id: str,
|
||||
node_type: NodeType,
|
||||
node_run_index: int,
|
||||
outputs: dict) -> None:
|
||||
"""
|
||||
Publish iteration completed
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def on_event(self, event: AppQueueEvent) -> None:
|
||||
"""
|
||||
Publish event
|
||||
Published event
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -1,10 +1,8 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
|
||||
|
||||
class WorkflowNodeRunFailedError(Exception):
|
||||
def __init__(self, node_id: str, node_type: NodeType, node_title: str, error: str):
|
||||
self.node_id = node_id
|
||||
self.node_type = node_type
|
||||
self.node_title = node_title
|
||||
def __init__(self, node_instance: BaseNode, error: str):
|
||||
self.node_instance = node_instance
|
||||
self.error = error
|
||||
super().__init__(f"Node {node_title} run failed: {error}")
|
||||
super().__init__(f"Node {node_instance.node_data.title} run failed: {error}")
|
||||
|
||||
@ -0,0 +1,31 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class RunConditionHandler(ABC):
|
||||
def __init__(self,
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
condition: RunCondition):
|
||||
self.init_params = init_params
|
||||
self.graph = graph
|
||||
self.condition = condition
|
||||
|
||||
@abstractmethod
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
raise NotImplementedError
|
||||
@ -0,0 +1,28 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class BranchIdentifyRunConditionHandler(RunConditionHandler):
|
||||
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.branch_identify:
|
||||
raise Exception("Branch identify is required")
|
||||
|
||||
run_result = previous_route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return False
|
||||
|
||||
if not run_result.edge_source_handle:
|
||||
return False
|
||||
|
||||
return self.condition.branch_identify == run_result.edge_source_handle
|
||||
@ -0,0 +1,32 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.utils.condition.processor import ConditionProcessor
|
||||
|
||||
|
||||
class ConditionRunConditionHandlerHandler(RunConditionHandler):
|
||||
def check(self,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
previous_route_node_state: RouteNodeState
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the condition can be executed
|
||||
|
||||
:param graph_runtime_state: graph runtime state
|
||||
:param previous_route_node_state: previous route node state
|
||||
:return: bool
|
||||
"""
|
||||
if not self.condition.conditions:
|
||||
return True
|
||||
|
||||
# process condition
|
||||
condition_processor = ConditionProcessor()
|
||||
input_conditions, group_result = condition_processor.process_conditions(
|
||||
variable_pool=graph_runtime_state.variable_pool,
|
||||
conditions=self.condition.conditions
|
||||
)
|
||||
|
||||
# Apply the logical operator for the current case
|
||||
compare_result = all(group_result)
|
||||
|
||||
return compare_result
|
||||
@ -0,0 +1,35 @@
|
||||
from core.workflow.graph_engine.condition_handlers.base_handler import RunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.branch_identify_handler import BranchIdentifyRunConditionHandler
|
||||
from core.workflow.graph_engine.condition_handlers.condition_handler import ConditionRunConditionHandlerHandler
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
|
||||
|
||||
class ConditionManager:
|
||||
@staticmethod
|
||||
def get_condition_handler(
|
||||
init_params: GraphInitParams,
|
||||
graph: Graph,
|
||||
run_condition: RunCondition
|
||||
) -> RunConditionHandler:
|
||||
"""
|
||||
Get condition handler
|
||||
|
||||
:param init_params: init params
|
||||
:param graph: graph
|
||||
:param run_condition: run condition
|
||||
:return: condition handler
|
||||
"""
|
||||
if run_condition.type == "branch_identify":
|
||||
return BranchIdentifyRunConditionHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
else:
|
||||
return ConditionRunConditionHandlerHandler(
|
||||
init_params=init_params,
|
||||
graph=graph,
|
||||
condition=run_condition
|
||||
)
|
||||
@ -0,0 +1,163 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
|
||||
|
||||
class GraphEngineEvent(BaseModel):
|
||||
pass
|
||||
|
||||
|
||||
###########################################
|
||||
# Graph Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseGraphEvent(GraphEngineEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
"""outputs"""
|
||||
|
||||
|
||||
class GraphRunFailedEvent(BaseGraphEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Node Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseNodeEvent(GraphEngineEvent):
|
||||
id: str = Field(..., description="node execution id")
|
||||
node_id: str = Field(..., description="node id")
|
||||
node_type: NodeType = Field(..., description="node type")
|
||||
node_data: BaseNodeData = Field(..., description="node data")
|
||||
route_node_state: RouteNodeState = Field(..., description="route node state")
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class NodeRunStartedEvent(BaseNodeEvent):
|
||||
predecessor_node_id: Optional[str] = None
|
||||
"""predecessor node id"""
|
||||
|
||||
|
||||
class NodeRunStreamChunkEvent(BaseNodeEvent):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: Optional[list[str]] = None
|
||||
"""from variable selector"""
|
||||
|
||||
|
||||
class NodeRunRetrieverResourceEvent(BaseNodeEvent):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
class NodeRunSucceededEvent(BaseNodeEvent):
|
||||
pass
|
||||
|
||||
|
||||
class NodeRunFailedEvent(BaseNodeEvent):
|
||||
error: str = Field(..., description="error")
|
||||
|
||||
|
||||
###########################################
|
||||
# Parallel Branch Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseParallelBranchEvent(GraphEngineEvent):
|
||||
parallel_id: str = Field(..., description="parallel id")
|
||||
"""parallel id"""
|
||||
parallel_start_node_id: str = Field(..., description="parallel start node id")
|
||||
"""parallel start node id"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
in_iteration_id: Optional[str] = None
|
||||
"""iteration id if node is in iteration"""
|
||||
|
||||
|
||||
class ParallelBranchRunStartedEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunSucceededEvent(BaseParallelBranchEvent):
|
||||
pass
|
||||
|
||||
|
||||
class ParallelBranchRunFailedEvent(BaseParallelBranchEvent):
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
###########################################
|
||||
# Iteration Events
|
||||
###########################################
|
||||
|
||||
|
||||
class BaseIterationEvent(GraphEngineEvent):
|
||||
iteration_id: str = Field(..., description="iteration node execution id")
|
||||
iteration_node_id: str = Field(..., description="iteration node id")
|
||||
iteration_node_type: NodeType = Field(..., description="node type, iteration or loop")
|
||||
iteration_node_data: BaseNodeData = Field(..., description="node data")
|
||||
parallel_id: Optional[str] = None
|
||||
"""parallel id if node is in parallel"""
|
||||
parallel_start_node_id: Optional[str] = None
|
||||
"""parallel start node id if node is in parallel"""
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id if node is in parallel"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id if node is in parallel"""
|
||||
|
||||
|
||||
class IterationRunStartedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
predecessor_node_id: Optional[str] = None
|
||||
|
||||
|
||||
class IterationRunNextEvent(BaseIterationEvent):
|
||||
index: int = Field(..., description="index")
|
||||
pre_iteration_output: Optional[Any] = Field(None, description="pre iteration output")
|
||||
|
||||
|
||||
class IterationRunSucceededEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
|
||||
|
||||
class IterationRunFailedEvent(BaseIterationEvent):
|
||||
start_at: datetime = Field(..., description="start at")
|
||||
inputs: Optional[dict[str, Any]] = None
|
||||
outputs: Optional[dict[str, Any]] = None
|
||||
metadata: Optional[dict[str, Any]] = None
|
||||
steps: int = 0
|
||||
error: str = Field(..., description="failed reason")
|
||||
|
||||
|
||||
InNodeEvent = BaseNodeEvent | BaseParallelBranchEvent | BaseIterationEvent
|
||||
@ -0,0 +1,692 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes.answer.answer_stream_generate_router import AnswerStreamGeneratorRouter
|
||||
from core.workflow.nodes.answer.entities import AnswerStreamGenerateRoute
|
||||
from core.workflow.nodes.end.end_stream_generate_router import EndStreamGeneratorRouter
|
||||
from core.workflow.nodes.end.entities import EndStreamParam
|
||||
|
||||
|
||||
class GraphEdge(BaseModel):
|
||||
source_node_id: str = Field(..., description="source node id")
|
||||
target_node_id: str = Field(..., description="target node id")
|
||||
run_condition: Optional[RunCondition] = None
|
||||
"""run condition"""
|
||||
|
||||
|
||||
class GraphParallel(BaseModel):
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="random uuid parallel id")
|
||||
start_from_node_id: str = Field(..., description="start from node id")
|
||||
parent_parallel_id: Optional[str] = None
|
||||
"""parent parallel id"""
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
"""parent parallel start node id"""
|
||||
end_to_node_id: Optional[str] = None
|
||||
"""end to node id"""
|
||||
|
||||
|
||||
class Graph(BaseModel):
|
||||
root_node_id: str = Field(..., description="root node id of the graph")
|
||||
node_ids: list[str] = Field(default_factory=list, description="graph node ids")
|
||||
node_id_config_mapping: dict[str, dict] = Field(
|
||||
default_factory=list,
|
||||
description="node configs mapping (node id: node config)"
|
||||
)
|
||||
edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph edge mapping (source node id: edges)"
|
||||
)
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = Field(
|
||||
default_factory=dict,
|
||||
description="reverse graph edge mapping (target node id: edges)"
|
||||
)
|
||||
parallel_mapping: dict[str, GraphParallel] = Field(
|
||||
default_factory=dict,
|
||||
description="graph parallel mapping (parallel id: parallel)"
|
||||
)
|
||||
node_parallel_mapping: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="graph node parallel mapping (node id: parallel id)"
|
||||
)
|
||||
answer_stream_generate_routes: AnswerStreamGenerateRoute = Field(
|
||||
...,
|
||||
description="answer stream generate routes"
|
||||
)
|
||||
end_stream_param: EndStreamParam = Field(
|
||||
...,
|
||||
description="end stream param"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
root_node_id: Optional[str] = None) -> "Graph":
|
||||
"""
|
||||
Init graph
|
||||
|
||||
:param graph_config: graph config
|
||||
:param root_node_id: root node id
|
||||
:return: graph
|
||||
"""
|
||||
# edge configs
|
||||
edge_configs = graph_config.get('edges')
|
||||
if edge_configs is None:
|
||||
edge_configs = []
|
||||
|
||||
edge_configs = cast(list, edge_configs)
|
||||
|
||||
# reorganize edges mapping
|
||||
edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]] = {}
|
||||
target_edge_ids = set()
|
||||
for edge_config in edge_configs:
|
||||
source_node_id = edge_config.get('source')
|
||||
if not source_node_id:
|
||||
continue
|
||||
|
||||
if source_node_id not in edge_mapping:
|
||||
edge_mapping[source_node_id] = []
|
||||
|
||||
target_node_id = edge_config.get('target')
|
||||
if not target_node_id:
|
||||
continue
|
||||
|
||||
if target_node_id not in reverse_edge_mapping:
|
||||
reverse_edge_mapping[target_node_id] = []
|
||||
|
||||
# is target node id in source node id edge mapping
|
||||
if any(graph_edge.target_node_id == target_node_id for graph_edge in edge_mapping[source_node_id]):
|
||||
continue
|
||||
|
||||
target_edge_ids.add(target_node_id)
|
||||
|
||||
# parse run condition
|
||||
run_condition = None
|
||||
if edge_config.get('sourceHandle') and edge_config.get('sourceHandle') != 'source':
|
||||
run_condition = RunCondition(
|
||||
type='branch_identify',
|
||||
branch_identify=edge_config.get('sourceHandle')
|
||||
)
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
edge_mapping[source_node_id].append(graph_edge)
|
||||
reverse_edge_mapping[target_node_id].append(graph_edge)
|
||||
|
||||
# node configs
|
||||
node_configs = graph_config.get('nodes')
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = cast(list, node_configs)
|
||||
|
||||
# fetch nodes that have no predecessor node
|
||||
root_node_configs = []
|
||||
all_node_id_config_mapping: dict[str, dict] = {}
|
||||
for node_config in node_configs:
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
continue
|
||||
|
||||
if node_id not in target_edge_ids:
|
||||
root_node_configs.append(node_config)
|
||||
|
||||
all_node_id_config_mapping[node_id] = node_config
|
||||
|
||||
root_node_ids = [node_config.get('id') for node_config in root_node_configs]
|
||||
|
||||
# fetch root node
|
||||
if not root_node_id:
|
||||
# if no root node id, use the START type node as root node
|
||||
root_node_id = next((node_config.get("id") for node_config in root_node_configs
|
||||
if node_config.get('data', {}).get('type', '') == NodeType.START.value), None)
|
||||
|
||||
if not root_node_id or root_node_id not in root_node_ids:
|
||||
raise ValueError(f"Root node id {root_node_id} not found in the graph")
|
||||
|
||||
# Check whether it is connected to the previous node
|
||||
cls._check_connected_to_previous_node(
|
||||
route=[root_node_id],
|
||||
edge_mapping=edge_mapping
|
||||
)
|
||||
|
||||
# fetch all node ids from root node
|
||||
node_ids = [root_node_id]
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=root_node_id
|
||||
)
|
||||
|
||||
node_id_config_mapping = {node_id: all_node_id_config_mapping[node_id] for node_id in node_ids}
|
||||
|
||||
# init parallel mapping
|
||||
parallel_mapping: dict[str, GraphParallel] = {}
|
||||
node_parallel_mapping: dict[str, str] = {}
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=root_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
# Check if it exceeds N layers of parallel
|
||||
for parallel in parallel_mapping.values():
|
||||
if parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=3,
|
||||
parent_parallel_id=parallel.parent_parallel_id
|
||||
)
|
||||
|
||||
# init answer stream generate routes
|
||||
answer_stream_generate_routes = AnswerStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping
|
||||
)
|
||||
|
||||
# init end stream param
|
||||
end_stream_param = EndStreamGeneratorRouter.init(
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = cls(
|
||||
root_node_id=root_node_id,
|
||||
node_ids=node_ids,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
answer_stream_generate_routes=answer_stream_generate_routes,
|
||||
end_stream_param=end_stream_param
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
def add_extra_edge(self, source_node_id: str,
|
||||
target_node_id: str,
|
||||
run_condition: Optional[RunCondition] = None) -> None:
|
||||
"""
|
||||
Add extra edge to the graph
|
||||
|
||||
:param source_node_id: source node id
|
||||
:param target_node_id: target node id
|
||||
:param run_condition: run condition
|
||||
"""
|
||||
if source_node_id not in self.node_ids or target_node_id not in self.node_ids:
|
||||
return
|
||||
|
||||
if source_node_id not in self.edge_mapping:
|
||||
self.edge_mapping[source_node_id] = []
|
||||
|
||||
if target_node_id in [graph_edge.target_node_id for graph_edge in self.edge_mapping[source_node_id]]:
|
||||
return
|
||||
|
||||
graph_edge = GraphEdge(
|
||||
source_node_id=source_node_id,
|
||||
target_node_id=target_node_id,
|
||||
run_condition=run_condition
|
||||
)
|
||||
|
||||
self.edge_mapping[source_node_id].append(graph_edge)
|
||||
|
||||
def get_leaf_node_ids(self) -> list[str]:
|
||||
"""
|
||||
Get leaf node ids of the graph
|
||||
|
||||
:return: leaf node ids
|
||||
"""
|
||||
leaf_node_ids = []
|
||||
for node_id in self.node_ids:
|
||||
if node_id not in self.edge_mapping:
|
||||
leaf_node_ids.append(node_id)
|
||||
elif (len(self.edge_mapping[node_id]) == 1
|
||||
and self.edge_mapping[node_id][0].target_node_id == self.root_node_id):
|
||||
leaf_node_ids.append(node_id)
|
||||
|
||||
return leaf_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_node_ids(cls,
|
||||
node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param node_ids: node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param node_id: node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(node_id, []):
|
||||
if graph_edge.target_node_id in node_ids:
|
||||
continue
|
||||
|
||||
node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_node_ids(
|
||||
node_ids=node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_connected_to_previous_node(
|
||||
cls,
|
||||
route: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]]
|
||||
) -> None:
|
||||
"""
|
||||
Check whether it is connected to the previous node
|
||||
"""
|
||||
last_node_id = route[-1]
|
||||
|
||||
for graph_edge in edge_mapping.get(last_node_id, []):
|
||||
if not graph_edge.target_node_id:
|
||||
continue
|
||||
|
||||
if graph_edge.target_node_id in route:
|
||||
raise ValueError(f"Node {graph_edge.source_node_id} is connected to the previous node, please check the graph.")
|
||||
|
||||
new_route = route[:]
|
||||
new_route.append(graph_edge.target_node_id)
|
||||
cls._check_connected_to_previous_node(
|
||||
route=new_route,
|
||||
edge_mapping=edge_mapping,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallels(
|
||||
cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
node_parallel_mapping: dict[str, str],
|
||||
parent_parallel: Optional[GraphParallel] = None
|
||||
) -> None:
|
||||
"""
|
||||
Recursively add parallel ids
|
||||
|
||||
:param edge_mapping: edge mapping
|
||||
:param start_node_id: start from node id
|
||||
:param parallel_mapping: parallel mapping
|
||||
:param node_parallel_mapping: node parallel mapping
|
||||
:param parent_parallel: parent parallel
|
||||
"""
|
||||
target_node_edges = edge_mapping.get(start_node_id, [])
|
||||
parallel = None
|
||||
if len(target_node_edges) > 1:
|
||||
# fetch all node ids in current parallels
|
||||
parallel_branch_node_ids = []
|
||||
condition_edge_mappings = {}
|
||||
for graph_edge in target_node_edges:
|
||||
if graph_edge.run_condition is None:
|
||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
||||
else:
|
||||
condition_hash = graph_edge.run_condition.hash
|
||||
if not condition_hash in condition_edge_mappings:
|
||||
condition_edge_mappings[condition_hash] = []
|
||||
|
||||
condition_edge_mappings[condition_hash].append(graph_edge)
|
||||
|
||||
for _, graph_edges in condition_edge_mappings.items():
|
||||
if len(graph_edges) > 1:
|
||||
for graph_edge in graph_edges:
|
||||
parallel_branch_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
# any target node id in node_parallel_mapping
|
||||
if parallel_branch_node_ids:
|
||||
parent_parallel_id = parent_parallel.id if parent_parallel else None
|
||||
|
||||
parallel = GraphParallel(
|
||||
start_from_node_id=start_node_id,
|
||||
parent_parallel_id=parent_parallel.id if parent_parallel else None,
|
||||
parent_parallel_start_node_id=parent_parallel.start_from_node_id if parent_parallel else None
|
||||
)
|
||||
parallel_mapping[parallel.id] = parallel
|
||||
|
||||
in_branch_node_ids = cls._fetch_all_node_ids_in_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
parallel_branch_node_ids=parallel_branch_node_ids
|
||||
)
|
||||
|
||||
# collect all branches node ids
|
||||
parallel_node_ids = []
|
||||
for _, node_ids in in_branch_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
in_parent_parallel = True
|
||||
if parent_parallel_id:
|
||||
in_parent_parallel = False
|
||||
for parallel_node_id, parallel_id in node_parallel_mapping.items():
|
||||
if parallel_id == parent_parallel_id and parallel_node_id == node_id:
|
||||
in_parent_parallel = True
|
||||
break
|
||||
|
||||
if in_parent_parallel:
|
||||
parallel_node_ids.append(node_id)
|
||||
node_parallel_mapping[node_id] = parallel.id
|
||||
|
||||
outside_parallel_target_node_ids = set()
|
||||
for node_id in parallel_node_ids:
|
||||
if node_id == parallel.start_from_node_id:
|
||||
continue
|
||||
|
||||
node_edges = edge_mapping.get(node_id)
|
||||
if not node_edges:
|
||||
continue
|
||||
|
||||
if len(node_edges) > 1:
|
||||
continue
|
||||
|
||||
target_node_id = node_edges[0].target_node_id
|
||||
if target_node_id in parallel_node_ids:
|
||||
continue
|
||||
|
||||
if parent_parallel_id:
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
continue
|
||||
|
||||
if (
|
||||
(node_parallel_mapping.get(target_node_id) and node_parallel_mapping.get(target_node_id) == parent_parallel_id)
|
||||
or (parent_parallel and parent_parallel.end_to_node_id and target_node_id == parent_parallel.end_to_node_id)
|
||||
or (not node_parallel_mapping.get(target_node_id) and not parent_parallel)
|
||||
):
|
||||
outside_parallel_target_node_ids.add(target_node_id)
|
||||
|
||||
if len(outside_parallel_target_node_ids) == 1:
|
||||
if parent_parallel and parent_parallel.end_to_node_id and parallel.end_to_node_id == parent_parallel.end_to_node_id:
|
||||
parallel.end_to_node_id = None
|
||||
else:
|
||||
parallel.end_to_node_id = outside_parallel_target_node_ids.pop()
|
||||
|
||||
for graph_edge in target_node_edges:
|
||||
current_parallel = None
|
||||
if parallel:
|
||||
current_parallel = parallel
|
||||
elif parent_parallel:
|
||||
if not parent_parallel.end_to_node_id or (parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel.end_to_node_id):
|
||||
current_parallel = parent_parallel
|
||||
else:
|
||||
# fetch parent parallel's parent parallel
|
||||
parent_parallel_parent_parallel_id = parent_parallel.parent_parallel_id
|
||||
if parent_parallel_parent_parallel_id:
|
||||
parent_parallel_parent_parallel = parallel_mapping.get(parent_parallel_parent_parallel_id)
|
||||
if (
|
||||
parent_parallel_parent_parallel
|
||||
and (
|
||||
not parent_parallel_parent_parallel.end_to_node_id
|
||||
or (parent_parallel_parent_parallel.end_to_node_id and graph_edge.target_node_id != parent_parallel_parent_parallel.end_to_node_id)
|
||||
)
|
||||
):
|
||||
current_parallel = parent_parallel_parent_parallel
|
||||
|
||||
cls._recursively_add_parallels(
|
||||
edge_mapping=edge_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
parallel_mapping=parallel_mapping,
|
||||
node_parallel_mapping=node_parallel_mapping,
|
||||
parent_parallel=current_parallel
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _check_exceed_parallel_limit(
|
||||
cls,
|
||||
parallel_mapping: dict[str, GraphParallel],
|
||||
level_limit: int,
|
||||
parent_parallel_id: str,
|
||||
current_level: int = 1
|
||||
) -> None:
|
||||
"""
|
||||
Check if it exceeds N layers of parallel
|
||||
"""
|
||||
parent_parallel = parallel_mapping.get(parent_parallel_id)
|
||||
if not parent_parallel:
|
||||
return
|
||||
|
||||
current_level += 1
|
||||
if current_level > level_limit:
|
||||
raise ValueError(f"Exceeds {level_limit} layers of parallel")
|
||||
|
||||
if parent_parallel.parent_parallel_id:
|
||||
cls._check_exceed_parallel_limit(
|
||||
parallel_mapping=parallel_mapping,
|
||||
level_limit=level_limit,
|
||||
parent_parallel_id=parent_parallel.parent_parallel_id,
|
||||
current_level=current_level
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _recursively_add_parallel_node_ids(cls,
|
||||
branch_node_ids: list[str],
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
merge_node_id: str,
|
||||
start_node_id: str) -> None:
|
||||
"""
|
||||
Recursively add node ids
|
||||
|
||||
:param branch_node_ids: in branch node ids
|
||||
:param edge_mapping: edge mapping
|
||||
:param merge_node_id: merge node id
|
||||
:param start_node_id: start node id
|
||||
"""
|
||||
for graph_edge in edge_mapping.get(start_node_id, []):
|
||||
if (graph_edge.target_node_id != merge_node_id
|
||||
and graph_edge.target_node_id not in branch_node_ids):
|
||||
branch_node_ids.append(graph_edge.target_node_id)
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=branch_node_ids,
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=graph_edge.target_node_id
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _fetch_all_node_ids_in_parallels(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
parallel_branch_node_ids: list[str]) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch all node ids in parallels
|
||||
"""
|
||||
routes_node_ids: dict[str, list[str]] = {}
|
||||
for parallel_branch_node_id in parallel_branch_node_ids:
|
||||
routes_node_ids[parallel_branch_node_id] = [parallel_branch_node_id]
|
||||
|
||||
# fetch routes node ids
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=parallel_branch_node_id,
|
||||
routes_node_ids=routes_node_ids[parallel_branch_node_id]
|
||||
)
|
||||
|
||||
# fetch leaf node ids from routes node ids
|
||||
leaf_node_ids: dict[str, list[str]] = {}
|
||||
merge_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
if node_id not in edge_mapping or len(edge_mapping[node_id]) == 0:
|
||||
if branch_node_id not in leaf_node_ids:
|
||||
leaf_node_ids[branch_node_id] = []
|
||||
|
||||
leaf_node_ids[branch_node_id].append(node_id)
|
||||
|
||||
for branch_node_id2, inner_route2 in routes_node_ids.items():
|
||||
if (
|
||||
branch_node_id != branch_node_id2
|
||||
and node_id in inner_route2
|
||||
and len(reverse_edge_mapping.get(node_id, [])) > 1
|
||||
and cls._is_node_in_routes(
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
start_node_id=node_id,
|
||||
routes_node_ids=routes_node_ids
|
||||
)
|
||||
):
|
||||
if node_id not in merge_branch_node_ids:
|
||||
merge_branch_node_ids[node_id] = []
|
||||
|
||||
if branch_node_id2 not in merge_branch_node_ids[node_id]:
|
||||
merge_branch_node_ids[node_id].append(branch_node_id2)
|
||||
|
||||
# sorted merge_branch_node_ids by branch_node_ids length desc
|
||||
merge_branch_node_ids = dict(sorted(merge_branch_node_ids.items(), key=lambda x: len(x[1]), reverse=True))
|
||||
|
||||
duplicate_end_node_ids = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
for node_id2, branch_node_ids2 in merge_branch_node_ids.items():
|
||||
if node_id != node_id2 and set(branch_node_ids) == set(branch_node_ids2):
|
||||
if (node_id, node_id2) not in duplicate_end_node_ids and (node_id2, node_id) not in duplicate_end_node_ids:
|
||||
duplicate_end_node_ids[(node_id, node_id2)] = branch_node_ids
|
||||
|
||||
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
|
||||
# check which node is after
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=node_id,
|
||||
node2_id=node_id2,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
if node_id in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id2]
|
||||
elif cls._is_node2_after_node1(
|
||||
node1_id=node_id2,
|
||||
node2_id=node_id,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
if node_id2 in merge_branch_node_ids:
|
||||
del merge_branch_node_ids[node_id]
|
||||
|
||||
branches_merge_node_ids: dict[str, str] = {}
|
||||
for node_id, branch_node_ids in merge_branch_node_ids.items():
|
||||
if len(branch_node_ids) <= 1:
|
||||
continue
|
||||
|
||||
for branch_node_id in branch_node_ids:
|
||||
if branch_node_id in branches_merge_node_ids:
|
||||
continue
|
||||
|
||||
branches_merge_node_ids[branch_node_id] = node_id
|
||||
|
||||
in_branch_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
in_branch_node_ids[branch_node_id] = []
|
||||
if branch_node_id not in branches_merge_node_ids:
|
||||
# all node ids in current branch is in this thread
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
in_branch_node_ids[branch_node_id].extend(node_ids)
|
||||
else:
|
||||
merge_node_id = branches_merge_node_ids[branch_node_id]
|
||||
if merge_node_id != branch_node_id:
|
||||
in_branch_node_ids[branch_node_id].append(branch_node_id)
|
||||
|
||||
# fetch all node ids from branch_node_id and merge_node_id
|
||||
cls._recursively_add_parallel_node_ids(
|
||||
branch_node_ids=in_branch_node_ids[branch_node_id],
|
||||
edge_mapping=edge_mapping,
|
||||
merge_node_id=merge_node_id,
|
||||
start_node_id=branch_node_id
|
||||
)
|
||||
|
||||
return in_branch_node_ids
|
||||
|
||||
@classmethod
|
||||
def _recursively_fetch_routes(cls,
|
||||
edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
routes_node_ids: list[str]) -> None:
|
||||
"""
|
||||
Recursively fetch route
|
||||
"""
|
||||
if start_node_id not in edge_mapping:
|
||||
return
|
||||
|
||||
for graph_edge in edge_mapping[start_node_id]:
|
||||
# find next node ids
|
||||
if graph_edge.target_node_id not in routes_node_ids:
|
||||
routes_node_ids.append(graph_edge.target_node_id)
|
||||
|
||||
cls._recursively_fetch_routes(
|
||||
edge_mapping=edge_mapping,
|
||||
start_node_id=graph_edge.target_node_id,
|
||||
routes_node_ids=routes_node_ids
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _is_node_in_routes(cls,
|
||||
reverse_edge_mapping: dict[str, list[GraphEdge]],
|
||||
start_node_id: str,
|
||||
routes_node_ids: dict[str, list[str]]) -> bool:
|
||||
"""
|
||||
Recursively check if the node is in the routes
|
||||
"""
|
||||
if start_node_id not in reverse_edge_mapping:
|
||||
return False
|
||||
|
||||
all_routes_node_ids = set()
|
||||
parallel_start_node_ids: dict[str, list[str]] = {}
|
||||
for branch_node_id, node_ids in routes_node_ids.items():
|
||||
for node_id in node_ids:
|
||||
all_routes_node_ids.add(node_id)
|
||||
|
||||
if branch_node_id in reverse_edge_mapping:
|
||||
for graph_edge in reverse_edge_mapping[branch_node_id]:
|
||||
if graph_edge.source_node_id not in parallel_start_node_ids:
|
||||
parallel_start_node_ids[graph_edge.source_node_id] = []
|
||||
|
||||
parallel_start_node_ids[graph_edge.source_node_id].append(branch_node_id)
|
||||
|
||||
parallel_start_node_id = None
|
||||
for p_start_node_id, branch_node_ids in parallel_start_node_ids.items():
|
||||
if set(branch_node_ids) == set(routes_node_ids.keys()):
|
||||
parallel_start_node_id = p_start_node_id
|
||||
return True
|
||||
|
||||
if not parallel_start_node_id:
|
||||
raise Exception("Parallel start node id not found")
|
||||
|
||||
for graph_edge in reverse_edge_mapping[start_node_id]:
|
||||
if graph_edge.source_node_id not in all_routes_node_ids or graph_edge.source_node_id != parallel_start_node_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def _is_node2_after_node1(
|
||||
cls,
|
||||
node1_id: str,
|
||||
node2_id: str,
|
||||
edge_mapping: dict[str, list[GraphEdge]]
|
||||
) -> bool:
|
||||
"""
|
||||
is node2 after node1
|
||||
"""
|
||||
if node1_id not in edge_mapping:
|
||||
return False
|
||||
|
||||
for graph_edge in edge_mapping[node1_id]:
|
||||
if graph_edge.target_node_id == node2_id:
|
||||
return True
|
||||
|
||||
if cls._is_node2_after_node1(
|
||||
node1_id=graph_edge.target_node_id,
|
||||
node2_id=node2_id,
|
||||
edge_mapping=edge_mapping
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
@ -0,0 +1,21 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from models.workflow import WorkflowType
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
workflow_type: WorkflowType = Field(..., description="workflow type")
|
||||
workflow_id: str = Field(..., description="workflow id")
|
||||
graph_config: Mapping[str, Any] = Field(..., description="graph config")
|
||||
user_id: str = Field(..., description="user id")
|
||||
user_from: UserFrom = Field(..., description="user from, account or end-user")
|
||||
invoke_from: InvokeFrom = Field(..., description="invoke from, service-api, web-app, explore or debugger")
|
||||
call_depth: int = Field(..., description="call depth")
|
||||
@ -0,0 +1,27 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
|
||||
|
||||
class GraphRuntimeState(BaseModel):
|
||||
variable_pool: VariablePool = Field(..., description="variable pool")
|
||||
"""variable pool"""
|
||||
|
||||
start_at: float = Field(..., description="start time")
|
||||
"""start time"""
|
||||
total_tokens: int = 0
|
||||
"""total tokens"""
|
||||
llm_usage: LLMUsage = LLMUsage.empty_usage()
|
||||
"""llm usage info"""
|
||||
outputs: dict[str, Any] = {}
|
||||
"""outputs"""
|
||||
|
||||
node_run_steps: int = 0
|
||||
"""node run steps"""
|
||||
|
||||
node_run_state: RuntimeRouteState = RuntimeRouteState()
|
||||
"""node run state"""
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
@ -0,0 +1,21 @@
|
||||
import hashlib
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class RunCondition(BaseModel):
|
||||
type: Literal["branch_identify", "condition"]
|
||||
"""condition type"""
|
||||
|
||||
branch_identify: Optional[str] = None
|
||||
"""branch identify like: sourceHandle, required when type is branch_identify"""
|
||||
|
||||
conditions: Optional[list[Condition]] = None
|
||||
"""conditions to run the node, required when type is condition"""
|
||||
|
||||
@property
|
||||
def hash(self) -> str:
|
||||
return hashlib.sha256(self.model_dump_json().encode()).hexdigest()
|
||||
@ -0,0 +1,111 @@
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class RouteNodeState(BaseModel):
|
||||
class Status(Enum):
|
||||
RUNNING = "running"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
PAUSED = "paused"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
"""node state id"""
|
||||
|
||||
node_id: str
|
||||
"""node id"""
|
||||
|
||||
node_run_result: Optional[NodeRunResult] = None
|
||||
"""node run result"""
|
||||
|
||||
status: Status = Status.RUNNING
|
||||
"""node status"""
|
||||
|
||||
start_at: datetime
|
||||
"""start time"""
|
||||
|
||||
paused_at: Optional[datetime] = None
|
||||
"""paused time"""
|
||||
|
||||
finished_at: Optional[datetime] = None
|
||||
"""finished time"""
|
||||
|
||||
failed_reason: Optional[str] = None
|
||||
"""failed reason"""
|
||||
|
||||
paused_by: Optional[str] = None
|
||||
"""paused by"""
|
||||
|
||||
index: int = 1
|
||||
|
||||
def set_finished(self, run_result: NodeRunResult) -> None:
|
||||
"""
|
||||
Node finished
|
||||
|
||||
:param run_result: run result
|
||||
"""
|
||||
if self.status in [RouteNodeState.Status.SUCCESS, RouteNodeState.Status.FAILED]:
|
||||
raise Exception(f"Route state {self.id} already finished")
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
self.status = RouteNodeState.Status.SUCCESS
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
self.status = RouteNodeState.Status.FAILED
|
||||
self.failed_reason = run_result.error
|
||||
else:
|
||||
raise Exception(f"Invalid route status {run_result.status}")
|
||||
|
||||
self.node_run_result = run_result
|
||||
self.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
|
||||
class RuntimeRouteState(BaseModel):
|
||||
routes: dict[str, list[str]] = Field(
|
||||
default_factory=dict,
|
||||
description="graph state routes (source_node_state_id: target_node_state_id)"
|
||||
)
|
||||
|
||||
node_state_mapping: dict[str, RouteNodeState] = Field(
|
||||
default_factory=dict,
|
||||
description="node state mapping (route_node_state_id: route_node_state)"
|
||||
)
|
||||
|
||||
def create_node_state(self, node_id: str) -> RouteNodeState:
|
||||
"""
|
||||
Create node state
|
||||
|
||||
:param node_id: node id
|
||||
"""
|
||||
state = RouteNodeState(node_id=node_id, start_at=datetime.now(timezone.utc).replace(tzinfo=None))
|
||||
self.node_state_mapping[state.id] = state
|
||||
return state
|
||||
|
||||
def add_route(self, source_node_state_id: str, target_node_state_id: str) -> None:
|
||||
"""
|
||||
Add route to the graph state
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:param target_node_state_id: target node state id
|
||||
"""
|
||||
if source_node_state_id not in self.routes:
|
||||
self.routes[source_node_state_id] = []
|
||||
|
||||
self.routes[source_node_state_id].append(target_node_state_id)
|
||||
|
||||
def get_routes_with_node_state_by_source_node_state_id(self, source_node_state_id: str) \
|
||||
-> list[RouteNodeState]:
|
||||
"""
|
||||
Get routes with node state by source node id
|
||||
|
||||
:param source_node_state_id: source node state id
|
||||
:return: routes with node state
|
||||
"""
|
||||
return [self.node_state_mapping[target_state_id]
|
||||
for target_state_id in self.routes.get(source_node_state_id, [])]
|
||||
@ -0,0 +1,716 @@
|
||||
import logging
|
||||
import queue
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor, wait
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import (
|
||||
NodeRunMetadataKey,
|
||||
NodeType,
|
||||
UserFrom,
|
||||
)
|
||||
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||
from core.workflow.graph_engine.condition_handlers.condition_manager import ConditionManager
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseIterationEvent,
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
ParallelBranchRunFailedEvent,
|
||||
ParallelBranchRunStartedEvent,
|
||||
ParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph, GraphEdge
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.nodes.answer.answer_stream_processor import AnswerStreamProcessor
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from extensions.ext_database import db
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GraphEngineThreadPool(ThreadPoolExecutor):
|
||||
def __init__(self, max_workers=None, thread_name_prefix='',
|
||||
initializer=None, initargs=(), max_submit_count=100) -> None:
|
||||
super().__init__(max_workers, thread_name_prefix, initializer, initargs)
|
||||
self.max_submit_count = max_submit_count
|
||||
self.submit_count = 0
|
||||
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
self.submit_count += 1
|
||||
self.check_is_full()
|
||||
|
||||
return super().submit(fn, *args, **kwargs)
|
||||
|
||||
def check_is_full(self) -> None:
|
||||
print(f"submit_count: {self.submit_count}, max_submit_count: {self.max_submit_count}")
|
||||
if self.submit_count > self.max_submit_count:
|
||||
raise ValueError(f"Max submit count {self.max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
|
||||
class GraphEngine:
|
||||
workflow_thread_pool_mapping: dict[str, GraphEngineThreadPool] = {}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
graph: Graph,
|
||||
graph_config: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
max_execution_steps: int,
|
||||
max_execution_time: int,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
thread_pool_max_submit_count = 100
|
||||
thread_pool_max_workers = 10
|
||||
|
||||
## init thread pool
|
||||
if thread_pool_id:
|
||||
if not thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
|
||||
raise ValueError(f"Max submit count {thread_pool_max_submit_count} of workflow thread pool reached.")
|
||||
|
||||
self.thread_pool_id = thread_pool_id
|
||||
self.thread_pool = GraphEngine.workflow_thread_pool_mapping[thread_pool_id]
|
||||
self.is_main_thread_pool = False
|
||||
else:
|
||||
self.thread_pool = GraphEngineThreadPool(max_workers=thread_pool_max_workers, max_submit_count=thread_pool_max_submit_count)
|
||||
self.thread_pool_id = str(uuid.uuid4())
|
||||
self.is_main_thread_pool = True
|
||||
GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id] = self.thread_pool
|
||||
|
||||
self.graph = graph
|
||||
self.init_params = GraphInitParams(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_id=workflow_id,
|
||||
graph_config=graph_config,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth
|
||||
)
|
||||
|
||||
self.graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
|
||||
self.max_execution_steps = max_execution_steps
|
||||
self.max_execution_time = max_execution_time
|
||||
|
||||
def run(self) -> Generator[GraphEngineEvent, None, None]:
|
||||
# trigger graph run start event
|
||||
yield GraphRunStartedEvent()
|
||||
|
||||
try:
|
||||
stream_processor_cls: type[AnswerStreamProcessor | EndStreamProcessor]
|
||||
if self.init_params.workflow_type == WorkflowType.CHAT:
|
||||
stream_processor_cls = AnswerStreamProcessor
|
||||
else:
|
||||
stream_processor_cls = EndStreamProcessor
|
||||
|
||||
stream_processor = stream_processor_cls(
|
||||
graph=self.graph,
|
||||
variable_pool=self.graph_runtime_state.variable_pool
|
||||
)
|
||||
|
||||
# run graph
|
||||
generator = stream_processor.process(
|
||||
self._run(start_node_id=self.graph.root_node_id)
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
try:
|
||||
yield item
|
||||
if isinstance(item, NodeRunFailedEvent):
|
||||
yield GraphRunFailedEvent(error=item.route_node_state.failed_reason or 'Unknown error.')
|
||||
return
|
||||
elif isinstance(item, NodeRunSucceededEvent):
|
||||
if item.node_type == NodeType.END:
|
||||
self.graph_runtime_state.outputs = (item.route_node_state.node_run_result.outputs
|
||||
if item.route_node_state.node_run_result
|
||||
and item.route_node_state.node_run_result.outputs
|
||||
else {})
|
||||
elif item.node_type == NodeType.ANSWER:
|
||||
if "answer" not in self.graph_runtime_state.outputs:
|
||||
self.graph_runtime_state.outputs["answer"] = ""
|
||||
|
||||
self.graph_runtime_state.outputs["answer"] += "\n" + (item.route_node_state.node_run_result.outputs.get("answer", "")
|
||||
if item.route_node_state.node_run_result
|
||||
and item.route_node_state.node_run_result.outputs
|
||||
else "")
|
||||
|
||||
self.graph_runtime_state.outputs["answer"] = self.graph_runtime_state.outputs["answer"].strip()
|
||||
except Exception as e:
|
||||
logger.exception(f"Graph run failed: {str(e)}")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
return
|
||||
|
||||
# trigger graph run success event
|
||||
yield GraphRunSucceededEvent(outputs=self.graph_runtime_state.outputs)
|
||||
except GraphRunFailedError as e:
|
||||
yield GraphRunFailedEvent(error=e.error)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when graph running")
|
||||
yield GraphRunFailedEvent(error=str(e))
|
||||
raise e
|
||||
finally:
|
||||
if self.is_main_thread_pool and self.thread_pool_id in GraphEngine.workflow_thread_pool_mapping:
|
||||
del GraphEngine.workflow_thread_pool_mapping[self.thread_pool_id]
|
||||
|
||||
def _run(
|
||||
self,
|
||||
start_node_id: str,
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
parallel_start_node_id = None
|
||||
if in_parallel_id:
|
||||
parallel_start_node_id = start_node_id
|
||||
|
||||
next_node_id = start_node_id
|
||||
previous_route_node_state: Optional[RouteNodeState] = None
|
||||
while True:
|
||||
# max steps reached
|
||||
if self.graph_runtime_state.node_run_steps > self.max_execution_steps:
|
||||
raise GraphRunFailedError('Max steps {} reached.'.format(self.max_execution_steps))
|
||||
|
||||
# or max execution time reached
|
||||
if self._is_timed_out(
|
||||
start_at=self.graph_runtime_state.start_at,
|
||||
max_execution_time=self.max_execution_time
|
||||
):
|
||||
raise GraphRunFailedError('Max execution time {}s reached.'.format(self.max_execution_time))
|
||||
|
||||
# init route node state
|
||||
route_node_state = self.graph_runtime_state.node_run_state.create_node_state(
|
||||
node_id=next_node_id
|
||||
)
|
||||
|
||||
# get node config
|
||||
node_id = route_node_state.node_id
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError(f'Node {node_id} config not found.')
|
||||
|
||||
# convert to specific node
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
raise GraphRunFailedError(f'Node {node_id} type {node_type} not found.')
|
||||
|
||||
previous_node_id = previous_route_node_state.node_id if previous_route_node_state else None
|
||||
|
||||
# init workflow run state
|
||||
node_instance = node_cls( # type: ignore
|
||||
id=route_node_state.id,
|
||||
config=node_config,
|
||||
graph_init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_node_id=previous_node_id,
|
||||
thread_pool_id=self.thread_pool_id
|
||||
)
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = self._run_node(
|
||||
node_instance=node_instance,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
if isinstance(item, NodeRunStartedEvent):
|
||||
self.graph_runtime_state.node_run_steps += 1
|
||||
item.route_node_state.index = self.graph_runtime_state.node_run_steps
|
||||
|
||||
yield item
|
||||
|
||||
self.graph_runtime_state.node_run_state.node_state_mapping[route_node_state.id] = route_node_state
|
||||
|
||||
# append route
|
||||
if previous_route_node_state:
|
||||
self.graph_runtime_state.node_run_state.add_route(
|
||||
source_node_state_id=previous_route_node_state.id,
|
||||
target_node_state_id=route_node_state.id
|
||||
)
|
||||
except Exception as e:
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = str(e)
|
||||
yield NodeRunFailedEvent(
|
||||
error=str(e),
|
||||
id=node_instance.id,
|
||||
node_id=next_node_id,
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
raise e
|
||||
|
||||
# It may not be necessary, but it is necessary. :)
|
||||
if (self.graph.node_id_config_mapping[next_node_id]
|
||||
.get("data", {}).get("type", "").lower() == NodeType.END.value):
|
||||
break
|
||||
|
||||
previous_route_node_state = route_node_state
|
||||
|
||||
# get next node ids
|
||||
edge_mappings = self.graph.edge_mapping.get(next_node_id)
|
||||
if not edge_mappings:
|
||||
break
|
||||
|
||||
if len(edge_mappings) == 1:
|
||||
edge = edge_mappings[0]
|
||||
|
||||
if edge.run_condition:
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state
|
||||
)
|
||||
|
||||
if not result:
|
||||
break
|
||||
|
||||
next_node_id = edge.target_node_id
|
||||
else:
|
||||
final_node_id = None
|
||||
|
||||
if any(edge.run_condition for edge in edge_mappings):
|
||||
# if nodes has run conditions, get node id which branch to take based on the run condition results
|
||||
condition_edge_mappings = {}
|
||||
for edge in edge_mappings:
|
||||
if edge.run_condition:
|
||||
run_condition_hash = edge.run_condition.hash
|
||||
if run_condition_hash not in condition_edge_mappings:
|
||||
condition_edge_mappings[run_condition_hash] = []
|
||||
|
||||
condition_edge_mappings[run_condition_hash].append(edge)
|
||||
|
||||
for _, sub_edge_mappings in condition_edge_mappings.items():
|
||||
if len(sub_edge_mappings) == 0:
|
||||
continue
|
||||
|
||||
edge = sub_edge_mappings[0]
|
||||
|
||||
result = ConditionManager.get_condition_handler(
|
||||
init_params=self.init_params,
|
||||
graph=self.graph,
|
||||
run_condition=edge.run_condition,
|
||||
).check(
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
previous_route_node_state=previous_route_node_state,
|
||||
)
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
if len(sub_edge_mappings) == 1:
|
||||
final_node_id = edge.target_node_id
|
||||
else:
|
||||
parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=sub_edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
else:
|
||||
yield item
|
||||
|
||||
break
|
||||
|
||||
if not final_node_id:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
else:
|
||||
parallel_generator = self._run_parallel_branches(
|
||||
edge_mappings=edge_mappings,
|
||||
in_parallel_id=in_parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id
|
||||
)
|
||||
|
||||
for item in parallel_generator:
|
||||
if isinstance(item, str):
|
||||
final_node_id = item
|
||||
else:
|
||||
yield item
|
||||
|
||||
if not final_node_id:
|
||||
break
|
||||
|
||||
next_node_id = final_node_id
|
||||
|
||||
if in_parallel_id and self.graph.node_parallel_mapping.get(next_node_id, '') != in_parallel_id:
|
||||
break
|
||||
|
||||
def _run_parallel_branches(
|
||||
self,
|
||||
edge_mappings: list[GraphEdge],
|
||||
in_parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
) -> Generator[GraphEngineEvent | str, None, None]:
|
||||
# if nodes has no run conditions, parallel run all nodes
|
||||
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
|
||||
if not parallel_id:
|
||||
node_id = edge_mappings[0].target_node_id
|
||||
node_config = self.graph.node_id_config_mapping.get(node_id)
|
||||
if not node_config:
|
||||
raise GraphRunFailedError(f'Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches.')
|
||||
|
||||
node_title = node_config.get('data', {}).get('title')
|
||||
raise GraphRunFailedError(f'Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches.')
|
||||
|
||||
parallel = self.graph.parallel_mapping.get(parallel_id)
|
||||
if not parallel:
|
||||
raise GraphRunFailedError(f'Parallel {parallel_id} not found.')
|
||||
|
||||
# run parallel nodes, run in new thread and use queue to get results
|
||||
q: queue.Queue = queue.Queue()
|
||||
|
||||
# Create a list to store the threads
|
||||
futures = []
|
||||
|
||||
# new thread
|
||||
for edge in edge_mappings:
|
||||
if (
|
||||
edge.target_node_id not in self.graph.node_parallel_mapping
|
||||
or self.graph.node_parallel_mapping.get(edge.target_node_id, '') != parallel_id
|
||||
):
|
||||
continue
|
||||
|
||||
futures.append(
|
||||
self.thread_pool.submit(self._run_parallel_node, **{
|
||||
'flask_app': current_app._get_current_object(), # type: ignore[attr-defined]
|
||||
'q': q,
|
||||
'parallel_id': parallel_id,
|
||||
'parallel_start_node_id': edge.target_node_id,
|
||||
'parent_parallel_id': in_parallel_id,
|
||||
'parent_parallel_start_node_id': parallel_start_node_id,
|
||||
})
|
||||
)
|
||||
|
||||
succeeded_count = 0
|
||||
while True:
|
||||
try:
|
||||
event = q.get(timeout=1)
|
||||
if event is None:
|
||||
break
|
||||
|
||||
yield event
|
||||
if event.parallel_id == parallel_id:
|
||||
if isinstance(event, ParallelBranchRunSucceededEvent):
|
||||
succeeded_count += 1
|
||||
if succeeded_count == len(futures):
|
||||
q.put(None)
|
||||
|
||||
continue
|
||||
elif isinstance(event, ParallelBranchRunFailedEvent):
|
||||
raise GraphRunFailedError(event.error)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
# wait all threads
|
||||
wait(futures)
|
||||
|
||||
# get final node id
|
||||
final_node_id = parallel.end_to_node_id
|
||||
if final_node_id:
|
||||
yield final_node_id
|
||||
|
||||
def _run_parallel_node(
|
||||
self,
|
||||
flask_app: Flask,
|
||||
q: queue.Queue,
|
||||
parallel_id: str,
|
||||
parallel_start_node_id: str,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run parallel nodes
|
||||
"""
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
q.put(ParallelBranchRunStartedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
))
|
||||
|
||||
# run node
|
||||
generator = self._run(
|
||||
start_node_id=parallel_start_node_id,
|
||||
in_parallel_id=parallel_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
for item in generator:
|
||||
q.put(item)
|
||||
|
||||
# trigger graph run success event
|
||||
q.put(ParallelBranchRunSucceededEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
))
|
||||
except GraphRunFailedError as e:
|
||||
q.put(ParallelBranchRunFailedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
error=e.error
|
||||
))
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when generating in parallel")
|
||||
q.put(ParallelBranchRunFailedEvent(
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id,
|
||||
error=str(e)
|
||||
))
|
||||
finally:
|
||||
db.session.remove()
|
||||
|
||||
def _run_node(
|
||||
self,
|
||||
node_instance: BaseNode,
|
||||
route_node_state: RouteNodeState,
|
||||
parallel_id: Optional[str] = None,
|
||||
parallel_start_node_id: Optional[str] = None,
|
||||
parent_parallel_id: Optional[str] = None,
|
||||
parent_parallel_start_node_id: Optional[str] = None,
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Run node
|
||||
"""
|
||||
# trigger node run start event
|
||||
yield NodeRunStartedEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
predecessor_node_id=node_instance.previous_node_id,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
db.session.close()
|
||||
|
||||
try:
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
for item in generator:
|
||||
if isinstance(item, GraphEngineEvent):
|
||||
if isinstance(item, BaseIterationEvent):
|
||||
# add parallel info to iteration event
|
||||
item.parallel_id = parallel_id
|
||||
item.parallel_start_node_id = parallel_start_node_id
|
||||
item.parent_parallel_id = parent_parallel_id
|
||||
item.parent_parallel_start_node_id = parent_parallel_start_node_id
|
||||
|
||||
yield item
|
||||
else:
|
||||
if isinstance(item, RunCompletedEvent):
|
||||
run_result = item.run_result
|
||||
route_node_state.set_finished(run_result=run_result)
|
||||
|
||||
if run_result.status == WorkflowNodeExecutionStatus.FAILED:
|
||||
yield NodeRunFailedEvent(
|
||||
error=route_node_state.failed_reason or 'Unknown error.',
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
elif run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED:
|
||||
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
|
||||
# plus state total_tokens
|
||||
self.graph_runtime_state.total_tokens += int(
|
||||
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if run_result.llm_usage:
|
||||
# use the latest usage
|
||||
self.graph_runtime_state.llm_usage += run_result.llm_usage
|
||||
|
||||
# append node output variables to variable pool
|
||||
if run_result.outputs:
|
||||
for variable_key, variable_value in run_result.outputs.items():
|
||||
# append variables to variable pool recursively
|
||||
self._append_variables_recursively(
|
||||
node_id=node_instance.node_id,
|
||||
variable_key_list=[variable_key],
|
||||
variable_value=variable_value
|
||||
)
|
||||
|
||||
# add parallel info to run result metadata
|
||||
if parallel_id and parallel_start_node_id:
|
||||
if not run_result.metadata:
|
||||
run_result.metadata = {}
|
||||
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_ID] = parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARALLEL_START_NODE_ID] = parallel_start_node_id
|
||||
if parent_parallel_id and parent_parallel_start_node_id:
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_ID] = parent_parallel_id
|
||||
run_result.metadata[NodeRunMetadataKey.PARENT_PARALLEL_START_NODE_ID] = parent_parallel_start_node_id
|
||||
|
||||
yield NodeRunSucceededEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
|
||||
break
|
||||
elif isinstance(item, RunStreamChunkEvent):
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
chunk_content=item.chunk_content,
|
||||
from_variable_selector=item.from_variable_selector,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
elif isinstance(item, RunRetrieverResourceEvent):
|
||||
yield NodeRunRetrieverResourceEvent(
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
retriever_resources=item.retriever_resources,
|
||||
context=item.context,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
except GenerateTaskStoppedException:
|
||||
# trigger node run failed event
|
||||
route_node_state.status = RouteNodeState.Status.FAILED
|
||||
route_node_state.failed_reason = "Workflow stopped."
|
||||
yield NodeRunFailedEvent(
|
||||
error="Workflow stopped.",
|
||||
id=node_instance.id,
|
||||
node_id=node_instance.node_id,
|
||||
node_type=node_instance.node_type,
|
||||
node_data=node_instance.node_data,
|
||||
route_node_state=route_node_state,
|
||||
parallel_id=parallel_id,
|
||||
parallel_start_node_id=parallel_start_node_id,
|
||||
parent_parallel_id=parent_parallel_id,
|
||||
parent_parallel_start_node_id=parent_parallel_start_node_id
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception(f"Node {node_instance.node_data.title} run failed: {str(e)}")
|
||||
raise e
|
||||
finally:
|
||||
db.session.close()
|
||||
|
||||
def _append_variables_recursively(self,
|
||||
node_id: str,
|
||||
variable_key_list: list[str],
|
||||
variable_value: VariableValue):
|
||||
"""
|
||||
Append variables recursively
|
||||
:param node_id: node id
|
||||
:param variable_key_list: variable key list
|
||||
:param variable_value: variable value
|
||||
:return:
|
||||
"""
|
||||
self.graph_runtime_state.variable_pool.add(
|
||||
[node_id] + variable_key_list,
|
||||
variable_value
|
||||
)
|
||||
|
||||
# if variable_value is a dict, then recursively append variables
|
||||
if isinstance(variable_value, dict):
|
||||
for key, value in variable_value.items():
|
||||
# construct new key list
|
||||
new_key_list = variable_key_list + [key]
|
||||
self._append_variables_recursively(
|
||||
node_id=node_id,
|
||||
variable_key_list=new_key_list,
|
||||
variable_value=value
|
||||
)
|
||||
|
||||
def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool:
|
||||
"""
|
||||
Check timeout
|
||||
:param start_at: start time
|
||||
:param max_execution_time: max execution time
|
||||
:return:
|
||||
"""
|
||||
return time.perf_counter() - start_at > max_execution_time
|
||||
|
||||
|
||||
class GraphRunFailedError(Exception):
|
||||
def __init__(self, error: str):
|
||||
self.error = error
|
||||
@ -0,0 +1,169 @@
|
||||
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.entities import (
|
||||
AnswerNodeData,
|
||||
AnswerStreamGenerateRoute,
|
||||
GenerateRouteChunk,
|
||||
TextGenerateRouteChunk,
|
||||
VarGenerateRouteChunk,
|
||||
)
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class AnswerStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]] # type: ignore[name-defined]
|
||||
) -> AnswerStreamGenerateRoute:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# parse stream output node value selectors of answer nodes
|
||||
answer_generate_route: dict[str, list[GenerateRouteChunk]] = {}
|
||||
for answer_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.ANSWER.value:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
generate_route = cls._extract_generate_route_selectors(node_config)
|
||||
answer_generate_route[answer_node_id] = generate_route
|
||||
|
||||
# fetch answer dependencies
|
||||
answer_node_ids = list(answer_generate_route.keys())
|
||||
answer_dependencies = cls._fetch_answers_dependencies(
|
||||
answer_node_ids=answer_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
)
|
||||
|
||||
return AnswerStreamGenerateRoute(
|
||||
answer_generate_route=answer_generate_route,
|
||||
answer_dependencies=answer_dependencies
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_generate_route_from_node_data(cls, node_data: AnswerNodeData) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route from node data
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_template_parser = VariableTemplateParser(template=node_data.answer)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
|
||||
value_selector_mapping = {
|
||||
variable_selector.variable: variable_selector.value_selector
|
||||
for variable_selector in variable_selectors
|
||||
}
|
||||
|
||||
variable_keys = list(value_selector_mapping.keys())
|
||||
|
||||
# format answer template
|
||||
template_parser = PromptTemplateParser(template=node_data.answer, with_variable_tmpl=True)
|
||||
template_variable_keys = template_parser.variable_keys
|
||||
|
||||
# Take the intersection of variable_keys and template_variable_keys
|
||||
variable_keys = list(set(variable_keys) & set(template_variable_keys))
|
||||
|
||||
template = node_data.answer
|
||||
for var in variable_keys:
|
||||
template = template.replace(f'{{{{{var}}}}}', f'Ω{{{{{var}}}}}Ω')
|
||||
|
||||
generate_routes: list[GenerateRouteChunk] = []
|
||||
for part in template.split('Ω'):
|
||||
if part:
|
||||
if cls._is_variable(part, variable_keys):
|
||||
var_key = part.replace('Ω', '').replace('{{', '').replace('}}', '')
|
||||
value_selector = value_selector_mapping[var_key]
|
||||
generate_routes.append(VarGenerateRouteChunk(
|
||||
value_selector=value_selector
|
||||
))
|
||||
else:
|
||||
generate_routes.append(TextGenerateRouteChunk(
|
||||
text=part
|
||||
))
|
||||
|
||||
return generate_routes
|
||||
|
||||
@classmethod
|
||||
def _extract_generate_route_selectors(cls, config: dict) -> list[GenerateRouteChunk]:
|
||||
"""
|
||||
Extract generate route selectors
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = AnswerNodeData(**config.get("data", {}))
|
||||
return cls.extract_generate_route_from_node_data(node_data)
|
||||
|
||||
@classmethod
|
||||
def _is_variable(cls, part, variable_keys):
|
||||
cleaned_part = part.replace('{{', '').replace('}}', '')
|
||||
return part.startswith('{{') and cleaned_part in variable_keys
|
||||
|
||||
@classmethod
|
||||
def _fetch_answers_dependencies(cls,
|
||||
answer_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch answer dependencies
|
||||
:param answer_node_ids: answer node ids
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:return:
|
||||
"""
|
||||
answer_dependencies: dict[str, list[str]] = {}
|
||||
for answer_node_id in answer_node_ids:
|
||||
if answer_dependencies.get(answer_node_id) is None:
|
||||
answer_dependencies[answer_node_id] = []
|
||||
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
current_node_id=answer_node_id,
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
)
|
||||
|
||||
return answer_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_answer_dependencies(cls,
|
||||
current_node_id: str,
|
||||
answer_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
answer_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch answer dependencies
|
||||
:param current_node_id: current node id
|
||||
:param answer_node_id: answer node id
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param answer_dependencies: answer dependencies
|
||||
:return:
|
||||
"""
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
if source_node_type in (
|
||||
NodeType.ANSWER.value,
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
answer_dependencies[answer_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_answer_dependencies(
|
||||
current_node_id=source_node_id,
|
||||
answer_node_id=answer_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
answer_dependencies=answer_dependencies
|
||||
)
|
||||
@ -0,0 +1,221 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Optional, cast
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
from core.workflow.nodes.answer.entities import GenerateRouteChunk, TextGenerateRouteChunk, VarGenerateRouteChunk
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AnswerStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.generate_routes = graph.answer_stream_generate_routes
|
||||
self.route_position = {}
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_answer_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
]
|
||||
else:
|
||||
stream_out_answer_node_ids = self._get_stream_out_answer_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_answer_node_ids
|
||||
|
||||
for _ in stream_out_answer_node_ids:
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
for answer_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
# remove unreachable nodes
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
self.route_position = {}
|
||||
for answer_node_id, route_chunks in self.generate_routes.answer_generate_route.items():
|
||||
self.route_position[answer_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
:return:
|
||||
"""
|
||||
for answer_node_id, position in self.route_position.items():
|
||||
# all depends on answer node id not in rest node ids
|
||||
if (event.route_node_state.node_id != answer_node_id
|
||||
and (answer_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]))):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[answer_node_id]
|
||||
route_chunks = self.generate_routes.answer_generate_route[answer_node_id][route_position:]
|
||||
|
||||
for route_chunk in route_chunks:
|
||||
if route_chunk.type == GenerateRouteChunk.ChunkType.TEXT:
|
||||
route_chunk = cast(TextGenerateRouteChunk, route_chunk)
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=route_chunk.text,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
else:
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
if not value_selector:
|
||||
break
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
|
||||
self.route_position[answer_node_id] += 1
|
||||
|
||||
def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.from_variable_selector:
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_answer_node_ids = []
|
||||
for answer_node_id, route_position in self.route_position.items():
|
||||
if answer_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
|
||||
# all depends on answer node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.generate_routes.answer_dependencies[answer_node_id]):
|
||||
if route_position >= len(self.generate_routes.answer_generate_route[answer_node_id]):
|
||||
continue
|
||||
|
||||
route_chunk = self.generate_routes.answer_generate_route[answer_node_id][route_position]
|
||||
|
||||
if route_chunk.type != GenerateRouteChunk.ChunkType.VAR:
|
||||
continue
|
||||
|
||||
route_chunk = cast(VarGenerateRouteChunk, route_chunk)
|
||||
value_selector = route_chunk.value_selector
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_answer_node_ids.append(answer_node_id)
|
||||
|
||||
return stream_out_answer_node_ids
|
||||
|
||||
@classmethod
|
||||
def _fetch_files_from_variable_value(cls, value: dict | list) -> list[dict]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file_var = cls._get_file_var_from_value(item)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
elif isinstance(value, dict):
|
||||
file_var = cls._get_file_var_from_value(value)
|
||||
if file_var:
|
||||
files.append(file_var)
|
||||
|
||||
return files
|
||||
|
||||
@classmethod
|
||||
def _get_file_var_from_value(cls, value: dict | list) -> Optional[dict]:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict):
|
||||
if '__variant' in value and value['__variant'] == FileVar.__name__:
|
||||
return value
|
||||
elif isinstance(value, FileVar):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
@ -0,0 +1,71 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, NodeRunSucceededEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
|
||||
|
||||
class StreamProcessor(ABC):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
self.graph = graph
|
||||
self.variable_pool = variable_pool
|
||||
self.rest_node_ids = graph.node_ids.copy()
|
||||
|
||||
@abstractmethod
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def _remove_unreachable_nodes(self, event: NodeRunSucceededEvent) -> None:
|
||||
finished_node_id = event.route_node_state.node_id
|
||||
if finished_node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
# remove finished node id
|
||||
self.rest_node_ids.remove(finished_node_id)
|
||||
|
||||
run_result = event.route_node_state.node_run_result
|
||||
if not run_result:
|
||||
return
|
||||
|
||||
if run_result.edge_source_handle:
|
||||
reachable_node_ids = []
|
||||
unreachable_first_node_ids = []
|
||||
for edge in self.graph.edge_mapping[finished_node_id]:
|
||||
if (edge.run_condition
|
||||
and edge.run_condition.branch_identify
|
||||
and run_result.edge_source_handle == edge.run_condition.branch_identify):
|
||||
reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
continue
|
||||
else:
|
||||
unreachable_first_node_ids.append(edge.target_node_id)
|
||||
|
||||
for node_id in unreachable_first_node_ids:
|
||||
self._remove_node_ids_in_unreachable_branch(node_id, reachable_node_ids)
|
||||
|
||||
def _fetch_node_ids_in_reachable_branch(self, node_id: str) -> list[str]:
|
||||
node_ids = []
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id == self.graph.root_node_id:
|
||||
continue
|
||||
|
||||
node_ids.append(edge.target_node_id)
|
||||
node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id))
|
||||
return node_ids
|
||||
|
||||
def _remove_node_ids_in_unreachable_branch(self, node_id: str, reachable_node_ids: list[str]) -> None:
|
||||
"""
|
||||
remove target node ids until merge
|
||||
"""
|
||||
if node_id not in self.rest_node_ids:
|
||||
return
|
||||
|
||||
self.rest_node_ids.remove(node_id)
|
||||
for edge in self.graph.edge_mapping.get(node_id, []):
|
||||
if edge.target_node_id in reachable_node_ids:
|
||||
continue
|
||||
|
||||
self._remove_node_ids_in_unreachable_branch(edge.target_node_id, reachable_node_ids)
|
||||
@ -0,0 +1,148 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.end.entities import EndNodeData, EndStreamParam
|
||||
|
||||
|
||||
class EndStreamGeneratorRouter:
|
||||
|
||||
@classmethod
|
||||
def init(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_parallel_mapping: dict[str, str]
|
||||
) -> EndStreamParam:
|
||||
"""
|
||||
Get stream generate routes.
|
||||
:return:
|
||||
"""
|
||||
# parse stream output node value selector of end nodes
|
||||
end_stream_variable_selectors_mapping: dict[str, list[list[str]]] = {}
|
||||
for end_node_id, node_config in node_id_config_mapping.items():
|
||||
if not node_config.get('data', {}).get('type') == NodeType.END.value:
|
||||
continue
|
||||
|
||||
# skip end node in parallel
|
||||
if end_node_id in node_parallel_mapping:
|
||||
continue
|
||||
|
||||
# get generate route for stream output
|
||||
stream_variable_selectors = cls._extract_stream_variable_selector(node_id_config_mapping, node_config)
|
||||
end_stream_variable_selectors_mapping[end_node_id] = stream_variable_selectors
|
||||
|
||||
# fetch end dependencies
|
||||
end_node_ids = list(end_stream_variable_selectors_mapping.keys())
|
||||
end_dependencies = cls._fetch_ends_dependencies(
|
||||
end_node_ids=end_node_ids,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
node_id_config_mapping=node_id_config_mapping
|
||||
)
|
||||
|
||||
return EndStreamParam(
|
||||
end_stream_variable_selector_mapping=end_stream_variable_selectors_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_stream_variable_selector_from_node_data(cls,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
node_data: EndNodeData) -> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node data
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param node_data: node data object
|
||||
:return:
|
||||
"""
|
||||
variable_selectors = node_data.outputs
|
||||
|
||||
value_selectors = []
|
||||
for variable_selector in variable_selectors:
|
||||
if not variable_selector.value_selector:
|
||||
continue
|
||||
|
||||
node_id = variable_selector.value_selector[0]
|
||||
if node_id != 'sys' and node_id in node_id_config_mapping:
|
||||
node = node_id_config_mapping[node_id]
|
||||
node_type = node.get('data', {}).get('type')
|
||||
if (
|
||||
variable_selector.value_selector not in value_selectors
|
||||
and node_type == NodeType.LLM.value
|
||||
and variable_selector.value_selector[1] == 'text'
|
||||
):
|
||||
value_selectors.append(variable_selector.value_selector)
|
||||
|
||||
return value_selectors
|
||||
|
||||
@classmethod
|
||||
def _extract_stream_variable_selector(cls, node_id_config_mapping: dict[str, dict], config: dict) \
|
||||
-> list[list[str]]:
|
||||
"""
|
||||
Extract stream variable selector from node config
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param config: node config
|
||||
:return:
|
||||
"""
|
||||
node_data = EndNodeData(**config.get("data", {}))
|
||||
return cls.extract_stream_variable_selector_from_node_data(node_id_config_mapping, node_data)
|
||||
|
||||
@classmethod
|
||||
def _fetch_ends_dependencies(cls,
|
||||
end_node_ids: list[str],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]], # type: ignore[name-defined]
|
||||
node_id_config_mapping: dict[str, dict]
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
Fetch end dependencies
|
||||
:param end_node_ids: end node ids
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:return:
|
||||
"""
|
||||
end_dependencies: dict[str, list[str]] = {}
|
||||
for end_node_id in end_node_ids:
|
||||
if end_dependencies.get(end_node_id) is None:
|
||||
end_dependencies[end_node_id] = []
|
||||
|
||||
cls._recursive_fetch_end_dependencies(
|
||||
current_node_id=end_node_id,
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
)
|
||||
|
||||
return end_dependencies
|
||||
|
||||
@classmethod
|
||||
def _recursive_fetch_end_dependencies(cls,
|
||||
current_node_id: str,
|
||||
end_node_id: str,
|
||||
node_id_config_mapping: dict[str, dict],
|
||||
reverse_edge_mapping: dict[str, list["GraphEdge"]],
|
||||
# type: ignore[name-defined]
|
||||
end_dependencies: dict[str, list[str]]
|
||||
) -> None:
|
||||
"""
|
||||
Recursive fetch end dependencies
|
||||
:param current_node_id: current node id
|
||||
:param end_node_id: end node id
|
||||
:param node_id_config_mapping: node id config mapping
|
||||
:param reverse_edge_mapping: reverse edge mapping
|
||||
:param end_dependencies: end dependencies
|
||||
:return:
|
||||
"""
|
||||
reverse_edges = reverse_edge_mapping.get(current_node_id, [])
|
||||
for edge in reverse_edges:
|
||||
source_node_id = edge.source_node_id
|
||||
source_node_type = node_id_config_mapping[source_node_id].get('data', {}).get('type')
|
||||
if source_node_type in (
|
||||
NodeType.IF_ELSE.value,
|
||||
NodeType.QUESTION_CLASSIFIER,
|
||||
):
|
||||
end_dependencies[end_node_id].append(source_node_id)
|
||||
else:
|
||||
cls._recursive_fetch_end_dependencies(
|
||||
current_node_id=source_node_id,
|
||||
end_node_id=end_node_id,
|
||||
node_id_config_mapping=node_id_config_mapping,
|
||||
reverse_edge_mapping=reverse_edge_mapping,
|
||||
end_dependencies=end_dependencies
|
||||
)
|
||||
@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
GraphEngineEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.nodes.answer.base_stream_processor import StreamProcessor
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EndStreamProcessor(StreamProcessor):
|
||||
|
||||
def __init__(self, graph: Graph, variable_pool: VariablePool) -> None:
|
||||
super().__init__(graph, variable_pool)
|
||||
self.end_stream_param = graph.end_stream_param
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
self.current_stream_chunk_generating_node_ids: dict[str, list[str]] = {}
|
||||
self.has_outputed = False
|
||||
self.outputed_node_ids = set()
|
||||
|
||||
def process(self,
|
||||
generator: Generator[GraphEngineEvent, None, None]
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
for event in generator:
|
||||
if isinstance(event, NodeRunStartedEvent):
|
||||
if event.route_node_state.node_id == self.graph.root_node_id and not self.rest_node_ids:
|
||||
self.reset()
|
||||
|
||||
yield event
|
||||
elif isinstance(event, NodeRunStreamChunkEvent):
|
||||
if event.in_iteration_id:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
yield event
|
||||
continue
|
||||
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
stream_out_end_node_ids = self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
]
|
||||
else:
|
||||
stream_out_end_node_ids = self._get_stream_out_end_node_ids(event)
|
||||
self.current_stream_chunk_generating_node_ids[
|
||||
event.route_node_state.node_id
|
||||
] = stream_out_end_node_ids
|
||||
|
||||
if stream_out_end_node_ids:
|
||||
if self.has_outputed and event.node_id not in self.outputed_node_ids:
|
||||
event.chunk_content = '\n' + event.chunk_content
|
||||
|
||||
self.outputed_node_ids.add(event.node_id)
|
||||
self.has_outputed = True
|
||||
yield event
|
||||
elif isinstance(event, NodeRunSucceededEvent):
|
||||
yield event
|
||||
if event.route_node_state.node_id in self.current_stream_chunk_generating_node_ids:
|
||||
# update self.route_position after all stream event finished
|
||||
for end_node_id in self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]:
|
||||
self.route_position[end_node_id] += 1
|
||||
|
||||
del self.current_stream_chunk_generating_node_ids[event.route_node_state.node_id]
|
||||
|
||||
# remove unreachable nodes
|
||||
self._remove_unreachable_nodes(event)
|
||||
|
||||
# generate stream outputs
|
||||
yield from self._generate_stream_outputs_when_node_finished(event)
|
||||
else:
|
||||
yield event
|
||||
|
||||
def reset(self) -> None:
|
||||
self.route_position = {}
|
||||
for end_node_id, _ in self.end_stream_param.end_stream_variable_selector_mapping.items():
|
||||
self.route_position[end_node_id] = 0
|
||||
self.rest_node_ids = self.graph.node_ids.copy()
|
||||
self.current_stream_chunk_generating_node_ids = {}
|
||||
|
||||
def _generate_stream_outputs_when_node_finished(self,
|
||||
event: NodeRunSucceededEvent
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
Generate stream outputs.
|
||||
:param event: node run succeeded event
|
||||
:return:
|
||||
"""
|
||||
for end_node_id, position in self.route_position.items():
|
||||
# all depends on end node id not in rest node ids
|
||||
if (event.route_node_state.node_id != end_node_id
|
||||
and (end_node_id not in self.rest_node_ids
|
||||
or not all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]))):
|
||||
continue
|
||||
|
||||
route_position = self.route_position[end_node_id]
|
||||
|
||||
position = 0
|
||||
value_selectors = []
|
||||
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||
if position >= route_position:
|
||||
value_selectors.append(current_value_selectors)
|
||||
|
||||
position += 1
|
||||
|
||||
for value_selector in value_selectors:
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
value = self.variable_pool.get(
|
||||
value_selector
|
||||
)
|
||||
|
||||
if value is None:
|
||||
break
|
||||
|
||||
text = value.markdown
|
||||
|
||||
if text:
|
||||
current_node_id = value_selector[0]
|
||||
if self.has_outputed and current_node_id not in self.outputed_node_ids:
|
||||
text = '\n' + text
|
||||
|
||||
self.outputed_node_ids.add(current_node_id)
|
||||
self.has_outputed = True
|
||||
yield NodeRunStreamChunkEvent(
|
||||
id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_data=event.node_data,
|
||||
chunk_content=text,
|
||||
from_variable_selector=value_selector,
|
||||
route_node_state=event.route_node_state,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
)
|
||||
|
||||
self.route_position[end_node_id] += 1
|
||||
|
||||
def _get_stream_out_end_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
|
||||
"""
|
||||
Is stream out support
|
||||
:param event: queue text chunk event
|
||||
:return:
|
||||
"""
|
||||
if not event.from_variable_selector:
|
||||
return []
|
||||
|
||||
stream_output_value_selector = event.from_variable_selector
|
||||
if not stream_output_value_selector:
|
||||
return []
|
||||
|
||||
stream_out_end_node_ids = []
|
||||
for end_node_id, route_position in self.route_position.items():
|
||||
if end_node_id not in self.rest_node_ids:
|
||||
continue
|
||||
|
||||
# all depends on end node id not in rest node ids
|
||||
if all(dep_id not in self.rest_node_ids
|
||||
for dep_id in self.end_stream_param.end_dependencies[end_node_id]):
|
||||
if route_position >= len(self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]):
|
||||
continue
|
||||
|
||||
position = 0
|
||||
value_selector = None
|
||||
for current_value_selectors in self.end_stream_param.end_stream_variable_selector_mapping[end_node_id]:
|
||||
if position == route_position:
|
||||
value_selector = current_value_selectors
|
||||
break
|
||||
|
||||
position += 1
|
||||
|
||||
if not value_selector:
|
||||
continue
|
||||
|
||||
# check chunk node id is before current node id or equal to current node id
|
||||
if value_selector != stream_output_value_selector:
|
||||
continue
|
||||
|
||||
stream_out_end_node_ids.append(end_node_id)
|
||||
|
||||
return stream_out_end_node_ids
|
||||
@ -0,0 +1,20 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
|
||||
|
||||
class RunCompletedEvent(BaseModel):
|
||||
run_result: NodeRunResult = Field(..., description="run result")
|
||||
|
||||
|
||||
class RunStreamChunkEvent(BaseModel):
|
||||
chunk_content: str = Field(..., description="chunk content")
|
||||
from_variable_selector: list[str] = Field(..., description="from variable selector")
|
||||
|
||||
|
||||
class RunRetrieverResourceEvent(BaseModel):
|
||||
retriever_resources: list[dict] = Field(..., description="retriever resources")
|
||||
context: str = Field(..., description="context")
|
||||
|
||||
|
||||
RunEvent = RunCompletedEvent | RunStreamChunkEvent | RunRetrieverResourceEvent
|
||||
@ -1,124 +1,371 @@
|
||||
from typing import cast
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.workflow.entities.base_node_data_entities import BaseIterationState
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationState
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, NodeType
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseGraphEvent,
|
||||
BaseNodeEvent,
|
||||
BaseParallelBranchEvent,
|
||||
GraphRunFailedEvent,
|
||||
InNodeEvent,
|
||||
IterationRunFailedEvent,
|
||||
IterationRunNextEvent,
|
||||
IterationRunStartedEvent,
|
||||
IterationRunSucceededEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class IterationNode(BaseIterationNode):
|
||||
|
||||
class IterationNode(BaseNode):
|
||||
"""
|
||||
Iteration Node.
|
||||
"""
|
||||
_node_data_cls = IterationNodeData
|
||||
_node_type = NodeType.ITERATION
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> BaseIterationState:
|
||||
def _run(self) -> Generator[RunEvent | InNodeEvent, None, None]:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
self.node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_any(self.node_data.iterator_selector)
|
||||
|
||||
if not isinstance(iterator, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator}, please provide a list.")
|
||||
iterator_list_segment = self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector)
|
||||
|
||||
state = IterationState(iteration_node_id=self.node_id, index=-1, inputs={
|
||||
'iterator_selector': iterator
|
||||
}, outputs=[], metadata=IterationState.MetaData(
|
||||
iterator_length=len(iterator) if iterator is not None else 0
|
||||
))
|
||||
if not iterator_list_segment:
|
||||
raise ValueError(f"Iterator variable {self.node_data.iterator_selector} not found")
|
||||
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
return state
|
||||
iterator_list_value = iterator_list_segment.to_object()
|
||||
|
||||
def _get_next_iteration(self, variable_pool: VariablePool, state: IterationState) -> NodeRunResult | str:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
:param graph: graph
|
||||
:return: next node id
|
||||
"""
|
||||
# resolve current output
|
||||
self._resolve_current_output(variable_pool, state)
|
||||
# move to next iteration
|
||||
self._next_iteration(variable_pool, state)
|
||||
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
if self._reached_iteration_limit(variable_pool, state):
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
if not isinstance(iterator_list_value, list):
|
||||
raise ValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.")
|
||||
|
||||
inputs = {
|
||||
"iterator_selector": iterator_list_value
|
||||
}
|
||||
|
||||
graph_config = self.graph_config
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f'field start_node_id in iteration {self.node_id} not found')
|
||||
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=root_node_id
|
||||
)
|
||||
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
|
||||
leaf_node_ids = iteration_graph.get_leaf_node_ids()
|
||||
iteration_leaf_node_ids = []
|
||||
for leaf_node_id in leaf_node_ids:
|
||||
node_config = iteration_graph.node_id_config_mapping.get(leaf_node_id)
|
||||
if not node_config:
|
||||
continue
|
||||
|
||||
leaf_node_iteration_id = node_config.get("data", {}).get("iteration_id")
|
||||
if not leaf_node_iteration_id:
|
||||
continue
|
||||
|
||||
if leaf_node_iteration_id != self.node_id:
|
||||
continue
|
||||
|
||||
iteration_leaf_node_ids.append(leaf_node_id)
|
||||
|
||||
# add condition of end nodes to root node
|
||||
iteration_graph.add_extra_edge(
|
||||
source_node_id=leaf_node_id,
|
||||
target_node_id=root_node_id,
|
||||
run_condition=RunCondition(
|
||||
type="condition",
|
||||
conditions=[
|
||||
Condition(
|
||||
variable_selector=[self.node_id, "index"],
|
||||
comparison_operator="<",
|
||||
value=str(len(iterator_list_value))
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# append iteration variable (item, index) to variable pool
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
0
|
||||
)
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[0]
|
||||
)
|
||||
|
||||
# init graph engine
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=self.app_id,
|
||||
workflow_type=self.workflow_type,
|
||||
workflow_id=self.workflow_id,
|
||||
user_id=self.user_id,
|
||||
user_from=self.user_from,
|
||||
invoke_from=self.invoke_from,
|
||||
call_depth=self.workflow_call_depth,
|
||||
graph=iteration_graph,
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
|
||||
)
|
||||
|
||||
start_at = datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
yield IterationRunStartedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
metadata={
|
||||
"iterator_length": len(iterator_list_value)
|
||||
},
|
||||
predecessor_node_id=self.previous_node_id
|
||||
)
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=0,
|
||||
pre_iteration_output=None
|
||||
)
|
||||
|
||||
outputs: list[Any] = []
|
||||
try:
|
||||
# run workflow
|
||||
rst = graph_engine.run()
|
||||
for event in rst:
|
||||
if isinstance(event, (BaseNodeEvent | BaseParallelBranchEvent)) and not event.in_iteration_id:
|
||||
event.in_iteration_id = self.node_id
|
||||
|
||||
if isinstance(event, BaseNodeEvent) and event.node_type == NodeType.ITERATION_START:
|
||||
continue
|
||||
|
||||
if isinstance(event, NodeRunSucceededEvent):
|
||||
if event.route_node_state.node_run_result:
|
||||
metadata = event.route_node_state.node_run_result.metadata
|
||||
if not metadata:
|
||||
metadata = {}
|
||||
|
||||
if NodeRunMetadataKey.ITERATION_ID not in metadata:
|
||||
metadata[NodeRunMetadataKey.ITERATION_ID] = self.node_id
|
||||
metadata[NodeRunMetadataKey.ITERATION_INDEX] = variable_pool.get_any([self.node_id, 'index'])
|
||||
event.route_node_state.node_run_result.metadata = metadata
|
||||
|
||||
yield event
|
||||
|
||||
# handle iteration run result
|
||||
if event.route_node_state.node_id in iteration_leaf_node_ids:
|
||||
# append to iteration output variable list
|
||||
current_iteration_output = variable_pool.get_any(self.node_data.output_selector)
|
||||
outputs.append(current_iteration_output)
|
||||
|
||||
# remove all nodes outputs from variable pool
|
||||
for node_id in iteration_graph.node_ids:
|
||||
variable_pool.remove_node(node_id)
|
||||
|
||||
# move to next iteration
|
||||
current_index = variable_pool.get([self.node_id, 'index'])
|
||||
if current_index is None:
|
||||
raise ValueError(f'iteration {self.node_id} current index not found')
|
||||
|
||||
next_index = int(current_index.to_object()) + 1
|
||||
variable_pool.add(
|
||||
[self.node_id, 'index'],
|
||||
next_index
|
||||
)
|
||||
|
||||
if next_index < len(iterator_list_value):
|
||||
variable_pool.add(
|
||||
[self.node_id, 'item'],
|
||||
iterator_list_value[next_index]
|
||||
)
|
||||
|
||||
yield IterationRunNextEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
index=next_index,
|
||||
pre_iteration_output=jsonable_encoder(
|
||||
current_iteration_output) if current_iteration_output else None
|
||||
)
|
||||
elif isinstance(event, BaseGraphEvent):
|
||||
if isinstance(event, GraphRunFailedEvent):
|
||||
# iteration run failed
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
error=event.error,
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=event.error,
|
||||
)
|
||||
)
|
||||
break
|
||||
else:
|
||||
event = cast(InNodeEvent, event)
|
||||
yield event
|
||||
|
||||
yield IterationRunSucceededEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
'output': jsonable_encoder(state.outputs)
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
}
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
'output': jsonable_encoder(outputs)
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
# iteration run failed
|
||||
logger.exception("Iteration run failed")
|
||||
yield IterationRunFailedEvent(
|
||||
iteration_id=self.id,
|
||||
iteration_node_id=self.node_id,
|
||||
iteration_node_type=self.node_type,
|
||||
iteration_node_data=self.node_data,
|
||||
start_at=start_at,
|
||||
inputs=inputs,
|
||||
outputs={
|
||||
"output": jsonable_encoder(outputs)
|
||||
},
|
||||
steps=len(iterator_list_value),
|
||||
metadata={
|
||||
"total_tokens": graph_engine.graph_runtime_state.total_tokens
|
||||
},
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
return node_data.start_node_id
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
finally:
|
||||
# remove iteration variable (item, index) from variable pool after iteration run completed
|
||||
variable_pool.remove([self.node_id, 'index'])
|
||||
variable_pool.remove([self.node_id, 'item'])
|
||||
|
||||
def _set_current_iteration_variable(self, variable_pool: VariablePool, state: IterationState):
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Set current iteration variable.
|
||||
:variable_pool: variable pool
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
variable_mapping = {
|
||||
f'{node_id}.input_selector': node_data.iterator_selector,
|
||||
}
|
||||
|
||||
variable_pool.add((self.node_id, 'index'), state.index)
|
||||
# get the iterator value
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
# init graph
|
||||
iteration_graph = Graph.init(
|
||||
graph_config=graph_config,
|
||||
root_node_id=node_data.start_node_id
|
||||
)
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return
|
||||
if not iteration_graph:
|
||||
raise ValueError('iteration graph not found')
|
||||
|
||||
if state.index < len(iterator):
|
||||
variable_pool.add((self.node_id, 'item'), iterator[state.index])
|
||||
for sub_node_id, sub_node_config in iteration_graph.node_id_config_mapping.items():
|
||||
if sub_node_config.get('data', {}).get('iteration_id') != node_id:
|
||||
continue
|
||||
|
||||
def _next_iteration(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Move to next iteration.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
state.index += 1
|
||||
self._set_current_iteration_variable(variable_pool, state)
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
# Get node class
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
node_type = NodeType.value_of(sub_node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
if not node_cls:
|
||||
continue
|
||||
|
||||
def _reached_iteration_limit(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Check if iteration limit is reached.
|
||||
:return: True if iteration limit is reached, False otherwise
|
||||
"""
|
||||
node_data = cast(IterationNodeData, self.node_data)
|
||||
iterator = variable_pool.get_any(node_data.iterator_selector)
|
||||
node_cls = cast(BaseNode, node_cls)
|
||||
|
||||
sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=graph_config,
|
||||
config=sub_node_config
|
||||
)
|
||||
sub_node_variable_mapping = cast(dict[str, list[str]], sub_node_variable_mapping)
|
||||
except NotImplementedError:
|
||||
sub_node_variable_mapping = {}
|
||||
|
||||
if iterator is None or not isinstance(iterator, list):
|
||||
return True
|
||||
# remove iteration variables
|
||||
sub_node_variable_mapping = {
|
||||
sub_node_id + '.' + key: value for key, value in sub_node_variable_mapping.items()
|
||||
if value[0] != node_id
|
||||
}
|
||||
|
||||
return state.index >= len(iterator)
|
||||
|
||||
def _resolve_current_output(self, variable_pool: VariablePool, state: IterationState):
|
||||
"""
|
||||
Resolve current output.
|
||||
:param variable_pool: variable pool
|
||||
"""
|
||||
output_selector = cast(IterationNodeData, self.node_data).output_selector
|
||||
output = variable_pool.get_any(output_selector)
|
||||
# clear the output for this iteration
|
||||
variable_pool.remove([self.node_id] + output_selector[1:])
|
||||
state.current_output = output
|
||||
if output is not None:
|
||||
# NOTE: This is a temporary patch to process double nested list (for example, DALL-E output in iteration).
|
||||
if isinstance(output, list):
|
||||
state.outputs.extend(output)
|
||||
else:
|
||||
state.outputs.append(output)
|
||||
variable_mapping.update(sub_node_variable_mapping)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(cls, node_data: IterationNodeData) -> dict[str, list[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {
|
||||
'input_selector': node_data.iterator_selector,
|
||||
}
|
||||
# remove variable out from iteration
|
||||
variable_mapping = {
|
||||
key: value for key, value in variable_mapping.items()
|
||||
if value[0] not in iteration_graph.node_ids
|
||||
}
|
||||
|
||||
return variable_mapping
|
||||
|
||||
@ -0,0 +1,39 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.iteration.entities import IterationNodeData, IterationStartNodeData
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
|
||||
class IterationStartNode(BaseNode):
|
||||
"""
|
||||
Iteration Start Node.
|
||||
"""
|
||||
_node_data_cls = IterationStartNodeData
|
||||
_node_type = NodeType.ITERATION_START
|
||||
|
||||
def _run(self) -> NodeRunResult:
|
||||
"""
|
||||
Run the node.
|
||||
"""
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: IterationNodeData
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selector to variable mapping
|
||||
:param graph_config: graph config
|
||||
:param node_id: node id
|
||||
:param node_data: node data
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
@ -1,20 +1,34 @@
|
||||
from core.workflow.entities.node_entities import NodeRunResult, NodeType
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import BaseIterationNode
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.loop.entities import LoopNodeData, LoopState
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
class LoopNode(BaseIterationNode):
|
||||
class LoopNode(BaseNode):
|
||||
"""
|
||||
Loop Node.
|
||||
"""
|
||||
_node_data_cls = LoopNodeData
|
||||
_node_type = NodeType.LOOP
|
||||
|
||||
def _run(self, variable_pool: VariablePool) -> LoopState:
|
||||
return super()._run(variable_pool)
|
||||
def _run(self) -> LoopState:
|
||||
return super()._run()
|
||||
|
||||
def _get_next_iteration(self, variable_loop: VariablePool) -> NodeRunResult | str:
|
||||
@classmethod
|
||||
def get_conditions(cls, node_config: dict[str, Any]) -> list[Condition]:
|
||||
"""
|
||||
Get next iteration start node id based on the graph.
|
||||
Get conditions.
|
||||
"""
|
||||
node_id = node_config.get('id')
|
||||
if not node_id:
|
||||
return []
|
||||
|
||||
# TODO waiting for implementation
|
||||
return [Condition(
|
||||
variable_selector=[node_id, 'index'],
|
||||
comparison_operator="≤",
|
||||
value_type="value_selector",
|
||||
value_selector=[]
|
||||
)]
|
||||
|
||||
@ -0,0 +1,37 @@
|
||||
from core.workflow.entities.node_entities import NodeType
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.end.end_node import EndNode
|
||||
from core.workflow.nodes.http_request.http_request_node import HttpRequestNode
|
||||
from core.workflow.nodes.if_else.if_else_node import IfElseNode
|
||||
from core.workflow.nodes.iteration.iteration_node import IterationNode
|
||||
from core.workflow.nodes.iteration.iteration_start_node import IterationStartNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from core.workflow.nodes.tool.tool_node import ToolNode
|
||||
from core.workflow.nodes.variable_aggregator.variable_aggregator_node import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner import VariableAssignerNode
|
||||
|
||||
node_classes = {
|
||||
NodeType.START: StartNode,
|
||||
NodeType.END: EndNode,
|
||||
NodeType.ANSWER: AnswerNode,
|
||||
NodeType.LLM: LLMNode,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: KnowledgeRetrievalNode,
|
||||
NodeType.IF_ELSE: IfElseNode,
|
||||
NodeType.CODE: CodeNode,
|
||||
NodeType.TEMPLATE_TRANSFORM: TemplateTransformNode,
|
||||
NodeType.QUESTION_CLASSIFIER: QuestionClassifierNode,
|
||||
NodeType.HTTP_REQUEST: HttpRequestNode,
|
||||
NodeType.TOOL: ToolNode,
|
||||
NodeType.VARIABLE_AGGREGATOR: VariableAggregatorNode,
|
||||
NodeType.VARIABLE_ASSIGNER: VariableAggregatorNode, # original name of VARIABLE_AGGREGATOR
|
||||
NodeType.ITERATION: IterationNode,
|
||||
NodeType.ITERATION_START: IterationStartNode,
|
||||
NodeType.PARAMETER_EXTRACTOR: ParameterExtractorNode,
|
||||
NodeType.CONVERSATION_VARIABLE_ASSIGNER: VariableAssignerNode,
|
||||
}
|
||||
@ -0,0 +1,17 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Condition(BaseModel):
|
||||
"""
|
||||
Condition entity
|
||||
"""
|
||||
variable_selector: list[str]
|
||||
comparison_operator: Literal[
|
||||
# for string or array
|
||||
"contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty",
|
||||
# for number
|
||||
"=", "≠", ">", "<", "≥", "≤", "null", "not null"
|
||||
]
|
||||
value: Optional[str] = None
|
||||
@ -0,0 +1,383 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.file.file_obj import FileVar
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||
|
||||
|
||||
class ConditionProcessor:
|
||||
def process_conditions(self, variable_pool: VariablePool, conditions: Sequence[Condition]):
|
||||
input_conditions = []
|
||||
group_result = []
|
||||
|
||||
index = 0
|
||||
for condition in conditions:
|
||||
index += 1
|
||||
actual_value = variable_pool.get_any(
|
||||
condition.variable_selector
|
||||
)
|
||||
|
||||
expected_value = None
|
||||
if condition.value is not None:
|
||||
variable_template_parser = VariableTemplateParser(template=condition.value)
|
||||
variable_selectors = variable_template_parser.extract_variable_selectors()
|
||||
if variable_selectors:
|
||||
for variable_selector in variable_selectors:
|
||||
value = variable_pool.get_any(
|
||||
variable_selector.value_selector
|
||||
)
|
||||
expected_value = variable_template_parser.format({variable_selector.variable: value})
|
||||
|
||||
if expected_value is None:
|
||||
expected_value = condition.value
|
||||
else:
|
||||
expected_value = condition.value
|
||||
|
||||
comparison_operator = condition.comparison_operator
|
||||
input_conditions.append(
|
||||
{
|
||||
"actual_value": actual_value,
|
||||
"expected_value": expected_value,
|
||||
"comparison_operator": comparison_operator
|
||||
}
|
||||
)
|
||||
|
||||
result = self.evaluate_condition(actual_value, comparison_operator, expected_value)
|
||||
group_result.append(result)
|
||||
|
||||
return input_conditions, group_result
|
||||
|
||||
def evaluate_condition(
|
||||
self,
|
||||
actual_value: Optional[str | int | float | dict[Any, Any] | list[Any] | FileVar | None],
|
||||
comparison_operator: str,
|
||||
expected_value: Optional[str] = None
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate condition
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:param comparison_operator: comparison operator
|
||||
|
||||
:return: bool
|
||||
"""
|
||||
if comparison_operator == "contains":
|
||||
return self._assert_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "not contains":
|
||||
return self._assert_not_contains(actual_value, expected_value)
|
||||
elif comparison_operator == "start with":
|
||||
return self._assert_start_with(actual_value, expected_value)
|
||||
elif comparison_operator == "end with":
|
||||
return self._assert_end_with(actual_value, expected_value)
|
||||
elif comparison_operator == "is":
|
||||
return self._assert_is(actual_value, expected_value)
|
||||
elif comparison_operator == "is not":
|
||||
return self._assert_is_not(actual_value, expected_value)
|
||||
elif comparison_operator == "empty":
|
||||
return self._assert_empty(actual_value)
|
||||
elif comparison_operator == "not empty":
|
||||
return self._assert_not_empty(actual_value)
|
||||
elif comparison_operator == "=":
|
||||
return self._assert_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≠":
|
||||
return self._assert_not_equal(actual_value, expected_value)
|
||||
elif comparison_operator == ">":
|
||||
return self._assert_greater_than(actual_value, expected_value)
|
||||
elif comparison_operator == "<":
|
||||
return self._assert_less_than(actual_value, expected_value)
|
||||
elif comparison_operator == "≥":
|
||||
return self._assert_greater_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "≤":
|
||||
return self._assert_less_than_or_equal(actual_value, expected_value)
|
||||
elif comparison_operator == "null":
|
||||
return self._assert_null(actual_value)
|
||||
elif comparison_operator == "not null":
|
||||
return self._assert_not_null(actual_value)
|
||||
else:
|
||||
raise ValueError(f"Invalid comparison operator: {comparison_operator}")
|
||||
|
||||
def _assert_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value not in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_contains(self, actual_value: Optional[str | list], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert not contains
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
|
||||
if not isinstance(actual_value, str | list):
|
||||
raise ValueError('Invalid actual value type: string or array')
|
||||
|
||||
if expected_value in actual_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_start_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert start with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.startswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_end_with(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert end with
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if not actual_value.endswith(expected_value):
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_is_not(self, actual_value: Optional[str], expected_value: str) -> bool:
|
||||
"""
|
||||
Assert is not
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, str):
|
||||
raise ValueError('Invalid actual value type: string')
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if not actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_empty(self, actual_value: Optional[str]) -> bool:
|
||||
"""
|
||||
Assert not empty
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value != expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_not_equal(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert not equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value == expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert greater than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value <= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than(self, actual_value: Optional[int | float], expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert less than
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value >= expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_greater_than_or_equal(self, actual_value: Optional[int | float],
|
||||
expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert greater than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value < expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_less_than_or_equal(self, actual_value: Optional[int | float],
|
||||
expected_value: str | int | float) -> bool:
|
||||
"""
|
||||
Assert less than or equal
|
||||
:param actual_value: actual value
|
||||
:param expected_value: expected value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return False
|
||||
|
||||
if not isinstance(actual_value, int | float):
|
||||
raise ValueError('Invalid actual value type: number')
|
||||
|
||||
if isinstance(actual_value, int):
|
||||
expected_value = int(expected_value)
|
||||
else:
|
||||
expected_value = float(expected_value)
|
||||
|
||||
if actual_value > expected_value:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _assert_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is None:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _assert_not_null(self, actual_value: Optional[int | float]) -> bool:
|
||||
"""
|
||||
Assert not null
|
||||
:param actual_value: actual value
|
||||
:return:
|
||||
"""
|
||||
if actual_value is not None:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class ConditionAssertionError(Exception):
|
||||
def __init__(self, message: str, conditions: list[dict], sub_condition_compare_results: list[dict]) -> None:
|
||||
self.message = message
|
||||
self.conditions = conditions
|
||||
self.sub_condition_compare_results = sub_condition_compare_results
|
||||
super().__init__(self.message)
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,314 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.app_config.entities import FileExtraConfig
|
||||
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedException
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.file_obj import FileTransferMethod, FileType, FileVar
|
||||
from core.workflow.callbacks.base_workflow_callback import WorkflowCallback
|
||||
from core.workflow.entities.base_node_data_entities import BaseNodeData
|
||||
from core.workflow.entities.node_entities import NodeType, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.base_node import BaseNode
|
||||
from core.workflow.nodes.event import RunEvent
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.node_mapping import node_classes
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowType,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowEntry:
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
workflow_type: WorkflowType,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph: Graph,
|
||||
user_id: str,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int,
|
||||
variable_pool: VariablePool,
|
||||
thread_pool_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Init workflow entry
|
||||
:param tenant_id: tenant id
|
||||
:param app_id: app id
|
||||
:param workflow_id: workflow id
|
||||
:param workflow_type: workflow type
|
||||
:param graph_config: workflow graph config
|
||||
:param graph: workflow graph
|
||||
:param user_id: user id
|
||||
:param user_from: user from
|
||||
:param invoke_from: invoke from
|
||||
:param call_depth: call depth
|
||||
:param variable_pool: variable pool
|
||||
:param thread_pool_id: thread pool id
|
||||
"""
|
||||
# check call depth
|
||||
workflow_call_max_depth = dify_config.WORKFLOW_CALL_MAX_DEPTH
|
||||
if call_depth > workflow_call_max_depth:
|
||||
raise ValueError('Max workflow call depth {} reached.'.format(workflow_call_max_depth))
|
||||
|
||||
# init workflow run state
|
||||
self.graph_engine = GraphEngine(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
workflow_type=workflow_type,
|
||||
workflow_id=workflow_id,
|
||||
user_id=user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=call_depth,
|
||||
graph=graph,
|
||||
graph_config=graph_config,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS,
|
||||
max_execution_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME,
|
||||
thread_pool_id=thread_pool_id
|
||||
)
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
callbacks: Sequence[WorkflowCallback],
|
||||
) -> Generator[GraphEngineEvent, None, None]:
|
||||
"""
|
||||
:param callbacks: workflow callbacks
|
||||
"""
|
||||
graph_engine = self.graph_engine
|
||||
|
||||
try:
|
||||
# run workflow
|
||||
generator = graph_engine.run()
|
||||
for event in generator:
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
event=event
|
||||
)
|
||||
yield event
|
||||
except GenerateTaskStoppedException:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Unknown Error when workflow entry running")
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
callback.on_event(
|
||||
event=GraphRunFailedEvent(
|
||||
error=str(e)
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def single_step_run(
|
||||
cls,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_id: str,
|
||||
user_inputs: dict
|
||||
) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
|
||||
"""
|
||||
Single step run workflow node
|
||||
:param workflow: Workflow instance
|
||||
:param node_id: node id
|
||||
:param user_id: user id
|
||||
:param user_inputs: user inputs
|
||||
:return:
|
||||
"""
|
||||
# fetch node info from workflow graph
|
||||
graph = workflow.graph_dict
|
||||
if not graph:
|
||||
raise ValueError('workflow graph not found')
|
||||
|
||||
nodes = graph.get('nodes')
|
||||
if not nodes:
|
||||
raise ValueError('nodes not found in workflow graph')
|
||||
|
||||
# fetch node config from node id
|
||||
node_config = None
|
||||
for node in nodes:
|
||||
if node.get('id') == node_id:
|
||||
node_config = node
|
||||
break
|
||||
|
||||
if not node_config:
|
||||
raise ValueError('node id not found in workflow graph')
|
||||
|
||||
# Get node class
|
||||
node_type = NodeType.value_of(node_config.get('data', {}).get('type'))
|
||||
node_cls = node_classes.get(node_type)
|
||||
node_cls = cast(type[BaseNode], node_cls)
|
||||
|
||||
if not node_cls:
|
||||
raise ValueError(f'Node class not found for node type {node_type}')
|
||||
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables={},
|
||||
user_inputs={},
|
||||
environment_variables=workflow.environment_variables,
|
||||
)
|
||||
|
||||
# init graph
|
||||
graph = Graph.init(
|
||||
graph_config=workflow.graph_dict
|
||||
)
|
||||
|
||||
# init workflow run state
|
||||
node_instance: BaseNode = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id=workflow.tenant_id,
|
||||
app_id=workflow.app_id,
|
||||
workflow_type=WorkflowType.value_of(workflow.type),
|
||||
workflow_id=workflow.id,
|
||||
graph_config=workflow.graph_dict,
|
||||
user_id=user_id,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0
|
||||
),
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter()
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
try:
|
||||
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
|
||||
graph_config=workflow.graph_dict,
|
||||
config=node_config
|
||||
)
|
||||
except NotImplementedError:
|
||||
variable_mapping = {}
|
||||
|
||||
cls.mapping_user_inputs_to_variable_pool(
|
||||
variable_mapping=variable_mapping,
|
||||
user_inputs=user_inputs,
|
||||
variable_pool=variable_pool,
|
||||
tenant_id=workflow.tenant_id,
|
||||
node_type=node_type,
|
||||
node_data=node_instance.node_data
|
||||
)
|
||||
|
||||
# run node
|
||||
generator = node_instance.run()
|
||||
|
||||
return node_instance, generator
|
||||
except Exception as e:
|
||||
raise WorkflowNodeRunFailedError(
|
||||
node_instance=node_instance,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
|
||||
"""
|
||||
Handle special values
|
||||
:param value: value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
new_value = dict(value) if value else {}
|
||||
if isinstance(new_value, dict):
|
||||
for key, val in new_value.items():
|
||||
if isinstance(val, FileVar):
|
||||
new_value[key] = val.to_dict()
|
||||
elif isinstance(val, list):
|
||||
new_val = []
|
||||
for v in val:
|
||||
if isinstance(v, FileVar):
|
||||
new_val.append(v.to_dict())
|
||||
else:
|
||||
new_val.append(v)
|
||||
|
||||
new_value[key] = new_val
|
||||
|
||||
return new_value
|
||||
|
||||
@classmethod
|
||||
def mapping_user_inputs_to_variable_pool(
|
||||
cls,
|
||||
variable_mapping: Mapping[str, Sequence[str]],
|
||||
user_inputs: dict,
|
||||
variable_pool: VariablePool,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_data: BaseNodeData
|
||||
) -> None:
|
||||
for node_variable, variable_selector in variable_mapping.items():
|
||||
# fetch node id and variable key from node_variable
|
||||
node_variable_list = node_variable.split('.')
|
||||
if len(node_variable_list) < 1:
|
||||
raise ValueError(f'Invalid node variable {node_variable}')
|
||||
|
||||
node_variable_key = '.'.join(node_variable_list[1:])
|
||||
|
||||
if (
|
||||
node_variable_key not in user_inputs
|
||||
and node_variable not in user_inputs
|
||||
) and not variable_pool.get(variable_selector):
|
||||
raise ValueError(f'Variable key {node_variable} not found in user inputs.')
|
||||
|
||||
# fetch variable node id from variable selector
|
||||
variable_node_id = variable_selector[0]
|
||||
variable_key_list = variable_selector[1:]
|
||||
variable_key_list = cast(list[str], variable_key_list)
|
||||
|
||||
# get input value
|
||||
input_value = user_inputs.get(node_variable)
|
||||
if not input_value:
|
||||
input_value = user_inputs.get(node_variable_key)
|
||||
|
||||
# FIXME: temp fix for image type
|
||||
if node_type == NodeType.LLM:
|
||||
new_value = []
|
||||
if isinstance(input_value, list):
|
||||
node_data = cast(LLMNodeData, node_data)
|
||||
|
||||
detail = node_data.vision.configs.detail if node_data.vision.configs else None
|
||||
|
||||
for item in input_value:
|
||||
if isinstance(item, dict) and 'type' in item and item['type'] == 'image':
|
||||
transfer_method = FileTransferMethod.value_of(item.get('transfer_method'))
|
||||
file = FileVar(
|
||||
tenant_id=tenant_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=transfer_method,
|
||||
url=item.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
related_id=item.get(
|
||||
'upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
extra_config=FileExtraConfig(image_config={'detail': detail} if detail else None),
|
||||
)
|
||||
new_value.append(file)
|
||||
|
||||
if new_value:
|
||||
value = new_value
|
||||
|
||||
# append variable and value to variable pool
|
||||
variable_pool.add([variable_node_id] + variable_key_list, input_value)
|
||||
@ -0,0 +1,35 @@
|
||||
"""add node_execution_id into node_executions
|
||||
|
||||
Revision ID: 675b5321501b
|
||||
Revises: 030f4915f36a
|
||||
Create Date: 2024-08-12 10:54:02.259331
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
import models as models
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '675b5321501b'
|
||||
down_revision = '030f4915f36a'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('node_execution_id', sa.String(length=255), nullable=True))
|
||||
batch_op.create_index('workflow_node_execution_id_idx', ['tenant_id', 'app_id', 'workflow_id', 'triggered_from', 'node_execution_id'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_node_executions', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_node_execution_id_idx')
|
||||
batch_op.drop_column('node_execution_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -1,46 +1,84 @@
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.nodes.base_node import UserFrom
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock
|
||||
|
||||
|
||||
@pytest.mark.parametrize("setup_code_executor_mock", [["none"]], indirect=True)
|
||||
def test_execute_code(setup_code_executor_mock):
|
||||
code = """{{args2}}"""
|
||||
node = TemplateTransformNode(
|
||||
config = {
|
||||
"id": "1",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
],
|
||||
"template": code,
|
||||
},
|
||||
}
|
||||
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-next-target",
|
||||
"source": "start",
|
||||
"target": "1",
|
||||
},
|
||||
],
|
||||
"nodes": [{"data": {"type": "start"}, "id": "start"}, config],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
init_params = GraphInitParams(
|
||||
tenant_id="1",
|
||||
app_id="1",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="1",
|
||||
graph_config=graph_config,
|
||||
user_id="1",
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
user_from=UserFrom.END_USER,
|
||||
config={
|
||||
"id": "1",
|
||||
"data": {
|
||||
"title": "123",
|
||||
"variables": [
|
||||
{
|
||||
"variable": "args1",
|
||||
"value_selector": ["1", "123", "args1"],
|
||||
},
|
||||
{"variable": "args2", "value_selector": ["1", "123", "args2"]},
|
||||
],
|
||||
"template": code,
|
||||
},
|
||||
},
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# construct variable pool
|
||||
pool = VariablePool(system_variables={}, user_inputs={}, environment_variables=[])
|
||||
pool.add(["1", "123", "args1"], 1)
|
||||
pool.add(["1", "123", "args2"], 3)
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"},
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
variable_pool.add(["1", "123", "args1"], 1)
|
||||
variable_pool.add(["1", "123", "args2"], 3)
|
||||
|
||||
node = TemplateTransformNode(
|
||||
id=str(uuid.uuid4()),
|
||||
graph_init_params=init_params,
|
||||
graph=graph,
|
||||
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
|
||||
config=config,
|
||||
)
|
||||
|
||||
# execute node
|
||||
result = node.run(pool)
|
||||
result = node._run()
|
||||
|
||||
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert result.outputs is not None
|
||||
assert result.outputs["output"] == "3"
|
||||
|
||||
@ -1,7 +1,24 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
# Getting the absolute path of the current file's directory
|
||||
ABS_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# Getting the absolute path of the project's root directory
|
||||
PROJECT_DIR = os.path.abspath(os.path.join(ABS_PATH, os.pardir, os.pardir))
|
||||
|
||||
CACHED_APP = Flask(__name__)
|
||||
CACHED_APP.config.update({"TESTING": True})
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def app() -> Flask:
|
||||
return CACHED_APP
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _provide_app_context(app: Flask):
|
||||
with app.app_context():
|
||||
yield
|
||||
|
||||
@ -0,0 +1,791 @@
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.run_condition import RunCondition
|
||||
from core.workflow.utils.condition.entities import Condition
|
||||
|
||||
|
||||
def test_init():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "llm-source-answer-target",
|
||||
"source": "llm",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "start-source-qc-target",
|
||||
"source": "start",
|
||||
"target": "qc",
|
||||
},
|
||||
{
|
||||
"id": "qc-1-llm-target",
|
||||
"source": "qc",
|
||||
"sourceHandle": "1",
|
||||
"target": "llm",
|
||||
},
|
||||
{
|
||||
"id": "qc-2-http-target",
|
||||
"source": "qc",
|
||||
"sourceHandle": "2",
|
||||
"target": "http",
|
||||
},
|
||||
{
|
||||
"id": "http-source-answer2-target",
|
||||
"source": "http",
|
||||
"target": "answer2",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "question-classifier"},
|
||||
"id": "qc",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "http-request",
|
||||
},
|
||||
"id": "http",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
start_node_id = "start"
|
||||
|
||||
assert graph.root_node_id == start_node_id
|
||||
assert graph.edge_mapping.get(start_node_id)[0].target_node_id == "qc"
|
||||
assert {"llm", "http"} == {node.target_node_id for node in graph.edge_mapping.get("qc")}
|
||||
|
||||
|
||||
def test__init_iteration_graph():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "llm-answer",
|
||||
"source": "llm",
|
||||
"sourceHandle": "source",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "iteration-source-llm-target",
|
||||
"source": "iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "llm",
|
||||
},
|
||||
{
|
||||
"id": "template-transform-in-iteration-source-llm-in-iteration-target",
|
||||
"source": "template-transform-in-iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "llm-in-iteration",
|
||||
},
|
||||
{
|
||||
"id": "llm-in-iteration-source-answer-in-iteration-target",
|
||||
"source": "llm-in-iteration",
|
||||
"sourceHandle": "source",
|
||||
"target": "answer-in-iteration",
|
||||
},
|
||||
{
|
||||
"id": "start-source-code-target",
|
||||
"source": "start",
|
||||
"sourceHandle": "source",
|
||||
"target": "code",
|
||||
},
|
||||
{
|
||||
"id": "code-source-iteration-target",
|
||||
"source": "code",
|
||||
"sourceHandle": "source",
|
||||
"target": "iteration",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start",
|
||||
},
|
||||
"id": "start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {"type": "iteration"},
|
||||
"id": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "template-transform",
|
||||
},
|
||||
"id": "template-transform-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer-in-iteration",
|
||||
"parentId": "iteration",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config, root_node_id="template-transform-in-iteration")
|
||||
graph.add_extra_edge(
|
||||
source_node_id="answer-in-iteration",
|
||||
target_node_id="template-transform-in-iteration",
|
||||
run_condition=RunCondition(
|
||||
type="condition",
|
||||
conditions=[Condition(variable_selector=["iteration", "index"], comparison_operator="≤", value="5")],
|
||||
),
|
||||
)
|
||||
|
||||
# iteration:
|
||||
# [template-transform-in-iteration -> llm-in-iteration -> answer-in-iteration]
|
||||
|
||||
assert graph.root_node_id == "template-transform-in-iteration"
|
||||
assert graph.edge_mapping.get("template-transform-in-iteration")[0].target_node_id == "llm-in-iteration"
|
||||
assert graph.edge_mapping.get("llm-in-iteration")[0].target_node_id == "answer-in-iteration"
|
||||
assert graph.edge_mapping.get("answer-in-iteration")[0].target_node_id == "template-transform-in-iteration"
|
||||
|
||||
|
||||
def test_parallels_graph():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-answer-target",
|
||||
"source": "llm3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
start_edges = graph.edge_mapping.get("start")
|
||||
assert start_edges is not None
|
||||
assert start_edges[i].target_node_id == f"llm{i+1}"
|
||||
|
||||
llm_edges = graph.edge_mapping.get(f"llm{i+1}")
|
||||
assert llm_edges is not None
|
||||
assert llm_edges[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph2():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
if i < 2:
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph3():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 3
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph4():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-answer-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-answer-target",
|
||||
"source": "llm2",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-code3-target",
|
||||
"source": "llm3",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code3-source-answer-target",
|
||||
"source": "code3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"llm{i + 1}")[0].target_node_id == f"code{i + 1}"
|
||||
assert graph.edge_mapping.get(f"code{i + 1}") is not None
|
||||
assert graph.edge_mapping.get(f"code{i + 1}")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 6
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph5():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm4",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm5",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code1-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-code1-target",
|
||||
"source": "llm2",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-code2-target",
|
||||
"source": "llm3",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm4-source-code2-target",
|
||||
"source": "llm4",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm5-source-code3-target",
|
||||
"source": "llm5",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm4",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm5",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(5):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm2") is not None
|
||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm3") is not None
|
||||
assert graph.edge_mapping.get("llm3")[0].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm4") is not None
|
||||
assert graph.edge_mapping.get("llm4")[0].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm5") is not None
|
||||
assert graph.edge_mapping.get("llm5")[0].target_node_id == "code3"
|
||||
assert graph.edge_mapping.get("code1") is not None
|
||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code2") is not None
|
||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 1
|
||||
assert len(graph.node_parallel_mapping) == 8
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "llm4", "llm5", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
|
||||
def test_parallels_graph6():
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "start-source-llm1-target",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm2-target",
|
||||
"source": "start",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "start-source-llm3-target",
|
||||
"source": "start",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code1-target",
|
||||
"source": "llm1",
|
||||
"target": "code1",
|
||||
},
|
||||
{
|
||||
"id": "llm1-source-code2-target",
|
||||
"source": "llm1",
|
||||
"target": "code2",
|
||||
},
|
||||
{
|
||||
"id": "llm2-source-code3-target",
|
||||
"source": "llm2",
|
||||
"target": "code3",
|
||||
},
|
||||
{
|
||||
"id": "code1-source-answer-target",
|
||||
"source": "code1",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code2-source-answer-target",
|
||||
"source": "code2",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "code3-source-answer-target",
|
||||
"source": "code3",
|
||||
"target": "answer",
|
||||
},
|
||||
{
|
||||
"id": "llm3-source-answer-target",
|
||||
"source": "llm3",
|
||||
"target": "answer",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start"}, "id": "start"},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "code",
|
||||
},
|
||||
"id": "code3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer", "answer": "1"},
|
||||
"id": "answer",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
assert graph.root_node_id == "start"
|
||||
for i in range(3):
|
||||
assert graph.edge_mapping.get("start")[i].target_node_id == f"llm{i + 1}"
|
||||
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[0].target_node_id == "code1"
|
||||
assert graph.edge_mapping.get("llm1") is not None
|
||||
assert graph.edge_mapping.get("llm1")[1].target_node_id == "code2"
|
||||
assert graph.edge_mapping.get("llm2") is not None
|
||||
assert graph.edge_mapping.get("llm2")[0].target_node_id == "code3"
|
||||
assert graph.edge_mapping.get("code1") is not None
|
||||
assert graph.edge_mapping.get("code1")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code2") is not None
|
||||
assert graph.edge_mapping.get("code2")[0].target_node_id == "answer"
|
||||
assert graph.edge_mapping.get("code3") is not None
|
||||
assert graph.edge_mapping.get("code3")[0].target_node_id == "answer"
|
||||
|
||||
assert len(graph.parallel_mapping) == 2
|
||||
assert len(graph.node_parallel_mapping) == 6
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code1", "code2", "code3"]:
|
||||
assert node_id in graph.node_parallel_mapping
|
||||
|
||||
parent_parallel = None
|
||||
child_parallel = None
|
||||
for p_id, parallel in graph.parallel_mapping.items():
|
||||
if parallel.parent_parallel_id is None:
|
||||
parent_parallel = parallel
|
||||
else:
|
||||
child_parallel = parallel
|
||||
|
||||
for node_id in ["llm1", "llm2", "llm3", "code3"]:
|
||||
assert graph.node_parallel_mapping[node_id] == parent_parallel.id
|
||||
|
||||
for node_id in ["code1", "code2"]:
|
||||
assert graph.node_parallel_mapping[node_id] == child_parallel.id
|
||||
@ -0,0 +1,505 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult, UserFrom
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.graph_engine.entities.event import (
|
||||
BaseNodeEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunStartedEvent,
|
||||
NodeRunStreamChunkEvent,
|
||||
NodeRunSucceededEvent,
|
||||
)
|
||||
from core.workflow.graph_engine.entities.graph import Graph
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState
|
||||
from core.workflow.graph_engine.graph_engine import GraphEngine
|
||||
from core.workflow.nodes.event import RunCompletedEvent, RunStreamChunkEvent
|
||||
from core.workflow.nodes.llm.llm_node import LLMNode
|
||||
from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
|
||||
|
||||
|
||||
@patch("extensions.ext_database.db.session.remove")
|
||||
@patch("extensions.ext_database.db.session.close")
|
||||
def test_run_parallel_in_workflow(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "llm1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "llm1",
|
||||
"target": "llm2",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "llm1",
|
||||
"target": "llm3",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "llm2",
|
||||
"target": "end1",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "llm3",
|
||||
"target": "end2",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"type": "start",
|
||||
"title": "start",
|
||||
"variables": [
|
||||
{
|
||||
"label": "query",
|
||||
"max_length": 48,
|
||||
"options": [],
|
||||
"required": True,
|
||||
"type": "text-input",
|
||||
"variable": "query",
|
||||
}
|
||||
],
|
||||
},
|
||||
"id": "start",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm1",
|
||||
"context": {"enabled": False, "variable_selector": []},
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "say hi"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
},
|
||||
"id": "llm1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm2",
|
||||
"context": {"enabled": False, "variable_selector": []},
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "say bye"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
},
|
||||
"id": "llm2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "llm",
|
||||
"title": "llm3",
|
||||
"context": {"enabled": False, "variable_selector": []},
|
||||
"model": {
|
||||
"completion_params": {"temperature": 0.7},
|
||||
"mode": "chat",
|
||||
"name": "gpt-4o",
|
||||
"provider": "openai",
|
||||
},
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "say good morning"},
|
||||
{"role": "user", "text": "{{#start.query#}}"},
|
||||
],
|
||||
"vision": {"configs": {"detail": "high"}, "enabled": False},
|
||||
},
|
||||
"id": "llm3",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "end",
|
||||
"title": "end1",
|
||||
"outputs": [
|
||||
{"value_selector": ["llm2", "text"], "variable": "result2"},
|
||||
{"value_selector": ["start", "query"], "variable": "query"},
|
||||
],
|
||||
},
|
||||
"id": "end1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"type": "end",
|
||||
"title": "end2",
|
||||
"outputs": [
|
||||
{"value_selector": ["llm1", "text"], "variable": "result1"},
|
||||
{"value_selector": ["llm3", "text"], "variable": "result3"},
|
||||
],
|
||||
},
|
||||
"id": "end2",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={SystemVariableKey.FILES: [], SystemVariableKey.USER_ID: "aaa"}, user_inputs={"query": "hi"}
|
||||
)
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.WORKFLOW,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
def llm_generator(self):
|
||||
contents = ["hi", "bye", "good morning"]
|
||||
|
||||
yield RunStreamChunkEvent(
|
||||
chunk_content=contents[int(self.node_id[-1]) - 1], from_variable_selector=[self.node_id, "text"]
|
||||
)
|
||||
|
||||
yield RunCompletedEvent(
|
||||
run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
process_data={},
|
||||
outputs={},
|
||||
metadata={
|
||||
NodeRunMetadataKey.TOTAL_TOKENS: 1,
|
||||
NodeRunMetadataKey.TOTAL_PRICE: 1,
|
||||
NodeRunMetadataKey.CURRENCY: "USD",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
with patch.object(LLMNode, "_run", new=llm_generator):
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
if isinstance(item, NodeRunSucceededEvent):
|
||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||
|
||||
assert not isinstance(item, NodeRunFailedEvent)
|
||||
assert not isinstance(item, GraphRunFailedEvent)
|
||||
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in ["llm2", "llm3", "end1", "end2"]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 18
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == "start"
|
||||
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||
assert items[2].route_node_state.node_id == "start"
|
||||
|
||||
|
||||
@patch("extensions.ext_database.db.session.remove")
|
||||
@patch("extensions.ext_database.db.session.close")
|
||||
def test_run_parallel_in_chatflow(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "answer1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "answer1",
|
||||
"target": "answer2",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "answer1",
|
||||
"target": "answer3",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "answer2",
|
||||
"target": "answer4",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "answer3",
|
||||
"target": "answer5",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{"data": {"type": "start", "title": "start"}, "id": "start"},
|
||||
{"data": {"type": "answer", "title": "answer1", "answer": "1"}, "id": "answer1"},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer2", "answer": "2"},
|
||||
"id": "answer2",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer3", "answer": "3"},
|
||||
"id": "answer3",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer4", "answer": "4"},
|
||||
"id": "answer4",
|
||||
},
|
||||
{
|
||||
"data": {"type": "answer", "title": "answer5", "answer": "5"},
|
||||
"id": "answer5",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "what's the weather in SF",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
user_inputs={},
|
||||
)
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
if isinstance(item, NodeRunSucceededEvent):
|
||||
assert item.route_node_state.status == RouteNodeState.Status.SUCCESS
|
||||
|
||||
assert not isinstance(item, NodeRunFailedEvent)
|
||||
assert not isinstance(item, GraphRunFailedEvent)
|
||||
|
||||
if isinstance(item, BaseNodeEvent) and item.route_node_state.node_id in [
|
||||
"answer2",
|
||||
"answer3",
|
||||
"answer4",
|
||||
"answer5",
|
||||
]:
|
||||
assert item.parallel_id is not None
|
||||
|
||||
assert len(items) == 23
|
||||
assert isinstance(items[0], GraphRunStartedEvent)
|
||||
assert isinstance(items[1], NodeRunStartedEvent)
|
||||
assert items[1].route_node_state.node_id == "start"
|
||||
assert isinstance(items[2], NodeRunSucceededEvent)
|
||||
assert items[2].route_node_state.node_id == "start"
|
||||
|
||||
|
||||
@patch("extensions.ext_database.db.session.remove")
|
||||
@patch("extensions.ext_database.db.session.close")
|
||||
def test_run_branch(mock_close, mock_remove):
|
||||
graph_config = {
|
||||
"edges": [
|
||||
{
|
||||
"id": "1",
|
||||
"source": "start",
|
||||
"target": "if-else-1",
|
||||
},
|
||||
{
|
||||
"id": "2",
|
||||
"source": "if-else-1",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-1",
|
||||
},
|
||||
{
|
||||
"id": "3",
|
||||
"source": "if-else-1",
|
||||
"sourceHandle": "false",
|
||||
"target": "if-else-2",
|
||||
},
|
||||
{
|
||||
"id": "4",
|
||||
"source": "if-else-2",
|
||||
"sourceHandle": "true",
|
||||
"target": "answer-2",
|
||||
},
|
||||
{
|
||||
"id": "5",
|
||||
"source": "if-else-2",
|
||||
"sourceHandle": "false",
|
||||
"target": "answer-3",
|
||||
},
|
||||
],
|
||||
"nodes": [
|
||||
{
|
||||
"data": {
|
||||
"title": "Start",
|
||||
"type": "start",
|
||||
"variables": [
|
||||
{
|
||||
"label": "uid",
|
||||
"max_length": 48,
|
||||
"options": [],
|
||||
"required": True,
|
||||
"type": "text-input",
|
||||
"variable": "uid",
|
||||
}
|
||||
],
|
||||
},
|
||||
"id": "start",
|
||||
},
|
||||
{
|
||||
"data": {"answer": "1 {{#start.uid#}}", "title": "Answer", "type": "answer", "variables": []},
|
||||
"id": "answer-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "contains",
|
||||
"id": "b0f02473-08b6-4a81-af91-15345dcb2ec8",
|
||||
"value": "hi",
|
||||
"varType": "string",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"id": "true",
|
||||
"logical_operator": "and",
|
||||
}
|
||||
],
|
||||
"desc": "",
|
||||
"title": "IF/ELSE",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else-1",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"cases": [
|
||||
{
|
||||
"case_id": "true",
|
||||
"conditions": [
|
||||
{
|
||||
"comparison_operator": "contains",
|
||||
"id": "ae895199-5608-433b-b5f0-0997ae1431e4",
|
||||
"value": "takatost",
|
||||
"varType": "string",
|
||||
"variable_selector": ["sys", "query"],
|
||||
}
|
||||
],
|
||||
"id": "true",
|
||||
"logical_operator": "and",
|
||||
}
|
||||
],
|
||||
"title": "IF/ELSE 2",
|
||||
"type": "if-else",
|
||||
},
|
||||
"id": "if-else-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "2",
|
||||
"title": "Answer 2",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-2",
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"answer": "3",
|
||||
"title": "Answer 3",
|
||||
"type": "answer",
|
||||
},
|
||||
"id": "answer-3",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
graph = Graph.init(graph_config=graph_config)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables={
|
||||
SystemVariableKey.QUERY: "hi",
|
||||
SystemVariableKey.FILES: [],
|
||||
SystemVariableKey.CONVERSATION_ID: "abababa",
|
||||
SystemVariableKey.USER_ID: "aaa",
|
||||
},
|
||||
user_inputs={"uid": "takato"},
|
||||
)
|
||||
|
||||
graph_engine = GraphEngine(
|
||||
tenant_id="111",
|
||||
app_id="222",
|
||||
workflow_type=WorkflowType.CHAT,
|
||||
workflow_id="333",
|
||||
graph_config=graph_config,
|
||||
user_id="444",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.WEB_APP,
|
||||
call_depth=0,
|
||||
graph=graph,
|
||||
variable_pool=variable_pool,
|
||||
max_execution_steps=500,
|
||||
max_execution_time=1200,
|
||||
)
|
||||
|
||||
# print("")
|
||||
|
||||
items = []
|
||||
generator = graph_engine.run()
|
||||
for item in generator:
|
||||
# print(type(item), item)
|
||||
items.append(item)
|
||||
|
||||
assert len(items) == 10
|
||||
assert items[3].route_node_state.node_id == "if-else-1"
|
||||
assert items[4].route_node_state.node_id == "if-else-1"
|
||||
assert isinstance(items[5], NodeRunStreamChunkEvent)
|
||||
assert items[5].chunk_content == "1 "
|
||||
assert isinstance(items[6], NodeRunStreamChunkEvent)
|
||||
assert items[6].chunk_content == "takato"
|
||||
assert items[7].route_node_state.node_id == "answer-1"
|
||||
assert items[8].route_node_state.node_id == "answer-1"
|
||||
assert items[8].route_node_state.node_run_result.outputs["answer"] == "1 takato"
|
||||
assert isinstance(items[9], GraphRunSucceededEvent)
|
||||
|
||||
# print(graph_engine.graph_runtime_state.model_dump_json(indent=2))
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue