From 601fcbc21ae148d6a5b85822037d4c4c92d27c19 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 15 May 2025 12:34:58 +0800 Subject: [PATCH] feat(api): Add `EnumText` column type for SQLAlchemy Introduce `EnumText`, a custom column type for SQLAlchemy designed to work seamlessly with enumeration classes based on `StrEnum`. This type stores enum members as `VARCHAR` in the database and automatically handles conversion between the enumeration type and its string representation during reads and writes. Additionally, it validates that the stored and retrieved values are valid members of the associated enumeration class. --- api/models/types.py | 53 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 52 insertions(+), 1 deletion(-) diff --git a/api/models/types.py b/api/models/types.py index cb6773e70c..6235c2594e 100644 --- a/api/models/types.py +++ b/api/models/types.py @@ -1,4 +1,7 @@ -from sqlalchemy import CHAR, TypeDecorator +import enum +from typing import Generic, TypeVar + +from sqlalchemy import CHAR, VARCHAR, TypeDecorator from sqlalchemy.dialects.postgresql import UUID @@ -24,3 +27,51 @@ class StringUUID(TypeDecorator): if value is None: return value return str(value) + + +_E = TypeVar("_E", bound=enum.StrEnum) + + +class EnumText(TypeDecorator, Generic[_E]): + impl = VARCHAR + cache_ok = True + + _length: int + _enum_class: type[_E] + + def __init__(self, enum_class: type[_E], length: int | None = None): + self._enum_class = enum_class + max_enum_value_len = max(len(e.value) for e in enum_class) + if length is not None: + if length < max_enum_value_len: + raise ValueError("length should be greater than enum value length.") + self._length = length + else: + # leave some rooms for future longer enum values. + self._length = max(max_enum_value_len, 20) + + def process_bind_param(self, value: _E | str, dialect): + if value is None: + return value + if isinstance(value, self._enum_class): + return value.value + elif isinstance(value, str): + self._enum_class(value) + return value + else: + raise TypeError(f"expected str or {self._enum_class}, got {type(value)}") + + def load_dialect_impl(self, dialect): + return dialect.type_descriptor(VARCHAR(self._length)) + + def process_result_value(self, value, dialect) -> _E | None: + if value is None: + return value + if not isinstance(value, str): + raise TypeError(f"expected str, got {type(value)}") + return self._enum_class(value) + + def compare_values(self, x, y): + if x is None or y is None: + return x is y + return x == y