test(api): Add tests for GraphRuntimeState serialization / deserialzation

pull/22025/head
QuantumGhost 11 months ago
parent d9bc894bb9
commit 5a51e58548

@ -5,7 +5,6 @@ from uuid import uuid4
from pydantic import Discriminator, Field, Tag from pydantic import Discriminator, Field, Tag
from core.helper import encrypter from core.helper import encrypter
from core.variables.segment_group import SegmentGroup
from .segments import ( from .segments import (
ArrayAnySegment, ArrayAnySegment,

@ -2,8 +2,6 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Any, cast
from uuid import uuid4 from uuid import uuid4
from shapely import is_valid
from configs import dify_config from configs import dify_config
from core.file import File from core.file import File
from core.variables.exc import VariableError from core.variables.exc import VariableError

@ -1,14 +1,49 @@
import dataclasses
from pydantic import BaseModel
from core.file import File, FileTransferMethod, FileType
from core.helper import encrypter from core.helper import encrypter
from core.variables import SecretVariable, StringVariable from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FileSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
Segment,
SegmentUnion,
StringSegment,
get_segment_discriminator,
)
from core.variables.types import SegmentType
from core.variables.variables import (
ArrayAnyVariable,
ArrayFileVariable,
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FileVariable,
FloatVariable,
IntegerVariable,
NoneVariable,
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.system_variable import SystemVariable
def test_segment_group_to_text(): def test_segment_group_to_text():
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={ system_variables=SystemVariable(user_id="fake-user-id"),
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={}, user_inputs={},
environment_variables=[ environment_variables=[
SecretVariable(name="secret_key", value="fake-secret-key"), SecretVariable(name="secret_key", value="fake-secret-key"),
@ -30,7 +65,7 @@ def test_segment_group_to_text():
def test_convert_constant_to_segment_group(): def test_convert_constant_to_segment_group():
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={}, system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"),
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
conversation_variables=[], conversation_variables=[],
@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group():
def test_convert_variable_to_segment_group(): def test_convert_variable_to_segment_group():
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={ system_variables=SystemVariable(user_id="fake-user-id"),
SystemVariableKey("user_id"): "fake-user-id",
},
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
conversation_variables=[], conversation_variables=[],
@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group():
assert segments_group.log == "fake-user-id" assert segments_group.log == "fake-user-id"
assert isinstance(segments_group.value[0], StringVariable) assert isinstance(segments_group.value[0], StringVariable)
assert segments_group.value[0].value == "fake-user-id" assert segments_group.value[0].value == "fake-user-id"
class _Segments(BaseModel):
segments: list[SegmentUnion]
class _Variables(BaseModel):
variables: list[VariableUnion]
def create_test_file(
file_type: FileType = FileType.DOCUMENT,
transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE,
filename: str = "test.txt",
extension: str = ".txt",
mime_type: str = "text/plain",
size: int = 1024,
) -> File:
"""Factory function to create File objects for testing"""
return File(
tenant_id="test-tenant",
type=file_type,
transfer_method=transfer_method,
filename=filename,
extension=extension,
mime_type=mime_type,
size=size,
related_id="test-file-id" if transfer_method != FileTransferMethod.REMOTE_URL else None,
remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None,
storage_key="test-storage-key",
)
class TestSegmentDumpAndLoad:
"""Test suite for segment and variable serialization/deserialization"""
def test_segments(self):
"""Test basic segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_segment_number(self):
"""Test number segment serialization compatibility"""
model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)])
json = model.model_dump_json()
print("Json: ", json)
loaded = _Segments.model_validate_json(json)
assert loaded == model
def test_variables(self):
"""Test variable serialization compatibility"""
model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")])
json = model.model_dump_json()
print("Json: ", json)
restored = _Variables.model_validate_json(json)
assert restored == model
def test_all_segments_serialization(self):
"""Test serialization/deserialization of all segment types"""
# Create one instance of each segment type
test_file = create_test_file()
all_segments: list[SegmentUnion] = [
NoneSegment(),
StringSegment(value="test string"),
IntegerSegment(value=42),
FloatSegment(value=3.14),
ObjectSegment(value={"key": "value", "number": 123}),
FileSegment(value=test_file),
ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]),
ArrayStringSegment(value=["hello", "world"]),
ArrayNumberSegment(value=[1, 2.5, 3]),
ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]),
ArrayFileSegment(value=[]), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Segments(segments=all_segments)
json_str = model.model_dump_json()
loaded = _Segments.model_validate_json(json_str)
# Verify all segments are preserved
assert len(loaded.segments) == len(all_segments)
for original, loaded_segment in zip(all_segments, loaded.segments):
assert type(loaded_segment) == type(original)
assert loaded_segment.value_type == original.value_type
# For file segments, compare key properties instead of exact equality
if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment):
orig_file = original.value
loaded_file = loaded_segment.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_segment.value == original.value
def test_all_variables_serialization(self):
"""Test serialization/deserialization of all variable types"""
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),
FloatVariable(value=3.14, name="float_var"),
ObjectVariable(value={"key": "value", "number": 123}, name="object_var"),
FileVariable(value=test_file, name="file_var"),
ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"),
ArrayStringVariable(value=["hello", "world"], name="array_string_var"),
ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"),
ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"),
ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity
]
# Test serialization and deserialization
model = _Variables(variables=all_variables)
json_str = model.model_dump_json()
loaded = _Variables.model_validate_json(json_str)
# Verify all variables are preserved
assert len(loaded.variables) == len(all_variables)
for original, loaded_variable in zip(all_variables, loaded.variables):
assert type(loaded_variable) == type(original)
assert loaded_variable.value_type == original.value_type
assert loaded_variable.name == original.name
# For file variables, compare key properties instead of exact equality
if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable):
orig_file = original.value
loaded_file = loaded_variable.value
assert isinstance(orig_file, File)
assert isinstance(loaded_file, File)
assert loaded_file.tenant_id == orig_file.tenant_id
assert loaded_file.type == orig_file.type
assert loaded_file.filename == orig_file.filename
else:
assert loaded_variable.value == original.value
def test_segment_discriminator_function_for_segment_types(self):
"""Test the segment discriminator function"""
@dataclasses.dataclass
class TestCase:
segment: Segment
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneSegment(),
SegmentType.NONE,
),
TestCase(
StringSegment(value=""),
SegmentType.STRING,
),
TestCase(
FloatSegment(value=0.0),
SegmentType.FLOAT,
),
TestCase(
IntegerSegment(value=0),
SegmentType.INTEGER,
),
TestCase(
ObjectSegment(value={}),
SegmentType.OBJECT,
),
TestCase(
FileSegment(value=file1),
SegmentType.FILE,
),
TestCase(
ArrayAnySegment(value=[0, 0.0, ""]),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringSegment(value=[""]),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberSegment(value=[0, 0.0]),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectSegment(value=[{}]),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileSegment(value=[file1, file2]),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
segment = test_case.segment
assert get_segment_discriminator(segment) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(segment)}"
)
model_dict = segment.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(segment)}"
)
def test_variable_discriminator_function_for_variable_types(self):
"""Test the variable discriminator function"""
@dataclasses.dataclass
class TestCase:
variable: Variable
expected_segment_type: SegmentType
file1 = create_test_file()
file2 = create_test_file(filename="test2.txt")
cases = [
TestCase(
NoneVariable(name="none_var"),
SegmentType.NONE,
),
TestCase(
StringVariable(value="test", name="string_var"),
SegmentType.STRING,
),
TestCase(
FloatVariable(value=0.0, name="float_var"),
SegmentType.FLOAT,
),
TestCase(
IntegerVariable(value=0, name="int_var"),
SegmentType.INTEGER,
),
TestCase(
ObjectVariable(value={}, name="object_var"),
SegmentType.OBJECT,
),
TestCase(
FileVariable(value=file1, name="file_var"),
SegmentType.FILE,
),
TestCase(
SecretVariable(value="secret", name="secret_var"),
SegmentType.SECRET,
),
TestCase(
ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"),
SegmentType.ARRAY_ANY,
),
TestCase(
ArrayStringVariable(value=[""], name="array_string_var"),
SegmentType.ARRAY_STRING,
),
TestCase(
ArrayNumberVariable(value=[0, 0.0], name="array_number_var"),
SegmentType.ARRAY_NUMBER,
),
TestCase(
ArrayObjectVariable(value=[{}], name="array_object_var"),
SegmentType.ARRAY_OBJECT,
),
TestCase(
ArrayFileVariable(value=[file1, file2], name="array_file_var"),
SegmentType.ARRAY_FILE,
),
]
for test_case in cases:
variable = test_case.variable
assert get_segment_discriminator(variable) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for type {type(variable)}"
)
model_dict = variable.model_dump(mode="json")
assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, (
f"get_segment_discriminator failed for serialized form of type {type(variable)}"
)
def test_invlaid_value_for_discriminator(self):
# Test invalid cases
assert get_segment_discriminator({"value_type": "invalid"}) is None
assert get_segment_discriminator({}) is None
assert get_segment_discriminator("not_a_dict") is None
assert get_segment_discriminator(42) is None
assert get_segment_discriminator(object) is None

@ -0,0 +1,60 @@
from core.variables.types import SegmentType
class TestSegmentTypeIsArrayType:
"""
Test class for SegmentType.is_array_type method.
Provides comprehensive coverage of all SegmentType values to ensure
correct identification of array and non-array types.
"""
def test_is_array_type(self):
"""
Test that all SegmentType enum values are covered in our test cases.
Ensures comprehensive coverage by verifying that every SegmentType
value is tested for the is_array_type method.
"""
# Arrange
all_segment_types = set(SegmentType)
expected_array_types = [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
expected_non_array_types = [
SegmentType.INTEGER,
SegmentType.FLOAT,
SegmentType.NUMBER,
SegmentType.STRING,
SegmentType.OBJECT,
SegmentType.SECRET,
SegmentType.FILE,
SegmentType.NONE,
SegmentType.GROUP,
]
for seg_type in expected_array_types:
assert seg_type.is_array_type()
for seg_type in expected_non_array_types:
assert not seg_type.is_array_type()
# Act & Assert
covered_types = set(expected_array_types) | set(expected_non_array_types)
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
def test_all_enum_values_are_supported(self):
"""
Test that all enum values are supported and return boolean values.
Validates that every SegmentType enum value can be processed by
is_array_type method and returns a boolean value.
"""
enum_values: list[SegmentType] = list(SegmentType)
for seg_type in enum_values:
is_array = seg_type.is_array_type()
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"

@ -0,0 +1,146 @@
import time
from decimal import Decimal
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
from core.workflow.system_variable import SystemVariable
def create_test_graph_runtime_state() -> GraphRuntimeState:
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
# Create a variable pool with system variables
system_vars = SystemVariable(
user_id="test_user_123",
app_id="test_app_456",
workflow_id="test_workflow_789",
workflow_execution_id="test_execution_001",
query="test query",
conversation_id="test_conv_123",
dialogue_count=5,
)
variable_pool = VariablePool(system_variables=system_vars)
# Add some variables to the variable pool
variable_pool.add(["test_node", "test_var"], "test_value")
variable_pool.add(["another_node", "another_var"], 42)
# Create LLM usage with realistic values
llm_usage = LLMUsage(
prompt_tokens=150,
prompt_unit_price=Decimal("0.001"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.15"),
completion_tokens=75,
completion_unit_price=Decimal("0.002"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.15"),
total_tokens=225,
total_price=Decimal("0.30"),
currency="USD",
latency=1.25,
)
# Create runtime route state with some node states
node_run_state = RuntimeRouteState()
node_state = node_run_state.create_node_state("test_node_1")
node_run_state.add_route(node_state.id, "target_node_id")
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
total_tokens=100,
llm_usage=llm_usage,
outputs={
"string_output": "test result",
"int_output": 42,
"float_output": 3.14,
"list_output": ["item1", "item2", "item3"],
"dict_output": {"key1": "value1", "key2": 123},
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
},
node_run_steps=5,
node_run_state=node_run_state,
)
def test_basic_round_trip_serialization():
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
# Create a state with non-empty values
original_state = create_test_graph_runtime_state()
# Serialize to JSON and deserialize back
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Core test: ensure the round-trip preserves all values
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="python")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
# Serialize to JSON and deserialize back
dict_data = original_state.model_dump(mode="json")
deserialized_state = GraphRuntimeState.model_validate(dict_data)
assert deserialized_state == original_state
def test_outputs_field_round_trip():
"""Test the problematic outputs field maintains values through round-trip serialization."""
original_state = create_test_graph_runtime_state()
# Serialize and deserialize
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
# Verify the outputs field specifically maintains its values
assert deserialized_state.outputs == original_state.outputs
assert deserialized_state == original_state
def test_empty_outputs_round_trip():
"""Test round-trip serialization with empty outputs field."""
variable_pool = VariablePool.empty()
original_state = GraphRuntimeState(
variable_pool=variable_pool,
start_at=time.perf_counter(),
outputs={}, # Empty outputs
)
json_data = original_state.model_dump_json()
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
assert deserialized_state == original_state
def test_llm_usage_round_trip():
# Create LLM usage with specific decimal values
llm_usage = LLMUsage(
prompt_tokens=100,
prompt_unit_price=Decimal("0.0015"),
prompt_price_unit=Decimal("1000"),
prompt_price=Decimal("0.15"),
completion_tokens=50,
completion_unit_price=Decimal("0.003"),
completion_price_unit=Decimal("1000"),
completion_price=Decimal("0.15"),
total_tokens=150,
total_price=Decimal("0.30"),
currency="USD",
latency=2.5,
)
json_data = llm_usage.model_dump_json()
deserialized = LLMUsage.model_validate_json(json_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="python")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage
dict_data = llm_usage.model_dump(mode="json")
deserialized = LLMUsage.model_validate(dict_data)
assert deserialized == llm_usage

@ -0,0 +1,401 @@
import json
import uuid
from datetime import UTC, datetime
import pytest
from pydantic import ValidationError
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
class TestRouteNodeStateSerialization:
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
def _test_route_node_state(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_run_result = NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"input_key": "input_value"},
outputs={"output_key": "output_value"},
)
node_state = RouteNodeState(
node_id="comprehensive_test_node",
start_at=_TEST_DATETIME,
finished_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
node_run_result=node_run_result,
index=5,
paused_at=_TEST_DATETIME,
paused_by="user_123",
failed_reason="test_reason",
)
return node_state
def test_route_node_state_comprehensive_field_validation(self):
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
node_state = self._test_route_node_state()
serialized = node_state.model_dump()
# Comprehensive validation of all RouteNodeState fields
assert serialized["node_id"] == "comprehensive_test_node"
assert serialized["status"] == RouteNodeState.Status.SUCCESS
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["finished_at"] == _TEST_DATETIME
assert serialized["paused_at"] == _TEST_DATETIME
assert serialized["paused_by"] == "user_123"
assert serialized["failed_reason"] == "test_reason"
assert serialized["index"] == 5
assert "id" in serialized
assert isinstance(serialized["id"], str)
uuid.UUID(serialized["id"]) # Validate UUID format
# Validate nested NodeRunResult structure
assert serialized["node_run_result"] is not None
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
def test_route_node_state_minimal_required_fields(self):
"""Test RouteNodeState with only required fields, focusing on defaults."""
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
serialized = node_state.model_dump()
# Focus on required fields and default values (not re-testing all fields)
assert serialized["node_id"] == "minimal_node"
assert serialized["start_at"] == _TEST_DATETIME
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
assert serialized["index"] == 1 # Default index
assert serialized["node_run_result"] is None # Default None
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
def test_route_node_state_deserialization_from_dict(self):
"""Test RouteNodeState deserialization from dictionary data."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
test_id = str(uuid.uuid4())
dict_data = {
"id": test_id,
"node_id": "deserialized_node",
"start_at": test_datetime,
"status": "success",
"finished_at": test_datetime,
"index": 3,
}
node_state = RouteNodeState.model_validate(dict_data)
# Focus on deserialization accuracy
assert node_state.id == test_id
assert node_state.node_id == "deserialized_node"
assert node_state.start_at == test_datetime
assert node_state.status == RouteNodeState.Status.SUCCESS
assert node_state.finished_at == test_datetime
assert node_state.index == 3
def test_route_node_state_round_trip_consistency(self):
node_states = (
self._test_route_node_state(),
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
)
for node_state in node_states:
json = node_state.model_dump_json()
deserialized = RouteNodeState.model_validate_json(json)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="python")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
dict_ = node_state.model_dump(mode="json")
deserialized = RouteNodeState.model_validate(dict_)
assert deserialized == node_state
class TestRouteNodeStateEnumSerialization:
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
def test_status_enum_model_dump_behavior(self):
"""Test Status enum serialization in model_dump() returns enum objects."""
for status_enum in RouteNodeState.Status:
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
serialized = node_state.model_dump(mode="python")
assert serialized["status"] == status_enum
serialized = node_state.model_dump(mode="json")
assert serialized["status"] == status_enum.value
def test_status_enum_json_serialization_behavior(self):
"""Test Status enum serialization in JSON returns string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
enum_to_string_mapping = {
RouteNodeState.Status.RUNNING: "running",
RouteNodeState.Status.SUCCESS: "success",
RouteNodeState.Status.FAILED: "failed",
RouteNodeState.Status.PAUSED: "paused",
RouteNodeState.Status.EXCEPTION: "exception",
}
for status_enum, expected_string in enum_to_string_mapping.items():
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
json_data = json.loads(node_state.model_dump_json())
assert json_data["status"] == expected_string
def test_status_enum_deserialization_from_string(self):
"""Test Status enum deserialization from string values."""
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
string_to_enum_mapping = {
"running": RouteNodeState.Status.RUNNING,
"success": RouteNodeState.Status.SUCCESS,
"failed": RouteNodeState.Status.FAILED,
"paused": RouteNodeState.Status.PAUSED,
"exception": RouteNodeState.Status.EXCEPTION,
}
for status_string, expected_enum in string_to_enum_mapping.items():
dict_data = {
"node_id": "enum_deserialize_test",
"start_at": test_datetime,
"status": status_string,
}
node_state = RouteNodeState.model_validate(dict_data)
assert node_state.status == expected_enum
class TestRuntimeRouteStateSerialization:
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
_NODE1_ID = "node_1"
_ROUTE_STATE1_ID = str(uuid.uuid4())
_NODE2_ID = "node_2"
_ROUTE_STATE2_ID = str(uuid.uuid4())
_NODE3_ID = "node_3"
_ROUTE_STATE3_ID = str(uuid.uuid4())
def _get_runtime_route_state(self):
# Create node states with different configurations
node_state_1 = RouteNodeState(
id=self._ROUTE_STATE1_ID,
node_id=self._NODE1_ID,
start_at=_TEST_DATETIME,
index=1,
)
node_state_2 = RouteNodeState(
id=self._ROUTE_STATE2_ID,
node_id=self._NODE2_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.SUCCESS,
finished_at=_TEST_DATETIME,
index=2,
)
node_state_3 = RouteNodeState(
id=self._ROUTE_STATE3_ID,
node_id=self._NODE3_ID,
start_at=_TEST_DATETIME,
status=RouteNodeState.Status.FAILED,
failed_reason="Test failure",
index=3,
)
runtime_state = RuntimeRouteState(
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
node_state_mapping={
node_state_1.id: node_state_1,
node_state_2.id: node_state_2,
node_state_3.id: node_state_3,
},
)
return runtime_state
def test_runtime_route_state_comprehensive_structure_validation(self):
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
runtime_state = self._get_runtime_route_state()
serialized = runtime_state.model_dump()
# Comprehensive validation of RuntimeRouteState structure
assert "routes" in serialized
assert "node_state_mapping" in serialized
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
# Validate routes dictionary structure and content
assert len(serialized["routes"]) == 2
assert self._ROUTE_STATE1_ID in serialized["routes"]
assert self._ROUTE_STATE2_ID in serialized["routes"]
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
# Validate node_state_mapping dictionary structure and content
assert len(serialized["node_state_mapping"]) == 3
for state_id in [
self._ROUTE_STATE1_ID,
self._ROUTE_STATE2_ID,
self._ROUTE_STATE3_ID,
]:
assert state_id in serialized["node_state_mapping"]
node_data = serialized["node_state_mapping"][state_id]
node_state = runtime_state.node_state_mapping[state_id]
assert node_data["node_id"] == node_state.node_id
assert node_data["status"] == node_state.status
assert node_data["index"] == node_state.index
def test_runtime_route_state_empty_collections(self):
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
runtime_state = RuntimeRouteState()
serialized = runtime_state.model_dump()
# Focus on default empty collection behavior
assert serialized["routes"] == {}
assert serialized["node_state_mapping"] == {}
assert isinstance(serialized["routes"], dict)
assert isinstance(serialized["node_state_mapping"], dict)
def test_runtime_route_state_json_serialization_structure(self):
"""Test RuntimeRouteState JSON serialization structure."""
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
runtime_state = RuntimeRouteState(
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
)
json_str = runtime_state.model_dump_json()
json_data = json.loads(json_str)
# Focus on JSON structure validation
assert isinstance(json_str, str)
assert isinstance(json_data, dict)
assert "routes" in json_data
assert "node_state_mapping" in json_data
assert json_data["routes"]["source"] == ["target1", "target2"]
assert node_state.id in json_data["node_state_mapping"]
def test_runtime_route_state_deserialization_from_dict(self):
"""Test RuntimeRouteState deserialization from dictionary data."""
node_id = str(uuid.uuid4())
dict_data = {
"routes": {"source_node": ["target_node_1", "target_node_2"]},
"node_state_mapping": {
node_id: {
"id": node_id,
"node_id": "test_node",
"start_at": _TEST_DATETIME,
"status": "running",
"index": 1,
}
},
}
runtime_state = RuntimeRouteState.model_validate(dict_data)
# Focus on deserialization accuracy
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
assert len(runtime_state.node_state_mapping) == 1
assert node_id in runtime_state.node_state_mapping
deserialized_node = runtime_state.node_state_mapping[node_id]
assert deserialized_node.node_id == "test_node"
assert deserialized_node.status == RouteNodeState.Status.RUNNING
assert deserialized_node.index == 1
def test_runtime_route_state_round_trip_consistency(self):
"""Test RuntimeRouteState round-trip serialization consistency."""
original = self._get_runtime_route_state()
# Dictionary round trip
dict_data = original.model_dump(mode="python")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
dict_data = original.model_dump(mode="json")
reconstructed = RuntimeRouteState.model_validate(dict_data)
assert reconstructed == original
# JSON round trip
json_str = original.model_dump_json()
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
assert json_reconstructed == original
class TestSerializationEdgeCases:
"""Test edge cases and error conditions for serialization/deserialization."""
def test_invalid_status_deserialization(self):
"""Test deserialization with invalid status values."""
test_datetime = _TEST_DATETIME
invalid_data = {
"node_id": "invalid_test",
"start_at": test_datetime,
"status": "invalid_status",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "status" in str(exc_info.value)
def test_missing_required_fields_deserialization(self):
"""Test deserialization with missing required fields."""
incomplete_data = {"id": str(uuid.uuid4())}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(incomplete_data)
error_str = str(exc_info.value)
assert "node_id" in error_str or "start_at" in error_str
def test_invalid_datetime_deserialization(self):
"""Test deserialization with invalid datetime values."""
invalid_data = {
"node_id": "datetime_test",
"start_at": "invalid_datetime",
"status": "running",
}
with pytest.raises(ValidationError) as exc_info:
RouteNodeState.model_validate(invalid_data)
assert "start_at" in str(exc_info.value)
def test_invalid_routes_structure_deserialization(self):
"""Test RuntimeRouteState deserialization with invalid routes structure."""
invalid_data = {
"routes": "invalid_routes_structure", # Should be dict
"node_state_mapping": {},
}
with pytest.raises(ValidationError) as exc_info:
RuntimeRouteState.model_validate(invalid_data)
assert "routes" in str(exc_info.value)
def test_timezone_handling_in_datetime_fields(self):
"""Test timezone handling in datetime field serialization."""
utc_datetime = datetime.now(UTC)
naive_datetime = utc_datetime.replace(tzinfo=None)
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
dict_ = node_state.model_dump()
assert dict_["start_at"] == naive_datetime
# Test round trip
reconstructed = RouteNodeState.model_validate(dict_)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None
json = node_state.model_dump_json()
reconstructed = RouteNodeState.model_validate_json(json)
assert reconstructed.start_at == naive_datetime
assert reconstructed.start_at.tzinfo is None

@ -0,0 +1,251 @@
import json
from typing import Any
import pytest
from pydantic import ValidationError
from core.file.enums import FileTransferMethod, FileType
from core.file.models import File
from core.workflow.system_variable import SystemVariable
# Test data constants for SystemVariable serialization tests
VALID_BASE_DATA: dict[str, Any] = {
"user_id": "a20f06b1-8703-45ab-937c-860a60072113",
"app_id": "661bed75-458d-49c9-b487-fda0762677b9",
"workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43",
}
COMPLETE_VALID_DATA: dict[str, Any] = {
**VALID_BASE_DATA,
"query": "test query",
"files": [],
"conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9",
"dialogue_count": 5,
"workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5",
}
def create_test_file() -> File:
"""Create a test File object for serialization tests."""
return File(
tenant_id="test-tenant-id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test-file-id",
filename="test.txt",
extension=".txt",
mime_type="text/plain",
size=1024,
storage_key="test-storage-key",
)
class TestSystemVariableSerialization:
"""Focused tests for SystemVariable serialization/deserialization logic."""
def test_basic_deserialization(self):
"""Test successful deserialization from JSON structure with all fields correctly mapped."""
# Test with complete data
system_var = SystemVariable(**COMPLETE_VALID_DATA)
# Verify all fields are correctly mapped
assert system_var.user_id == COMPLETE_VALID_DATA["user_id"]
assert system_var.app_id == COMPLETE_VALID_DATA["app_id"]
assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"]
assert system_var.query == COMPLETE_VALID_DATA["query"]
assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"]
assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"]
assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"]
assert system_var.files == []
# Test with minimal data (only required fields)
minimal_var = SystemVariable(**VALID_BASE_DATA)
assert minimal_var.user_id == VALID_BASE_DATA["user_id"]
assert minimal_var.app_id == VALID_BASE_DATA["app_id"]
assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"]
assert minimal_var.query is None
assert minimal_var.conversation_id is None
assert minimal_var.dialogue_count is None
assert minimal_var.workflow_execution_id is None
assert minimal_var.files == []
def test_alias_handling(self):
"""Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic."""
workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5"
# Test workflow_run_id only (preferred alias)
data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var1 = SystemVariable(**data_run_id)
assert system_var1.workflow_execution_id == workflow_id
# Test workflow_execution_id only (direct field name)
data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var2 = SystemVariable(**data_execution_id)
assert system_var2.workflow_execution_id == workflow_id
# Test both present - workflow_run_id should take precedence
data_both = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-ignored",
"workflow_run_id": workflow_id,
}
system_var3 = SystemVariable(**data_both)
assert system_var3.workflow_execution_id == workflow_id
# Test neither present - should be None
system_var4 = SystemVariable(**VALID_BASE_DATA)
assert system_var4.workflow_execution_id is None
def test_serialization_round_trip(self):
"""Test that serialize → deserialize produces the same result with alias handling."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to dict
serialized = original.model_dump(mode="json")
# Verify alias is used in serialization (workflow_run_id, not workflow_execution_id)
assert "workflow_run_id" in serialized
assert "workflow_execution_id" not in serialized
assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize back
deserialized = SystemVariable(**serialized)
# Verify all fields match after round-trip
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
assert deserialized.query == original.query
assert deserialized.conversation_id == original.conversation_id
assert deserialized.dialogue_count == original.dialogue_count
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert list(deserialized.files) == list(original.files)
def test_json_round_trip(self):
"""Test JSON serialization/deserialization consistency with proper structure."""
# Create original SystemVariable
original = SystemVariable(**COMPLETE_VALID_DATA)
# Serialize to JSON string
json_str = original.model_dump_json()
# Parse JSON and verify structure
json_data = json.loads(json_str)
assert "workflow_run_id" in json_data
assert "workflow_execution_id" not in json_data
assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"]
# Deserialize from JSON data
deserialized = SystemVariable(**json_data)
# Verify key fields match after JSON round-trip
assert deserialized.workflow_execution_id == original.workflow_execution_id
assert deserialized.user_id == original.user_id
assert deserialized.app_id == original.app_id
assert deserialized.workflow_id == original.workflow_id
def test_files_field_deserialization(self):
"""Test deserialization with File objects in the files field - SystemVariable specific logic."""
# Test with empty files list
data_empty = {**VALID_BASE_DATA, "files": []}
system_var_empty = SystemVariable(**data_empty)
assert system_var_empty.files == []
# Test with single File object
test_file = create_test_file()
data_single = {**VALID_BASE_DATA, "files": [test_file]}
system_var_single = SystemVariable(**data_single)
assert len(system_var_single.files) == 1
assert system_var_single.files[0].filename == "test.txt"
assert system_var_single.files[0].tenant_id == "test-tenant-id"
# Test with multiple File objects
file1 = File(
tenant_id="tenant1",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="file1",
filename="doc1.txt",
storage_key="key1",
)
file2 = File(
tenant_id="tenant2",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
remote_url="https://example.com/image.jpg",
filename="image.jpg",
storage_key="key2",
)
data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]}
system_var_multiple = SystemVariable(**data_multiple)
assert len(system_var_multiple.files) == 2
assert system_var_multiple.files[0].filename == "doc1.txt"
assert system_var_multiple.files[1].filename == "image.jpg"
# Verify files field serialization/deserialization
serialized = system_var_multiple.model_dump(mode="json")
deserialized = SystemVariable(**serialized)
assert len(deserialized.files) == 2
assert deserialized.files[0].filename == "doc1.txt"
assert deserialized.files[1].filename == "image.jpg"
def test_alias_serialization_consistency(self):
"""Test that alias handling works consistently in both serialization directions."""
workflow_id = "test-workflow-id"
# Create with workflow_run_id (alias)
data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id}
system_var = SystemVariable(**data_with_alias)
# Serialize and verify alias is used
serialized = system_var.model_dump()
assert serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in serialized
# Deserialize and verify field mapping
deserialized = SystemVariable(**serialized)
assert deserialized.workflow_execution_id == workflow_id
# Test JSON serialization path
json_serialized = json.loads(system_var.model_dump_json())
assert json_serialized["workflow_run_id"] == workflow_id
assert "workflow_execution_id" not in json_serialized
json_deserialized = SystemVariable(**json_serialized)
assert json_deserialized.workflow_execution_id == workflow_id
def test_model_validator_serialization_logic(self):
"""Test the custom model validator behavior for serialization scenarios."""
workflow_id = "test-workflow-execution-id"
# Test direct instantiation with workflow_execution_id (should work)
data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id}
system_var1 = SystemVariable(**data1)
assert system_var1.workflow_execution_id == workflow_id
# Test serialization of the above (should use alias)
serialized1 = system_var1.model_dump()
assert "workflow_run_id" in serialized1
assert serialized1["workflow_run_id"] == workflow_id
# Test both present - workflow_run_id takes precedence (validator logic)
data2 = {
**VALID_BASE_DATA,
"workflow_execution_id": "should-be-removed",
"workflow_run_id": workflow_id,
}
system_var2 = SystemVariable(**data2)
assert system_var2.workflow_execution_id == workflow_id
# Verify serialization consistency
serialized2 = system_var2.model_dump()
assert serialized2["workflow_run_id"] == workflow_id
def test_constructor_with_extra_key():
# Test that SystemVariable should forbid extra keys
with pytest.raises(ValidationError):
# This should fail because there is an unexpected key.
SystemVariable(invalid_key=1) # type: ignore

@ -1,17 +1,40 @@
import pytest import pytest
from pydantic import ValidationError
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod, FileType
from core.variables import FileSegment, StringSegment from core.variables import FileSegment, StringSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from core.variables.segments import (
ArrayAnySegment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
FloatSegment,
IntegerSegment,
NoneSegment,
ObjectSegment,
)
from core.variables.variables import (
ArrayNumberVariable,
ArrayObjectVariable,
ArrayStringVariable,
FloatVariable,
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.system_variable import SystemVariable
from factories.variable_factory import build_segment, segment_to_variable from factories.variable_factory import build_segment, segment_to_variable
@pytest.fixture @pytest.fixture
def pool(): def pool():
return VariablePool(system_variables={}, user_inputs={}) return VariablePool(
system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"),
user_inputs={},
)
@pytest.fixture @pytest.fixture
@ -52,18 +75,28 @@ def test_use_long_selector(pool):
class TestVariablePool: class TestVariablePool:
def test_constructor(self): def test_constructor(self):
pool = VariablePool() # Test with minimal required SystemVariable
minimal_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool(system_variables=minimal_system_vars)
# Test with all parameters
pool = VariablePool( pool = VariablePool(
variable_dictionary={}, variable_dictionary={},
user_inputs={}, user_inputs={},
system_variables={}, system_variables=minimal_system_vars,
environment_variables=[], environment_variables=[],
conversation_variables=[], conversation_variables=[],
) )
# Test with more complex SystemVariable
complex_system_vars = SystemVariable(
user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"
)
pool = VariablePool( pool = VariablePool(
user_inputs={"key": "value"}, user_inputs={"key": "value"},
system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, system_variables=complex_system_vars,
environment_variables=[ environment_variables=[
segment_to_variable( segment_to_variable(
segment=build_segment(1), segment=build_segment(1),
@ -80,6 +113,302 @@ class TestVariablePool:
], ],
) )
def test_constructor_with_invalid_system_variable_key(self): def test_get_system_variables(self):
with pytest.raises(ValidationError): sys_var = SystemVariable(
VariablePool(system_variables={"invalid_key": "value"}) # type: ignore user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_var)
kv = [
("user_id", sys_var.user_id),
("app_id", sys_var.app_id),
("workflow_id", sys_var.workflow_id),
("workflow_run_id", sys_var.workflow_execution_id),
("query", sys_var.query),
("conversation_id", sys_var.conversation_id),
("dialogue_count", sys_var.dialogue_count),
]
for key, expected_value in kv:
segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key])
assert segment is not None
assert segment.value == expected_value
class TestVariablePoolSerialization:
"""Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods.
These tests focus exclusively on serialization/deserialization logic to ensure that
VariablePool data can be properly serialized to dictionaries/JSON and reconstructed
while preserving all data integrity.
"""
_NODE1_ID = "node_1"
_NODE2_ID = "node_2"
_NODE3_ID = "node_3"
def _create_pool_without_file(self):
# Create comprehensive system variables
system_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
StringVariable(
id="env_string_id",
name="env_string",
value="env_string_value",
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"],
),
IntegerVariable(
id="env_integer_id",
name="env_integer",
value=1,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"],
),
FloatVariable(
id="env_float_id",
name="env_float",
value=1.0,
selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"],
),
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
StringVariable(
id="conv_string_id",
name="conv_string",
value="conv_string_value",
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"],
),
IntegerVariable(
id="conv_integer_id",
name="conv_integer",
value=1,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"],
),
FloatVariable(
id="conv_float_id",
name="conv_float",
value=1.0,
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"],
),
ObjectVariable(
id="conv_object_id",
name="conv_object",
value={"key": "value", "nested": {"data": 123}},
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"],
),
ArrayStringVariable(
id="conv_array_string_id",
name="conv_array_string",
value=["conv_array_string_value"],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"],
),
ArrayNumberVariable(
id="conv_array_number_id",
name="conv_array_number",
value=[1, 1.0],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"],
),
ArrayObjectVariable(
id="conv_array_object_id",
name="conv_array_object",
value=[{"a": 1}, {"b": "2"}],
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"],
),
]
# Create comprehensive user inputs
user_inputs = {
"string_input": "test_value",
"number_input": 42,
"object_input": {"nested": {"key": "value"}},
"array_input": ["item1", "item2", "item3"],
}
# Create VariablePool
pool = VariablePool(
system_variables=system_vars,
user_inputs=user_inputs,
environment_variables=env_vars,
conversation_variables=conv_vars,
)
return pool
def _add_node_data_to_pool(self, pool: VariablePool, with_file=False):
test_file = File(
tenant_id="test_tenant_id",
type=FileType.DOCUMENT,
transfer_method=FileTransferMethod.LOCAL_FILE,
related_id="test_related_id",
remote_url="test_url",
filename="test_file.txt",
storage_key="test_storage_key",
)
# Add various segment types to variable dictionary
pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string"))
pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123))
pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67))
pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"}))
if with_file:
pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file))
pool.add((self._NODE1_ID, "none_var"), NoneSegment())
# Add array segments including ArrayFileVariable
pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"]))
pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3]))
pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}]))
if with_file:
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file]))
pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}]))
# Add nested variables
pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value"))
def test_system_variables(self):
sys_vars = SystemVariable(
user_id="test_user_id",
app_id="test_app_id",
workflow_id="test_workflow_id",
workflow_execution_id="test_execution_123",
query="test query",
conversation_id="test_conv_id",
dialogue_count=5,
)
pool = VariablePool(system_variables=sys_vars)
json = pool.model_dump_json()
pool2 = VariablePool.model_validate_json(json)
assert pool2.system_variables == sys_vars
for mode in ["json", "python"]:
dict_ = pool.model_dump(mode=mode)
pool2 = VariablePool.model_validate(dict_)
assert pool2.system_variables == sys_vars
def test_pool_without_file_vars(self):
pool = self._create_pool_without_file()
json = pool.model_dump_json()
pool2 = pool.model_validate_json(json)
assert pool2.system_variables == pool.system_variables
assert pool2.conversation_variables == pool.conversation_variables
assert pool2.environment_variables == pool.environment_variables
assert pool2.user_inputs == pool.user_inputs
assert pool2.variable_dictionary == pool.variable_dictionary
assert pool2 == pool
def test_basic_dictionary_round_trip(self):
"""Test basic round-trip serialization: model_dump() → model_validate()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to dictionary using Pydantic's model_dump()
serialized_data = original_pool.model_dump()
# Verify serialized data structure
assert isinstance(serialized_data, dict)
assert "system_variables" in serialized_data
assert "user_inputs" in serialized_data
assert "environment_variables" in serialized_data
assert "conversation_variables" in serialized_data
assert "variable_dictionary" in serialized_data
# Deserialize back using Pydantic's model_validate()
reconstructed_pool = VariablePool.model_validate(serialized_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_json_round_trip(self):
"""Test JSON round-trip serialization: model_dump_json() → model_validate_json()"""
# Create a comprehensive VariablePool with all data types
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool)
# Serialize to JSON string using Pydantic's model_dump_json()
json_data = original_pool.model_dump_json()
# Verify JSON is valid string
assert isinstance(json_data, str)
assert len(json_data) > 0
# Deserialize back using Pydantic's model_validate_json()
reconstructed_pool = VariablePool.model_validate_json(json_data)
# Verify data integrity is preserved
self._assert_pools_equal(original_pool, reconstructed_pool)
def test_complex_data_serialization(self):
"""Test serialization of complex data structures including ArrayFileVariable"""
original_pool = self._create_pool_without_file()
self._add_node_data_to_pool(original_pool, with_file=True)
# Test dictionary round-trip
dict_data = original_pool.model_dump()
reconstructed_dict = VariablePool.model_validate(dict_data)
# Test JSON round-trip
json_data = original_pool.model_dump_json()
reconstructed_json = VariablePool.model_validate_json(json_data)
# Verify both reconstructed pools are equivalent
self._assert_pools_equal(reconstructed_dict, reconstructed_json)
# TODO: assert the data for file object...
def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
"""Assert that two VariablePools contain equivalent data"""
# Compare system variables
assert pool1.system_variables == pool2.system_variables
# Compare user inputs
assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
# Compare environment variables count
assert pool1.environment_variables == pool2.environment_variables
# Compare conversation variables count
assert pool1.conversation_variables == pool2.conversation_variables
# Test key variable retrievals to ensure functionality is preserved
test_selectors = [
(SYSTEM_VARIABLE_NODE_ID, "user_id"),
(SYSTEM_VARIABLE_NODE_ID, "app_id"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
(ENVIRONMENT_VARIABLE_NODE_ID, "env_number"),
(CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
(self._NODE1_ID, "string_var"),
(self._NODE1_ID, "int_var"),
(self._NODE1_ID, "float_var"),
(self._NODE2_ID, "array_string"),
(self._NODE2_ID, "array_number"),
(self._NODE3_ID, "nested", "deep", "var"),
]
for selector in test_selectors:
val1 = pool1.get(selector)
val2 = pool2.get(selector)
# Both should exist or both should be None
assert (val1 is None) == (val2 is None)
if val1 is not None and val2 is not None:
# Values should be equal
assert val1.value == val2.value
# Value types should be the same (more important than exact class type)
assert val1.value_type == val2.value_type

Loading…
Cancel
Save