feat(api): Making VariablePool and GraphRuntimeState serializable

pull/22025/head
QuantumGhost 11 months ago
parent 11bc5c0dec
commit d9bc894bb9

@ -68,13 +68,18 @@ def _create_pagination_parser():
return parser return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = { _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String, "id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()), "type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String, "name": fields.String,
"description": fields.String, "description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String, "value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited), "edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean, "visible": fields.Boolean,
} }
@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"name": fields.String, "name": fields.String,
"description": fields.String, "description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()), "selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String, "value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited), "edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean, "visible": fields.Boolean,
} }
@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name, "name": v.name,
"description": v.description, "description": v.description,
"selector": v.selector, "selector": v.selector,
"value_type": v.value_type.value, "value_type": v.value_type.exposed_type().value,
"value": v.value, "value": v.value,
# Do not track edited for env vars. # Do not track edited for env vars.
"edited": False, "edited": False,

@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent, QueueTextChunkEvent,
) )
from core.moderation.base import ModerationError from core.moderation.base import ModerationError
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable from core.workflow.system_variable import SystemVariable
@ -152,7 +153,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs, system_variables=system_inputs,
user_inputs=inputs, user_inputs=inputs,
environment_variables=workflow.environment_variables, environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables, # TODO(QuantumGhost): find a better way to resolve typing issue.
conversation_variables=cast(list[VariableUnion], conversation_variables),
) )
# init graph # init graph

@ -1,9 +1,9 @@
import json import json
import sys import sys
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any from typing import Annotated, Any
from pydantic import BaseModel, ConfigDict, field_validator from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File from core.file import File
@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel): class Segment(BaseModel):
"""Segment is runtime type used during the execution of workflow.
Note: this class is abstract, you should use subclasses of this class instead.
"""
model_config = ConfigDict(frozen=True) model_config = ConfigDict(frozen=True)
value_type: SegmentType value_type: SegmentType
@ -73,7 +78,7 @@ class StringSegment(Segment):
class FloatSegment(Segment): class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER value_type: SegmentType = SegmentType.FLOAT
value: float value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass. # The following tests cannot pass.
@ -92,7 +97,7 @@ class FloatSegment(Segment):
class IntegerSegment(Segment): class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER value_type: SegmentType = SegmentType.INTEGER
value: int value: int
@ -181,3 +186,38 @@ class ArrayFileSegment(ArraySegment):
@property @property
def text(self) -> str: def text(self) -> str:
return "" return ""
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
elif isinstance(v, dict):
value_type = v.get("value_type")
if value_type is None:
return None
try:
seg_type = SegmentType(value_type)
except ValueError:
return None
return seg_type
else:
# return None if the discriminator value isn't found
return None
SegmentUnion = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
),
Discriminator(get_segment_discriminator),
]

@ -1,8 +1,27 @@
from collections.abc import Mapping
from enum import StrEnum from enum import StrEnum
from typing import Any, Optional
from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
# Skip element validation (only check array container)
NONE = "none"
# Validate the first element (if array is non-empty)
FIRST = "first"
# Validate all elements in the array.
ALL = "all"
class SegmentType(StrEnum): class SegmentType(StrEnum):
NUMBER = "number" NUMBER = "number"
INTEGER = "integer"
FLOAT = "float"
STRING = "string" STRING = "string"
OBJECT = "object" OBJECT = "object"
SECRET = "secret" SECRET = "secret"
@ -19,16 +38,138 @@ class SegmentType(StrEnum):
GROUP = "group" GROUP = "group"
def is_array_type(self): def is_array_type(self) -> bool:
return self in _ARRAY_TYPES return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
"""Try to infer the SegmentType from the Python type of
the `value` parameter.
"""
if isinstance(value, list):
elem_types: set[SegmentType] = set()
for i in value:
segment_type = cls.infer_segment_type(i)
if segment_type is None:
return None
elem_types.add(segment_type)
if len(elem_types) != 1:
if elem_types.issubset(_NUMERICAL_TYPES):
return SegmentType.ARRAY_NUMBER
return SegmentType.ARRAY_ANY
elif all(i.is_array_type() for i in elem_types):
return SegmentType.ARRAY_ANY
match elem_types.pop():
case SegmentType.STRING:
return SegmentType.ARRAY_STRING
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return SegmentType.ARRAY_NUMBER
case SegmentType.OBJECT:
return SegmentType.ARRAY_OBJECT
case SegmentType.FILE:
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
elif isinstance(value, str):
return SegmentType.STRING
elif isinstance(value, dict):
return SegmentType.OBJECT
elif isinstance(value, File):
return SegmentType.FILE
elif isinstance(value, str):
return SegmentType.STRING
else:
return None
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
if not isinstance(value, list):
return False
# Skip element validation if array is empty
if len(value) == 0:
return True
if self == SegmentType.ARRAY_ANY:
return True
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
if array_validation == ArrayValidation.NONE:
return True
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
`isinstance` manually.
Args:
value: The value to validate
array_validation: Validation strategy for array types (ignored for non-array types)
Returns:
True if the value matches the type under the given validation strategy
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
elif self == SegmentType.NUMBER:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
return isinstance(value, str)
elif self == SegmentType.OBJECT:
return isinstance(value, dict)
elif self == SegmentType.SECRET:
return isinstance(value, str)
elif self == SegmentType.FILE:
return isinstance(value, File)
elif self == SegmentType.NONE:
return value is None
else:
raise AssertionError("this statement should be unreachable.")
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
"""
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
return SegmentType.NUMBER
return self
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
}
_ARRAY_TYPES = frozenset( _ARRAY_TYPES = frozenset(
[ list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ [
SegmentType.ARRAY_ANY, SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING, ]
SegmentType.ARRAY_NUMBER, )
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
_NUMERICAL_TYPES = frozenset(
[
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
] ]
) )

@ -1,10 +1,11 @@
from collections.abc import Sequence from collections.abc import Sequence
from typing import cast from typing import Annotated, TypeAlias, cast
from uuid import uuid4 from uuid import uuid4
from pydantic import Field from pydantic import Discriminator, Field, Tag
from core.helper import encrypter from core.helper import encrypter
from core.variables.segment_group import SegmentGroup
from .segments import ( from .segments import (
ArrayAnySegment, ArrayAnySegment,
@ -20,6 +21,7 @@ from .segments import (
ObjectSegment, ObjectSegment,
Segment, Segment,
StringSegment, StringSegment,
get_segment_discriminator,
) )
from .types import SegmentType from .types import SegmentType
@ -27,6 +29,10 @@ from .types import SegmentType
class Variable(Segment): class Variable(Segment):
""" """
A variable is a segment that has a name. A variable is a segment that has a name.
It is mainly used to store segments and their selector in VariablePool.
Note: this class is abstract, you should use subclasses of this class instead.
""" """
id: str = Field( id: str = Field(
@ -93,3 +99,22 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable): class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass pass
VariableUnion: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),
]

@ -1,7 +1,7 @@
import re import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Union from typing import Any, Union, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment from core.variables.segments import FileSegment, NoneSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey from core.workflow.system_variable import SystemVariable
from factories import variable_factory from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File] VariableValue = Union[str, int, float, dict, list, File]
@ -23,31 +24,35 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field( variable_dictionary: dict[str, dict[int, VariableUnion]] = Field(
description="Variables mapping", description="Variables mapping",
default=defaultdict(dict), default=defaultdict(dict),
) )
# TODO: This user inputs is not used for pool.
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field( user_inputs: Mapping[str, Any] = Field(
description="User inputs", description="User inputs",
default_factory=dict, default_factory=dict,
) )
system_variables: Mapping[SystemVariableKey, Any] = Field( # system_variables: Mapping[SystemVariableKey, Any] = Field(
# description="System variables",
# default_factory=dict,
# )
system_variables: SystemVariable = Field(
description="System variables", description="System variables",
default_factory=dict,
) )
environment_variables: Sequence[Variable] = Field( environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.", description="Environment variables.",
default_factory=list, default_factory=list,
) )
conversation_variables: Sequence[Variable] = Field( conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.", description="Conversation variables.",
default_factory=list, default_factory=list,
) )
def model_post_init(self, context: Any, /) -> None: def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items(): # Create a mapping from field names to SystemVariableKey enum values
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value) self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool # Add environment variables to the variable pool
for var in self.environment_variables: for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@ -83,8 +88,23 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value) segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector) variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:])) key, hash_key = self._selector_to_keys(selector)
self.variable_dictionary[selector[0]][hash_key] = variable # Ensure the first-level key exists in the dictionary
if key not in self.variable_dictionary:
self.variable_dictionary[key] = {}
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None: def get(self, selector: Sequence[str], /) -> Segment | None:
""" """
@ -102,8 +122,8 @@ class VariablePool(BaseModel):
if len(selector) < MIN_SELECTORS_LENGTH: if len(selector) < MIN_SELECTORS_LENGTH:
return None return None
hash_key = hash(tuple(selector[1:])) key, hash_key = self._selector_to_keys(selector)
value = self.variable_dictionary[selector[0]].get(hash_key) value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None: if value is None:
selector, attr = selector[:-1], selector[-1] selector, attr = selector[:-1], selector[-1]
@ -136,8 +156,9 @@ class VariablePool(BaseModel):
if len(selector) == 1: if len(selector) == 1:
self.variable_dictionary[selector[0]] = {} self.variable_dictionary[selector[0]] = {}
return return
key, hash_key = self._selector_to_keys(selector)
hash_key = hash(tuple(selector[1:])) hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None) self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /): def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template) parts = VARIABLE_PATTERN.split(template)
@ -154,3 +175,31 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment): if isinstance(segment, FileSegment):
return segment return segment
return None return None
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():
if value is None:
continue
selector = (SYSTEM_VARIABLE_NODE_ID, key)
# If the system variable already exists, do not add it again.
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
@classmethod
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, json_data: str) -> "VariablePool":
return VariablePool.model_validate_json(json_data)
def reload_storage_keys_for_file_types(self):
# TODO
pass

@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
"""total tokens""" """total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage() llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info""" """llm usage info"""
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = {} outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0 node_run_steps: int = 0
"""node run steps""" """node run steps"""

@ -1,11 +1,29 @@
from collections.abc import Mapping from collections.abc import Mapping
from typing import Any, Literal, Optional from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field from pydantic import AfterValidator, BaseModel, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition from core.workflow.utils.condition.entities import Condition
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
]
)
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
if seg_type not in _VALID_VAR_TYPE:
raise ValueError(...)
return seg_type
class LoopVariableData(BaseModel): class LoopVariableData(BaseModel):
""" """
@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
""" """
label: str label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"] var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"] value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None value: Optional[Any | list[str]] = None

@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config from configs import dify_config
from core.variables import ( from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment, IntegerSegment,
ObjectSegment,
Segment, Segment,
SegmentType, SegmentType,
StringSegment,
) )
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING: if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping return variable_mapping
@staticmethod @staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment: def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value.""" """Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]: if var_type in ["array[string]", "array[number]", "array[object]"]:
if value: if value and isinstance(value, str):
value = json.loads(value) value = json.loads(value)
else: else:
value = [] value = []
segment_info = segment_mapping.get(var_type) try:
if not segment_info: return build_segment_with_type(var_type, value)
raise ValueError(f"Invalid variable type: {var_type}") except TypeMismatchError as type_exc:
segment_class, value_type = segment_info # Attempt to parse the value as a JSON-encoded string, if applicable.
return segment_class(value=value, value_type=value_type) if not isinstance(value, str):
raise
try:
value = json.loads(value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):
def _run(self) -> NodeRunResult: def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling # TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs. # Set system variables as node outputs.

@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
def get_zero_value(t: SegmentType): def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t: match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER: case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([]) return variable_factory.build_segment([])
@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment({}) return variable_factory.build_segment({})
case SegmentType.STRING: case SegmentType.STRING:
return variable_factory.build_segment("") return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER: case SegmentType.NUMBER:
return variable_factory.build_segment(0) return variable_factory.build_segment(0)
case _: case _:

@ -1,5 +1,6 @@
from core.variables import SegmentType from core.variables import SegmentType
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
EMPTY_VALUE_MAPPING = { EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "", SegmentType.STRING: "",
SegmentType.NUMBER: 0, SegmentType.NUMBER: 0,

@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
case Operation.OVER_WRITE | Operation.CLEAR: case Operation.OVER_WRITE | Operation.CLEAR:
return True return True
case Operation.SET: case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER} return variable_type in {
SegmentType.OBJECT,
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided # Only number variable can be added, subtracted, multiplied or divided
return variable_type == SegmentType.NUMBER return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND: case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended # Only array variable can be appended or extended
return variable_type in { return variable_type in {
@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
match variable_type: match variable_type:
case SegmentType.STRING | SegmentType.OBJECT: case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET} return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER: case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in { return operation in {
Operation.OVER_WRITE, Operation.OVER_WRITE,
Operation.SET, Operation.SET,
@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING: case SegmentType.STRING:
return isinstance(value, str) return isinstance(value, str)
case SegmentType.NUMBER: case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float): if not isinstance(value, int | float):
return False return False
if operation == Operation.DIVIDE and value == 0: if operation == Operation.DIVIDE and value == 0:

@ -60,7 +60,7 @@ class WorkflowCycleManager:
# Iterate over SystemVariable fields using Pydantic's model_fields # Iterate over SystemVariable fields using Pydantic's model_fields
if self._workflow_system_variables: if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict(): for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID: if field_name == SystemVariableKey.CONVERSATION_ID:
continue continue
inputs[f"sys.{field_name}"] = value inputs[f"sys.{field_name}"] = value

@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory from factories import file_factory
from models.enums import UserFrom from models.enums import UserFrom
@ -254,7 +255,7 @@ class WorkflowEntry:
# init variable pool # init variable pool
variable_pool = VariablePool( variable_pool = VariablePool(
system_variables={}, system_variables=SystemVariable.empty(),
user_inputs={}, user_inputs={},
environment_variables=[], environment_variables=[],
) )

@ -2,6 +2,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast from typing import Any, cast
from uuid import uuid4 from uuid import uuid4
from shapely import is_valid
from configs import dify_config from configs import dify_config
from core.file import File from core.file import File
from core.variables.exc import VariableError from core.variables.exc import VariableError
@ -91,9 +93,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = StringVariable.model_validate(mapping) result = StringVariable.model_validate(mapping)
case SegmentType.SECRET: case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping) result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int): case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.INTEGER
result = IntegerVariable.model_validate(mapping) result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float): case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping) result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int): case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}") raise VariableError(f"invalid number value {value}")
@ -119,6 +125,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType:
def build_segment(value: Any, /) -> Segment: def build_segment(value: Any, /) -> Segment:
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
# below
if value is None: if value is None:
return NoneSegment() return NoneSegment()
if isinstance(value, str): if isinstance(value, str):
@ -134,12 +142,17 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, list): if isinstance(value, list):
items = [build_segment(item) for item in value] items = [build_segment(item) for item in value]
types = {item.value_type for item in items} types = {item.value_type for item in items}
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items): if all(isinstance(item, ArraySegment) for item in items):
return ArrayAnySegment(value=value)
elif len(types) != 1:
if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
return ArrayNumberSegment(value=value)
return ArrayAnySegment(value=value) return ArrayAnySegment(value=value)
match types.pop(): match types.pop():
case SegmentType.STRING: case SegmentType.STRING:
return ArrayStringSegment(value=value) return ArrayStringSegment(value=value)
case SegmentType.NUMBER: case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value) return ArrayNumberSegment(value=value)
case SegmentType.OBJECT: case SegmentType.OBJECT:
return ArrayObjectSegment(value=value) return ArrayObjectSegment(value=value)
@ -153,6 +166,22 @@ def build_segment(value: Any, /) -> Segment:
raise ValueError(f"not supported value {value}") raise ValueError(f"not supported value {value}")
_segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.NONE: NoneSegment,
SegmentType.STRING: StringSegment,
SegmentType.INTEGER: IntegerSegment,
SegmentType.FLOAT: FloatSegment,
SegmentType.FILE: FileSegment,
SegmentType.OBJECT: ObjectSegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
}
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
""" """
Build a segment with explicit type checking. Build a segment with explicit type checking.
@ -190,7 +219,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
if segment_type == SegmentType.NONE: if segment_type == SegmentType.NONE:
return NoneSegment() return NoneSegment()
else: else:
raise TypeMismatchError(f"Expected {segment_type}, but got None") raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
# Handle empty list special case for array types # Handle empty list special case for array types
if isinstance(value, list) and len(value) == 0: if isinstance(value, list) and len(value) == 0:
@ -205,21 +234,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
elif segment_type == SegmentType.ARRAY_FILE: elif segment_type == SegmentType.ARRAY_FILE:
return ArrayFileSegment(value=value) return ArrayFileSegment(value=value)
else: else:
raise TypeMismatchError(f"Expected {segment_type}, but got empty list") raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
# Build segment using existing logic to infer actual type
inferred_segment = build_segment(value)
inferred_type = inferred_segment.value_type
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking # Type compatibility checking
if inferred_type is None:
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
)
if inferred_type == segment_type: if inferred_type == segment_type:
return inferred_segment segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
# Type mismatch - raise error with descriptive message elif segment_type == SegmentType.NUMBER and inferred_type in (
raise TypeMismatchError( SegmentType.INTEGER,
f"Type mismatch: expected {segment_type}, but value '{value}' " SegmentType.FLOAT,
f"(type: {type(value).__name__}) corresponds to {inferred_type}" ):
) segment_class = _segment_factory[inferred_type]
return segment_class(value_type=inferred_type, value=value)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
def segment_to_variable( def segment_to_variable(
@ -247,6 +280,6 @@ def segment_to_variable(
name=name, name=name,
description=description, description=description,
value=segment.value, value=segment.value,
selector=selector, selector=list(selector),
), ),
) )

@ -0,0 +1,15 @@
from typing import TypedDict
from core.variables.segments import Segment
from core.variables.types import SegmentType
class _VarTypedDict(TypedDict, total=False):
value_type: SegmentType
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
else:
return v["value_type"].exposed_type().value

@ -2,10 +2,12 @@ from flask_restful import fields
from libs.helper import TimestampField from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
conversation_variable_fields = { conversation_variable_fields = {
"id": fields.String, "id": fields.String,
"name": fields.String, "name": fields.String,
"value_type": fields.String(attribute="value_type.value"), "value_type": fields.String(attribute=serialize_value_type),
"value": fields.String, "value": fields.String,
"description": fields.String, "description": fields.String,
"created_at": TimestampField, "created_at": TimestampField,

@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
from fields.member_fields import simple_account_fields from fields.member_fields import simple_account_fields
from libs.helper import TimestampField from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET) ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
@ -23,10 +25,15 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id, "id": value.id,
"name": value.name, "name": value.name,
"value": value.value, "value": value.value,
"value_type": value.value_type.value, "value_type": value.value_type.exposed_type().value,
} }
if isinstance(value, dict): if isinstance(value, dict):
value_type = value.get("value_type") value_type_str = value.get("value_type")
if not isinstance(value_type_str, str):
raise TypeError(
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
)
value_type = SegmentType(value_type_str).exposed_type()
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES: if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f"Unsupported environment variable value type: {value_type}") raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value return value
@ -35,7 +42,7 @@ class EnvironmentVariableField(fields.Raw):
conversation_variable_fields = { conversation_variable_fields = {
"id": fields.String, "id": fields.String,
"name": fields.String, "name": fields.String,
"value_type": fields.String(attribute="value_type.value"), "value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw, "value": fields.Raw,
"description": fields.String, "description": fields.String,
} }

@ -12,6 +12,7 @@ from sqlalchemy import orm
from core.file.constants import maybe_file_object from core.file.constants import maybe_file_object
from core.file.models import File from core.file.models import File
from core.variables import utils as variable_utils from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -346,7 +347,7 @@ class Workflow(Base):
) )
@property @property
def environment_variables(self) -> Sequence[Variable]: def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created. # TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None: if self._environment_variables is None:
self._environment_variables = "{}" self._environment_variables = "{}"
@ -371,11 +372,15 @@ class Workflow(Base):
def decrypt_func(var): def decrypt_func(var):
if isinstance(var, SecretVariable): if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)}) return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
else: elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var return var
else:
raise AssertionError("this statement should be unreachable.")
results = list(map(decrypt_func, results)) decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
return results map(decrypt_func, results)
)
return decrypted_results
@environment_variables.setter @environment_variables.setter
def environment_variables(self, value: Sequence[Variable]): def environment_variables(self, value: Sequence[Variable]):

@ -3,7 +3,7 @@ import time
import uuid import uuid
from collections.abc import Callable, Generator, Mapping, Sequence from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Optional from typing import Any, Optional, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import select from sqlalchemy import select
@ -15,6 +15,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File from core.file import File
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
@ -692,7 +693,7 @@ def _setup_variable_pool(
system_variables=system_variable, system_variables=system_variable,
user_inputs=user_inputs, user_inputs=user_inputs,
environment_variables=workflow.environment_variables, environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables, conversation_variables=cast(list[VariableUnion], conversation_variables),
) )
return variable_pool return variable_pool

@ -11,6 +11,7 @@ from core.variables import (
SegmentType, SegmentType,
StringVariable, StringVariable,
) )
from core.variables.variables import Variable
def test_frozen_variables(): def test_frozen_variables():
@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object(): def test_variable_to_object():
var = StringVariable(name="text", value="text") var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text" assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42) var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42 assert var.to_object() == 42

@ -509,8 +509,8 @@ def test_build_segment_type_for_scalar():
size=1000, size=1000,
) )
cases = [ cases = [
TestCase(0, SegmentType.NUMBER), TestCase(0, SegmentType.INTEGER),
TestCase(0.0, SegmentType.NUMBER), TestCase(0.0, SegmentType.FLOAT),
TestCase("", SegmentType.STRING), TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE), TestCase(file, SegmentType.FILE),
] ]
@ -535,14 +535,14 @@ class TestBuildSegmentWithType:
result = build_segment_with_type(SegmentType.NUMBER, 42) result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment) assert isinstance(result, IntegerSegment)
assert result.value == 42 assert result.value == 42
assert result.value_type == SegmentType.NUMBER assert result.value_type == SegmentType.INTEGER
def test_number_type_float(self): def test_number_type_float(self):
"""Test building a number segment with float value.""" """Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14) result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment) assert isinstance(result, FloatSegment)
assert result.value == 3.14 assert result.value == 3.14
assert result.value_type == SegmentType.NUMBER assert result.value_type == SegmentType.FLOAT
def test_object_type(self): def test_object_type(self):
"""Test building an object segment with correct type.""" """Test building an object segment with correct type."""
@ -656,14 +656,14 @@ class TestBuildSegmentWithType:
with pytest.raises(TypeMismatchError) as exc_info: with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None) build_segment_with_type(SegmentType.STRING, None)
assert "Expected string, but got None" in str(exc_info.value) assert "expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self): def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list.""" """Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info: with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, []) build_segment_with_type(SegmentType.STRING, [])
assert "Expected string, but got empty list" in str(exc_info.value) assert "expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self): def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object.""" """Test type mismatch when expecting array but getting object."""
@ -678,19 +678,19 @@ class TestBuildSegmentWithType:
# Integer should work # Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42) result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment) assert isinstance(result_int, IntegerSegment)
assert result_int.value_type == SegmentType.NUMBER assert result_int.value_type == SegmentType.INTEGER
# Float should work # Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14) result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment) assert isinstance(result_float, FloatSegment)
assert result_float.value_type == SegmentType.NUMBER assert result_float.value_type == SegmentType.FLOAT
@pytest.mark.parametrize( @pytest.mark.parametrize(
("segment_type", "value", "expected_class"), ("segment_type", "value", "expected_class"),
[ [
(SegmentType.STRING, "test", StringSegment), (SegmentType.STRING, "test", StringSegment),
(SegmentType.NUMBER, 42, IntegerSegment), (SegmentType.INTEGER, 42, IntegerSegment),
(SegmentType.NUMBER, 3.14, FloatSegment), (SegmentType.FLOAT, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment), (SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment), (SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment), (SegmentType.ARRAY_STRING, [], ArrayStringSegment),
@ -861,5 +861,5 @@ class TestBuildSegmentValueErrors:
# Verify they are processed as integers, not as errors # Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1" assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0" assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.NUMBER assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.NUMBER assert false_segment.value_type == SegmentType.INTEGER

Loading…
Cancel
Save