commit
3dc299aaa3
@ -1,5 +1,11 @@
|
||||
blank_issues_enabled: false
|
||||
contact_links:
|
||||
- name: "\U0001F4A1 Model Providers & Plugins"
|
||||
url: "https://github.com/langgenius/dify-official-plugins/issues/new/choose"
|
||||
about: Report issues with official plugins or model providers, you will need to provide the plugin version and other relevant details.
|
||||
- name: "\U0001F4AC Documentation Issues"
|
||||
url: "https://github.com/langgenius/dify-docs/issues/new"
|
||||
about: Report issues with the documentation, such as typos, outdated information, or missing content. Please provide the specific section and details of the issue.
|
||||
- name: "\U0001F4E7 Discussions"
|
||||
url: https://github.com/langgenius/dify/discussions/categories/general
|
||||
about: General discussions and request help from the community
|
||||
about: General discussions and seek help from the community
|
||||
|
||||
@ -0,0 +1,89 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from core.file.models import File
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
|
||||
|
||||
class SystemVariable(BaseModel):
|
||||
"""A model for managing system variables.
|
||||
|
||||
Fields with a value of `None` are treated as absent and will not be included
|
||||
in the variable pool.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(
|
||||
extra="forbid",
|
||||
serialize_by_alias=True,
|
||||
validate_by_alias=True,
|
||||
)
|
||||
|
||||
user_id: str | None = None
|
||||
|
||||
# Ideally, `app_id` and `workflow_id` should be required and not `None`.
|
||||
# However, there are scenarios in the codebase where these fields are not set.
|
||||
# To maintain compatibility, they are marked as optional here.
|
||||
app_id: str | None = None
|
||||
workflow_id: str | None = None
|
||||
|
||||
files: Sequence[File] = Field(default_factory=list)
|
||||
|
||||
# NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`.
|
||||
# To maintain compatibility with existing workflows, it must be serialized
|
||||
# as `workflow_run_id` in dictionaries or JSON objects, and also referenced
|
||||
# as `workflow_run_id` in the variable pool.
|
||||
workflow_execution_id: str | None = Field(
|
||||
validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"),
|
||||
serialization_alias="workflow_run_id",
|
||||
default=None,
|
||||
)
|
||||
# Chatflow related fields.
|
||||
query: str | None = None
|
||||
conversation_id: str | None = None
|
||||
dialogue_count: int | None = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_json_fields(cls, data):
|
||||
if isinstance(data, dict):
|
||||
# For JSON validation, only allow workflow_run_id
|
||||
if "workflow_execution_id" in data and "workflow_run_id" not in data:
|
||||
# This is likely from direct instantiation, allow it
|
||||
return data
|
||||
elif "workflow_execution_id" in data and "workflow_run_id" in data:
|
||||
# Both present, remove workflow_execution_id
|
||||
data = data.copy()
|
||||
data.pop("workflow_execution_id")
|
||||
return data
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def empty(cls) -> "SystemVariable":
|
||||
return cls()
|
||||
|
||||
def to_dict(self) -> dict[SystemVariableKey, Any]:
|
||||
# NOTE: This method is provided for compatibility with legacy code.
|
||||
# New code should use the `SystemVariable` object directly instead of converting
|
||||
# it to a dictionary, as this conversion results in the loss of type information
|
||||
# for each key, making static analysis more difficult.
|
||||
|
||||
d: dict[SystemVariableKey, Any] = {
|
||||
SystemVariableKey.FILES: self.files,
|
||||
}
|
||||
if self.user_id is not None:
|
||||
d[SystemVariableKey.USER_ID] = self.user_id
|
||||
if self.app_id is not None:
|
||||
d[SystemVariableKey.APP_ID] = self.app_id
|
||||
if self.workflow_id is not None:
|
||||
d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id
|
||||
if self.workflow_execution_id is not None:
|
||||
d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id
|
||||
if self.query is not None:
|
||||
d[SystemVariableKey.QUERY] = self.query
|
||||
if self.conversation_id is not None:
|
||||
d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id
|
||||
if self.dialogue_count is not None:
|
||||
d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count
|
||||
return d
|
||||
@ -0,0 +1,15 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class _VarTypedDict(TypedDict, total=False):
|
||||
value_type: SegmentType
|
||||
|
||||
|
||||
def serialize_value_type(v: _VarTypedDict | Segment) -> str:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type.exposed_type().value
|
||||
else:
|
||||
return v["value_type"].exposed_type().value
|
||||
@ -0,0 +1,164 @@
|
||||
import secrets
|
||||
import struct
|
||||
import time
|
||||
import uuid
|
||||
|
||||
# Reference for UUIDv7 specification:
|
||||
# RFC 9562, Section 5.7 - https://www.rfc-editor.org/rfc/rfc9562.html#section-5.7
|
||||
|
||||
# Define the format for packing the timestamp as an unsigned 64-bit integer (big-endian).
|
||||
#
|
||||
# For details on the `struct.pack` format, refer to:
|
||||
# https://docs.python.org/3/library/struct.html#byte-order-size-and-alignment
|
||||
_PACK_TIMESTAMP = ">Q"
|
||||
|
||||
# Define the format for packing the 12-bit random data A (as specified in RFC 9562 Section 5.7)
|
||||
# into an unsigned 16-bit integer (big-endian).
|
||||
_PACK_RAND_A = ">H"
|
||||
|
||||
|
||||
def _create_uuidv7_bytes(timestamp_ms: int, random_bytes: bytes) -> bytes:
|
||||
"""Create UUIDv7 byte structure with given timestamp and random bytes.
|
||||
|
||||
This is a private helper function that handles the common logic for creating
|
||||
UUIDv7 byte structure according to RFC 9562 specification.
|
||||
|
||||
UUIDv7 Structure:
|
||||
- 48 bits: timestamp (milliseconds since Unix epoch)
|
||||
- 12 bits: random data A (with version bits)
|
||||
- 62 bits: random data B (with variant bits)
|
||||
|
||||
The function performs the following operations:
|
||||
1. Creates a 128-bit (16-byte) UUID structure
|
||||
2. Packs the timestamp into the first 48 bits (6 bytes)
|
||||
3. Sets the version bits to 7 (0111) in the correct position
|
||||
4. Sets the variant bits to 10 (binary) in the correct position
|
||||
5. Fills the remaining bits with the provided random bytes
|
||||
|
||||
Args:
|
||||
timestamp_ms: The timestamp in milliseconds since Unix epoch (48 bits).
|
||||
random_bytes: Random bytes to use for the random portions (must be 10 bytes).
|
||||
First 2 bytes are used for random data A (12 bits after version).
|
||||
Last 8 bytes are used for random data B (62 bits after variant).
|
||||
|
||||
Returns:
|
||||
A 16-byte bytes object representing the complete UUIDv7 structure.
|
||||
|
||||
Note:
|
||||
This function assumes the random_bytes parameter is exactly 10 bytes.
|
||||
The caller is responsible for providing appropriate random data.
|
||||
"""
|
||||
# Create the 128-bit UUID structure
|
||||
uuid_bytes = bytearray(16)
|
||||
|
||||
# Pack timestamp (48 bits) into first 6 bytes
|
||||
uuid_bytes[0:6] = struct.pack(_PACK_TIMESTAMP, timestamp_ms)[2:8] # Take last 6 bytes of 8-byte big-endian
|
||||
|
||||
# Next 16 bits: random data A (12 bits) + version (4 bits)
|
||||
# Take first 2 random bytes and set version to 7
|
||||
rand_a = struct.unpack(_PACK_RAND_A, random_bytes[0:2])[0]
|
||||
# Clear the highest 4 bits to make room for the version field
|
||||
# by performing a bitwise AND with 0x0FFF (binary: 0b0000_1111_1111_1111).
|
||||
rand_a = rand_a & 0x0FFF
|
||||
# Set the version field to 7 (binary: 0111) by performing a bitwise OR with 0x7000 (binary: 0b0111_0000_0000_0000).
|
||||
rand_a = rand_a | 0x7000
|
||||
uuid_bytes[6:8] = struct.pack(_PACK_RAND_A, rand_a)
|
||||
|
||||
# Last 64 bits: random data B (62 bits) + variant (2 bits)
|
||||
# Use remaining 8 random bytes and set variant to 10 (binary)
|
||||
uuid_bytes[8:16] = random_bytes[2:10]
|
||||
|
||||
# Set variant bits (first 2 bits of byte 8 should be '10')
|
||||
uuid_bytes[8] = (uuid_bytes[8] & 0x3F) | 0x80 # Set variant to 10xxxxxx
|
||||
|
||||
return bytes(uuid_bytes)
|
||||
|
||||
|
||||
def uuidv7(timestamp_ms: int | None = None) -> uuid.UUID:
|
||||
"""Generate a UUID version 7 according to RFC 9562 specification.
|
||||
|
||||
UUIDv7 features a time-ordered value field derived from the widely
|
||||
implemented and well known Unix Epoch timestamp source, the number of
|
||||
milliseconds since midnight 1 Jan 1970 UTC, leap seconds excluded.
|
||||
|
||||
Structure:
|
||||
- 48 bits: timestamp (milliseconds since Unix epoch)
|
||||
- 12 bits: random data A (with version bits)
|
||||
- 62 bits: random data B (with variant bits)
|
||||
|
||||
Args:
|
||||
timestamp_ms: The timestamp used when generating UUID, use the current time if unspecified.
|
||||
Should be an integer representing milliseconds since Unix epoch.
|
||||
|
||||
Returns:
|
||||
A UUID object representing a UUIDv7.
|
||||
|
||||
Example:
|
||||
>>> import time
|
||||
>>> # Generate UUIDv7 with current time
|
||||
>>> uuid_current = uuidv7()
|
||||
>>> # Generate UUIDv7 with specific timestamp
|
||||
>>> uuid_specific = uuidv7(int(time.time() * 1000))
|
||||
"""
|
||||
if timestamp_ms is None:
|
||||
timestamp_ms = int(time.time() * 1000)
|
||||
|
||||
# Generate 10 random bytes for the random portions
|
||||
random_bytes = secrets.token_bytes(10)
|
||||
|
||||
# Create UUIDv7 bytes using the helper function
|
||||
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, random_bytes)
|
||||
|
||||
return uuid.UUID(bytes=uuid_bytes)
|
||||
|
||||
|
||||
def uuidv7_timestamp(id_: uuid.UUID) -> int:
|
||||
"""Extract the timestamp from a UUIDv7.
|
||||
|
||||
UUIDv7 contains a 48-bit timestamp field representing milliseconds since
|
||||
the Unix epoch (1970-01-01 00:00:00 UTC). This function extracts and
|
||||
returns that timestamp as an integer representing milliseconds since the epoch.
|
||||
|
||||
Args:
|
||||
id_: A UUID object that should be a UUIDv7 (version 7).
|
||||
|
||||
Returns:
|
||||
The timestamp as an integer representing milliseconds since Unix epoch.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided UUID is not version 7.
|
||||
|
||||
Example:
|
||||
>>> uuid_v7 = uuidv7()
|
||||
>>> timestamp = uuidv7_timestamp(uuid_v7)
|
||||
>>> print(f"UUID was created at: {timestamp} ms")
|
||||
"""
|
||||
# Verify this is a UUIDv7
|
||||
if id_.version != 7:
|
||||
raise ValueError(f"Expected UUIDv7 (version 7), got version {id_.version}")
|
||||
|
||||
# Extract the UUID bytes
|
||||
uuid_bytes = id_.bytes
|
||||
|
||||
# Extract the first 48 bits (6 bytes) as the timestamp in milliseconds
|
||||
# Pad with 2 zero bytes at the beginning to make it 8 bytes for unpacking as Q (unsigned long long)
|
||||
timestamp_bytes = b"\x00\x00" + uuid_bytes[0:6]
|
||||
ts_in_ms = struct.unpack(_PACK_TIMESTAMP, timestamp_bytes)[0]
|
||||
|
||||
# Return timestamp directly in milliseconds as integer
|
||||
assert isinstance(ts_in_ms, int)
|
||||
return ts_in_ms
|
||||
|
||||
|
||||
def uuidv7_boundary(timestamp_ms: int) -> uuid.UUID:
|
||||
"""Generate a non-random uuidv7 with the given timestamp (first 48 bits) and
|
||||
all random bits to 0. As the smallest possible uuidv7 for that timestamp,
|
||||
it may be used as a boundary for partitions.
|
||||
"""
|
||||
# Use zero bytes for all random portions
|
||||
zero_random_bytes = b"\x00" * 10
|
||||
|
||||
# Create UUIDv7 bytes using the helper function
|
||||
uuid_bytes = _create_uuidv7_bytes(timestamp_ms, zero_random_bytes)
|
||||
|
||||
return uuid.UUID(bytes=uuid_bytes)
|
||||
@ -0,0 +1,86 @@
|
||||
"""add uuidv7 function in SQL
|
||||
|
||||
Revision ID: 1c9ba48be8e4
|
||||
Revises: 58eb7bdb93fe
|
||||
Create Date: 2025-07-02 23:32:38.484499
|
||||
|
||||
"""
|
||||
|
||||
"""
|
||||
The functions in this files comes from https://github.com/dverite/postgres-uuidv7-sql/, with minor modifications.
|
||||
|
||||
LICENSE:
|
||||
|
||||
# Copyright and License
|
||||
|
||||
Copyright (c) 2024, Daniel Vérité
|
||||
|
||||
Permission to use, copy, modify, and distribute this software and its documentation for any purpose, without fee, and without a written agreement is hereby granted, provided that the above copyright notice and this paragraph and the following two paragraphs appear in all copies.
|
||||
|
||||
In no event shall Daniel Vérité be liable to any party for direct, indirect, special, incidental, or consequential damages, including lost profits, arising out of the use of this software and its documentation, even if Daniel Vérité has been advised of the possibility of such damage.
|
||||
|
||||
Daniel Vérité specifically disclaims any warranties, including, but not limited to, the implied warranties of merchantability and fitness for a particular purpose. The software provided hereunder is on an "AS IS" basis, and Daniel Vérité has no obligations to provide maintenance, support, updates, enhancements, or modifications.
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '1c9ba48be8e4'
|
||||
down_revision = '58eb7bdb93fe'
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# This implementation differs slightly from the original uuidv7 function in
|
||||
# https://github.com/dverite/postgres-uuidv7-sql/.
|
||||
# The ability to specify source timestamp has been removed because its type signature is incompatible with
|
||||
# PostgreSQL 18's `uuidv7` function. This capability is rarely needed in practice, as IDs can be
|
||||
# generated and controlled within the application layer.
|
||||
op.execute(sa.text(r"""
|
||||
/* Main function to generate a uuidv7 value with millisecond precision */
|
||||
CREATE FUNCTION uuidv7() RETURNS uuid
|
||||
AS
|
||||
$$
|
||||
-- Replace the first 48 bits of a uuidv4 with the current
|
||||
-- number of milliseconds since 1970-01-01 UTC
|
||||
-- and set the "ver" field to 7 by setting additional bits
|
||||
SELECT encode(
|
||||
set_bit(
|
||||
set_bit(
|
||||
overlay(uuid_send(gen_random_uuid()) placing
|
||||
substring(int8send((extract(epoch from clock_timestamp()) * 1000)::bigint) from
|
||||
3)
|
||||
from 1 for 6),
|
||||
52, 1),
|
||||
53, 1), 'hex')::uuid;
|
||||
$$ LANGUAGE SQL VOLATILE PARALLEL SAFE;
|
||||
|
||||
COMMENT ON FUNCTION uuidv7 IS
|
||||
'Generate a uuid-v7 value with a 48-bit timestamp (millisecond precision) and 74 bits of randomness';
|
||||
"""))
|
||||
|
||||
op.execute(sa.text(r"""
|
||||
CREATE FUNCTION uuidv7_boundary(timestamptz) RETURNS uuid
|
||||
AS
|
||||
$$
|
||||
/* uuid fields: version=0b0111, variant=0b10 */
|
||||
SELECT encode(
|
||||
overlay('\x00000000000070008000000000000000'::bytea
|
||||
placing substring(int8send(floor(extract(epoch from $1) * 1000)::bigint) from 3)
|
||||
from 1 for 6),
|
||||
'hex')::uuid;
|
||||
$$ LANGUAGE SQL STABLE STRICT PARALLEL SAFE;
|
||||
|
||||
COMMENT ON FUNCTION uuidv7_boundary(timestamptz) IS
|
||||
'Generate a non-random uuidv7 with the given timestamp (first 48 bits) and all random bits to 0. As the smallest possible uuidv7 for that timestamp, it may be used as a boundary for partitions.';
|
||||
"""
|
||||
))
|
||||
|
||||
|
||||
def downgrade():
|
||||
op.execute(sa.text("DROP FUNCTION uuidv7"))
|
||||
op.execute(sa.text("DROP FUNCTION uuidv7_boundary"))
|
||||
@ -0,0 +1,380 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin
|
||||
|
||||
from controllers.console.error import NotInitValidateError, NotSetupError, UnauthorizedAndForceLogout
|
||||
from controllers.console.workspace.error import AccountNotInitializedError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
cloud_edition_billing_resource_check,
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
only_edition_enterprise,
|
||||
only_edition_self_hosted,
|
||||
setup_required,
|
||||
)
|
||||
from models.account import AccountStatus
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Simple User class for testing."""
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
self.current_tenant_id = "tenant123"
|
||||
|
||||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
|
||||
def create_app_with_login():
|
||||
"""Create a Flask app with LoginManager configured."""
|
||||
app = Flask(__name__)
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str):
|
||||
return MockUser(user_id)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
class TestAccountInitialization:
|
||||
"""Test account initialization decorator"""
|
||||
|
||||
def test_should_allow_initialized_account(self):
|
||||
"""Test that initialized accounts can access protected views"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.status = AccountStatus.ACTIVE
|
||||
|
||||
@account_initialization_required
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
result = protected_view()
|
||||
|
||||
# Assert
|
||||
assert result == "success"
|
||||
|
||||
def test_should_reject_uninitialized_account(self):
|
||||
"""Test that uninitialized accounts raise AccountNotInitializedError"""
|
||||
# Arrange
|
||||
mock_user = MagicMock()
|
||||
mock_user.status = AccountStatus.UNINITIALIZED
|
||||
|
||||
@account_initialization_required
|
||||
def protected_view():
|
||||
return "success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.current_user", mock_user):
|
||||
with pytest.raises(AccountNotInitializedError):
|
||||
protected_view()
|
||||
|
||||
|
||||
class TestEditionChecks:
|
||||
"""Test edition-specific decorators"""
|
||||
|
||||
def test_only_edition_cloud_allows_cloud_edition(self):
|
||||
"""Test cloud edition decorator allows CLOUD edition"""
|
||||
|
||||
# Arrange
|
||||
@only_edition_cloud
|
||||
def cloud_view():
|
||||
return "cloud_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "CLOUD"):
|
||||
result = cloud_view()
|
||||
|
||||
# Assert
|
||||
assert result == "cloud_success"
|
||||
|
||||
def test_only_edition_cloud_rejects_other_editions(self):
|
||||
"""Test cloud edition decorator rejects non-CLOUD editions"""
|
||||
# Arrange
|
||||
app = Flask(__name__)
|
||||
|
||||
@only_edition_cloud
|
||||
def cloud_view():
|
||||
return "cloud_success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
cloud_view()
|
||||
assert exc_info.value.code == 404
|
||||
|
||||
def test_only_edition_enterprise_allows_when_enabled(self):
|
||||
"""Test enterprise edition decorator allows when ENTERPRISE_ENABLED is True"""
|
||||
|
||||
# Arrange
|
||||
@only_edition_enterprise
|
||||
def enterprise_view():
|
||||
return "enterprise_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.dify_config.ENTERPRISE_ENABLED", True):
|
||||
result = enterprise_view()
|
||||
|
||||
# Assert
|
||||
assert result == "enterprise_success"
|
||||
|
||||
def test_only_edition_self_hosted_allows_self_hosted(self):
|
||||
"""Test self-hosted edition decorator allows SELF_HOSTED edition"""
|
||||
|
||||
# Arrange
|
||||
@only_edition_self_hosted
|
||||
def self_hosted_view():
|
||||
return "self_hosted_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
result = self_hosted_view()
|
||||
|
||||
# Assert
|
||||
assert result == "self_hosted_success"
|
||||
|
||||
|
||||
class TestBillingResourceLimits:
|
||||
"""Test billing resource limit decorators"""
|
||||
|
||||
def test_should_allow_when_under_resource_limit(self):
|
||||
"""Test that requests are allowed when under resource limits"""
|
||||
# Arrange
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.members.limit = 10
|
||||
mock_features.members.size = 5
|
||||
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def add_member():
|
||||
return "member_added"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = add_member()
|
||||
|
||||
# Assert
|
||||
assert result == "member_added"
|
||||
|
||||
def test_should_reject_when_over_resource_limit(self):
|
||||
"""Test that requests are rejected when over resource limits"""
|
||||
# Arrange
|
||||
app = create_app_with_login()
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.members.limit = 10
|
||||
mock_features.members.size = 10
|
||||
|
||||
@cloud_edition_billing_resource_check("members")
|
||||
def add_member():
|
||||
return "member_added"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
add_member()
|
||||
assert exc_info.value.code == 403
|
||||
assert "members has reached the limit" in str(exc_info.value.description)
|
||||
|
||||
def test_should_check_source_for_documents_limit(self):
|
||||
"""Test document limit checks request source"""
|
||||
# Arrange
|
||||
app = create_app_with_login()
|
||||
mock_features = MagicMock()
|
||||
mock_features.billing.enabled = True
|
||||
mock_features.documents_upload_quota.limit = 100
|
||||
mock_features.documents_upload_quota.size = 100
|
||||
|
||||
@cloud_edition_billing_resource_check("documents")
|
||||
def upload_document():
|
||||
return "document_uploaded"
|
||||
|
||||
# Test 1: Should reject when source is datasets
|
||||
with app.test_request_context("/?source=datasets"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
upload_document()
|
||||
assert exc_info.value.code == 403
|
||||
|
||||
# Test 2: Should allow when source is not datasets
|
||||
with app.test_request_context("/?source=other"):
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch("controllers.console.wraps.FeatureService.get_features", return_value=mock_features):
|
||||
result = upload_document()
|
||||
assert result == "document_uploaded"
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiting decorator"""
|
||||
|
||||
@patch("controllers.console.wraps.redis_client")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_allow_requests_within_rate_limit(self, mock_db, mock_redis):
|
||||
"""Test that requests within rate limit are allowed"""
|
||||
# Arrange
|
||||
mock_rate_limit = MagicMock()
|
||||
mock_rate_limit.enabled = True
|
||||
mock_rate_limit.limit = 10
|
||||
mock_redis.zcard.return_value = 5 # 5 requests in window
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def knowledge_request():
|
||||
return "knowledge_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.current_user"):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
result = knowledge_request()
|
||||
|
||||
# Assert
|
||||
assert result == "knowledge_success"
|
||||
mock_redis.zadd.assert_called_once()
|
||||
mock_redis.zremrangebyscore.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.redis_client")
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_reject_requests_over_rate_limit(self, mock_db, mock_redis):
|
||||
"""Test that requests over rate limit are rejected and logged"""
|
||||
# Arrange
|
||||
app = create_app_with_login()
|
||||
mock_rate_limit = MagicMock()
|
||||
mock_rate_limit.enabled = True
|
||||
mock_rate_limit.limit = 10
|
||||
mock_rate_limit.subscription_plan = "pro"
|
||||
mock_redis.zcard.return_value = 11 # Over limit
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_db.session = mock_session
|
||||
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def knowledge_request():
|
||||
return "knowledge_success"
|
||||
|
||||
# Act & Assert
|
||||
with app.test_request_context():
|
||||
with patch("controllers.console.wraps.current_user", MockUser("test_user")):
|
||||
with patch(
|
||||
"controllers.console.wraps.FeatureService.get_knowledge_rate_limit", return_value=mock_rate_limit
|
||||
):
|
||||
with pytest.raises(Exception) as exc_info:
|
||||
knowledge_request()
|
||||
|
||||
# Verify error
|
||||
assert exc_info.value.code == 403
|
||||
assert "rate limit" in str(exc_info.value.description)
|
||||
|
||||
# Verify rate limit log was created
|
||||
mock_session.add.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestSystemSetup:
|
||||
"""Test system setup decorator"""
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
def test_should_allow_when_setup_complete(self, mock_db):
|
||||
"""Test that requests are allowed when setup is complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = MagicMock() # Setup exists
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
return "admin_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
result = admin_view()
|
||||
|
||||
# Assert
|
||||
assert result == "admin_success"
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.wraps.os.environ.get")
|
||||
def test_should_raise_not_init_validate_error_with_init_password(self, mock_environ_get, mock_db):
|
||||
"""Test NotInitValidateError when INIT_PASSWORD is set but setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = None # No setup
|
||||
mock_environ_get.return_value = "some_password"
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
return "admin_success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
with pytest.raises(NotInitValidateError):
|
||||
admin_view()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.wraps.os.environ.get")
|
||||
def test_should_raise_not_setup_error_without_init_password(self, mock_environ_get, mock_db):
|
||||
"""Test NotSetupError when no INIT_PASSWORD and setup not complete"""
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.first.return_value = None # No setup
|
||||
mock_environ_get.return_value = None # No INIT_PASSWORD
|
||||
|
||||
@setup_required
|
||||
def admin_view():
|
||||
return "admin_success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.dify_config.EDITION", "SELF_HOSTED"):
|
||||
with pytest.raises(NotSetupError):
|
||||
admin_view()
|
||||
|
||||
|
||||
class TestEnterpriseLicense:
|
||||
"""Test enterprise license decorator"""
|
||||
|
||||
def test_should_allow_with_valid_license(self):
|
||||
"""Test that valid licenses allow access"""
|
||||
# Arrange
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.license.status = LicenseStatus.ACTIVE
|
||||
|
||||
@enterprise_license_required
|
||||
def enterprise_feature():
|
||||
return "enterprise_success"
|
||||
|
||||
# Act
|
||||
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
|
||||
result = enterprise_feature()
|
||||
|
||||
# Assert
|
||||
assert result == "enterprise_success"
|
||||
|
||||
@pytest.mark.parametrize("invalid_status", [LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST])
|
||||
def test_should_reject_with_invalid_license(self, invalid_status):
|
||||
"""Test that invalid licenses raise UnauthorizedAndForceLogout"""
|
||||
# Arrange
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.license.status = invalid_status
|
||||
|
||||
@enterprise_license_required
|
||||
def enterprise_feature():
|
||||
return "enterprise_success"
|
||||
|
||||
# Act & Assert
|
||||
with patch("controllers.console.wraps.FeatureService.get_system_features", return_value=mock_settings):
|
||||
with pytest.raises(UnauthorizedAndForceLogout) as exc_info:
|
||||
enterprise_feature()
|
||||
assert "license is invalid" in str(exc_info.value)
|
||||
@ -0,0 +1,56 @@
|
||||
import pytest
|
||||
from flask import Flask
|
||||
|
||||
from core.tools.utils.parser import ApiBasedToolSchemaParser
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
app = Flask(__name__)
|
||||
return app
|
||||
|
||||
|
||||
def test_parse_openapi_to_tool_bundle_operation_id(app):
|
||||
openapi = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {"title": "Simple API", "version": "1.0.0"},
|
||||
"servers": [{"url": "http://localhost:3000"}],
|
||||
"paths": {
|
||||
"/": {
|
||||
"get": {
|
||||
"summary": "Root endpoint",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"/api/resources": {
|
||||
"get": {
|
||||
"summary": "Non-root endpoint without an operationId",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful response",
|
||||
}
|
||||
},
|
||||
},
|
||||
"post": {
|
||||
"summary": "Non-root endpoint with an operationId",
|
||||
"operationId": "createResource",
|
||||
"responses": {
|
||||
"201": {
|
||||
"description": "Resource created",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
with app.test_request_context():
|
||||
tool_bundles = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi)
|
||||
|
||||
assert len(tool_bundles) == 3
|
||||
assert tool_bundles[0].operation_id == "<root>_get"
|
||||
assert tool_bundles[1].operation_id == "apiresources_get"
|
||||
assert tool_bundles[2].operation_id == "createResource"
|
||||
@ -0,0 +1,60 @@
|
||||
from core.variables.types import SegmentType
|
||||
|
||||
|
||||
class TestSegmentTypeIsArrayType:
|
||||
"""
|
||||
Test class for SegmentType.is_array_type method.
|
||||
|
||||
Provides comprehensive coverage of all SegmentType values to ensure
|
||||
correct identification of array and non-array types.
|
||||
"""
|
||||
|
||||
def test_is_array_type(self):
|
||||
"""
|
||||
Test that all SegmentType enum values are covered in our test cases.
|
||||
|
||||
Ensures comprehensive coverage by verifying that every SegmentType
|
||||
value is tested for the is_array_type method.
|
||||
"""
|
||||
# Arrange
|
||||
all_segment_types = set(SegmentType)
|
||||
expected_array_types = [
|
||||
SegmentType.ARRAY_ANY,
|
||||
SegmentType.ARRAY_STRING,
|
||||
SegmentType.ARRAY_NUMBER,
|
||||
SegmentType.ARRAY_OBJECT,
|
||||
SegmentType.ARRAY_FILE,
|
||||
]
|
||||
expected_non_array_types = [
|
||||
SegmentType.INTEGER,
|
||||
SegmentType.FLOAT,
|
||||
SegmentType.NUMBER,
|
||||
SegmentType.STRING,
|
||||
SegmentType.OBJECT,
|
||||
SegmentType.SECRET,
|
||||
SegmentType.FILE,
|
||||
SegmentType.NONE,
|
||||
SegmentType.GROUP,
|
||||
]
|
||||
|
||||
for seg_type in expected_array_types:
|
||||
assert seg_type.is_array_type()
|
||||
|
||||
for seg_type in expected_non_array_types:
|
||||
assert not seg_type.is_array_type()
|
||||
|
||||
# Act & Assert
|
||||
covered_types = set(expected_array_types) | set(expected_non_array_types)
|
||||
assert covered_types == set(SegmentType), "All SegmentType values should be covered in tests"
|
||||
|
||||
def test_all_enum_values_are_supported(self):
|
||||
"""
|
||||
Test that all enum values are supported and return boolean values.
|
||||
|
||||
Validates that every SegmentType enum value can be processed by
|
||||
is_array_type method and returns a boolean value.
|
||||
"""
|
||||
enum_values: list[SegmentType] = list(SegmentType)
|
||||
for seg_type in enum_values:
|
||||
is_array = seg_type.is_array_type()
|
||||
assert isinstance(is_array, bool), f"is_array_type does not return a boolean for segment type {seg_type}"
|
||||
@ -0,0 +1,146 @@
|
||||
import time
|
||||
from decimal import Decimal
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RuntimeRouteState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
|
||||
|
||||
def create_test_graph_runtime_state() -> GraphRuntimeState:
|
||||
"""Factory function to create a GraphRuntimeState with non-empty values for testing."""
|
||||
# Create a variable pool with system variables
|
||||
system_vars = SystemVariable(
|
||||
user_id="test_user_123",
|
||||
app_id="test_app_456",
|
||||
workflow_id="test_workflow_789",
|
||||
workflow_execution_id="test_execution_001",
|
||||
query="test query",
|
||||
conversation_id="test_conv_123",
|
||||
dialogue_count=5,
|
||||
)
|
||||
variable_pool = VariablePool(system_variables=system_vars)
|
||||
|
||||
# Add some variables to the variable pool
|
||||
variable_pool.add(["test_node", "test_var"], "test_value")
|
||||
variable_pool.add(["another_node", "another_var"], 42)
|
||||
|
||||
# Create LLM usage with realistic values
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=150,
|
||||
prompt_unit_price=Decimal("0.001"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.15"),
|
||||
completion_tokens=75,
|
||||
completion_unit_price=Decimal("0.002"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.15"),
|
||||
total_tokens=225,
|
||||
total_price=Decimal("0.30"),
|
||||
currency="USD",
|
||||
latency=1.25,
|
||||
)
|
||||
|
||||
# Create runtime route state with some node states
|
||||
node_run_state = RuntimeRouteState()
|
||||
node_state = node_run_state.create_node_state("test_node_1")
|
||||
node_run_state.add_route(node_state.id, "target_node_id")
|
||||
|
||||
return GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
total_tokens=100,
|
||||
llm_usage=llm_usage,
|
||||
outputs={
|
||||
"string_output": "test result",
|
||||
"int_output": 42,
|
||||
"float_output": 3.14,
|
||||
"list_output": ["item1", "item2", "item3"],
|
||||
"dict_output": {"key1": "value1", "key2": 123},
|
||||
"nested_dict": {"level1": {"level2": ["nested", "list", 456]}},
|
||||
},
|
||||
node_run_steps=5,
|
||||
node_run_state=node_run_state,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_round_trip_serialization():
|
||||
"""Test basic round-trip serialization ensures GraphRuntimeState values remain unchanged."""
|
||||
# Create a state with non-empty values
|
||||
original_state = create_test_graph_runtime_state()
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
# Core test: ensure the round-trip preserves all values
|
||||
assert deserialized_state == original_state
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
dict_data = original_state.model_dump(mode="python")
|
||||
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||
assert deserialized_state == original_state
|
||||
|
||||
# Serialize to JSON and deserialize back
|
||||
dict_data = original_state.model_dump(mode="json")
|
||||
deserialized_state = GraphRuntimeState.model_validate(dict_data)
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_outputs_field_round_trip():
|
||||
"""Test the problematic outputs field maintains values through round-trip serialization."""
|
||||
original_state = create_test_graph_runtime_state()
|
||||
|
||||
# Serialize and deserialize
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
# Verify the outputs field specifically maintains its values
|
||||
assert deserialized_state.outputs == original_state.outputs
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_empty_outputs_round_trip():
|
||||
"""Test round-trip serialization with empty outputs field."""
|
||||
variable_pool = VariablePool.empty()
|
||||
original_state = GraphRuntimeState(
|
||||
variable_pool=variable_pool,
|
||||
start_at=time.perf_counter(),
|
||||
outputs={}, # Empty outputs
|
||||
)
|
||||
|
||||
json_data = original_state.model_dump_json()
|
||||
deserialized_state = GraphRuntimeState.model_validate_json(json_data)
|
||||
|
||||
assert deserialized_state == original_state
|
||||
|
||||
|
||||
def test_llm_usage_round_trip():
|
||||
# Create LLM usage with specific decimal values
|
||||
llm_usage = LLMUsage(
|
||||
prompt_tokens=100,
|
||||
prompt_unit_price=Decimal("0.0015"),
|
||||
prompt_price_unit=Decimal(1000),
|
||||
prompt_price=Decimal("0.15"),
|
||||
completion_tokens=50,
|
||||
completion_unit_price=Decimal("0.003"),
|
||||
completion_price_unit=Decimal(1000),
|
||||
completion_price=Decimal("0.15"),
|
||||
total_tokens=150,
|
||||
total_price=Decimal("0.30"),
|
||||
currency="USD",
|
||||
latency=2.5,
|
||||
)
|
||||
|
||||
json_data = llm_usage.model_dump_json()
|
||||
deserialized = LLMUsage.model_validate_json(json_data)
|
||||
assert deserialized == llm_usage
|
||||
|
||||
dict_data = llm_usage.model_dump(mode="python")
|
||||
deserialized = LLMUsage.model_validate(dict_data)
|
||||
assert deserialized == llm_usage
|
||||
|
||||
dict_data = llm_usage.model_dump(mode="json")
|
||||
deserialized = LLMUsage.model_validate(dict_data)
|
||||
assert deserialized == llm_usage
|
||||
@ -0,0 +1,401 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph_engine.entities.runtime_route_state import RouteNodeState, RuntimeRouteState
|
||||
|
||||
_TEST_DATETIME = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
|
||||
class TestRouteNodeStateSerialization:
|
||||
"""Test cases for RouteNodeState Pydantic serialization/deserialization."""
|
||||
|
||||
def _test_route_node_state(self):
|
||||
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||
|
||||
node_run_result = NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={"input_key": "input_value"},
|
||||
outputs={"output_key": "output_value"},
|
||||
)
|
||||
|
||||
node_state = RouteNodeState(
|
||||
node_id="comprehensive_test_node",
|
||||
start_at=_TEST_DATETIME,
|
||||
finished_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.SUCCESS,
|
||||
node_run_result=node_run_result,
|
||||
index=5,
|
||||
paused_at=_TEST_DATETIME,
|
||||
paused_by="user_123",
|
||||
failed_reason="test_reason",
|
||||
)
|
||||
return node_state
|
||||
|
||||
def test_route_node_state_comprehensive_field_validation(self):
|
||||
"""Test comprehensive RouteNodeState serialization with all core fields validation."""
|
||||
node_state = self._test_route_node_state()
|
||||
serialized = node_state.model_dump()
|
||||
|
||||
# Comprehensive validation of all RouteNodeState fields
|
||||
assert serialized["node_id"] == "comprehensive_test_node"
|
||||
assert serialized["status"] == RouteNodeState.Status.SUCCESS
|
||||
assert serialized["start_at"] == _TEST_DATETIME
|
||||
assert serialized["finished_at"] == _TEST_DATETIME
|
||||
assert serialized["paused_at"] == _TEST_DATETIME
|
||||
assert serialized["paused_by"] == "user_123"
|
||||
assert serialized["failed_reason"] == "test_reason"
|
||||
assert serialized["index"] == 5
|
||||
assert "id" in serialized
|
||||
assert isinstance(serialized["id"], str)
|
||||
uuid.UUID(serialized["id"]) # Validate UUID format
|
||||
|
||||
# Validate nested NodeRunResult structure
|
||||
assert serialized["node_run_result"] is not None
|
||||
assert serialized["node_run_result"]["status"] == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
assert serialized["node_run_result"]["inputs"] == {"input_key": "input_value"}
|
||||
assert serialized["node_run_result"]["outputs"] == {"output_key": "output_value"}
|
||||
|
||||
def test_route_node_state_minimal_required_fields(self):
|
||||
"""Test RouteNodeState with only required fields, focusing on defaults."""
|
||||
node_state = RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME)
|
||||
|
||||
serialized = node_state.model_dump()
|
||||
|
||||
# Focus on required fields and default values (not re-testing all fields)
|
||||
assert serialized["node_id"] == "minimal_node"
|
||||
assert serialized["start_at"] == _TEST_DATETIME
|
||||
assert serialized["status"] == RouteNodeState.Status.RUNNING # Default status
|
||||
assert serialized["index"] == 1 # Default index
|
||||
assert serialized["node_run_result"] is None # Default None
|
||||
json = node_state.model_dump_json()
|
||||
deserialized = RouteNodeState.model_validate_json(json)
|
||||
assert deserialized == node_state
|
||||
|
||||
def test_route_node_state_deserialization_from_dict(self):
|
||||
"""Test RouteNodeState deserialization from dictionary data."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
test_id = str(uuid.uuid4())
|
||||
|
||||
dict_data = {
|
||||
"id": test_id,
|
||||
"node_id": "deserialized_node",
|
||||
"start_at": test_datetime,
|
||||
"status": "success",
|
||||
"finished_at": test_datetime,
|
||||
"index": 3,
|
||||
}
|
||||
|
||||
node_state = RouteNodeState.model_validate(dict_data)
|
||||
|
||||
# Focus on deserialization accuracy
|
||||
assert node_state.id == test_id
|
||||
assert node_state.node_id == "deserialized_node"
|
||||
assert node_state.start_at == test_datetime
|
||||
assert node_state.status == RouteNodeState.Status.SUCCESS
|
||||
assert node_state.finished_at == test_datetime
|
||||
assert node_state.index == 3
|
||||
|
||||
def test_route_node_state_round_trip_consistency(self):
|
||||
node_states = (
|
||||
self._test_route_node_state(),
|
||||
RouteNodeState(node_id="minimal_node", start_at=_TEST_DATETIME),
|
||||
)
|
||||
for node_state in node_states:
|
||||
json = node_state.model_dump_json()
|
||||
deserialized = RouteNodeState.model_validate_json(json)
|
||||
assert deserialized == node_state
|
||||
|
||||
dict_ = node_state.model_dump(mode="python")
|
||||
deserialized = RouteNodeState.model_validate(dict_)
|
||||
assert deserialized == node_state
|
||||
|
||||
dict_ = node_state.model_dump(mode="json")
|
||||
deserialized = RouteNodeState.model_validate(dict_)
|
||||
assert deserialized == node_state
|
||||
|
||||
|
||||
class TestRouteNodeStateEnumSerialization:
|
||||
"""Dedicated tests for RouteNodeState Status enum serialization behavior."""
|
||||
|
||||
def test_status_enum_model_dump_behavior(self):
|
||||
"""Test Status enum serialization in model_dump() returns enum objects."""
|
||||
|
||||
for status_enum in RouteNodeState.Status:
|
||||
node_state = RouteNodeState(node_id="enum_test", start_at=_TEST_DATETIME, status=status_enum)
|
||||
serialized = node_state.model_dump(mode="python")
|
||||
assert serialized["status"] == status_enum
|
||||
serialized = node_state.model_dump(mode="json")
|
||||
assert serialized["status"] == status_enum.value
|
||||
|
||||
def test_status_enum_json_serialization_behavior(self):
|
||||
"""Test Status enum serialization in JSON returns string values."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
enum_to_string_mapping = {
|
||||
RouteNodeState.Status.RUNNING: "running",
|
||||
RouteNodeState.Status.SUCCESS: "success",
|
||||
RouteNodeState.Status.FAILED: "failed",
|
||||
RouteNodeState.Status.PAUSED: "paused",
|
||||
RouteNodeState.Status.EXCEPTION: "exception",
|
||||
}
|
||||
|
||||
for status_enum, expected_string in enum_to_string_mapping.items():
|
||||
node_state = RouteNodeState(node_id="json_enum_test", start_at=test_datetime, status=status_enum)
|
||||
|
||||
json_data = json.loads(node_state.model_dump_json())
|
||||
assert json_data["status"] == expected_string
|
||||
|
||||
def test_status_enum_deserialization_from_string(self):
|
||||
"""Test Status enum deserialization from string values."""
|
||||
test_datetime = datetime(2024, 1, 15, 10, 30, 45)
|
||||
|
||||
string_to_enum_mapping = {
|
||||
"running": RouteNodeState.Status.RUNNING,
|
||||
"success": RouteNodeState.Status.SUCCESS,
|
||||
"failed": RouteNodeState.Status.FAILED,
|
||||
"paused": RouteNodeState.Status.PAUSED,
|
||||
"exception": RouteNodeState.Status.EXCEPTION,
|
||||
}
|
||||
|
||||
for status_string, expected_enum in string_to_enum_mapping.items():
|
||||
dict_data = {
|
||||
"node_id": "enum_deserialize_test",
|
||||
"start_at": test_datetime,
|
||||
"status": status_string,
|
||||
}
|
||||
|
||||
node_state = RouteNodeState.model_validate(dict_data)
|
||||
assert node_state.status == expected_enum
|
||||
|
||||
|
||||
class TestRuntimeRouteStateSerialization:
|
||||
"""Test cases for RuntimeRouteState Pydantic serialization/deserialization."""
|
||||
|
||||
_NODE1_ID = "node_1"
|
||||
_ROUTE_STATE1_ID = str(uuid.uuid4())
|
||||
_NODE2_ID = "node_2"
|
||||
_ROUTE_STATE2_ID = str(uuid.uuid4())
|
||||
_NODE3_ID = "node_3"
|
||||
_ROUTE_STATE3_ID = str(uuid.uuid4())
|
||||
|
||||
def _get_runtime_route_state(self):
|
||||
# Create node states with different configurations
|
||||
node_state_1 = RouteNodeState(
|
||||
id=self._ROUTE_STATE1_ID,
|
||||
node_id=self._NODE1_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
index=1,
|
||||
)
|
||||
node_state_2 = RouteNodeState(
|
||||
id=self._ROUTE_STATE2_ID,
|
||||
node_id=self._NODE2_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.SUCCESS,
|
||||
finished_at=_TEST_DATETIME,
|
||||
index=2,
|
||||
)
|
||||
node_state_3 = RouteNodeState(
|
||||
id=self._ROUTE_STATE3_ID,
|
||||
node_id=self._NODE3_ID,
|
||||
start_at=_TEST_DATETIME,
|
||||
status=RouteNodeState.Status.FAILED,
|
||||
failed_reason="Test failure",
|
||||
index=3,
|
||||
)
|
||||
|
||||
runtime_state = RuntimeRouteState(
|
||||
routes={node_state_1.id: [node_state_2.id, node_state_3.id], node_state_2.id: [node_state_3.id]},
|
||||
node_state_mapping={
|
||||
node_state_1.id: node_state_1,
|
||||
node_state_2.id: node_state_2,
|
||||
node_state_3.id: node_state_3,
|
||||
},
|
||||
)
|
||||
|
||||
return runtime_state
|
||||
|
||||
def test_runtime_route_state_comprehensive_structure_validation(self):
|
||||
"""Test comprehensive RuntimeRouteState serialization with full structure validation."""
|
||||
|
||||
runtime_state = self._get_runtime_route_state()
|
||||
serialized = runtime_state.model_dump()
|
||||
|
||||
# Comprehensive validation of RuntimeRouteState structure
|
||||
assert "routes" in serialized
|
||||
assert "node_state_mapping" in serialized
|
||||
assert isinstance(serialized["routes"], dict)
|
||||
assert isinstance(serialized["node_state_mapping"], dict)
|
||||
|
||||
# Validate routes dictionary structure and content
|
||||
assert len(serialized["routes"]) == 2
|
||||
assert self._ROUTE_STATE1_ID in serialized["routes"]
|
||||
assert self._ROUTE_STATE2_ID in serialized["routes"]
|
||||
assert serialized["routes"][self._ROUTE_STATE1_ID] == [self._ROUTE_STATE2_ID, self._ROUTE_STATE3_ID]
|
||||
assert serialized["routes"][self._ROUTE_STATE2_ID] == [self._ROUTE_STATE3_ID]
|
||||
|
||||
# Validate node_state_mapping dictionary structure and content
|
||||
assert len(serialized["node_state_mapping"]) == 3
|
||||
for state_id in [
|
||||
self._ROUTE_STATE1_ID,
|
||||
self._ROUTE_STATE2_ID,
|
||||
self._ROUTE_STATE3_ID,
|
||||
]:
|
||||
assert state_id in serialized["node_state_mapping"]
|
||||
node_data = serialized["node_state_mapping"][state_id]
|
||||
node_state = runtime_state.node_state_mapping[state_id]
|
||||
assert node_data["node_id"] == node_state.node_id
|
||||
assert node_data["status"] == node_state.status
|
||||
assert node_data["index"] == node_state.index
|
||||
|
||||
def test_runtime_route_state_empty_collections(self):
|
||||
"""Test RuntimeRouteState with empty collections, focusing on default behavior."""
|
||||
runtime_state = RuntimeRouteState()
|
||||
serialized = runtime_state.model_dump()
|
||||
|
||||
# Focus on default empty collection behavior
|
||||
assert serialized["routes"] == {}
|
||||
assert serialized["node_state_mapping"] == {}
|
||||
assert isinstance(serialized["routes"], dict)
|
||||
assert isinstance(serialized["node_state_mapping"], dict)
|
||||
|
||||
def test_runtime_route_state_json_serialization_structure(self):
|
||||
"""Test RuntimeRouteState JSON serialization structure."""
|
||||
node_state = RouteNodeState(node_id="json_node", start_at=_TEST_DATETIME)
|
||||
|
||||
runtime_state = RuntimeRouteState(
|
||||
routes={"source": ["target1", "target2"]}, node_state_mapping={node_state.id: node_state}
|
||||
)
|
||||
|
||||
json_str = runtime_state.model_dump_json()
|
||||
json_data = json.loads(json_str)
|
||||
|
||||
# Focus on JSON structure validation
|
||||
assert isinstance(json_str, str)
|
||||
assert isinstance(json_data, dict)
|
||||
assert "routes" in json_data
|
||||
assert "node_state_mapping" in json_data
|
||||
assert json_data["routes"]["source"] == ["target1", "target2"]
|
||||
assert node_state.id in json_data["node_state_mapping"]
|
||||
|
||||
def test_runtime_route_state_deserialization_from_dict(self):
|
||||
"""Test RuntimeRouteState deserialization from dictionary data."""
|
||||
node_id = str(uuid.uuid4())
|
||||
|
||||
dict_data = {
|
||||
"routes": {"source_node": ["target_node_1", "target_node_2"]},
|
||||
"node_state_mapping": {
|
||||
node_id: {
|
||||
"id": node_id,
|
||||
"node_id": "test_node",
|
||||
"start_at": _TEST_DATETIME,
|
||||
"status": "running",
|
||||
"index": 1,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
runtime_state = RuntimeRouteState.model_validate(dict_data)
|
||||
|
||||
# Focus on deserialization accuracy
|
||||
assert runtime_state.routes == {"source_node": ["target_node_1", "target_node_2"]}
|
||||
assert len(runtime_state.node_state_mapping) == 1
|
||||
assert node_id in runtime_state.node_state_mapping
|
||||
|
||||
deserialized_node = runtime_state.node_state_mapping[node_id]
|
||||
assert deserialized_node.node_id == "test_node"
|
||||
assert deserialized_node.status == RouteNodeState.Status.RUNNING
|
||||
assert deserialized_node.index == 1
|
||||
|
||||
def test_runtime_route_state_round_trip_consistency(self):
|
||||
"""Test RuntimeRouteState round-trip serialization consistency."""
|
||||
original = self._get_runtime_route_state()
|
||||
|
||||
# Dictionary round trip
|
||||
dict_data = original.model_dump(mode="python")
|
||||
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||
assert reconstructed == original
|
||||
|
||||
dict_data = original.model_dump(mode="json")
|
||||
reconstructed = RuntimeRouteState.model_validate(dict_data)
|
||||
assert reconstructed == original
|
||||
|
||||
# JSON round trip
|
||||
json_str = original.model_dump_json()
|
||||
json_reconstructed = RuntimeRouteState.model_validate_json(json_str)
|
||||
assert json_reconstructed == original
|
||||
|
||||
|
||||
class TestSerializationEdgeCases:
|
||||
"""Test edge cases and error conditions for serialization/deserialization."""
|
||||
|
||||
def test_invalid_status_deserialization(self):
|
||||
"""Test deserialization with invalid status values."""
|
||||
test_datetime = _TEST_DATETIME
|
||||
invalid_data = {
|
||||
"node_id": "invalid_test",
|
||||
"start_at": test_datetime,
|
||||
"status": "invalid_status",
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(invalid_data)
|
||||
assert "status" in str(exc_info.value)
|
||||
|
||||
def test_missing_required_fields_deserialization(self):
|
||||
"""Test deserialization with missing required fields."""
|
||||
incomplete_data = {"id": str(uuid.uuid4())}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(incomplete_data)
|
||||
error_str = str(exc_info.value)
|
||||
assert "node_id" in error_str or "start_at" in error_str
|
||||
|
||||
def test_invalid_datetime_deserialization(self):
|
||||
"""Test deserialization with invalid datetime values."""
|
||||
invalid_data = {
|
||||
"node_id": "datetime_test",
|
||||
"start_at": "invalid_datetime",
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RouteNodeState.model_validate(invalid_data)
|
||||
assert "start_at" in str(exc_info.value)
|
||||
|
||||
def test_invalid_routes_structure_deserialization(self):
|
||||
"""Test RuntimeRouteState deserialization with invalid routes structure."""
|
||||
invalid_data = {
|
||||
"routes": "invalid_routes_structure", # Should be dict
|
||||
"node_state_mapping": {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
RuntimeRouteState.model_validate(invalid_data)
|
||||
assert "routes" in str(exc_info.value)
|
||||
|
||||
def test_timezone_handling_in_datetime_fields(self):
|
||||
"""Test timezone handling in datetime field serialization."""
|
||||
utc_datetime = datetime.now(UTC)
|
||||
naive_datetime = utc_datetime.replace(tzinfo=None)
|
||||
|
||||
node_state = RouteNodeState(node_id="timezone_test", start_at=naive_datetime)
|
||||
dict_ = node_state.model_dump()
|
||||
|
||||
assert dict_["start_at"] == naive_datetime
|
||||
|
||||
# Test round trip
|
||||
reconstructed = RouteNodeState.model_validate(dict_)
|
||||
assert reconstructed.start_at == naive_datetime
|
||||
assert reconstructed.start_at.tzinfo is None
|
||||
|
||||
json = node_state.model_dump_json()
|
||||
|
||||
reconstructed = RouteNodeState.model_validate_json(json)
|
||||
assert reconstructed.start_at == naive_datetime
|
||||
assert reconstructed.start_at.tzinfo is None
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue