chore(api): ensure `variable_dictionary` is always a defaultdict

Even after serialization / deserialization roundtrip
pull/22025/head
QuantumGhost 10 months ago
parent 2e2f654dd3
commit 2dd98831a1

@ -1,7 +1,7 @@
import re import re
from collections import defaultdict from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Union, cast from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -24,7 +24,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary. # The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one. # elements of the selector except the first one.
variable_dictionary: dict[str, dict[int, VariableUnion]] = Field( variable_dictionary: defaultdict[str, Annotated[dict[int, VariableUnion], Field(default_factory=dict)]] = Field(
description="Variables mapping", description="Variables mapping",
default=defaultdict(dict), default=defaultdict(dict),
) )
@ -85,9 +85,6 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector) variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
key, hash_key = self._selector_to_keys(selector) key, hash_key = self._selector_to_keys(selector)
# Ensure the first-level key exists in the dictionary
if key not in self.variable_dictionary:
self.variable_dictionary[key] = {}
# Based on the definition of `VariableUnion`, # Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible. # `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable) self.variable_dictionary[key][hash_key] = cast(VariableUnion, variable)
@ -190,10 +187,3 @@ class VariablePool(BaseModel):
def empty(cls) -> "VariablePool": def empty(cls) -> "VariablePool":
"""Create an empty variable pool.""" """Create an empty variable pool."""
return cls(system_variables=SystemVariable.empty()) return cls(system_variables=SystemVariable.empty())
def dumps(self) -> str:
return self.model_dump_json()
@classmethod
def loads(cls, json_data: str) -> "VariablePool":
return VariablePool.model_validate_json(json_data)

Loading…
Cancel
Save