diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 8f3b91f26e..7989f3b032 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -5,7 +5,6 @@ from uuid import uuid4 from pydantic import Discriminator, Field, Tag from core.helper import encrypter -from core.variables.segment_group import SegmentGroup from .segments import ( ArrayAnySegment, diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 5f362237b4..39ebd009d5 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -2,8 +2,6 @@ from collections.abc import Mapping, Sequence from typing import Any, cast from uuid import uuid4 -from shapely import is_valid - from configs import dify_config from core.file import File from core.variables.exc import VariableError diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 1b035d01a7..cdc261fd42 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ -1,14 +1,49 @@ +import dataclasses + +from pydantic import BaseModel + +from core.file import File, FileTransferMethod, FileType from core.helper import encrypter -from core.variables import SecretVariable, StringVariable +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FileSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, + Segment, + SegmentUnion, + StringSegment, + get_segment_discriminator, +) +from core.variables.types import SegmentType +from core.variables.variables import ( + ArrayAnyVariable, + ArrayFileVariable, + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FileVariable, + FloatVariable, + IntegerVariable, + NoneVariable, + ObjectVariable, + SecretVariable, + StringVariable, + Variable, + VariableUnion, 
+) from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable def test_segment_group_to_text(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), @@ -30,7 +65,7 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): variable_pool = VariablePool( - system_variables={}, + system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -43,9 +78,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): variable_pool = VariablePool( - system_variables={ - SystemVariableKey("user_id"): "fake-user-id", - }, + system_variables=SystemVariable(user_id="fake-user-id"), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -56,3 +89,297 @@ def test_convert_variable_to_segment_group(): assert segments_group.log == "fake-user-id" assert isinstance(segments_group.value[0], StringVariable) assert segments_group.value[0].value == "fake-user-id" + + +class _Segments(BaseModel): + segments: list[SegmentUnion] + + +class _Variables(BaseModel): + variables: list[VariableUnion] + + +def create_test_file( + file_type: FileType = FileType.DOCUMENT, + transfer_method: FileTransferMethod = FileTransferMethod.LOCAL_FILE, + filename: str = "test.txt", + extension: str = ".txt", + mime_type: str = "text/plain", + size: int = 1024, +) -> File: + """Factory function to create File objects for testing""" + return File( + tenant_id="test-tenant", + type=file_type, + transfer_method=transfer_method, + filename=filename, + extension=extension, + mime_type=mime_type, + size=size, + related_id="test-file-id" if 
transfer_method != FileTransferMethod.REMOTE_URL else None, + remote_url="https://example.com/file.txt" if transfer_method == FileTransferMethod.REMOTE_URL else None, + storage_key="test-storage-key", + ) + + +class TestSegmentDumpAndLoad: + """Test suite for segment and variable serialization/deserialization""" + + def test_segments(self): + """Test basic segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), StringSegment(value="a")]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_segment_number(self): + """Test number segment serialization compatibility""" + model = _Segments(segments=[IntegerSegment(value=1), FloatSegment(value=1.0)]) + json = model.model_dump_json() + print("Json: ", json) + loaded = _Segments.model_validate_json(json) + assert loaded == model + + def test_variables(self): + """Test variable serialization compatibility""" + model = _Variables(variables=[IntegerVariable(value=1, name="int"), StringVariable(value="a", name="str")]) + json = model.model_dump_json() + print("Json: ", json) + restored = _Variables.model_validate_json(json) + assert restored == model + + def test_all_segments_serialization(self): + """Test serialization/deserialization of all segment types""" + # Create one instance of each segment type + test_file = create_test_file() + + all_segments: list[SegmentUnion] = [ + NoneSegment(), + StringSegment(value="test string"), + IntegerSegment(value=42), + FloatSegment(value=3.14), + ObjectSegment(value={"key": "value", "number": 123}), + FileSegment(value=test_file), + ArrayAnySegment(value=[1, "string", 3.14, {"key": "value"}]), + ArrayStringSegment(value=["hello", "world"]), + ArrayNumberSegment(value=[1, 2.5, 3]), + ArrayObjectSegment(value=[{"id": 1}, {"id": 2}]), + ArrayFileSegment(value=[]), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = 
_Segments(segments=all_segments) + json_str = model.model_dump_json() + loaded = _Segments.model_validate_json(json_str) + + # Verify all segments are preserved + assert len(loaded.segments) == len(all_segments) + + for original, loaded_segment in zip(all_segments, loaded.segments): + assert type(loaded_segment) == type(original) + assert loaded_segment.value_type == original.value_type + + # For file segments, compare key properties instead of exact equality + if isinstance(original, FileSegment) and isinstance(loaded_segment, FileSegment): + orig_file = original.value + loaded_file = loaded_segment.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_segment.value == original.value + + def test_all_variables_serialization(self): + """Test serialization/deserialization of all variable types""" + # Create one instance of each variable type + test_file = create_test_file() + + all_variables: list[VariableUnion] = [ + NoneVariable(name="none_var"), + StringVariable(value="test string", name="string_var"), + IntegerVariable(value=42, name="int_var"), + FloatVariable(value=3.14, name="float_var"), + ObjectVariable(value={"key": "value", "number": 123}, name="object_var"), + FileVariable(value=test_file, name="file_var"), + ArrayAnyVariable(value=[1, "string", 3.14, {"key": "value"}], name="array_any_var"), + ArrayStringVariable(value=["hello", "world"], name="array_string_var"), + ArrayNumberVariable(value=[1, 2.5, 3], name="array_number_var"), + ArrayObjectVariable(value=[{"id": 1}, {"id": 2}], name="array_object_var"), + ArrayFileVariable(value=[], name="array_file_var"), # Empty array to avoid file complexity + ] + + # Test serialization and deserialization + model = _Variables(variables=all_variables) + json_str = model.model_dump_json() + loaded = 
_Variables.model_validate_json(json_str) + + # Verify all variables are preserved + assert len(loaded.variables) == len(all_variables) + + for original, loaded_variable in zip(all_variables, loaded.variables): + assert type(loaded_variable) == type(original) + assert loaded_variable.value_type == original.value_type + assert loaded_variable.name == original.name + + # For file variables, compare key properties instead of exact equality + if isinstance(original, FileVariable) and isinstance(loaded_variable, FileVariable): + orig_file = original.value + loaded_file = loaded_variable.value + assert isinstance(orig_file, File) + assert isinstance(loaded_file, File) + assert loaded_file.tenant_id == orig_file.tenant_id + assert loaded_file.type == orig_file.type + assert loaded_file.filename == orig_file.filename + else: + assert loaded_variable.value == original.value + + def test_segment_discriminator_function_for_segment_types(self): + """Test the segment discriminator function""" + + @dataclasses.dataclass + class TestCase: + segment: Segment + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneSegment(), + SegmentType.NONE, + ), + TestCase( + StringSegment(value=""), + SegmentType.STRING, + ), + TestCase( + FloatSegment(value=0.0), + SegmentType.FLOAT, + ), + TestCase( + IntegerSegment(value=0), + SegmentType.INTEGER, + ), + TestCase( + ObjectSegment(value={}), + SegmentType.OBJECT, + ), + TestCase( + FileSegment(value=file1), + SegmentType.FILE, + ), + TestCase( + ArrayAnySegment(value=[0, 0.0, ""]), + SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringSegment(value=[""]), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberSegment(value=[0, 0.0]), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectSegment(value=[{}]), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileSegment(value=[file1, file2]), + SegmentType.ARRAY_FILE, + ), + ] + + for 
test_case in cases: + segment = test_case.segment + assert get_segment_discriminator(segment) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for type {type(segment)}" + ) + model_dict = segment.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(segment)}" + ) + + def test_variable_discriminator_function_for_variable_types(self): + """Test the variable discriminator function""" + + @dataclasses.dataclass + class TestCase: + variable: Variable + expected_segment_type: SegmentType + + file1 = create_test_file() + file2 = create_test_file(filename="test2.txt") + + cases = [ + TestCase( + NoneVariable(name="none_var"), + SegmentType.NONE, + ), + TestCase( + StringVariable(value="test", name="string_var"), + SegmentType.STRING, + ), + TestCase( + FloatVariable(value=0.0, name="float_var"), + SegmentType.FLOAT, + ), + TestCase( + IntegerVariable(value=0, name="int_var"), + SegmentType.INTEGER, + ), + TestCase( + ObjectVariable(value={}, name="object_var"), + SegmentType.OBJECT, + ), + TestCase( + FileVariable(value=file1, name="file_var"), + SegmentType.FILE, + ), + TestCase( + SecretVariable(value="secret", name="secret_var"), + SegmentType.SECRET, + ), + TestCase( + ArrayAnyVariable(value=[0, 0.0, ""], name="array_any_var"), + SegmentType.ARRAY_ANY, + ), + TestCase( + ArrayStringVariable(value=[""], name="array_string_var"), + SegmentType.ARRAY_STRING, + ), + TestCase( + ArrayNumberVariable(value=[0, 0.0], name="array_number_var"), + SegmentType.ARRAY_NUMBER, + ), + TestCase( + ArrayObjectVariable(value=[{}], name="array_object_var"), + SegmentType.ARRAY_OBJECT, + ), + TestCase( + ArrayFileVariable(value=[file1, file2], name="array_file_var"), + SegmentType.ARRAY_FILE, + ), + ] + + for test_case in cases: + variable = test_case.variable + assert get_segment_discriminator(variable) == test_case.expected_segment_type, 
( + f"get_segment_discriminator failed for type {type(variable)}" + ) + model_dict = variable.model_dump(mode="json") + assert get_segment_discriminator(model_dict) == test_case.expected_segment_type, ( + f"get_segment_discriminator failed for serialized form of type {type(variable)}" + ) + + def test_invlaid_value_for_discriminator(self): + # Test invalid cases + assert get_segment_discriminator({"value_type": "invalid"}) is None + assert get_segment_discriminator({}) is None + assert get_segment_discriminator("not_a_dict") is None + assert get_segment_discriminator(42) is None + assert get_segment_discriminator(object) is None diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py new file mode 100644 index 0000000000..64d0d8c7e7 --- /dev/null +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -0,0 +1,60 @@ +from core.variables.types import SegmentType + + +class TestSegmentTypeIsArrayType: + """ + Test class for SegmentType.is_array_type method. + + Provides comprehensive coverage of all SegmentType values to ensure + correct identification of array and non-array types. + """ + + def test_is_array_type(self): + """ + Test that all SegmentType enum values are covered in our test cases. + + Ensures comprehensive coverage by verifying that every SegmentType + value is tested for the is_array_type method. 
+ """ + # Arrange + all_segment_types = set(SegmentType) + expected_array_types = [ + SegmentType.ARRAY_ANY, + SegmentType.ARRAY_STRING, + SegmentType.ARRAY_NUMBER, + SegmentType.ARRAY_OBJECT, + SegmentType.ARRAY_FILE, + ] + expected_non_array_types = [ + SegmentType.INTEGER, + SegmentType.FLOAT, + SegmentType.NUMBER, + SegmentType.STRING, + SegmentType.OBJECT, + SegmentType.SECRET, + SegmentType.FILE, + SegmentType.NONE, + SegmentType.GROUP, + ] + + for seg_type in expected_array_types: + assert seg_type.is_array_type() + + for seg_type in expected_non_array_types: + assert not seg_type.is_array_type() + + # Act & Assert + covered_types = set(expected_array_types) | set(expected_non_array_types) + assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests" + + def test_all_enum_values_are_supported(self): + """ + Test that all enum values are supported and return boolean values. + + Validates that every SegmentType enum value can be processed by + is_array_type method and returns a boolean value. 
+ """ + enum_values: list[SegmentType] = list(SegmentType) + for seg_type in enum_values: + is_array = seg_type.is_array_type() + assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py new file mode 100644 index 0000000000..8e9fcf196b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_graph_runtime_state.py @@ -0,0 +1,146 @@ +import time +from decimal import Decimal + +from core.model_runtime.entities.llm_entities import LLMUsage +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState +from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState +from core.workflow.system_variable import SystemVariable + + +def create_test_graph_runtime_state() -> GraphRuntimeState: + """Factory function to create a GraphRuntimeState with non-empty values for testing.""" + # Create a variable pool with system variables + system_vars = SystemVariable( + user_id="test_user_123", + app_id="test_app_456", + workflow_id="test_workflow_789", + workflow_execution_id="test_execution_001", + query="test query", + conversation_id="test_conv_123", + dialogue_count=5, + ) + variable_pool = VariablePool(system_variables=system_vars) + + # Add some variables to the variable pool + variable_pool.add(["test_node", "test_var"], "test_value") + variable_pool.add(["another_node", "another_var"], 42) + + # Create LLM usage with realistic values + llm_usage = LLMUsage( + prompt_tokens=150, + prompt_unit_price=Decimal("0.001"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.15"), + completion_tokens=75, + completion_unit_price=Decimal("0.002"), + completion_price_unit=Decimal("1000"), + 
completion_price=Decimal("0.15"), + total_tokens=225, + total_price=Decimal("0.30"), + currency="USD", + latency=1.25, + ) + + # Create runtime route state with some node states + node_run_state = RuntimeRouteState() + node_state = node_run_state.create_node_state("test_node_1") + node_run_state.add_route(node_state.id, "target_node_id") + + return GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + total_tokens=100, + llm_usage=llm_usage, + outputs={ + "string_output": "test result", + "int_output": 42, + "float_output": 3.14, + "list_output": ["item1", "item2", "item3"], + "dict_output": {"key1": "value1", "key2": 123}, + "nested_dict": {"level1": {"level2": ["nested", "list", 456]}}, + }, + node_run_steps=5, + node_run_state=node_run_state, + ) + + +def test_basic_round_trip_serialization(): + """Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged.""" + # Create a state with non-empty values + original_state = create_test_graph_runtime_state() + + # Serialize to JSON and deserialize back + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Core test: ensure the round-trip preserves all values + assert deserialized_state == original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="python") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + # Serialize to JSON and deserialize back + dict_data = original_state.model_dump(mode="json") + deserialized_state = GraphRuntimeState.model_validate(dict_data) + assert deserialized_state == original_state + + +def test_outputs_field_round_trip(): + """Test the problematic outputs field maintains values through round-trip serialization.""" + original_state = create_test_graph_runtime_state() + + # Serialize and deserialize + json_data = original_state.model_dump_json() + 
deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + # Verify the outputs field specifically maintains its values + assert deserialized_state.outputs == original_state.outputs + assert deserialized_state == original_state + + +def test_empty_outputs_round_trip(): + """Test round-trip serialization with empty outputs field.""" + variable_pool = VariablePool.empty() + original_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + outputs={}, # Empty outputs + ) + + json_data = original_state.model_dump_json() + deserialized_state = GraphRuntimeState.model_validate_json(json_data) + + assert deserialized_state == original_state + + +def test_llm_usage_round_trip(): + # Create LLM usage with specific decimal values + llm_usage = LLMUsage( + prompt_tokens=100, + prompt_unit_price=Decimal("0.0015"), + prompt_price_unit=Decimal("1000"), + prompt_price=Decimal("0.15"), + completion_tokens=50, + completion_unit_price=Decimal("0.003"), + completion_price_unit=Decimal("1000"), + completion_price=Decimal("0.15"), + total_tokens=150, + total_price=Decimal("0.30"), + currency="USD", + latency=2.5, + ) + + json_data = llm_usage.model_dump_json() + deserialized = LLMUsage.model_validate_json(json_data) + assert deserialized == llm_usage + + dict_data = llm_usage.model_dump(mode="python") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage + + dict_data = llm_usage.model_dump(mode="json") + deserialized = LLMUsage.model_validate(dict_data) + assert deserialized == llm_usage diff --git a/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py new file mode 100644 index 0000000000..f3de42479a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/graph_engine/entities/test_node_run_state.py @@ -0,0 +1,401 @@ +import json +import uuid +from datetime import UTC, datetime + +import pytest 
+from pydantic import ValidationError + +from core.workflow.entities.node_entities import NodeRunResult +from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus +from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState + +_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45) + + +class TestRouteNodeStateSerialization: + """Test cases for RouteNodeState Pydantic serialization/deserialization.""" + + def _test_route_node_state(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + + node_run_result = NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"input_key": "input_value"}, + outputs={"output_key": "output_value"}, + ) + + node_state = RouteNodeState( + node_id="comprehensive_test_node", + start_at=_TEST_DATETIME, + finished_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + node_run_result=node_run_result, + index=5, + paused_at=_TEST_DATETIME, + paused_by="user_123", + failed_reason="test_reason", + ) + return node_state + + def test_route_node_state_comprehensive_field_validation(self): + """Test comprehensive RouteNodeState serialization with all core fields validation.""" + node_state = self._test_route_node_state() + serialized = node_state.model_dump() + + # Comprehensive validation of all RouteNodeState fields + assert serialized["node_id"] == "comprehensive_test_node" + assert serialized["status"] == RouteNodeState.Status.SUCCESS + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["finished_at"] == _TEST_DATETIME + assert serialized["paused_at"] == _TEST_DATETIME + assert serialized["paused_by"] == "user_123" + assert serialized["failed_reason"] == "test_reason" + assert serialized["index"] == 5 + assert "id" in serialized + assert isinstance(serialized["id"], str) + uuid.UUID(serialized["id"]) # Validate UUID format + + # Validate nested NodeRunResult structure + assert 
serialized["node_run_result"] is not None + assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED + assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"} + assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"} + + def test_route_node_state_minimal_required_fields(self): + """Test RouteNodeState with only required fields, focusing on defaults.""" + node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME) + + serialized = node_state.model_dump() + + # Focus on required fields and default values (not re-testing all fields) + assert serialized["node_id"] == "minimal_node" + assert serialized["start_at"] == _TEST_DATETIME + assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status + assert serialized["index"] == 1 # Default index + assert serialized["node_run_result"] is None # Default None + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + def test_route_node_state_deserialization_from_dict(self): + """Test RouteNodeState deserialization from dictionary data.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + test_id = str(uuid.uuid4()) + + dict_data = { + "id": test_id, + "node_id": "deserialized_node", + "start_at": test_datetime, + "status": "success", + "finished_at": test_datetime, + "index": 3, + } + + node_state = RouteNodeState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert node_state.id == test_id + assert node_state.node_id == "deserialized_node" + assert node_state.start_at == test_datetime + assert node_state.status == RouteNodeState.Status.SUCCESS + assert node_state.finished_at == test_datetime + assert node_state.index == 3 + + def test_route_node_state_round_trip_consistency(self): + node_states = ( + self._test_route_node_state(), + RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME), + ) + for 
node_state in node_states: + json = node_state.model_dump_json() + deserialized = RouteNodeState.model_validate_json(json) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="python") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + dict_ = node_state.model_dump(mode="json") + deserialized = RouteNodeState.model_validate(dict_) + assert deserialized == node_state + + +class TestRouteNodeStateEnumSerialization: + """Dedicated tests for RouteNodeState Status enum serialization behavior.""" + + def test_status_enum_model_dump_behavior(self): + """Test Status enum serialization in model_dump() returns enum objects.""" + + for status_enum in RouteNodeState.Status: + node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum) + serialized = node_state.model_dump(mode="python") + assert serialized["status"] == status_enum + serialized = node_state.model_dump(mode="json") + assert serialized["status"] == status_enum.value + + def test_status_enum_json_serialization_behavior(self): + """Test Status enum serialization in JSON returns string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + enum_to_string_mapping = { + RouteNodeState.Status.RUNNING: "running", + RouteNodeState.Status.SUCCESS: "success", + RouteNodeState.Status.FAILED: "failed", + RouteNodeState.Status.PAUSED: "paused", + RouteNodeState.Status.EXCEPTION: "exception", + } + + for status_enum, expected_string in enum_to_string_mapping.items(): + node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum) + + json_data = json.loads(node_state.model_dump_json()) + assert json_data["status"] == expected_string + + def test_status_enum_deserialization_from_string(self): + """Test Status enum deserialization from string values.""" + test_datetime = datetime(2024, 1, 15, 10, 30, 45) + + string_to_enum_mapping = { + "running": RouteNodeState.Status.RUNNING, + 
"success": RouteNodeState.Status.SUCCESS, + "failed": RouteNodeState.Status.FAILED, + "paused": RouteNodeState.Status.PAUSED, + "exception": RouteNodeState.Status.EXCEPTION, + } + + for status_string, expected_enum in string_to_enum_mapping.items(): + dict_data = { + "node_id": "enum_deserialize_test", + "start_at": test_datetime, + "status": status_string, + } + + node_state = RouteNodeState.model_validate(dict_data) + assert node_state.status == expected_enum + + +class TestRuntimeRouteStateSerialization: + """Test cases for RuntimeRouteState Pydantic serialization/deserialization.""" + + _NODE1_ID = "node_1" + _ROUTE_STATE1_ID = str(uuid.uuid4()) + _NODE2_ID = "node_2" + _ROUTE_STATE2_ID = str(uuid.uuid4()) + _NODE3_ID = "node_3" + _ROUTE_STATE3_ID = str(uuid.uuid4()) + + def _get_runtime_route_state(self): + # Create node states with different configurations + node_state_1 = RouteNodeState( + id=self._ROUTE_STATE1_ID, + node_id=self._NODE1_ID, + start_at=_TEST_DATETIME, + index=1, + ) + node_state_2 = RouteNodeState( + id=self._ROUTE_STATE2_ID, + node_id=self._NODE2_ID, + start_at=_TEST_DATETIME, + status=RouteNodeState.Status.SUCCESS, + finished_at=_TEST_DATETIME, + index=2, + ) + node_state_3 = RouteNodeState( + id=self._ROUTE_STATE3_ID, + node_id=self._NODE3_ID, + start_at=_TEST_DATETIME, + status=RouteNodeState.Status.FAILED, + failed_reason="Test failure", + index=3, + ) + + runtime_state = RuntimeRouteState( + routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]}, + node_state_mapping={ + node_state_1.id: node_state_1, + node_state_2.id: node_state_2, + node_state_3.id: node_state_3, + }, + ) + + return runtime_state + + def test_runtime_route_state_comprehensive_structure_validation(self): + """Test comprehensive RuntimeRouteState serialization with full structure validation.""" + + runtime_state = self._get_runtime_route_state() + serialized = runtime_state.model_dump() + + # Comprehensive validation of 
RuntimeRouteState structure + assert "routes" in serialized + assert "node_state_mapping" in serialized + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + # Validate routes dictionary structure and content + assert len(serialized["routes"]) == 2 + assert self._ROUTE_STATE1_ID in serialized["routes"] + assert self._ROUTE_STATE2_ID in serialized["routes"] + assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID] + assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID] + + # Validate node_state_mapping dictionary structure and content + assert len(serialized["node_state_mapping"]) == 3 + for state_id in [ + self._ROUTE_STATE1_ID, + self._ROUTE_STATE2_ID, + self._ROUTE_STATE3_ID, + ]: + assert state_id in serialized["node_state_mapping"] + node_data = serialized["node_state_mapping"][state_id] + node_state = runtime_state.node_state_mapping[state_id] + assert node_data["node_id"] == node_state.node_id + assert node_data["status"] == node_state.status + assert node_data["index"] == node_state.index + + def test_runtime_route_state_empty_collections(self): + """Test RuntimeRouteState with empty collections, focusing on default behavior.""" + runtime_state = RuntimeRouteState() + serialized = runtime_state.model_dump() + + # Focus on default empty collection behavior + assert serialized["routes"] == {} + assert serialized["node_state_mapping"] == {} + assert isinstance(serialized["routes"], dict) + assert isinstance(serialized["node_state_mapping"], dict) + + def test_runtime_route_state_json_serialization_structure(self): + """Test RuntimeRouteState JSON serialization structure.""" + node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME) + + runtime_state = RuntimeRouteState( + routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state} + ) + + json_str = runtime_state.model_dump_json() + 
json_data = json.loads(json_str) + + # Focus on JSON structure validation + assert isinstance(json_str, str) + assert isinstance(json_data, dict) + assert "routes" in json_data + assert "node_state_mapping" in json_data + assert json_data["routes"]["source"] == ["target1", "target2"] + assert node_state.id in json_data["node_state_mapping"] + + def test_runtime_route_state_deserialization_from_dict(self): + """Test RuntimeRouteState deserialization from dictionary data.""" + node_id = str(uuid.uuid4()) + + dict_data = { + "routes": {"source_node": ["target_node_1", "target_node_2"]}, + "node_state_mapping": { + node_id: { + "id": node_id, + "node_id": "test_node", + "start_at": _TEST_DATETIME, + "status": "running", + "index": 1, + } + }, + } + + runtime_state = RuntimeRouteState.model_validate(dict_data) + + # Focus on deserialization accuracy + assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]} + assert len(runtime_state.node_state_mapping) == 1 + assert node_id in runtime_state.node_state_mapping + + deserialized_node = runtime_state.node_state_mapping[node_id] + assert deserialized_node.node_id == "test_node" + assert deserialized_node.status == RouteNodeState.Status.RUNNING + assert deserialized_node.index == 1 + + def test_runtime_route_state_round_trip_consistency(self): + """Test RuntimeRouteState round-trip serialization consistency.""" + original = self._get_runtime_route_state() + + # Dictionary round trip + dict_data = original.model_dump(mode="python") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + dict_data = original.model_dump(mode="json") + reconstructed = RuntimeRouteState.model_validate(dict_data) + assert reconstructed == original + + # JSON round trip + json_str = original.model_dump_json() + json_reconstructed = RuntimeRouteState.model_validate_json(json_str) + assert json_reconstructed == original + + +class TestSerializationEdgeCases: + """Test edge cases 
and error conditions for serialization/deserialization.""" + + def test_invalid_status_deserialization(self): + """Test deserialization with invalid status values.""" + test_datetime = _TEST_DATETIME + invalid_data = { + "node_id": "invalid_test", + "start_at": test_datetime, + "status": "invalid_status", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "status" in str(exc_info.value) + + def test_missing_required_fields_deserialization(self): + """Test deserialization with missing required fields.""" + incomplete_data = {"id": str(uuid.uuid4())} + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(incomplete_data) + error_str = str(exc_info.value) + assert "node_id" in error_str or "start_at" in error_str + + def test_invalid_datetime_deserialization(self): + """Test deserialization with invalid datetime values.""" + invalid_data = { + "node_id": "datetime_test", + "start_at": "invalid_datetime", + "status": "running", + } + + with pytest.raises(ValidationError) as exc_info: + RouteNodeState.model_validate(invalid_data) + assert "start_at" in str(exc_info.value) + + def test_invalid_routes_structure_deserialization(self): + """Test RuntimeRouteState deserialization with invalid routes structure.""" + invalid_data = { + "routes": "invalid_routes_structure", # Should be dict + "node_state_mapping": {}, + } + + with pytest.raises(ValidationError) as exc_info: + RuntimeRouteState.model_validate(invalid_data) + assert "routes" in str(exc_info.value) + + def test_timezone_handling_in_datetime_fields(self): + """Test timezone handling in datetime field serialization.""" + utc_datetime = datetime.now(UTC) + naive_datetime = utc_datetime.replace(tzinfo=None) + + node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime) + dict_ = node_state.model_dump() + + assert dict_["start_at"] == naive_datetime + + # Test round trip + reconstructed = 
RouteNodeState.model_validate(dict_) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None + + json = node_state.model_dump_json() + + reconstructed = RouteNodeState.model_validate_json(json) + assert reconstructed.start_at == naive_datetime + assert reconstructed.start_at.tzinfo is None diff --git a/api/tests/unit_tests/core/workflow/test_system_variable.py b/api/tests/unit_tests/core/workflow/test_system_variable.py new file mode 100644 index 0000000000..11d788ed79 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_system_variable.py @@ -0,0 +1,251 @@ +import json +from typing import Any + +import pytest +from pydantic import ValidationError + +from core.file.enums import FileTransferMethod, FileType +from core.file.models import File +from core.workflow.system_variable import SystemVariable + +# Test data constants for SystemVariable serialization tests +VALID_BASE_DATA: dict[str, Any] = { + "user_id": "a20f06b1-8703-45ab-937c-860a60072113", + "app_id": "661bed75-458d-49c9-b487-fda0762677b9", + "workflow_id": "d31f2136-b292-4ae0-96d4-1e77894a4f43", +} + +COMPLETE_VALID_DATA: dict[str, Any] = { + **VALID_BASE_DATA, + "query": "test query", + "files": [], + "conversation_id": "91f1eb7d-69f4-4d7b-b82f-4003d51744b9", + "dialogue_count": 5, + "workflow_run_id": "eb4704b5-2274-47f2-bfcd-0452daa82cb5", +} + + +def create_test_file() -> File: + """Create a test File object for serialization tests.""" + return File( + tenant_id="test-tenant-id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test-file-id", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=1024, + storage_key="test-storage-key", + ) + + +class TestSystemVariableSerialization: + """Focused tests for SystemVariable serialization/deserialization logic.""" + + def test_basic_deserialization(self): + """Test successful deserialization from JSON structure with all fields correctly mapped.""" + 
# Test with complete data + system_var = SystemVariable(**COMPLETE_VALID_DATA) + + # Verify all fields are correctly mapped + assert system_var.user_id == COMPLETE_VALID_DATA["user_id"] + assert system_var.app_id == COMPLETE_VALID_DATA["app_id"] + assert system_var.workflow_id == COMPLETE_VALID_DATA["workflow_id"] + assert system_var.query == COMPLETE_VALID_DATA["query"] + assert system_var.conversation_id == COMPLETE_VALID_DATA["conversation_id"] + assert system_var.dialogue_count == COMPLETE_VALID_DATA["dialogue_count"] + assert system_var.workflow_execution_id == COMPLETE_VALID_DATA["workflow_run_id"] + assert system_var.files == [] + + # Test with minimal data (only required fields) + minimal_var = SystemVariable(**VALID_BASE_DATA) + assert minimal_var.user_id == VALID_BASE_DATA["user_id"] + assert minimal_var.app_id == VALID_BASE_DATA["app_id"] + assert minimal_var.workflow_id == VALID_BASE_DATA["workflow_id"] + assert minimal_var.query is None + assert minimal_var.conversation_id is None + assert minimal_var.dialogue_count is None + assert minimal_var.workflow_execution_id is None + assert minimal_var.files == [] + + def test_alias_handling(self): + """Test workflow_execution_id vs workflow_run_id alias resolution - core deserialization logic.""" + workflow_id = "eb4704b5-2274-47f2-bfcd-0452daa82cb5" + + # Test workflow_run_id only (preferred alias) + data_run_id = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var1 = SystemVariable(**data_run_id) + assert system_var1.workflow_execution_id == workflow_id + + # Test workflow_execution_id only (direct field name) + data_execution_id = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var2 = SystemVariable(**data_execution_id) + assert system_var2.workflow_execution_id == workflow_id + + # Test both present - workflow_run_id should take precedence + data_both = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-ignored", + "workflow_run_id": workflow_id, + } + system_var3 
= SystemVariable(**data_both) + assert system_var3.workflow_execution_id == workflow_id + + # Test neither present - should be None + system_var4 = SystemVariable(**VALID_BASE_DATA) + assert system_var4.workflow_execution_id is None + + def test_serialization_round_trip(self): + """Test that serialize → deserialize produces the same result with alias handling.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to dict + serialized = original.model_dump(mode="json") + + # Verify alias is used in serialization (workflow_run_id, not workflow_execution_id) + assert "workflow_run_id" in serialized + assert "workflow_execution_id" not in serialized + assert serialized["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize back + deserialized = SystemVariable(**serialized) + + # Verify all fields match after round-trip + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + assert deserialized.query == original.query + assert deserialized.conversation_id == original.conversation_id + assert deserialized.dialogue_count == original.dialogue_count + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert list(deserialized.files) == list(original.files) + + def test_json_round_trip(self): + """Test JSON serialization/deserialization consistency with proper structure.""" + # Create original SystemVariable + original = SystemVariable(**COMPLETE_VALID_DATA) + + # Serialize to JSON string + json_str = original.model_dump_json() + + # Parse JSON and verify structure + json_data = json.loads(json_str) + assert "workflow_run_id" in json_data + assert "workflow_execution_id" not in json_data + assert json_data["workflow_run_id"] == COMPLETE_VALID_DATA["workflow_run_id"] + + # Deserialize from JSON data + deserialized = SystemVariable(**json_data) + + # Verify key fields match 
after JSON round-trip + assert deserialized.workflow_execution_id == original.workflow_execution_id + assert deserialized.user_id == original.user_id + assert deserialized.app_id == original.app_id + assert deserialized.workflow_id == original.workflow_id + + def test_files_field_deserialization(self): + """Test deserialization with File objects in the files field - SystemVariable specific logic.""" + # Test with empty files list + data_empty = {**VALID_BASE_DATA, "files": []} + system_var_empty = SystemVariable(**data_empty) + assert system_var_empty.files == [] + + # Test with single File object + test_file = create_test_file() + data_single = {**VALID_BASE_DATA, "files": [test_file]} + system_var_single = SystemVariable(**data_single) + assert len(system_var_single.files) == 1 + assert system_var_single.files[0].filename == "test.txt" + assert system_var_single.files[0].tenant_id == "test-tenant-id" + + # Test with multiple File objects + file1 = File( + tenant_id="tenant1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="file1", + filename="doc1.txt", + storage_key="key1", + ) + file2 = File( + tenant_id="tenant2", + type=FileType.IMAGE, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/image.jpg", + filename="image.jpg", + storage_key="key2", + ) + + data_multiple = {**VALID_BASE_DATA, "files": [file1, file2]} + system_var_multiple = SystemVariable(**data_multiple) + assert len(system_var_multiple.files) == 2 + assert system_var_multiple.files[0].filename == "doc1.txt" + assert system_var_multiple.files[1].filename == "image.jpg" + + # Verify files field serialization/deserialization + serialized = system_var_multiple.model_dump(mode="json") + deserialized = SystemVariable(**serialized) + assert len(deserialized.files) == 2 + assert deserialized.files[0].filename == "doc1.txt" + assert deserialized.files[1].filename == "image.jpg" + + def test_alias_serialization_consistency(self): + 
"""Test that alias handling works consistently in both serialization directions.""" + workflow_id = "test-workflow-id" + + # Create with workflow_run_id (alias) + data_with_alias = {**VALID_BASE_DATA, "workflow_run_id": workflow_id} + system_var = SystemVariable(**data_with_alias) + + # Serialize and verify alias is used + serialized = system_var.model_dump() + assert serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in serialized + + # Deserialize and verify field mapping + deserialized = SystemVariable(**serialized) + assert deserialized.workflow_execution_id == workflow_id + + # Test JSON serialization path + json_serialized = json.loads(system_var.model_dump_json()) + assert json_serialized["workflow_run_id"] == workflow_id + assert "workflow_execution_id" not in json_serialized + + json_deserialized = SystemVariable(**json_serialized) + assert json_deserialized.workflow_execution_id == workflow_id + + def test_model_validator_serialization_logic(self): + """Test the custom model validator behavior for serialization scenarios.""" + workflow_id = "test-workflow-execution-id" + + # Test direct instantiation with workflow_execution_id (should work) + data1 = {**VALID_BASE_DATA, "workflow_execution_id": workflow_id} + system_var1 = SystemVariable(**data1) + assert system_var1.workflow_execution_id == workflow_id + + # Test serialization of the above (should use alias) + serialized1 = system_var1.model_dump() + assert "workflow_run_id" in serialized1 + assert serialized1["workflow_run_id"] == workflow_id + + # Test both present - workflow_run_id takes precedence (validator logic) + data2 = { + **VALID_BASE_DATA, + "workflow_execution_id": "should-be-removed", + "workflow_run_id": workflow_id, + } + system_var2 = SystemVariable(**data2) + assert system_var2.workflow_execution_id == workflow_id + + # Verify serialization consistency + serialized2 = system_var2.model_dump() + assert serialized2["workflow_run_id"] == workflow_id + + +def 
test_constructor_with_extra_key(): + # Test that SystemVariable should forbid extra keys + with pytest.raises(ValidationError): + # This should fail because there is an unexpected key. + SystemVariable(invalid_key=1) # type: ignore diff --git a/api/tests/unit_tests/core/workflow/test_variable_pool.py b/api/tests/unit_tests/core/workflow/test_variable_pool.py index bb8d34fad5..cb9a4f3c44 100644 --- a/api/tests/unit_tests/core/workflow/test_variable_pool.py +++ b/api/tests/unit_tests/core/workflow/test_variable_pool.py @@ -1,17 +1,40 @@ import pytest -from pydantic import ValidationError from core.file import File, FileTransferMethod, FileType from core.variables import FileSegment, StringSegment -from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from core.variables.segments import ( + ArrayAnySegment, + ArrayFileSegment, + ArrayNumberSegment, + ArrayObjectSegment, + ArrayStringSegment, + FloatSegment, + IntegerSegment, + NoneSegment, + ObjectSegment, +) +from core.variables.variables import ( + ArrayNumberVariable, + ArrayObjectVariable, + ArrayStringVariable, + FloatVariable, + IntegerVariable, + ObjectVariable, + StringVariable, + VariableUnion, +) +from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.entities.variable_pool import VariablePool -from core.workflow.enums import SystemVariableKey +from core.workflow.system_variable import SystemVariable from factories.variable_factory import build_segment, segment_to_variable @pytest.fixture def pool(): - return VariablePool(system_variables={}, user_inputs={}) + return VariablePool( + system_variables=SystemVariable(user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id"), + user_inputs={}, + ) @pytest.fixture @@ -52,18 +75,28 @@ def test_use_long_selector(pool): class TestVariablePool: def test_constructor(self): - pool = VariablePool() + # Test with minimal 
required SystemVariable + minimal_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) + pool = VariablePool(system_variables=minimal_system_vars) + + # Test with all parameters pool = VariablePool( variable_dictionary={}, user_inputs={}, - system_variables={}, + system_variables=minimal_system_vars, environment_variables=[], conversation_variables=[], ) + # Test with more complex SystemVariable + complex_system_vars = SystemVariable( + user_id="test_user_id", app_id="test_app_id", workflow_id="test_workflow_id" + ) pool = VariablePool( user_inputs={"key": "value"}, - system_variables={SystemVariableKey.WORKFLOW_ID: "test_workflow_id"}, + system_variables=complex_system_vars, environment_variables=[ segment_to_variable( segment=build_segment(1), @@ -80,6 +113,302 @@ class TestVariablePool: ], ) - def test_constructor_with_invalid_system_variable_key(self): - with pytest.raises(ValidationError): - VariablePool(system_variables={"invalid_key": "value"}) # type: ignore + def test_get_system_variables(self): + sys_var = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_var) + + kv = [ + ("user_id", sys_var.user_id), + ("app_id", sys_var.app_id), + ("workflow_id", sys_var.workflow_id), + ("workflow_run_id", sys_var.workflow_execution_id), + ("query", sys_var.query), + ("conversation_id", sys_var.conversation_id), + ("dialogue_count", sys_var.dialogue_count), + ] + for key, expected_value in kv: + segment = pool.get([SYSTEM_VARIABLE_NODE_ID, key]) + assert segment is not None + assert segment.value == expected_value + + +class TestVariablePoolSerialization: + """Test cases for VariablePool serialization and deserialization using Pydantic's built-in methods. 
+ + These tests focus exclusively on serialization/deserialization logic to ensure that + VariablePool data can be properly serialized to dictionaries/JSON and reconstructed + while preserving all data integrity. + """ + + _NODE1_ID = "node_1" + _NODE2_ID = "node_2" + _NODE3_ID = "node_3" + + def _create_pool_without_file(self): + # Create comprehensive system variables + system_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + + # Create environment variables with all types including ArrayFileVariable + env_vars: list[VariableUnion] = [ + StringVariable( + id="env_string_id", + name="env_string", + value="env_string_value", + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_string"], + ), + IntegerVariable( + id="env_integer_id", + name="env_integer", + value=1, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"], + ), + FloatVariable( + id="env_float_id", + name="env_float", + value=1.0, + selector=[ENVIRONMENT_VARIABLE_NODE_ID, "env_float"], + ), + ] + + # Create conversation variables with complex data + conv_vars: list[VariableUnion] = [ + StringVariable( + id="conv_string_id", + name="conv_string", + value="conv_string_value", + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_string"], + ), + IntegerVariable( + id="conv_integer_id", + name="conv_integer", + value=1, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_integer"], + ), + FloatVariable( + id="conv_float_id", + name="conv_float", + value=1.0, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_float"], + ), + ObjectVariable( + id="conv_object_id", + name="conv_object", + value={"key": "value", "nested": {"data": 123}}, + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_object"], + ), + ArrayStringVariable( + id="conv_array_string_id", + name="conv_array_string", + value=["conv_array_string_value"], + 
selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_string"], + ), + ArrayNumberVariable( + id="conv_array_number_id", + name="conv_array_number", + value=[1, 1.0], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_number"], + ), + ArrayObjectVariable( + id="conv_array_object_id", + name="conv_array_object", + value=[{"a": 1}, {"b": "2"}], + selector=[CONVERSATION_VARIABLE_NODE_ID, "conv_array_object"], + ), + ] + + # Create comprehensive user inputs + user_inputs = { + "string_input": "test_value", + "number_input": 42, + "object_input": {"nested": {"key": "value"}}, + "array_input": ["item1", "item2", "item3"], + } + + # Create VariablePool + pool = VariablePool( + system_variables=system_vars, + user_inputs=user_inputs, + environment_variables=env_vars, + conversation_variables=conv_vars, + ) + return pool + + def _add_node_data_to_pool(self, pool: VariablePool, with_file=False): + test_file = File( + tenant_id="test_tenant_id", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + related_id="test_related_id", + remote_url="test_url", + filename="test_file.txt", + storage_key="test_storage_key", + ) + + # Add various segment types to variable dictionary + pool.add((self._NODE1_ID, "string_var"), StringSegment(value="test_string")) + pool.add((self._NODE1_ID, "int_var"), IntegerSegment(value=123)) + pool.add((self._NODE1_ID, "float_var"), FloatSegment(value=45.67)) + pool.add((self._NODE1_ID, "object_var"), ObjectSegment(value={"test": "data"})) + if with_file: + pool.add((self._NODE1_ID, "file_var"), FileSegment(value=test_file)) + pool.add((self._NODE1_ID, "none_var"), NoneSegment()) + + # Add array segments including ArrayFileVariable + pool.add((self._NODE2_ID, "array_string"), ArrayStringSegment(value=["a", "b", "c"])) + pool.add((self._NODE2_ID, "array_number"), ArrayNumberSegment(value=[1, 2, 3])) + pool.add((self._NODE2_ID, "array_object"), ArrayObjectSegment(value=[{"a": 1}, {"b": 2}])) + if with_file: + 
pool.add((self._NODE2_ID, "array_file"), ArrayFileSegment(value=[test_file])) + pool.add((self._NODE2_ID, "array_any"), ArrayAnySegment(value=["mixed", 123, {"key": "value"}])) + + # Add nested variables + pool.add((self._NODE3_ID, "nested", "deep", "var"), StringSegment(value="deep_value")) + + def test_system_variables(self): + sys_vars = SystemVariable( + user_id="test_user_id", + app_id="test_app_id", + workflow_id="test_workflow_id", + workflow_execution_id="test_execution_123", + query="test query", + conversation_id="test_conv_id", + dialogue_count=5, + ) + pool = VariablePool(system_variables=sys_vars) + json = pool.model_dump_json() + pool2 = VariablePool.model_validate_json(json) + assert pool2.system_variables == sys_vars + + for mode in ["json", "python"]: + dict_ = pool.model_dump(mode=mode) + pool2 = VariablePool.model_validate(dict_) + assert pool2.system_variables == sys_vars + + def test_pool_without_file_vars(self): + pool = self._create_pool_without_file() + json = pool.model_dump_json() + pool2 = pool.model_validate_json(json) + assert pool2.system_variables == pool.system_variables + assert pool2.conversation_variables == pool.conversation_variables + assert pool2.environment_variables == pool.environment_variables + assert pool2.user_inputs == pool.user_inputs + assert pool2.variable_dictionary == pool.variable_dictionary + assert pool2 == pool + + def test_basic_dictionary_round_trip(self): + """Test basic round-trip serialization: model_dump() → model_validate()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to dictionary using Pydantic's model_dump() + serialized_data = original_pool.model_dump() + + # Verify serialized data structure + assert isinstance(serialized_data, dict) + assert "system_variables" in serialized_data + assert "user_inputs" in serialized_data + assert "environment_variables" in serialized_data 
+ assert "conversation_variables" in serialized_data + assert "variable_dictionary" in serialized_data + + # Deserialize back using Pydantic's model_validate() + reconstructed_pool = VariablePool.model_validate(serialized_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_json_round_trip(self): + """Test JSON round-trip serialization: model_dump_json() → model_validate_json()""" + # Create a comprehensive VariablePool with all data types + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool) + + # Serialize to JSON string using Pydantic's model_dump_json() + json_data = original_pool.model_dump_json() + + # Verify JSON is valid string + assert isinstance(json_data, str) + assert len(json_data) > 0 + + # Deserialize back using Pydantic's model_validate_json() + reconstructed_pool = VariablePool.model_validate_json(json_data) + + # Verify data integrity is preserved + self._assert_pools_equal(original_pool, reconstructed_pool) + + def test_complex_data_serialization(self): + """Test serialization of complex data structures including ArrayFileVariable""" + original_pool = self._create_pool_without_file() + self._add_node_data_to_pool(original_pool, with_file=True) + + # Test dictionary round-trip + dict_data = original_pool.model_dump() + reconstructed_dict = VariablePool.model_validate(dict_data) + + # Test JSON round-trip + json_data = original_pool.model_dump_json() + reconstructed_json = VariablePool.model_validate_json(json_data) + + # Verify both reconstructed pools are equivalent + self._assert_pools_equal(reconstructed_dict, reconstructed_json) + # TODO: assert the data for file object... 
+
+    def _assert_pools_equal(self, pool1: VariablePool, pool2: VariablePool) -> None:
+        """Assert that two VariablePools contain equivalent data.
+
+        Compares system variables, user inputs, environment variables and
+        conversation variables for equality, then resolves a fixed set of
+        selectors in both pools and checks that each one yields the same
+        value and value type.
+
+        Args:
+            pool1: The first pool to compare.
+            pool2: The second pool to compare.
+
+        Raises:
+            AssertionError: If any compared attribute or selector lookup differs.
+        """
+
+        # Compare system variables
+        assert pool1.system_variables == pool2.system_variables
+
+        # Compare user inputs
+        assert dict(pool1.user_inputs) == dict(pool2.user_inputs)
+
+        # Compare environment variables (full equality, not just the count)
+        assert pool1.environment_variables == pool2.environment_variables
+
+        # Compare conversation variables (full equality, not just the count)
+        assert pool1.conversation_variables == pool2.conversation_variables
+
+        # Test key variable retrievals to ensure functionality is preserved.
+        # The environment selectors use the names created by
+        # `_create_pool_without_file` (env_string / env_integer / env_float).
+        # A name present in neither pool resolves to None on both sides and
+        # would only exercise the presence check below, verifying nothing.
+        test_selectors = [
+            (SYSTEM_VARIABLE_NODE_ID, "user_id"),
+            (SYSTEM_VARIABLE_NODE_ID, "app_id"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_string"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_integer"),
+            (ENVIRONMENT_VARIABLE_NODE_ID, "env_float"),
+            (CONVERSATION_VARIABLE_NODE_ID, "conv_string"),
+            (self._NODE1_ID, "string_var"),
+            (self._NODE1_ID, "int_var"),
+            (self._NODE1_ID, "float_var"),
+            (self._NODE2_ID, "array_string"),
+            (self._NODE2_ID, "array_number"),
+            (self._NODE3_ID, "nested", "deep", "var"),
+        ]
+
+        for selector in test_selectors:
+            val1 = pool1.get(selector)
+            val2 = pool2.get(selector)
+
+            # Both should exist or both should be None
+            assert (val1 is None) == (val2 is None)
+
+            if val1 is not None and val2 is not None:
+                # Values should be equal
+                assert val1.value == val2.value
+                # Value types should be the same (more important than exact class type)
+                assert val1.value_type == val2.value_type