|
|
|
@ -1,7 +1,7 @@
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
from collections import defaultdict
|
|
|
|
from collections import defaultdict
|
|
|
|
from collections.abc import Mapping, Sequence
|
|
|
|
from collections.abc import Mapping, Sequence
|
|
|
|
from typing import Any, Union
|
|
|
|
from typing import Any, Union, cast
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
|
|
|
|
|
|
|
|
@ -9,8 +9,9 @@ from core.file import File, FileAttribute, file_manager
|
|
|
|
from core.variables import Segment, SegmentGroup, Variable
|
|
|
|
from core.variables import Segment, SegmentGroup, Variable
|
|
|
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
|
|
|
from core.variables.consts import MIN_SELECTORS_LENGTH
|
|
|
|
from core.variables.segments import FileSegment, NoneSegment
|
|
|
|
from core.variables.segments import FileSegment, NoneSegment
|
|
|
|
|
|
|
|
from core.variables.variables import VariableUnion
|
|
|
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
|
|
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
|
|
|
from core.workflow.enums import SystemVariableKey
|
|
|
|
from core.workflow.system_variable import SystemVariable
|
|
|
|
from factories import variable_factory
|
|
|
|
from factories import variable_factory
|
|
|
|
|
|
|
|
|
|
|
|
VariableValue = Union[str, int, float, dict, list, File]
|
|
|
|
VariableValue = Union[str, int, float, dict, list, File]
|
|
|
|
@ -23,31 +24,35 @@ class VariablePool(BaseModel):
|
|
|
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
|
|
|
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
|
|
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
|
|
|
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
|
|
|
# elements of the selector except the first one.
|
|
|
|
# elements of the selector except the first one.
|
|
|
|
variable_dictionary: dict[str, dict[int, Segment]] = Field(
|
|
|
|
variable_dictionary: dict[str, dict[int, VariableUnion]] = Field(
|
|
|
|
description="Variables mapping",
|
|
|
|
description="Variables mapping",
|
|
|
|
default=defaultdict(dict),
|
|
|
|
default=defaultdict(dict),
|
|
|
|
)
|
|
|
|
)
|
|
|
|
# TODO: This user inputs is not used for pool.
|
|
|
|
|
|
|
|
|
|
|
|
# The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere.
|
|
|
|
user_inputs: Mapping[str, Any] = Field(
|
|
|
|
user_inputs: Mapping[str, Any] = Field(
|
|
|
|
description="User inputs",
|
|
|
|
description="User inputs",
|
|
|
|
default_factory=dict,
|
|
|
|
default_factory=dict,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
system_variables: Mapping[SystemVariableKey, Any] = Field(
|
|
|
|
# system_variables: Mapping[SystemVariableKey, Any] = Field(
|
|
|
|
|
|
|
|
# description="System variables",
|
|
|
|
|
|
|
|
# default_factory=dict,
|
|
|
|
|
|
|
|
# )
|
|
|
|
|
|
|
|
system_variables: SystemVariable = Field(
|
|
|
|
description="System variables",
|
|
|
|
description="System variables",
|
|
|
|
default_factory=dict,
|
|
|
|
|
|
|
|
)
|
|
|
|
)
|
|
|
|
environment_variables: Sequence[Variable] = Field(
|
|
|
|
environment_variables: Sequence[VariableUnion] = Field(
|
|
|
|
description="Environment variables.",
|
|
|
|
description="Environment variables.",
|
|
|
|
default_factory=list,
|
|
|
|
default_factory=list,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
conversation_variables: Sequence[Variable] = Field(
|
|
|
|
conversation_variables: Sequence[VariableUnion] = Field(
|
|
|
|
description="Conversation variables.",
|
|
|
|
description="Conversation variables.",
|
|
|
|
default_factory=list,
|
|
|
|
default_factory=list,
|
|
|
|
)
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
def model_post_init(self, context: Any, /) -> None:
|
|
|
|
def model_post_init(self, context: Any, /) -> None:
|
|
|
|
for key, value in self.system_variables.items():
|
|
|
|
# Create a mapping from field names to SystemVariableKey enum values
|
|
|
|
self.add((SYSTEM_VARIABLE_NODE_ID, key.value), value)
|
|
|
|
self._add_system_variables(self.system_variables)
|
|
|
|
# Add environment variables to the variable pool
|
|
|
|
# Add environment variables to the variable pool
|
|
|
|
for var in self.environment_variables:
|
|
|
|
for var in self.environment_variables:
|
|
|
|
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
|
|
|
self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var)
|
|
|
|
@ -83,8 +88,23 @@ class VariablePool(BaseModel):
|
|
|
|
segment = variable_factory.build_segment(value)
|
|
|
|
segment = variable_factory.build_segment(value)
|
|
|
|
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
|
|
|
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
|
|
|
|
|
|
|
|
|
|
|
hash_key = hash(tuple(selector[1:]))
|
|
|
|
key, hash_key = self._selector_to_keys(selector)
|
|
|
|
self.variable_dictionary[selector[0]][hash_key] = variable
|
|
|
|
# Ensure the first-level key exists in the dictionary
|
|
|
|
|
|
|
|
if key not in self.variable_dictionary:
|
|
|
|
|
|
|
|
self.variable_dictionary[key] = {}
|
|
|
|
|
|
|
|
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, int]:
|
|
|
|
|
|
|
|
return selector[0], hash(tuple(selector[1:]))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _has(self, selector: Sequence[str]) -> bool:
|
|
|
|
|
|
|
|
key, hash_key = self._selector_to_keys(selector)
|
|
|
|
|
|
|
|
if key not in self.variable_dictionary:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
if hash_key not in self.variable_dictionary[key]:
|
|
|
|
|
|
|
|
return False
|
|
|
|
|
|
|
|
return True
|
|
|
|
|
|
|
|
|
|
|
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
|
|
|
def get(self, selector: Sequence[str], /) -> Segment | None:
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@ -102,8 +122,8 @@ class VariablePool(BaseModel):
|
|
|
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
|
|
|
if len(selector) < MIN_SELECTORS_LENGTH:
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
hash_key = hash(tuple(selector[1:]))
|
|
|
|
key, hash_key = self._selector_to_keys(selector)
|
|
|
|
value = self.variable_dictionary[selector[0]].get(hash_key)
|
|
|
|
value: Segment | None = self.variable_dictionary[key].get(hash_key)
|
|
|
|
|
|
|
|
|
|
|
|
if value is None:
|
|
|
|
if value is None:
|
|
|
|
selector, attr = selector[:-1], selector[-1]
|
|
|
|
selector, attr = selector[:-1], selector[-1]
|
|
|
|
@ -136,8 +156,9 @@ class VariablePool(BaseModel):
|
|
|
|
if len(selector) == 1:
|
|
|
|
if len(selector) == 1:
|
|
|
|
self.variable_dictionary[selector[0]] = {}
|
|
|
|
self.variable_dictionary[selector[0]] = {}
|
|
|
|
return
|
|
|
|
return
|
|
|
|
|
|
|
|
key, hash_key = self._selector_to_keys(selector)
|
|
|
|
hash_key = hash(tuple(selector[1:]))
|
|
|
|
hash_key = hash(tuple(selector[1:]))
|
|
|
|
self.variable_dictionary[selector[0]].pop(hash_key, None)
|
|
|
|
self.variable_dictionary[key].pop(hash_key, None)
|
|
|
|
|
|
|
|
|
|
|
|
def convert_template(self, template: str, /):
|
|
|
|
def convert_template(self, template: str, /):
|
|
|
|
parts = VARIABLE_PATTERN.split(template)
|
|
|
|
parts = VARIABLE_PATTERN.split(template)
|
|
|
|
@ -154,3 +175,31 @@ class VariablePool(BaseModel):
|
|
|
|
if isinstance(segment, FileSegment):
|
|
|
|
if isinstance(segment, FileSegment):
|
|
|
|
return segment
|
|
|
|
return segment
|
|
|
|
return None
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _add_system_variables(self, system_variable: SystemVariable):
|
|
|
|
|
|
|
|
sys_var_mapping = system_variable.to_dict()
|
|
|
|
|
|
|
|
for key, value in sys_var_mapping.items():
|
|
|
|
|
|
|
|
if value is None:
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
selector = (SYSTEM_VARIABLE_NODE_ID, key)
|
|
|
|
|
|
|
|
# If the system variable already exists, do not add it again.
|
|
|
|
|
|
|
|
# This ensures that we can keep the id of the system variables intact.
|
|
|
|
|
|
|
|
if self._has(selector):
|
|
|
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
self.add(selector, value) # type: ignore
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def empty(cls) -> "VariablePool":
|
|
|
|
|
|
|
|
"""Create an empty variable pool."""
|
|
|
|
|
|
|
|
return cls(system_variables=SystemVariable.empty())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def dumps(self) -> str:
|
|
|
|
|
|
|
|
return self.model_dump_json()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
|
|
|
def loads(cls, json_data: str) -> "VariablePool":
|
|
|
|
|
|
|
|
return VariablePool.model_validate_json(json_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reload_storage_keys_for_file_types(self):
|
|
|
|
|
|
|
|
# TODO
|
|
|
|
|
|
|
|
pass
|
|
|
|
|