refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025)
refactor(api): Separate SegmentType for Integer/Float to Enable Pydantic Serialization (#22025) This PR addresses serialization issues in the VariablePool model by separating the `value_type` tags for `IntegerSegment`/`FloatSegment` and `IntegerVariable`/`FloatVariable`. Previously, both Integer and Float types shared the same `SegmentType.NUMBER` tag, causing conflicts during serialization. Key changes: - Introduce distinct `value_type` tags for Integer and Float segments/variables - Add `VariableUnion` and `SegmentUnion` types for proper type discrimination - Leverage Pydantic's discriminated union feature for seamless serialization/deserialization - Enable accurate serialization of data structures containing these types Closes #22024.
parent
229b4d621e
commit
2c1ab4879f
@ -0,0 +1,89 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from core.file.models import File
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
|
||||
class SystemVariable(BaseModel):
    """A model for managing system variables.

    Fields with a value of `None` are treated as absent and will not be included
    in the variable pool.
    """

    model_config = ConfigDict(
        extra="forbid",
        serialize_by_alias=True,
        validate_by_alias=True,
    )

    user_id: str | None = None

    # Ideally, `app_id` and `workflow_id` should be required and not `None`.
    # However, there are scenarios in the codebase where these fields are not set.
    # To maintain compatibility, they are marked as optional here.
    app_id: str | None = None
    workflow_id: str | None = None

    files: Sequence[File] = Field(default_factory=list)

    # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
    # To maintain compatibility with existing workflows, it must be serialized
    # as `workflow_run_id` in dictionaries or JSON objects, and also referenced
    # as `workflow_run_id` in the variable pool.
    workflow_execution_id: str | None = Field(
        validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
        serialization_alias="workflow_run_id",
        default=None,
    )
    # Chatflow related fields.
    query: str | None = None
    conversation_id: str | None = None
    dialogue_count: int | None = None

    @model_validator(mode="before")
    @classmethod
    def validate_json_fields(cls, data):
        """Resolve the workflow-run alias pair before field validation.

        When both `workflow_execution_id` and `workflow_run_id` are present in
        the incoming mapping, drop `workflow_execution_id` so the legacy
        `workflow_run_id` value takes precedence.
        """
        if not isinstance(data, dict):
            return data
        if "workflow_execution_id" in data and "workflow_run_id" in data:
            # Do not mutate the caller's mapping; work on a shallow copy.
            data = data.copy()
            data.pop("workflow_execution_id")
        return data

    @classmethod
    def empty(cls) -> "SystemVariable":
        """Return a SystemVariable with every field left at its default."""
        return cls()

    def to_dict(self) -> dict[SystemVariableKey, Any]:
        """Convert to a key/value mapping, omitting fields that are `None`.

        NOTE: This method is provided for compatibility with legacy code.
        New code should use the `SystemVariable` object directly instead of
        converting it to a dictionary, as this conversion results in the loss
        of type information for each key, making static analysis more difficult.
        """
        # `files` is always present (defaults to an empty list, never None).
        result: dict[SystemVariableKey, Any] = {
            SystemVariableKey.FILES: self.files,
        }
        optional_entries = (
            (SystemVariableKey.USER_ID, self.user_id),
            (SystemVariableKey.APP_ID, self.app_id),
            (SystemVariableKey.WORKFLOW_ID, self.workflow_id),
            (SystemVariableKey.WORKFLOW_EXECUTION_ID, self.workflow_execution_id),
            (SystemVariableKey.QUERY, self.query),
            (SystemVariableKey.CONVERSATION_ID, self.conversation_id),
            (SystemVariableKey.DIALOGUE_COUNT, self.dialogue_count),
        )
        for key, value in optional_entries:
            if value is not None:
                result[key] = value
        return result
|
||||
@ -0,0 +1,15 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class _VarTypedDict(TypedDict, total=False):
    """Minimal dict shape for serialized segments/variables (all keys optional)."""

    # Discriminator tag identifying the concrete segment/variable subtype.
    value_type: SegmentType
|
||||
|
||||
|
||||
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
    """Return the exposed value-type tag for a segment or a segment-shaped dict."""
    # Accept either a live Segment instance or its already-dumped dict form.
    value_type = v.value_type if isinstance(v, Segment) else v["value_type"]
    return value_type.exposed_type().value
|
||||
@ -0,0 +1,60 @@
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
    """
    Test class for SegmentType.is_array_type method.

    Provides comprehensive coverage of all SegmentType values to ensure
    correct identification of array and non-array types.
    """

    def test_is_array_type(self):
        """
        Verify is_array_type for every SegmentType value.

        Array types must report True and non-array types False; the two
        expectation lists together must cover the whole enum so a newly
        added member cannot slip through untested.
        """
        # Arrange: partition the enum into expected array / non-array members.
        # (The previously unused `all_segment_types` local has been removed.)
        expected_array_types = [
            SegmentType.ARRAY_ANY,
            SegmentType.ARRAY_STRING,
            SegmentType.ARRAY_NUMBER,
            SegmentType.ARRAY_OBJECT,
            SegmentType.ARRAY_FILE,
        ]
        expected_non_array_types = [
            SegmentType.INTEGER,
            SegmentType.FLOAT,
            SegmentType.NUMBER,
            SegmentType.STRING,
            SegmentType.OBJECT,
            SegmentType.SECRET,
            SegmentType.FILE,
            SegmentType.NONE,
            SegmentType.GROUP,
        ]

        # Act & Assert
        for seg_type in expected_array_types:
            assert seg_type.is_array_type()

        for seg_type in expected_non_array_types:
            assert not seg_type.is_array_type()

        # Exhaustiveness check: every enum member appears in exactly one list above.
        covered_types = set(expected_array_types) | set(expected_non_array_types)
        assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"

    def test_all_enum_values_are_supported(self):
        """
        Test that all enum values are supported and return boolean values.

        Validates that every SegmentType enum value can be processed by
        is_array_type method and returns a boolean value.
        """
        enum_values: list[SegmentType] = list(SegmentType)
        for seg_type in enum_values:
            is_array = seg_type.is_array_type()
            assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||
@ -0,0 +1,146 @@
|
||||
import time
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def create_test_graph_runtime_state() -> GraphRuntimeState:
    """Factory function to create a GraphRuntimeState with non-empty values for testing."""
    # System variables exercising every optional field.
    sys_vars = SystemVariable(
        user_id="test_user_123",
        app_id="test_app_456",
        workflow_id="test_workflow_789",
        workflow_execution_id="test_execution_001",
        query="test query",
        conversation_id="test_conv_123",
        dialogue_count=5,
    )

    # Variable pool seeded with a couple of node-scoped values.
    pool = VariablePool(system_variables=sys_vars)
    pool.add(["test_node", "test_var"], "test_value")
    pool.add(["another_node", "another_var"], 42)

    # LLM usage record with realistic token/price figures.
    usage = LLMUsage(
        prompt_tokens=150,
        prompt_unit_price=Decimal("0.001"),
        prompt_price_unit=Decimal(1000),
        prompt_price=Decimal("0.15"),
        completion_tokens=75,
        completion_unit_price=Decimal("0.002"),
        completion_price_unit=Decimal(1000),
        completion_price=Decimal("0.15"),
        total_tokens=225,
        total_price=Decimal("0.30"),
        currency="USD",
        latency=1.25,
    )

    # Route state with one node state and one outgoing edge.
    route_state = RuntimeRouteState()
    first_state = route_state.create_node_state("test_node_1")
    route_state.add_route(first_state.id, "target_node_id")

    return GraphRuntimeState(
        variable_pool=pool,
        start_at=time.perf_counter(),
        total_tokens=100,
        llm_usage=usage,
        outputs={
            "string_output": "test result",
            "int_output": 42,
            "float_output": 3.14,
            "list_output": ["item1", "item2", "item3"],
            "dict_output": {"key1": "value1", "key2": 123},
            "nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
        },
        node_run_steps=5,
        node_run_state=route_state,
    )
|
||||
|
||||
|
||||
def test_basic_round_trip_serialization():
    """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
    # Create a state with non-empty values
    original_state = create_test_graph_runtime_state()

    # JSON string round trip
    json_data = original_state.model_dump_json()
    deserialized_state = GraphRuntimeState.model_validate_json(json_data)

    # Core test: ensure the round-trip preserves all values
    assert deserialized_state == original_state

    # Python-mode dict round trip (values keep their native Python types)
    dict_data = original_state.model_dump(mode="python")
    deserialized_state = GraphRuntimeState.model_validate(dict_data)
    assert deserialized_state == original_state

    # JSON-mode dict round trip (values coerced to JSON-compatible types)
    dict_data = original_state.model_dump(mode="json")
    deserialized_state = GraphRuntimeState.model_validate(dict_data)
    assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_outputs_field_round_trip():
    """Test the problematic outputs field maintains values through round-trip serialization."""
    state = create_test_graph_runtime_state()

    # JSON round trip.
    restored = GraphRuntimeState.model_validate_json(state.model_dump_json())

    # The outputs mapping must survive intact, and so must the state as a whole.
    assert restored.outputs == state.outputs
    assert restored == state
|
||||
|
||||
|
||||
def test_empty_outputs_round_trip():
    """Test round-trip serialization with empty outputs field."""
    state = GraphRuntimeState(
        variable_pool=VariablePool.empty(),
        start_at=time.perf_counter(),
        outputs={},  # Empty outputs
    )

    restored = GraphRuntimeState.model_validate_json(state.model_dump_json())
    assert restored == state
|
||||
|
||||
|
||||
def test_llm_usage_round_trip():
    """LLMUsage must survive JSON and both dict-mode round trips unchanged."""
    # LLM usage with specific decimal values.
    usage = LLMUsage(
        prompt_tokens=100,
        prompt_unit_price=Decimal("0.0015"),
        prompt_price_unit=Decimal(1000),
        prompt_price=Decimal("0.15"),
        completion_tokens=50,
        completion_unit_price=Decimal("0.003"),
        completion_price_unit=Decimal(1000),
        completion_price=Decimal("0.15"),
        total_tokens=150,
        total_price=Decimal("0.30"),
        currency="USD",
        latency=2.5,
    )

    # JSON string round trip.
    restored = LLMUsage.model_validate_json(usage.model_dump_json())
    assert restored == usage

    # Dict round trips in both dump modes.
    for mode in ("python", "json"):
        restored = LLMUsage.model_validate(usage.model_dump(mode=mode))
        assert restored == usage
|
||||
@ -0,0 +1,401 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||
|
||||
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
|
||||
class TestRouteNodeStateSerialization:
    """Test cases for RouteNodeState Pydantic serialization/deserialization."""

    def _test_route_node_state(self) -> RouteNodeState:
        """Build a fully-populated RouteNodeState fixture covering all core fields."""

        node_run_result = NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            inputs={"input_key": "input_value"},
            outputs={"output_key": "output_value"},
        )

        node_state = RouteNodeState(
            node_id="comprehensive_test_node",
            start_at=_TEST_DATETIME,
            finished_at=_TEST_DATETIME,
            status=RouteNodeState.Status.SUCCESS,
            node_run_result=node_run_result,
            index=5,
            paused_at=_TEST_DATETIME,
            paused_by="user_123",
            failed_reason="test_reason",
        )
        return node_state

    def test_route_node_state_comprehensive_field_validation(self):
        """Test comprehensive RouteNodeState serialization with all core fields validation."""
        node_state = self._test_route_node_state()
        serialized = node_state.model_dump()

        # Comprehensive validation of all RouteNodeState fields
        assert serialized["node_id"] == "comprehensive_test_node"
        assert serialized["status"] == RouteNodeState.Status.SUCCESS
        assert serialized["start_at"] == _TEST_DATETIME
        assert serialized["finished_at"] == _TEST_DATETIME
        assert serialized["paused_at"] == _TEST_DATETIME
        assert serialized["paused_by"] == "user_123"
        assert serialized["failed_reason"] == "test_reason"
        assert serialized["index"] == 5
        # The `id` field is auto-generated; it must serialize as a valid UUID string.
        assert "id" in serialized
        assert isinstance(serialized["id"], str)
        uuid.UUID(serialized["id"])  # Validate UUID format

        # Validate nested NodeRunResult structure
        assert serialized["node_run_result"] is not None
        assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
        assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
        assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}

    def test_route_node_state_minimal_required_fields(self):
        """Test RouteNodeState with only required fields, focusing on defaults."""
        node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)

        serialized = node_state.model_dump()

        # Focus on required fields and default values (not re-testing all fields)
        assert serialized["node_id"] == "minimal_node"
        assert serialized["start_at"] == _TEST_DATETIME
        assert serialized["status"] == RouteNodeState.Status.RUNNING  # Default status
        assert serialized["index"] == 1  # Default index
        assert serialized["node_run_result"] is None  # Default None
        # JSON round trip must reproduce an equal object.
        json = node_state.model_dump_json()
        deserialized = RouteNodeState.model_validate_json(json)
        assert deserialized == node_state

    def test_route_node_state_deserialization_from_dict(self):
        """Test RouteNodeState deserialization from dictionary data."""
        test_datetime = datetime(2024, 1, 15, 10, 30, 45)
        test_id = str(uuid.uuid4())

        dict_data = {
            "id": test_id,
            "node_id": "deserialized_node",
            "start_at": test_datetime,
            "status": "success",
            "finished_at": test_datetime,
            "index": 3,
        }

        node_state = RouteNodeState.model_validate(dict_data)

        # Focus on deserialization accuracy
        assert node_state.id == test_id
        assert node_state.node_id == "deserialized_node"
        assert node_state.start_at == test_datetime
        assert node_state.status == RouteNodeState.Status.SUCCESS
        assert node_state.finished_at == test_datetime
        assert node_state.index == 3

    def test_route_node_state_round_trip_consistency(self):
        """Round trips (JSON and both dict modes) preserve equality for both fixtures."""
        node_states = (
            self._test_route_node_state(),
            RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
        )
        for node_state in node_states:
            # JSON string round trip
            json = node_state.model_dump_json()
            deserialized = RouteNodeState.model_validate_json(json)
            assert deserialized == node_state

            # Python-mode dict round trip
            dict_ = node_state.model_dump(mode="python")
            deserialized = RouteNodeState.model_validate(dict_)
            assert deserialized == node_state

            # JSON-mode dict round trip
            dict_ = node_state.model_dump(mode="json")
            deserialized = RouteNodeState.model_validate(dict_)
            assert deserialized == node_state
|
||||
|
||||
|
||||
class TestRouteNodeStateEnumSerialization:
    """Dedicated tests for RouteNodeState Status enum serialization behavior."""

    def test_status_enum_model_dump_behavior(self):
        """Test Status enum serialization in model_dump() returns enum objects."""

        for member in RouteNodeState.Status:
            state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=member)
            # Python mode keeps the enum object; JSON mode collapses to its value.
            assert state.model_dump(mode="python")["status"] == member
            assert state.model_dump(mode="json")["status"] == member.value

    def test_status_enum_json_serialization_behavior(self):
        """Test Status enum serialization in JSON returns string values."""
        when = datetime(2024, 1, 15, 10, 30, 45)

        expected_strings = {
            RouteNodeState.Status.RUNNING: "running",
            RouteNodeState.Status.SUCCESS: "success",
            RouteNodeState.Status.FAILED: "failed",
            RouteNodeState.Status.PAUSED: "paused",
            RouteNodeState.Status.EXCEPTION: "exception",
        }

        for member, text in expected_strings.items():
            state = RouteNodeState(node_id="json_enum_test", start_at=when, status=member)
            assert json.loads(state.model_dump_json())["status"] == text

    def test_status_enum_deserialization_from_string(self):
        """Test Status enum deserialization from string values."""
        when = datetime(2024, 1, 15, 10, 30, 45)

        expected_members = {
            "running": RouteNodeState.Status.RUNNING,
            "success": RouteNodeState.Status.SUCCESS,
            "failed": RouteNodeState.Status.FAILED,
            "paused": RouteNodeState.Status.PAUSED,
            "exception": RouteNodeState.Status.EXCEPTION,
        }

        for text, member in expected_members.items():
            parsed = RouteNodeState.model_validate(
                {
                    "node_id": "enum_deserialize_test",
                    "start_at": when,
                    "status": text,
                }
            )
            assert parsed.status == member
|
||||
|
||||
|
||||
class TestRuntimeRouteStateSerialization:
    """Test cases for RuntimeRouteState Pydantic serialization/deserialization."""

    # Stable node IDs paired with per-run random route-state IDs, shared by the
    # fixture builder and the structure-validation test below.
    _NODE1_ID = "node_1"
    _ROUTE_STATE1_ID = str(uuid.uuid4())
    _NODE2_ID = "node_2"
    _ROUTE_STATE2_ID = str(uuid.uuid4())
    _NODE3_ID = "node_3"
    _ROUTE_STATE3_ID = str(uuid.uuid4())

    def _get_runtime_route_state(self) -> RuntimeRouteState:
        """Build a RuntimeRouteState fixture with three node states and two routes."""
        # Create node states with different configurations
        node_state_1 = RouteNodeState(
            id=self._ROUTE_STATE1_ID,
            node_id=self._NODE1_ID,
            start_at=_TEST_DATETIME,
            index=1,
        )
        node_state_2 = RouteNodeState(
            id=self._ROUTE_STATE2_ID,
            node_id=self._NODE2_ID,
            start_at=_TEST_DATETIME,
            status=RouteNodeState.Status.SUCCESS,
            finished_at=_TEST_DATETIME,
            index=2,
        )
        node_state_3 = RouteNodeState(
            id=self._ROUTE_STATE3_ID,
            node_id=self._NODE3_ID,
            start_at=_TEST_DATETIME,
            status=RouteNodeState.Status.FAILED,
            failed_reason="Test failure",
            index=3,
        )

        runtime_state = RuntimeRouteState(
            routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
            node_state_mapping={
                node_state_1.id: node_state_1,
                node_state_2.id: node_state_2,
                node_state_3.id: node_state_3,
            },
        )

        return runtime_state

    def test_runtime_route_state_comprehensive_structure_validation(self):
        """Test comprehensive RuntimeRouteState serialization with full structure validation."""

        runtime_state = self._get_runtime_route_state()
        serialized = runtime_state.model_dump()

        # Comprehensive validation of RuntimeRouteState structure
        assert "routes" in serialized
        assert "node_state_mapping" in serialized
        assert isinstance(serialized["routes"], dict)
        assert isinstance(serialized["node_state_mapping"], dict)

        # Validate routes dictionary structure and content
        assert len(serialized["routes"]) == 2
        assert self._ROUTE_STATE1_ID in serialized["routes"]
        assert self._ROUTE_STATE2_ID in serialized["routes"]
        assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
        assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]

        # Validate node_state_mapping dictionary structure and content
        assert len(serialized["node_state_mapping"]) == 3
        for state_id in [
            self._ROUTE_STATE1_ID,
            self._ROUTE_STATE2_ID,
            self._ROUTE_STATE3_ID,
        ]:
            assert state_id in serialized["node_state_mapping"]
            # Each serialized entry must mirror its in-memory node state.
            node_data = serialized["node_state_mapping"][state_id]
            node_state = runtime_state.node_state_mapping[state_id]
            assert node_data["node_id"] == node_state.node_id
            assert node_data["status"] == node_state.status
            assert node_data["index"] == node_state.index

    def test_runtime_route_state_empty_collections(self):
        """Test RuntimeRouteState with empty collections, focusing on default behavior."""
        runtime_state = RuntimeRouteState()
        serialized = runtime_state.model_dump()

        # Focus on default empty collection behavior
        assert serialized["routes"] == {}
        assert serialized["node_state_mapping"] == {}
        assert isinstance(serialized["routes"], dict)
        assert isinstance(serialized["node_state_mapping"], dict)

    def test_runtime_route_state_json_serialization_structure(self):
        """Test RuntimeRouteState JSON serialization structure."""
        node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)

        runtime_state = RuntimeRouteState(
            routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
        )

        json_str = runtime_state.model_dump_json()
        json_data = json.loads(json_str)

        # Focus on JSON structure validation
        assert isinstance(json_str, str)
        assert isinstance(json_data, dict)
        assert "routes" in json_data
        assert "node_state_mapping" in json_data
        assert json_data["routes"]["source"] == ["target1", "target2"]
        assert node_state.id in json_data["node_state_mapping"]

    def test_runtime_route_state_deserialization_from_dict(self):
        """Test RuntimeRouteState deserialization from dictionary data."""
        node_id = str(uuid.uuid4())

        dict_data = {
            "routes": {"source_node": ["target_node_1", "target_node_2"]},
            "node_state_mapping": {
                node_id: {
                    "id": node_id,
                    "node_id": "test_node",
                    "start_at": _TEST_DATETIME,
                    "status": "running",
                    "index": 1,
                }
            },
        }

        runtime_state = RuntimeRouteState.model_validate(dict_data)

        # Focus on deserialization accuracy
        assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
        assert len(runtime_state.node_state_mapping) == 1
        assert node_id in runtime_state.node_state_mapping

        deserialized_node = runtime_state.node_state_mapping[node_id]
        assert deserialized_node.node_id == "test_node"
        assert deserialized_node.status == RouteNodeState.Status.RUNNING
        assert deserialized_node.index == 1

    def test_runtime_route_state_round_trip_consistency(self):
        """Test RuntimeRouteState round-trip serialization consistency."""
        original = self._get_runtime_route_state()

        # Dictionary round trip
        dict_data = original.model_dump(mode="python")
        reconstructed = RuntimeRouteState.model_validate(dict_data)
        assert reconstructed == original

        dict_data = original.model_dump(mode="json")
        reconstructed = RuntimeRouteState.model_validate(dict_data)
        assert reconstructed == original

        # JSON round trip
        json_str = original.model_dump_json()
        json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
        assert json_reconstructed == original
|
||||
|
||||
|
||||
class TestSerializationEdgeCases:
    """Test edge cases and error conditions for serialization/deserialization."""

    def test_invalid_status_deserialization(self):
        """Test deserialization with invalid status values."""
        payload = {
            "node_id": "invalid_test",
            "start_at": _TEST_DATETIME,
            "status": "invalid_status",
        }

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(payload)
        assert "status" in str(exc_info.value)

    def test_missing_required_fields_deserialization(self):
        """Test deserialization with missing required fields."""
        payload = {"id": str(uuid.uuid4())}

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(payload)
        message = str(exc_info.value)
        assert "node_id" in message or "start_at" in message

    def test_invalid_datetime_deserialization(self):
        """Test deserialization with invalid datetime values."""
        payload = {
            "node_id": "datetime_test",
            "start_at": "invalid_datetime",
            "status": "running",
        }

        with pytest.raises(ValidationError) as exc_info:
            RouteNodeState.model_validate(payload)
        assert "start_at" in str(exc_info.value)

    def test_invalid_routes_structure_deserialization(self):
        """Test RuntimeRouteState deserialization with invalid routes structure."""
        payload = {
            "routes": "invalid_routes_structure",  # Should be dict
            "node_state_mapping": {},
        }

        with pytest.raises(ValidationError) as exc_info:
            RuntimeRouteState.model_validate(payload)
        assert "routes" in str(exc_info.value)

    def test_timezone_handling_in_datetime_fields(self):
        """Test timezone handling in datetime field serialization."""
        naive_now = datetime.now(UTC).replace(tzinfo=None)

        state = RouteNodeState(node_id="timezone_test", start_at=naive_now)
        dumped = state.model_dump()

        assert dumped["start_at"] == naive_now

        # Dict round trip keeps the naive datetime unchanged.
        rebuilt = RouteNodeState.model_validate(dumped)
        assert rebuilt.start_at == naive_now
        assert rebuilt.start_at.tzinfo is None

        # JSON round trip must not introduce a timezone either.
        rebuilt = RouteNodeState.model_validate_json(state.model_dump_json())
        assert rebuilt.start_at == naive_now
        assert rebuilt.start_at.tzinfo is None
|
||||
@ -0,0 +1,251 @@
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.file.enums import FileTransferMethod, FileType
|
||||
from core.file.models import File
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
# Test data constants for SystemVariable serialization tests
# Minimal payload: only the identifier fields, all optional fields absent.
VALID_BASE_DATA: dict[str, Any] = {
    "user_id": "a20f06b1-8703-45ab-937c-860a60072113",
    "app_id": "661bed75-458d-49c9-b487-fda0762677b9",
    "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
}

# Superset of VALID_BASE_DATA with every optional field populated; uses the
# `workflow_run_id` key, which the tests below expect to map onto the
# `workflow_execution_id` field.
COMPLETE_VALID_DATA: dict[str, Any] = {
    **VALID_BASE_DATA,
    "query": "test query",
    "files": [],
    "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
    "dialogue_count": 5,
    "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
}
|
||||
|
||||
|
||||
def create_test_file() -> File:
    """Create a test File object for serialization tests."""
    file_fields: dict[str, Any] = {
        "tenant_id": "test-tenant-id",
        "type": FileType.DOCUMENT,
        "transfer_method": FileTransferMethod.LOCAL_FILE,
        "related_id": "test-file-id",
        "filename": "test.txt",
        "extension": ".txt",
        "mime_type": "text/plain",
        "size": 1024,
        "storage_key": "test-storage-key",
    }
    return File(**file_fields)
|
||||
|
||||
|
||||
class TestSystemVariableSerialization:
|
||||
"""Focused tests for SystemVariable serialization/deserialization logic."""
|
||||
|
||||
def test_basic_deserialization(self):
    """Test successful deserialization from JSON structure with all fields correctly mapped."""
    # Fully populated payload: every field must land on its model attribute.
    populated = SystemVariable(**COMPLETE_VALID_DATA)

    for field in ("user_id", "app_id", "workflow_id", "query", "conversation_id", "dialogue_count"):
        assert getattr(populated, field) == COMPLETE_VALID_DATA[field]
    # The `workflow_run_id` key feeds the `workflow_execution_id` field.
    assert populated.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
    assert populated.files == []

    # Minimal payload: identifiers set, everything else at its default.
    sparse = SystemVariable(**VALID_BASE_DATA)

    for field in ("user_id", "app_id", "workflow_id"):
        assert getattr(sparse, field) == VALID_BASE_DATA[field]
    assert sparse.query is None
    assert sparse.conversation_id is None
    assert sparse.dialogue_count is None
    assert sparse.workflow_execution_id is None
    assert sparse.files == []
|
||||
|
||||
def test_alias_handling(self):
|
||||
"""Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
|
||||
workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
|
||||
|
||||
# Test workflow_run_id only (preferred alias)
|
||||
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||
system_var1 = SystemVariable(**data_run_id)
|
||||
assert system_var1.workflow_execution_id == workflow_id
|
||||
|
||||
# Test workflow_execution_id only (direct field name)
|
||||
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||
system_var2 = SystemVariable(**data_execution_id)
|
||||
assert system_var2.workflow_execution_id == workflow_id
|
||||
|
||||
# Test both present - workflow_run_id should take precedence
|
||||
data_both = {
|
||||
**VALID_BASE_DATA,
|
||||
"workflow_execution_id": "should-be-ignored",
|
||||
"workflow_run_id": workflow_id,
|
||||
}
|
||||
system_var3 = SystemVariable(**data_both)
|
||||
assert system_var3.workflow_execution_id == workflow_id
|
||||
|
||||
# Test neither present - should be None
|
||||
system_var4 = SystemVariable(**VALID_BASE_DATA)
|
||||
assert system_var4.workflow_execution_id is None
|
||||
|
||||
def test_serialization_round_trip(self):
|
||||
"""Test that serialize → deserialize produces the same result with alias handling."""
|
||||
# Create original SystemVariable
|
||||
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||
|
||||
# Serialize to dict
|
||||
serialized = original.model_dump(mode="json")
|
||||
|
||||
# Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
|
||||
assert "workflow_run_id" in serialized
|
||||
assert "workflow_execution_id" not in serialized
|
||||
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||
|
||||
# Deserialize back
|
||||
deserialized = SystemVariable(**serialized)
|
||||
|
||||
# Verify all fields match after round-trip
|
||||
assert deserialized.user_id == original.user_id
|
||||
assert deserialized.app_id == original.app_id
|
||||
assert deserialized.workflow_id == original.workflow_id
|
||||
assert deserialized.query == original.query
|
||||
assert deserialized.conversation_id == original.conversation_id
|
||||
assert deserialized.dialogue_count == original.dialogue_count
|
||||
assert deserialized.workflow_execution_id == original.workflow_execution_id
|
||||
assert list(deserialized.files) == list(original.files)
|
||||
|
||||
def test_json_round_trip(self):
|
||||
"""Test JSON serialization/deserialization consistency with proper structure."""
|
||||
# Create original SystemVariable
|
||||
original = SystemVariable(**COMPLETE_VALID_DATA)
|
||||
|
||||
# Serialize to JSON string
|
||||
json_str = original.model_dump_json()
|
||||
|
||||
# Parse JSON and verify structure
|
||||
json_data = json.loads(json_str)
|
||||
assert "workflow_run_id" in json_data
|
||||
assert "workflow_execution_id" not in json_data
|
||||
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
|
||||
|
||||
# Deserialize from JSON data
|
||||
deserialized = SystemVariable(**json_data)
|
||||
|
||||
# Verify key fields match after JSON round-trip
|
||||
assert deserialized.workflow_execution_id == original.workflow_execution_id
|
||||
assert deserialized.user_id == original.user_id
|
||||
assert deserialized.app_id == original.app_id
|
||||
assert deserialized.workflow_id == original.workflow_id
|
||||
|
||||
def test_files_field_deserialization(self):
|
||||
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
|
||||
# Test with empty files list
|
||||
data_empty = {**VALID_BASE_DATA, "files": []}
|
||||
system_var_empty = SystemVariable(**data_empty)
|
||||
assert system_var_empty.files == []
|
||||
|
||||
# Test with single File object
|
||||
test_file = create_test_file()
|
||||
data_single = {**VALID_BASE_DATA, "files": [test_file]}
|
||||
system_var_single = SystemVariable(**data_single)
|
||||
assert len(system_var_single.files) == 1
|
||||
assert system_var_single.files[0].filename == "test.txt"
|
||||
assert system_var_single.files[0].tenant_id == "test-tenant-id"
|
||||
|
||||
# Test with multiple File objects
|
||||
file1 = File(
|
||||
tenant_id="tenant1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||
related_id="file1",
|
||||
filename="doc1.txt",
|
||||
storage_key="key1",
|
||||
)
|
||||
file2 = File(
|
||||
tenant_id="tenant2",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
remote_url="https://example.com/image.jpg",
|
||||
filename="image.jpg",
|
||||
storage_key="key2",
|
||||
)
|
||||
|
||||
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
|
||||
system_var_multiple = SystemVariable(**data_multiple)
|
||||
assert len(system_var_multiple.files) == 2
|
||||
assert system_var_multiple.files[0].filename == "doc1.txt"
|
||||
assert system_var_multiple.files[1].filename == "image.jpg"
|
||||
|
||||
# Verify files field serialization/deserialization
|
||||
serialized = system_var_multiple.model_dump(mode="json")
|
||||
deserialized = SystemVariable(**serialized)
|
||||
assert len(deserialized.files) == 2
|
||||
assert deserialized.files[0].filename == "doc1.txt"
|
||||
assert deserialized.files[1].filename == "image.jpg"
|
||||
|
||||
def test_alias_serialization_consistency(self):
|
||||
"""Test that alias handling works consistently in both serialization directions."""
|
||||
workflow_id = "test-workflow-id"
|
||||
|
||||
# Create with workflow_run_id (alias)
|
||||
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
|
||||
system_var = SystemVariable(**data_with_alias)
|
||||
|
||||
# Serialize and verify alias is used
|
||||
serialized = system_var.model_dump()
|
||||
assert serialized["workflow_run_id"] == workflow_id
|
||||
assert "workflow_execution_id" not in serialized
|
||||
|
||||
# Deserialize and verify field mapping
|
||||
deserialized = SystemVariable(**serialized)
|
||||
assert deserialized.workflow_execution_id == workflow_id
|
||||
|
||||
# Test JSON serialization path
|
||||
json_serialized = json.loads(system_var.model_dump_json())
|
||||
assert json_serialized["workflow_run_id"] == workflow_id
|
||||
assert "workflow_execution_id" not in json_serialized
|
||||
|
||||
json_deserialized = SystemVariable(**json_serialized)
|
||||
assert json_deserialized.workflow_execution_id == workflow_id
|
||||
|
||||
def test_model_validator_serialization_logic(self):
|
||||
"""Test the custom model validator behavior for serialization scenarios."""
|
||||
workflow_id = "test-workflow-execution-id"
|
||||
|
||||
# Test direct instantiation with workflow_execution_id (should work)
|
||||
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
|
||||
system_var1 = SystemVariable(**data1)
|
||||
assert system_var1.workflow_execution_id == workflow_id
|
||||
|
||||
# Test serialization of the above (should use alias)
|
||||
serialized1 = system_var1.model_dump()
|
||||
assert "workflow_run_id" in serialized1
|
||||
assert serialized1["workflow_run_id"] == workflow_id
|
||||
|
||||
# Test both present - workflow_run_id takes precedence (validator logic)
|
||||
data2 = {
|
||||
**VALID_BASE_DATA,
|
||||
"workflow_execution_id": "should-be-removed",
|
||||
"workflow_run_id": workflow_id,
|
||||
}
|
||||
system_var2 = SystemVariable(**data2)
|
||||
assert system_var2.workflow_execution_id == workflow_id
|
||||
|
||||
# Verify serialization consistency
|
||||
serialized2 = system_var2.model_dump()
|
||||
assert serialized2["workflow_run_id"] == workflow_id
|
||||
|
||||
|
||||
def test_constructor_with_extra_key():
    """SystemVariable forbids extra keys, so an unexpected kwarg must raise."""
    # The model is configured with extra="forbid"; any unknown key is rejected.
    bogus = {"invalid_key": 1}
    with pytest.raises(ValidationError):
        SystemVariable(**bogus)  # type: ignore
|
||||
Loading…
Reference in New Issue