diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 89bd3dd2a6..af15324f46 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -153,7 +153,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): system_variables=system_inputs, user_inputs=inputs, environment_variables=workflow.environment_variables, - # TODO(QuantumGhost): find a better way to resolve typing issue. + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. conversation_variables=cast(list[VariableUnion], conversation_variables), ) diff --git a/api/core/variables/segments.py b/api/core/variables/segments.py index c31a0f7747..13274f4e0e 100644 --- a/api/core/variables/segments.py +++ b/api/core/variables/segments.py @@ -1,7 +1,7 @@ import json import sys from collections.abc import Mapping, Sequence -from typing import Annotated, Any +from typing import Annotated, Any, TypeAlias from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator @@ -205,7 +205,15 @@ def get_segment_discriminator(v: Any) -> SegmentType | None: return None -SegmentUnion = Annotated[ +# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Segment` for type hinting when serialization is not required. +# +# Note: +# - All variants in `SegmentUnion` must inherit from the `Segment` class. +# - The union must include all non-abstract subclasses of `Segment`, except: +# - `SegmentGroup`, which is not added to the variable pool. +# - `Variable` and its subclasses, which are handled by `VariableUnion`. +SegmentUnion: TypeAlias = Annotated[ ( Annotated[NoneSegment, Tag(SegmentType.NONE)] | Annotated[StringSegment, Tag(SegmentType.STRING)] diff --git a/api/core/variables/types.py b/api/core/variables/types.py index 08227fce18..e39237dba5 100644 --- a/api/core/variables/types.py +++ b/api/core/variables/types.py @@ -43,8 +43,11 @@ class SegmentType(StrEnum): @classmethod def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]: - """Try to infer the SegmentType from the Python type of - the `value` parameter. + """ + Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. + + Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. + For example, this may occur if the input is a generic Python object of type `object`. """ if isinstance(value, list): diff --git a/api/core/variables/variables.py b/api/core/variables/variables.py index 7989f3b032..a31ebc848e 100644 --- a/api/core/variables/variables.py +++ b/api/core/variables/variables.py @@ -100,6 +100,12 @@ class ArrayFileVariable(ArrayFileSegment, ArrayVariable): pass +# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic. +# Use `Variable` for type hinting when serialization is not required. +# +# Note: +# - All variants in `VariableUnion` must inherit from the `Variable` class. +# - The union must include all non-abstract subclasses of `Segment`, except: VariableUnion: TypeAlias = Annotated[ ( Annotated[NoneVariable, Tag(SegmentType.NONE)] diff --git a/api/core/workflow/entities/variable_pool.py b/api/core/workflow/entities/variable_pool.py index 6222d435db..d96237741d 100644 --- a/api/core/workflow/entities/variable_pool.py +++ b/api/core/workflow/entities/variable_pool.py @@ -34,10 +34,6 @@ class VariablePool(BaseModel): description="User inputs", default_factory=dict, ) - # system_variables: Mapping[SystemVariableKey, Any] = Field( - # description="System variables", - # default_factory=dict, - # ) system_variables: SystemVariable = Field( description="System variables", ) @@ -92,6 +88,8 @@ class VariablePool(BaseModel): # Ensure the first-level key exists in the dictionary if key not in self.variable_dictionary: self.variable_dictionary[key] = {} + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) @classmethod @@ -199,7 +197,3 @@ class VariablePool(BaseModel): @classmethod def loads(cls, json_data: str) -> "VariablePool": return VariablePool.model_validate_json(json_data) - - def reload_storage_keys_for_file_types(self): - # TODO - pass diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 27a1ded63e..13f4fe329e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -693,7 +693,9 @@ def _setup_variable_pool( system_variables=system_variable, user_inputs=user_inputs, environment_variables=workflow.environment_variables, - conversation_variables=cast(list[VariableUnion], conversation_variables), + # Based on the definition of `VariableUnion`, + # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. + conversation_variables=cast(list[VariableUnion], conversation_variables), # ) return variable_pool diff --git a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py index 27eec08cae..3f83428834 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py @@ -5,7 +5,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom from core.workflow.entities.node_entities import NodeRunResult, WorkflowNodeExecutionMetadataKey from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from core.workflow.enums import SystemVariableKey from core.workflow.graph_engine.entities.event import ( GraphRunPartialSucceededEvent, NodeRunExceptionEvent,