chore(api): ensure `variable_dictionary` is always a defaultdict

Even after serialization / deserialization roundtrip
pull/22025/head
QuantumGhost 10 months ago
parent 2e2f654dd3
commit 2dd98831a1

@ -1,7 +1,7 @@
import re
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Union, cast
from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
@ -24,7 +24,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, VariableUnion]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@ -85,9 +85,6 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
key, hash_key = self._selector_to_keys(selector)
# Ensure the first-level key exists in the dictionary
if key not in self.variable_dictionary:
self.variable_dictionary[key] = {}
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@ -190,10 +187,3 @@ class VariablePool(BaseModel):
def empty(cls) -> "VariablePool":
"""Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty())
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, json_data: str) -> "VariablePool":
return VariablePool.model_validate_json(json_data)

Loading…
Cancel
Save