feat(api): Making VariablePool and GraphRuntimeState serializable

pull/22025/head
QuantumGhost 11 months ago
parent 11bc5c0dec
commit d9bc894bb9

@ -68,13 +68,18 @@ def _create_pagination_parser():
return parser
def _serialize_variable_type(workflow_draft_var: WorkflowDraftVariable) -> str:
value_type = workflow_draft_var.value_type
return value_type.exposed_type().value
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
"id": fields.String,
"type": fields.String(attribute=lambda model: model.get_variable_type()),
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@ -90,7 +95,7 @@ _WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
"name": fields.String,
"description": fields.String,
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
"value_type": fields.String,
"value_type": fields.String(attribute=_serialize_variable_type),
"edited": fields.Boolean(attribute=lambda model: model.edited),
"visible": fields.Boolean,
}
@ -396,7 +401,7 @@ class EnvironmentVariableCollectionApi(Resource):
"name": v.name,
"description": v.description,
"selector": v.selector,
"value_type": v.value_type.value,
"value_type": v.value_type.exposed_type().value,
"value": v.value,
# Do not track edited for env vars.
"edited": False,

@ -16,6 +16,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
)
from core.moderation.base import ModerationError
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
@ -152,7 +153,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
# TODO(QuantumGhost): find a better way to resolve typing issue.
conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph

@ -1,9 +1,9 @@
import json
import sys
from collections.abc import Mapping, Sequence
from typing import Any
from typing import Annotated, Any
from pydantic import BaseModel, ConfigDict, field_validator
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
from core.file import File
@ -11,6 +11,11 @@ from .types import SegmentType
class Segment(BaseModel):
"""Segment is runtime type used during the execution of workflow.
Note: this class is abstract, you should use subclasses of this class instead.
"""
model_config = ConfigDict(frozen=True)
value_type: SegmentType
@ -73,7 +78,7 @@ class StringSegment(Segment):
class FloatSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.FLOAT
value: float
# NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems.
# The following tests cannot pass.
@ -92,7 +97,7 @@ class FloatSegment(Segment):
class IntegerSegment(Segment):
value_type: SegmentType = SegmentType.NUMBER
value_type: SegmentType = SegmentType.INTEGER
value: int
@ -181,3 +186,38 @@ class ArrayFileSegment(ArraySegment):
@property
def text(self) -> str:
return ""
def get_segment_discriminator(v: Any) -> SegmentType | None:
if isinstance(v, Segment):
return v.value_type
elif isinstance(v, dict):
value_type = v.get("value_type")
if value_type is None:
return None
try:
seg_type = SegmentType(value_type)
except ValueError:
return None
return seg_type
else:
# return None if the discriminator value isn't found
return None
SegmentUnion = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]
| Annotated[StringSegment, Tag(SegmentType.STRING)]
| Annotated[FloatSegment, Tag(SegmentType.FLOAT)]
| Annotated[IntegerSegment, Tag(SegmentType.INTEGER)]
| Annotated[ObjectSegment, Tag(SegmentType.OBJECT)]
| Annotated[FileSegment, Tag(SegmentType.FILE)]
| Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
),
Discriminator(get_segment_discriminator),
]

@ -1,8 +1,27 @@
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Optional
from core.file.models import File
class ArrayValidation(StrEnum):
"""Strategy for validating array elements"""
# Skip element validation (only check array container)
NONE = "none"
# Validate the first element (if array is non-empty)
FIRST = "first"
# Validate all elements in the array.
ALL = "all"
class SegmentType(StrEnum):
NUMBER = "number"
INTEGER = "integer"
FLOAT = "float"
STRING = "string"
OBJECT = "object"
SECRET = "secret"
@ -19,16 +38,138 @@ class SegmentType(StrEnum):
GROUP = "group"
def is_array_type(self):
def is_array_type(self) -> bool:
return self in _ARRAY_TYPES
@classmethod
def infer_segment_type(cls, value: Any) -> Optional["SegmentType"]:
"""Try to infer the SegmentType from the Python type of
the `value` parameter.
"""
if isinstance(value, list):
elem_types: set[SegmentType] = set()
for i in value:
segment_type = cls.infer_segment_type(i)
if segment_type is None:
return None
elem_types.add(segment_type)
if len(elem_types) != 1:
if elem_types.issubset(_NUMERICAL_TYPES):
return SegmentType.ARRAY_NUMBER
return SegmentType.ARRAY_ANY
elif all(i.is_array_type() for i in elem_types):
return SegmentType.ARRAY_ANY
match elem_types.pop():
case SegmentType.STRING:
return SegmentType.ARRAY_STRING
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return SegmentType.ARRAY_NUMBER
case SegmentType.OBJECT:
return SegmentType.ARRAY_OBJECT
case SegmentType.FILE:
return SegmentType.ARRAY_FILE
case SegmentType.NONE:
return SegmentType.ARRAY_ANY
case _:
# This should be unreachable.
raise ValueError(f"not supported value {value}")
if value is None:
return SegmentType.NONE
elif isinstance(value, int) and not isinstance(value, bool):
return SegmentType.INTEGER
elif isinstance(value, float):
return SegmentType.FLOAT
elif isinstance(value, str):
return SegmentType.STRING
elif isinstance(value, dict):
return SegmentType.OBJECT
elif isinstance(value, File):
return SegmentType.FILE
elif isinstance(value, str):
return SegmentType.STRING
else:
return None
def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool:
if not isinstance(value, list):
return False
# Skip element validation if array is empty
if len(value) == 0:
return True
if self == SegmentType.ARRAY_ANY:
return True
element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self]
if array_validation == ArrayValidation.NONE:
return True
elif array_validation == ArrayValidation.FIRST:
return element_type.is_valid(value[0])
else:
return all([element_type.is_valid(i, array_validation=ArrayValidation.NONE)] for i in value)
def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.FIRST) -> bool:
"""
Check if a value matches the segment type.
Users of `SegmentType` should call this method, instead of using
`isinstance` manually.
Args:
value: The value to validate
array_validation: Validation strategy for array types (ignored for non-array types)
Returns:
True if the value matches the type under the given validation strategy
"""
if self.is_array_type():
return self._validate_array(value, array_validation)
elif self == SegmentType.NUMBER:
return isinstance(value, (int, float))
elif self == SegmentType.STRING:
return isinstance(value, str)
elif self == SegmentType.OBJECT:
return isinstance(value, dict)
elif self == SegmentType.SECRET:
return isinstance(value, str)
elif self == SegmentType.FILE:
return isinstance(value, File)
elif self == SegmentType.NONE:
return value is None
else:
raise AssertionError("this statement should be unreachable.")
def exposed_type(self) -> "SegmentType":
"""Returns the type exposed to the frontend.
The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here.
"""
if self in (SegmentType.INTEGER, SegmentType.FLOAT):
return SegmentType.NUMBER
return self
_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = {
# ARRAY_ANY does not have correpond element type.
SegmentType.ARRAY_STRING: SegmentType.STRING,
SegmentType.ARRAY_NUMBER: SegmentType.NUMBER,
SegmentType.ARRAY_OBJECT: SegmentType.OBJECT,
SegmentType.ARRAY_FILE: SegmentType.FILE,
}
_ARRAY_TYPES = frozenset(
[
list(_ARRAY_ELEMENT_TYPES_MAPPING.keys())
+ [
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_FILE,
]
)
_NUMERICAL_TYPES = frozenset(
[
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
]
)

@ -1,10 +1,11 @@
from collections.abc import Sequence
from typing import cast
from typing import Annotated, TypeAlias, cast
from uuid import uuid4
from pydantic import Field
from pydantic import Discriminator, Field, Tag
from core.helper import encrypter
from core.variables.segment_group import SegmentGroup
from .segments import (
ArrayAnySegment,
@ -20,6 +21,7 @@ from .segments import (
ObjectSegment,
Segment,
StringSegment,
get_segment_discriminator,
)
from .types import SegmentType
@ -27,6 +29,10 @@ from .types import SegmentType
class Variable(Segment):
"""
A variable is a segment that has a name.
It is mainly used to store segments and their selector in VariablePool.
Note: this class is abstract, you should use subclasses of this class instead.
"""
id: str = Field(
@ -93,3 +99,22 @@ class FileVariable(FileSegment, Variable):
class ArrayFileVariable(ArrayFileSegment, ArrayVariable):
pass
VariableUnion: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]
| Annotated[FloatVariable, Tag(SegmentType.FLOAT)]
| Annotated[IntegerVariable, Tag(SegmentType.INTEGER)]
| Annotated[ObjectVariable, Tag(SegmentType.OBJECT)]
| Annotated[FileVariable, Tag(SegmentType.FILE)]
| Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)]
| Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)]
| Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)]
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
),
Discriminator(get_segment_discriminator),
]

@ -1,7 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union
from typing import Any, Union, cast
from pydantic import BaseModel, Field
@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables.consts import MIN_SELECTORS_LENGTH
from core.variables.segments import FileSegment, NoneSegment
from core.variables.variables import VariableUnion
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.enums import SystemVariableKey
from core.workflow.system_variable import SystemVariable
from factories import variable_factory
VariableValue = Union[str, int, float, dict, list, File]
@ -23,31 +24,35 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, Segment]] = Field(
variable_dictionary: dict[str, dict[int, VariableUnion]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
# TODO: This user inputs is not used for pool.
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
user_inputs: Mapping[str, Any] = Field(
description="User inputs",
default_factory=dict,
)
system_variables: Mapping[SystemVariableKey, Any] = Field(
# system_variables: Mapping[SystemVariableKey, Any] = Field(
# description="System variables",
# default_factory=dict,
# )
system_variables: SystemVariable = Field(
description="System variables",
default_factory=dict,
)
environment_variables: Sequence[Variable] = Field(
environment_variables: Sequence[VariableUnion] = Field(
description="Environment variables.",
default_factory=list,
)
conversation_variables: Sequence[Variable] = Field(
conversation_variables: Sequence[VariableUnion] = Field(
description="Conversation variables.",
default_factory=list,
)
def model_post_init(self, context: Any, /) -> None:
for key, value in self.system_variables.items():
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
# Create a mapping from field names to SystemVariableKey enum values
self._add_system_variables(self.system_variables)
# Add environment variables to the variable pool
for var in self.environment_variables:
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
@ -83,8 +88,23 @@ class VariablePool(BaseModel):
segment = variable_factory.build_segment(value)
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]][hash_key] = variable
key, hash_key = self._selector_to_keys(selector)
# Ensure the first-level key exists in the dictionary
if key not in self.variable_dictionary:
self.variable_dictionary[key] = {}
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
return selector[0], hash(tuple(selector[1:]))
def _has(self, selector: Sequence[str]) -> bool:
key, hash_key = self._selector_to_keys(selector)
if key not in self.variable_dictionary:
return False
if hash_key not in self.variable_dictionary[key]:
return False
return True
def get(self, selector: Sequence[str], /) -> Segment | None:
"""
@ -102,8 +122,8 @@ class VariablePool(BaseModel):
if len(selector) < MIN_SELECTORS_LENGTH:
return None
hash_key = hash(tuple(selector[1:]))
value = self.variable_dictionary[selector[0]].get(hash_key)
key, hash_key = self._selector_to_keys(selector)
value: Segment | None = self.variable_dictionary[key].get(hash_key)
if value is None:
selector, attr = selector[:-1], selector[-1]
@ -136,8 +156,9 @@ class VariablePool(BaseModel):
if len(selector) == 1:
self.variable_dictionary[selector[0]] = {}
return
key, hash_key = self._selector_to_keys(selector)
hash_key = hash(tuple(selector[1:]))
self.variable_dictionary[selector[0]].pop(hash_key, None)
self.variable_dictionary[key].pop(hash_key, None)
def convert_template(self, template: str, /):
parts = VARIABLE_PATTERN.split(template)
@ -154,3 +175,31 @@ class VariablePool(BaseModel):
if isinstance(segment, FileSegment):
return segment
return None
def _add_system_variables(self, system_variable: SystemVariable):
sys_var_mapping = system_variable.to_dict()
for key, value in sys_var_mapping.items():
if value is None:
continue
selector = (SYSTEM_VARIABLE_NODE_ID, key)
# If the system variable already exists, do not add it again.
# This ensures that we can keep the id of the system variables intact.
if self._has(selector):
continue
self.add(selector, value) # type: ignore
@classmethod
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, json_data: str) -> "VariablePool":
return VariablePool.model_validate_json(json_data)
def reload_storage_keys_for_file_types(self):
# TODO
pass

@ -17,8 +17,12 @@ class GraphRuntimeState(BaseModel):
"""total tokens"""
llm_usage: LLMUsage = LLMUsage.empty_usage()
"""llm usage info"""
# The `outputs` field stores the final output values generated by executing workflows or chatflows.
#
# Note: Since the type of this field is `dict[str, Any]`, its values may not remain consistent
# after a serialization and deserialization round trip.
outputs: dict[str, Any] = {}
"""outputs"""
node_run_steps: int = 0
"""node run steps"""

@ -1,11 +1,29 @@
from collections.abc import Mapping
from typing import Any, Literal, Optional
from typing import Annotated, Any, Literal, Optional
from pydantic import BaseModel, Field
from pydantic import AfterValidator, BaseModel, Field
from core.variables.types import SegmentType
from core.workflow.nodes.base import BaseLoopNodeData, BaseLoopState, BaseNodeData
from core.workflow.utils.condition.entities import Condition
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_OBJECT,
]
)
def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
if seg_type not in _VALID_VAR_TYPE:
raise ValueError(...)
return seg_type
class LoopVariableData(BaseModel):
"""
@ -13,7 +31,7 @@ class LoopVariableData(BaseModel):
"""
label: str
var_type: Literal["string", "number", "object", "array[string]", "array[number]", "array[object]"]
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Optional[Any | list[str]] = None

@ -7,14 +7,9 @@ from typing import TYPE_CHECKING, Any, Literal, cast
from configs import dify_config
from core.variables import (
ArrayNumberSegment,
ArrayObjectSegment,
ArrayStringSegment,
IntegerSegment,
ObjectSegment,
Segment,
SegmentType,
StringSegment,
)
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@ -39,6 +34,7 @@ from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.event import NodeEvent, RunCompletedEvent
from core.workflow.nodes.loop.entities import LoopNodeData
from core.workflow.utils.condition.processor import ConditionProcessor
from factories.variable_factory import TypeMismatchError, build_segment_with_type
if TYPE_CHECKING:
from core.workflow.entities.variable_pool import VariablePool
@ -505,23 +501,21 @@ class LoopNode(BaseNode[LoopNodeData]):
return variable_mapping
@staticmethod
def _get_segment_for_constant(var_type: str, value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, value: Any) -> Segment:
"""Get the appropriate segment type for a constant value."""
segment_mapping: dict[str, tuple[type[Segment], SegmentType]] = {
"string": (StringSegment, SegmentType.STRING),
"number": (IntegerSegment, SegmentType.NUMBER),
"object": (ObjectSegment, SegmentType.OBJECT),
"array[string]": (ArrayStringSegment, SegmentType.ARRAY_STRING),
"array[number]": (ArrayNumberSegment, SegmentType.ARRAY_NUMBER),
"array[object]": (ArrayObjectSegment, SegmentType.ARRAY_OBJECT),
}
if var_type in ["array[string]", "array[number]", "array[object]"]:
if value:
if value and isinstance(value, str):
value = json.loads(value)
else:
value = []
segment_info = segment_mapping.get(var_type)
if not segment_info:
raise ValueError(f"Invalid variable type: {var_type}")
segment_class, value_type = segment_info
return segment_class(value=value, value_type=value_type)
try:
return build_segment_with_type(var_type, value)
except TypeMismatchError as type_exc:
# Attempt to parse the value as a JSON-encoded string, if applicable.
if not isinstance(value, str):
raise
try:
value = json.loads(value)
except ValueError:
raise type_exc
return build_segment_with_type(var_type, value)

@ -16,7 +16,7 @@ class StartNode(BaseNode[StartNodeData]):
def _run(self) -> NodeRunResult:
node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs)
system_inputs = self.graph_runtime_state.variable_pool.system_variables
system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict()
# TODO: System variables should be directly accessible, no need for special handling
# Set system variables as node outputs.

@ -130,6 +130,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerData]):
def get_zero_value(t: SegmentType):
# TODO(QuantumGhost): this should be a method of `SegmentType`.
match t:
case SegmentType.ARRAY_OBJECT | SegmentType.ARRAY_STRING | SegmentType.ARRAY_NUMBER:
return variable_factory.build_segment([])
@ -137,6 +138,10 @@ def get_zero_value(t: SegmentType):
return variable_factory.build_segment({})
case SegmentType.STRING:
return variable_factory.build_segment("")
case SegmentType.INTEGER:
return variable_factory.build_segment(0)
case SegmentType.FLOAT:
return variable_factory.build_segment(0.0)
case SegmentType.NUMBER:
return variable_factory.build_segment(0)
case _:

@ -1,5 +1,6 @@
from core.variables import SegmentType
# Note: This mapping is duplicated with `get_zero_value`. Consider refactoring to avoid redundancy.
EMPTY_VALUE_MAPPING = {
SegmentType.STRING: "",
SegmentType.NUMBER: 0,

@ -10,10 +10,16 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
case Operation.OVER_WRITE | Operation.CLEAR:
return True
case Operation.SET:
return variable_type in {SegmentType.OBJECT, SegmentType.STRING, SegmentType.NUMBER}
return variable_type in {
SegmentType.OBJECT,
SegmentType.STRING,
SegmentType.NUMBER,
SegmentType.INTEGER,
SegmentType.FLOAT,
}
case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE:
# Only number variable can be added, subtracted, multiplied or divided
return variable_type == SegmentType.NUMBER
return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}
case Operation.APPEND | Operation.EXTEND:
# Only array variable can be appended or extended
return variable_type in {
@ -46,7 +52,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
match variable_type:
case SegmentType.STRING | SegmentType.OBJECT:
return operation in {Operation.OVER_WRITE, Operation.SET}
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return operation in {
Operation.OVER_WRITE,
Operation.SET,
@ -66,7 +72,7 @@ def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, va
case SegmentType.STRING:
return isinstance(value, str)
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
if not isinstance(value, int | float):
return False
if operation == Operation.DIVIDE and value == 0:

@ -60,7 +60,7 @@ class WorkflowCycleManager:
# Iterate over SystemVariable fields using Pydantic's model_fields
if self._workflow_system_variables:
for field_name, value in self._workflow_system_variables.to_dict():
for field_name, value in self._workflow_system_variables.to_dict().items():
if field_name == SystemVariableKey.CONVERSATION_ID:
continue
inputs[f"sys.{field_name}"] = value

@ -21,6 +21,7 @@ from core.workflow.nodes import NodeType
from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.event import NodeEvent
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from factories import file_factory
from models.enums import UserFrom
@ -254,7 +255,7 @@ class WorkflowEntry:
# init variable pool
variable_pool = VariablePool(
system_variables={},
system_variables=SystemVariable.empty(),
user_inputs={},
environment_variables=[],
)

@ -2,6 +2,8 @@ from collections.abc import Mapping, Sequence
from typing import Any, cast
from uuid import uuid4
from shapely import is_valid
from configs import dify_config
from core.file import File
from core.variables.exc import VariableError
@ -91,9 +93,13 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
result = StringVariable.model_validate(mapping)
case SegmentType.SECRET:
result = SecretVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, int):
case SegmentType.NUMBER | SegmentType.INTEGER if isinstance(value, int):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.INTEGER
result = IntegerVariable.model_validate(mapping)
case SegmentType.NUMBER if isinstance(value, float):
case SegmentType.NUMBER | SegmentType.FLOAT if isinstance(value, float):
mapping = dict(mapping)
mapping["value_type"] = SegmentType.FLOAT
result = FloatVariable.model_validate(mapping)
case SegmentType.NUMBER if not isinstance(value, float | int):
raise VariableError(f"invalid number value {value}")
@ -119,6 +125,8 @@ def infer_segment_type_from_value(value: Any, /) -> SegmentType:
def build_segment(value: Any, /) -> Segment:
# NOTE: If you have runtime type information available, consider using the `build_segment_with_type`
# below
if value is None:
return NoneSegment()
if isinstance(value, str):
@ -134,12 +142,17 @@ def build_segment(value: Any, /) -> Segment:
if isinstance(value, list):
items = [build_segment(item) for item in value]
types = {item.value_type for item in items}
if len(types) != 1 or all(isinstance(item, ArraySegment) for item in items):
if all(isinstance(item, ArraySegment) for item in items):
return ArrayAnySegment(value=value)
elif len(types) != 1:
if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}):
return ArrayNumberSegment(value=value)
return ArrayAnySegment(value=value)
match types.pop():
case SegmentType.STRING:
return ArrayStringSegment(value=value)
case SegmentType.NUMBER:
case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT:
return ArrayNumberSegment(value=value)
case SegmentType.OBJECT:
return ArrayObjectSegment(value=value)
@ -153,6 +166,22 @@ def build_segment(value: Any, /) -> Segment:
raise ValueError(f"not supported value {value}")
_segment_factory: Mapping[SegmentType, type[Segment]] = {
SegmentType.NONE: NoneSegment,
SegmentType.STRING: StringSegment,
SegmentType.INTEGER: IntegerSegment,
SegmentType.FLOAT: FloatSegment,
SegmentType.FILE: FileSegment,
SegmentType.OBJECT: ObjectSegment,
# Array types
SegmentType.ARRAY_ANY: ArrayAnySegment,
SegmentType.ARRAY_STRING: ArrayStringSegment,
SegmentType.ARRAY_NUMBER: ArrayNumberSegment,
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
SegmentType.ARRAY_FILE: ArrayFileSegment,
}
def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
"""
Build a segment with explicit type checking.
@ -190,7 +219,7 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
if segment_type == SegmentType.NONE:
return NoneSegment()
else:
raise TypeMismatchError(f"Expected {segment_type}, but got None")
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None")
# Handle empty list special case for array types
if isinstance(value, list) and len(value) == 0:
@ -205,21 +234,25 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
elif segment_type == SegmentType.ARRAY_FILE:
return ArrayFileSegment(value=value)
else:
raise TypeMismatchError(f"Expected {segment_type}, but got empty list")
# Build segment using existing logic to infer actual type
inferred_segment = build_segment(value)
inferred_type = inferred_segment.value_type
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list")
inferred_type = SegmentType.infer_segment_type(value)
# Type compatibility checking
if inferred_type is None:
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}"
)
if inferred_type == segment_type:
return inferred_segment
# Type mismatch - raise error with descriptive message
raise TypeMismatchError(
f"Type mismatch: expected {segment_type}, but value '{value}' "
f"(type: {type(value).__name__}) corresponds to {inferred_type}"
)
segment_class = _segment_factory[segment_type]
return segment_class(value_type=segment_type, value=value)
elif segment_type == SegmentType.NUMBER and inferred_type in (
SegmentType.INTEGER,
SegmentType.FLOAT,
):
segment_class = _segment_factory[inferred_type]
return segment_class(value_type=inferred_type, value=value)
else:
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
def segment_to_variable(
@ -247,6 +280,6 @@ def segment_to_variable(
name=name,
description=description,
value=segment.value,
selector=selector,
selector=list(selector),
),
)

@ -0,0 +1,15 @@
from typing import TypedDict
from core.variables.segments import Segment
from core.variables.types import SegmentType
class _VarTypedDict(TypedDict, total=False):
value_type: SegmentType
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
if isinstance(v, Segment):
return v.value_type.exposed_type().value
else:
return v["value_type"].exposed_type().value

@ -2,10 +2,12 @@ from flask_restful import fields
from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value_type": fields.String(attribute=serialize_value_type),
"value": fields.String,
"description": fields.String,
"created_at": TimestampField,

@ -5,6 +5,8 @@ from core.variables import SecretVariable, SegmentType, Variable
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
from ._value_type_serializer import serialize_value_type
ENVIRONMENT_VARIABLE_SUPPORTED_TYPES = (SegmentType.STRING, SegmentType.NUMBER, SegmentType.SECRET)
@ -23,10 +25,15 @@ class EnvironmentVariableField(fields.Raw):
"id": value.id,
"name": value.name,
"value": value.value,
"value_type": value.value_type.value,
"value_type": value.value_type.exposed_type().value,
}
if isinstance(value, dict):
value_type = value.get("value_type")
value_type_str = value.get("value_type")
if not isinstance(value_type_str, str):
raise TypeError(
f"unexpected type for value_type field, value={value_type_str}, type={type(value_type_str)}"
)
value_type = SegmentType(value_type_str).exposed_type()
if value_type not in ENVIRONMENT_VARIABLE_SUPPORTED_TYPES:
raise ValueError(f"Unsupported environment variable value type: {value_type}")
return value
@ -35,7 +42,7 @@ class EnvironmentVariableField(fields.Raw):
conversation_variable_fields = {
"id": fields.String,
"name": fields.String,
"value_type": fields.String(attribute="value_type.value"),
"value_type": fields.String(attribute=serialize_value_type),
"value": fields.Raw,
"description": fields.String,
}

@ -12,6 +12,7 @@ from sqlalchemy import orm
from core.file.constants import maybe_file_object
from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.variables import FloatVariable, IntegerVariable, StringVariable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import TypeMismatchError, build_segment_with_type
@ -346,7 +347,7 @@ class Workflow(Base):
)
@property
def environment_variables(self) -> Sequence[Variable]:
def environment_variables(self) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
@ -371,11 +372,15 @@ class Workflow(Base):
def decrypt_func(var):
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
else:
elif isinstance(var, (StringVariable, IntegerVariable, FloatVariable)):
return var
else:
raise AssertionError("this statement should be unreachable.")
results = list(map(decrypt_func, results))
return results
decrypted_results: list[SecretVariable | StringVariable | IntegerVariable | FloatVariable] = list(
map(decrypt_func, results)
)
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):

@ -3,7 +3,7 @@ import time
import uuid
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional
from typing import Any, Optional, cast
from uuid import uuid4
from sqlalchemy import select
@ -15,6 +15,7 @@ from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecution, WorkflowNodeExecutionStatus
@ -692,7 +693,7 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
conversation_variables=conversation_variables,
conversation_variables=cast(list[VariableUnion], conversation_variables),
)
return variable_pool

@ -11,6 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
def test_frozen_variables():
@ -75,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var = StringVariable(name="text", value="text")
var: Variable = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

@ -509,8 +509,8 @@ def test_build_segment_type_for_scalar():
size=1000,
)
cases = [
TestCase(0, SegmentType.NUMBER),
TestCase(0.0, SegmentType.NUMBER),
TestCase(0, SegmentType.INTEGER),
TestCase(0.0, SegmentType.FLOAT),
TestCase("", SegmentType.STRING),
TestCase(file, SegmentType.FILE),
]
@ -535,14 +535,14 @@ class TestBuildSegmentWithType:
result = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result, IntegerSegment)
assert result.value == 42
assert result.value_type == SegmentType.NUMBER
assert result.value_type == SegmentType.INTEGER
def test_number_type_float(self):
"""Test building a number segment with float value."""
result = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result, FloatSegment)
assert result.value == 3.14
assert result.value_type == SegmentType.NUMBER
assert result.value_type == SegmentType.FLOAT
def test_object_type(self):
"""Test building an object segment with correct type."""
@ -656,14 +656,14 @@ class TestBuildSegmentWithType:
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, None)
assert "Expected string, but got None" in str(exc_info.value)
assert "expected string, but got None" in str(exc_info.value)
def test_type_mismatch_empty_list_to_non_array(self):
"""Test type mismatch when expecting non-array type but getting empty list."""
with pytest.raises(TypeMismatchError) as exc_info:
build_segment_with_type(SegmentType.STRING, [])
assert "Expected string, but got empty list" in str(exc_info.value)
assert "expected string, but got empty list" in str(exc_info.value)
def test_type_mismatch_object_to_array(self):
"""Test type mismatch when expecting array but getting object."""
@ -678,19 +678,19 @@ class TestBuildSegmentWithType:
# Integer should work
result_int = build_segment_with_type(SegmentType.NUMBER, 42)
assert isinstance(result_int, IntegerSegment)
assert result_int.value_type == SegmentType.NUMBER
assert result_int.value_type == SegmentType.INTEGER
# Float should work
result_float = build_segment_with_type(SegmentType.NUMBER, 3.14)
assert isinstance(result_float, FloatSegment)
assert result_float.value_type == SegmentType.NUMBER
assert result_float.value_type == SegmentType.FLOAT
@pytest.mark.parametrize(
("segment_type", "value", "expected_class"),
[
(SegmentType.STRING, "test", StringSegment),
(SegmentType.NUMBER, 42, IntegerSegment),
(SegmentType.NUMBER, 3.14, FloatSegment),
(SegmentType.INTEGER, 42, IntegerSegment),
(SegmentType.FLOAT, 3.14, FloatSegment),
(SegmentType.OBJECT, {}, ObjectSegment),
(SegmentType.NONE, None, NoneSegment),
(SegmentType.ARRAY_STRING, [], ArrayStringSegment),
@ -861,5 +861,5 @@ class TestBuildSegmentValueErrors:
# Verify they are processed as integers, not as errors
assert true_segment.value == 1, "Test case 1 (boolean_true): Expected True to be processed as integer 1"
assert false_segment.value == 0, "Test case 2 (boolean_false): Expected False to be processed as integer 0"
assert true_segment.value_type == SegmentType.NUMBER
assert false_segment.value_type == SegmentType.NUMBER
assert true_segment.value_type == SegmentType.INTEGER
assert false_segment.value_type == SegmentType.INTEGER

Loading…
Cancel
Save