feat(api): tracking execution time in GraphRuntimeState

Use the wall clock for time measurement (as the execution may be continued on another node).
pull/22621/head
QuantumGhost 7 months ago
parent 55c2c4a6b6
commit d99ad77837

@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
    """Retrieve a timestamp as a floating point number representing the number
    of seconds since the Unix epoch.

    This function is primarily used to measure the execution time of the workflow engine.

    Since workflow execution may be paused and resumed on a different machine,
    `time.perf_counter` cannot be used as it is inconsistent across machines.
    To address this, the function uses the wall clock as the time source.
    However, it assumes that the clocks of all servers are properly synchronized.
    """
    # Return the raw float: `round(time.time())` yielded an `int` of whole
    # seconds, contradicting the declared `-> float` return type and discarding
    # the sub-second precision needed for microsecond-level execution tracking.
    return time.time()

@ -1,18 +1,32 @@
from typing import Any
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine._engine_utils import get_timestamp
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
_SECOND_TO_US = 1_000_000
class GraphRuntimeState(BaseModel):
"""`GraphRuntimeState` encapsulates the runtime state of workflow execution,
including scheduling details, variable values, and timing information.
Values that are initialized prior to workflow execution and remain constant
throughout the execution should be part of `GraphInitParams` instead.
"""
variable_pool: VariablePool = Field(..., description="variable pool")
"""variable pool"""
start_at: float = Field(..., description="start time")
"""start time"""
# The `start_at` field records the execution start time of the workflow.
#
# This field is automatically generated, and its value or interpretation may evolve.
# Avoid manually setting this field to ensure compatibility with future updates.
start_at: float = Field(description="start time", default_factory=get_timestamp)
total_tokens: int = 0
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
@ -29,3 +43,28 @@ class GraphRuntimeState(BaseModel):
node_run_state: RuntimeRouteState = RuntimeRouteState()
"""node run state"""
# `execution_time_us` tracks the total execution time of the workflow in microseconds.
# Time spent in suspension is excluded from this calculation.
#
# This field is used to persist the time already spent while suspending a workflow.
execution_time_us: int = 0
# `_last_execution_started_at` records the timestamp of the most recent resume start.
# It is updated when the workflow resumes from a suspended state.
_last_execution_started_at: float = PrivateAttr(default_factory=get_timestamp)
def is_timed_out(self, max_execution_time_seconds: int) -> bool:
    """Return True when the workflow's total execution time exceeds
    `max_execution_time_seconds`.

    Total time = time persisted across earlier suspensions (`execution_time_us`)
    plus the wall-clock time elapsed since the current run (re)started.
    """
    budget_us = max_execution_time_seconds * _SECOND_TO_US
    # BUG FIX 1: the previous version returned False when the persisted time
    # alone had already exhausted the budget (inverted condition).
    # BUG FIX 2: the previous version truncated the elapsed delta to whole
    # seconds *before* scaling to microseconds; scale first, then truncate.
    elapsed_us = int((get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US)
    return self.execution_time_us + elapsed_us > budget_us
def record_suspend_state(self, next_node_id: str):
    """Accumulate the time spent in the current execution session and record
    the node from which execution should resume.

    This function should be called when suspending the workflow.
    """
    # Accumulate (`+=`) rather than overwrite: `execution_time_us` may already
    # hold time persisted from earlier suspend/resume cycles, while
    # `_last_execution_started_at` only covers the current session.
    # Scale to microseconds before truncating to keep sub-second precision.
    self.execution_time_us += int((get_timestamp() - self._last_execution_started_at) * _SECOND_TO_US)
    self.node_run_state.next_node_id = next_node_id

@ -3,7 +3,7 @@ from datetime import UTC, datetime
from enum import Enum
from typing import Any, Optional
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
@ -92,9 +92,11 @@ class RuntimeRouteState(BaseModel):
# If `previous_node_id` is not `None`, then the corresponding node has state in the dict
# `node_state_mapping`.
previous_node_state_id: Optional[str] = Field(None, description="The state of last executed node.")
previous_node_state_id: Optional[str] = Field(default=None, description="The state of last executed node.")
_state_by_id: dict[str, RouteNodeState]
# `_state_by_id` serves as a mapping from the unique identifier (`id`) of each `RouteNodeState`
# instance to the corresponding `RouteNodeState` object itself.
# Use `default_factory` for mutable defaults — the pydantic-recommended idiom,
# avoiding a shared mutable literal in the class definition.
_state_by_id: dict[str, RouteNodeState] = PrivateAttr(default_factory=dict)
def model_post_init(self, context: Any) -> None:
super().model_post_init(context)

@ -160,7 +160,7 @@ class IterationNode(BaseNode):
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine, GraphEngineThreadPool
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
tenant_id=self.tenant_id,

@ -124,7 +124,7 @@ class LoopNode(BaseNode):
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.graph_engine import GraphEngine
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
graph_engine = GraphEngine(
tenant_id=self.tenant_id,

@ -1,11 +1,10 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Optional, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
@ -70,7 +69,8 @@ class WorkflowEntry:
raise ValueError("Max workflow call depth {} reached.".format(workflow_call_max_depth))
# init workflow run state
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool)
self.graph_engine = GraphEngine(
tenant_id=tenant_id,
app_id=app_id,
@ -146,7 +146,7 @@ class WorkflowEntry:
graph = Graph.init(graph_config=workflow.graph_dict)
# init workflow run state
node = node_cls(
node_instance = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -161,7 +161,7 @@ class WorkflowEntry:
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
try:
@ -190,11 +190,17 @@ class WorkflowEntry:
try:
# run node
generator = node.run()
generator = node_instance.run()
except Exception as e:
logger.exception(f"error while running node, {workflow.id=}, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
return node, generator
logger.exception(
"error while running node_instance, workflow_id=%s, node_id=%s, type=%s, version=%s",
workflow.id,
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
return node_instance, generator
@classmethod
def run_free_node(
@ -256,7 +262,7 @@ class WorkflowEntry:
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node: BaseNode = node_cls(
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
@ -271,7 +277,7 @@ class WorkflowEntry:
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
)
try:
@ -291,12 +297,17 @@ class WorkflowEntry:
)
# run node
generator = node.run()
generator = node_instance.run()
return node, generator
return node_instance, generator
except Exception as e:
logger.exception(f"error while running node, {node.id=}, {node.type_=}, {node.version()=}")
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
logger.exception(
"error while running node_instance, node_id=%s, type=%s, version=%s",
node_instance.id,
node_instance.node_type,
node_instance.version(),
)
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@staticmethod
def handle_special_values(value: Optional[Mapping[str, Any]]) -> Mapping[str, Any] | None:

@ -1,4 +1,3 @@
import time
import uuid
from os import getenv
from typing import cast
@ -62,7 +61,9 @@ def init_code_node(code_config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=code_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from urllib.parse import urlencode
@ -56,7 +55,9 @@ def init_http_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)

@ -1,5 +1,4 @@
import json
import time
import uuid
from collections.abc import Generator
from unittest.mock import MagicMock, patch
@ -73,7 +72,9 @@ def init_llm_node(config: dict) -> LLMNode:
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)

@ -1,5 +1,4 @@
import os
import time
import uuid
from typing import Optional
from unittest.mock import MagicMock
@ -78,7 +77,9 @@ def init_parameter_extractor_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -1,4 +1,3 @@
import time
import uuid
import pytest
@ -73,7 +72,9 @@ def test_execute_code(setup_code_executor_mock):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock
@ -54,7 +53,9 @@ def init_tool_node(config: dict):
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=config,
)
node.init_node_data(config.get("data", {}))

@ -49,7 +49,6 @@ def create_test_graph_runtime_state() -> GraphRuntimeState:
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
@ -106,7 +105,6 @@ def test_empty_outputs_round_trip():
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
outputs={}, # Empty outputs
)

@ -175,8 +175,9 @@ def test_run_parallel_in_workflow(mock_close, mock_remove):
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.WORKFLOW,
@ -303,8 +304,9 @@ def test_run_parallel_in_chatflow(mock_close, mock_remove):
user_inputs={},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -484,8 +486,9 @@ def test_run_branch(mock_close, mock_remove):
user_inputs={"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,
@ -823,8 +826,9 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
user_inputs={"query": "hi"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_engine = GraphEngine(
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
tenant_id="111",
app_id="222",
workflow_type=WorkflowType.CHAT,

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import patch
@ -179,7 +178,9 @@ def test_run():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -401,7 +402,9 @@ def test_run_parallel():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -623,7 +626,9 @@ def test_iteration_run_in_parallel_mode():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=parallel_node_config,
)
@ -647,7 +652,9 @@ def test_iteration_run_in_parallel_mode():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=sequential_node_config,
)
@ -857,7 +864,9 @@ def test_iteration_run_error_handle():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=error_node_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock
@ -74,7 +73,9 @@ def test_execute_answer():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=variable_pool,
),
config=node_config,
)

@ -1,4 +1,3 @@
import time
from unittest.mock import patch
from core.app.entities.app_invoke_entities import InvokeFrom
@ -175,7 +174,9 @@ class ContinueOnErrorTestHelper:
),
user_inputs=user_inputs or {"uid": "takato"},
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
graph_runtime_state = GraphRuntimeState(
variable_pool=variable_pool,
)
return GraphEngine(
tenant_id="111",

@ -1,4 +1,3 @@
import time
import uuid
from unittest.mock import MagicMock, Mock
@ -106,7 +105,9 @@ def test_execute_if_else_result_true():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)
@ -192,7 +193,9 @@ def test_execute_if_else_result_false():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(
variable_pool=pool,
),
config=node_config,
)

@ -1,4 +1,3 @@
import time
import uuid
from unittest import mock
from uuid import uuid4
@ -96,7 +95,7 @@ def test_overwrite_string_variable():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
@ -197,7 +196,7 @@ def test_append_variable_to_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)
@ -289,7 +288,7 @@ def test_clear_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
conv_var_updater_factory=mock_conv_var_updater_factory,
)

@ -1,4 +1,3 @@
import time
import uuid
from uuid import uuid4
@ -135,7 +134,7 @@ def test_remove_first_from_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -227,7 +226,7 @@ def test_remove_last_from_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -311,7 +310,7 @@ def test_remove_first_from_empty_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)
@ -395,7 +394,7 @@ def test_remove_last_from_empty_array():
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool),
config=node_config,
)

Loading…
Cancel
Save