fix(event_handlers): DB dead lock
Signed-off-by: -LAN- <laipz8200@outlook.com>pull/21468/head
parent
164e5481c5
commit
4ac2715c2a
@ -0,0 +1,296 @@
|
|||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
import time as time_module
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||||
|
from core.entities.provider_entities import QuotaUnit
|
||||||
|
from core.plugin.entities.plugin import ModelProviderID
|
||||||
|
from events.message_event import message_was_created
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.provider import Provider, ProviderType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Performance monitoring counters
|
||||||
|
_update_stats = {
|
||||||
|
"total_updates": 0,
|
||||||
|
"successful_updates": 0,
|
||||||
|
"failed_updates": 0,
|
||||||
|
"deadlock_retries": 0,
|
||||||
|
"total_duration": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@message_was_created.connect
|
||||||
|
def handle(sender, **kwargs):
|
||||||
|
"""
|
||||||
|
Consolidated handler for Provider updates when a message is created.
|
||||||
|
|
||||||
|
This handler replaces both:
|
||||||
|
- update_provider_last_used_at_when_message_created
|
||||||
|
- deduct_quota_when_message_created
|
||||||
|
|
||||||
|
By performing all Provider updates in a single transaction, we eliminate
|
||||||
|
the deadlock that occurred when both handlers tried to update the same
|
||||||
|
Provider records concurrently.
|
||||||
|
"""
|
||||||
|
message = sender
|
||||||
|
application_generate_entity = kwargs.get("application_generate_entity")
|
||||||
|
|
||||||
|
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||||
|
return
|
||||||
|
|
||||||
|
tenant_id = application_generate_entity.app_config.tenant_id
|
||||||
|
provider_name = application_generate_entity.model_conf.provider
|
||||||
|
current_time = datetime.now(UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
# Prepare updates for both scenarios
|
||||||
|
updates_to_perform: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
# 1. Always update last_used for the provider
|
||||||
|
basic_update = {
|
||||||
|
"filters": {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"provider_name": provider_name,
|
||||||
|
},
|
||||||
|
"values": {"last_used": current_time},
|
||||||
|
"description": "basic_last_used_update",
|
||||||
|
}
|
||||||
|
updates_to_perform.append(basic_update)
|
||||||
|
|
||||||
|
# 2. Check if we need to deduct quota (system provider only)
|
||||||
|
model_config = application_generate_entity.model_conf
|
||||||
|
provider_model_bundle = model_config.provider_model_bundle
|
||||||
|
provider_configuration = provider_model_bundle.configuration
|
||||||
|
|
||||||
|
if (
|
||||||
|
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
||||||
|
and hasattr(provider_configuration, "system_configuration")
|
||||||
|
and provider_configuration.system_configuration
|
||||||
|
and provider_configuration.system_configuration.current_quota_type is not None
|
||||||
|
):
|
||||||
|
system_configuration = provider_configuration.system_configuration
|
||||||
|
|
||||||
|
# Calculate quota usage
|
||||||
|
used_quota = _calculate_quota_usage(message, system_configuration)
|
||||||
|
|
||||||
|
if used_quota is not None and used_quota > 0:
|
||||||
|
quota_update = {
|
||||||
|
"filters": {
|
||||||
|
"tenant_id": tenant_id,
|
||||||
|
"provider_name": ModelProviderID(model_config.provider).provider_name,
|
||||||
|
"provider_type": ProviderType.SYSTEM.value,
|
||||||
|
"quota_type": system_configuration.current_quota_type.value
|
||||||
|
if system_configuration.current_quota_type
|
||||||
|
else None,
|
||||||
|
},
|
||||||
|
"values": {"quota_used": Provider.quota_used + used_quota, "last_used": current_time},
|
||||||
|
"additional_filters": {
|
||||||
|
"quota_limit_check": True # Provider.quota_limit > Provider.quota_used
|
||||||
|
},
|
||||||
|
"description": "quota_deduction_update",
|
||||||
|
}
|
||||||
|
updates_to_perform.append(quota_update)
|
||||||
|
|
||||||
|
# Execute all updates with retry logic for deadlock prevention
|
||||||
|
start_time = time_module.perf_counter()
|
||||||
|
try:
|
||||||
|
_execute_provider_updates_with_retry(updates_to_perform)
|
||||||
|
|
||||||
|
# Log successful completion with timing and update stats
|
||||||
|
duration = time_module.perf_counter() - start_time
|
||||||
|
_update_stats["total_updates"] += 1
|
||||||
|
_update_stats["successful_updates"] += 1
|
||||||
|
_update_stats["total_duration"] += duration
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Provider updates completed successfully. "
|
||||||
|
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
|
||||||
|
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log performance stats periodically
|
||||||
|
if _update_stats["total_updates"] % 100 == 0:
|
||||||
|
_log_performance_stats()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Log failure with timing and context, update stats
|
||||||
|
duration = time_module.perf_counter() - start_time
|
||||||
|
_update_stats["total_updates"] += 1
|
||||||
|
_update_stats["failed_updates"] += 1
|
||||||
|
_update_stats["total_duration"] += duration
|
||||||
|
|
||||||
|
logger.exception(
|
||||||
|
f"Provider updates failed after {duration:.3f}s. "
|
||||||
|
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
|
||||||
|
f"Provider: {provider_name}"
|
||||||
|
)
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[int]:
|
||||||
|
"""Calculate quota usage based on message tokens and quota type."""
|
||||||
|
try:
|
||||||
|
if system_configuration.current_quota_type == QuotaUnit.TOKENS:
|
||||||
|
tokens = getattr(message, "answer_tokens", None)
|
||||||
|
return int(tokens) if tokens is not None else None
|
||||||
|
elif system_configuration.current_quota_type == QuotaUnit.TIMES:
|
||||||
|
return 1
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to calculate quota usage: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_provider_updates_with_retry(updates_to_perform: list[dict[str, Any]], max_retries: int = 3):
|
||||||
|
"""
|
||||||
|
Execute Provider updates with deadlock retry logic.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
updates_to_perform: List of update operations to perform
|
||||||
|
max_retries: Maximum number of retry attempts for deadlock recovery
|
||||||
|
"""
|
||||||
|
for attempt in range(max_retries + 1):
|
||||||
|
try:
|
||||||
|
_execute_provider_updates(updates_to_perform)
|
||||||
|
return # Success, exit retry loop
|
||||||
|
|
||||||
|
except IntegrityError as e:
|
||||||
|
# Don't retry integrity constraint violations
|
||||||
|
logger.exception("Integrity constraint violation in Provider update")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except (OperationalError, DatabaseError) as e:
|
||||||
|
error_msg = str(e).lower()
|
||||||
|
|
||||||
|
# Check for various deadlock/lock timeout patterns across different databases
|
||||||
|
is_deadlock = any(
|
||||||
|
pattern in error_msg
|
||||||
|
for pattern in [
|
||||||
|
"deadlock detected",
|
||||||
|
"deadlock found",
|
||||||
|
"lock wait timeout",
|
||||||
|
"lock timeout",
|
||||||
|
"serialization failure",
|
||||||
|
"could not serialize access",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_deadlock and attempt < max_retries:
|
||||||
|
# Track deadlock retry statistics
|
||||||
|
_update_stats["deadlock_retries"] += 1
|
||||||
|
|
||||||
|
# Exponential backoff with jitter for deadlock recovery
|
||||||
|
base_delay = (2**attempt) * 0.1
|
||||||
|
jitter = secrets.randbelow(100) / 1000.0 # 0.0 to 0.099
|
||||||
|
delay = base_delay + jitter
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f"Database lock conflict detected in Provider update "
|
||||||
|
f"(attempt {attempt + 1}/{max_retries + 1}). "
|
||||||
|
f"Retrying in {delay:.2f}s. Error: {e}"
|
||||||
|
)
|
||||||
|
time.sleep(delay)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Not a deadlock or max retries exceeded
|
||||||
|
logger.exception(f"Failed to update Provider after {attempt + 1} attempts")
|
||||||
|
raise
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unexpected error updating Provider")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]):
|
||||||
|
"""Execute all Provider updates in a single transaction with proper isolation."""
|
||||||
|
if not updates_to_perform:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Start a new transaction with explicit isolation level
|
||||||
|
try:
|
||||||
|
# Set transaction isolation level to prevent phantom reads
|
||||||
|
# This helps reduce deadlock probability
|
||||||
|
db.session.execute(text("SET TRANSACTION ISOLATION LEVEL READ COMMITTED"))
|
||||||
|
|
||||||
|
# Use a single transaction for all updates
|
||||||
|
for update_info in updates_to_perform:
|
||||||
|
filters = update_info["filters"]
|
||||||
|
values = update_info["values"]
|
||||||
|
additional_filters = update_info.get("additional_filters", {})
|
||||||
|
description = update_info.get("description", "unknown")
|
||||||
|
|
||||||
|
# Build the query with consistent ordering to prevent deadlocks
|
||||||
|
# Order by tenant_id, provider_name to ensure consistent lock acquisition
|
||||||
|
query = (
|
||||||
|
db.session.query(Provider)
|
||||||
|
.filter(Provider.tenant_id == filters["tenant_id"], Provider.provider_name == filters["provider_name"])
|
||||||
|
.order_by(Provider.tenant_id, Provider.provider_name)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add additional filters if specified
|
||||||
|
if "provider_type" in filters:
|
||||||
|
query = query.filter(Provider.provider_type == filters["provider_type"])
|
||||||
|
if "quota_type" in filters:
|
||||||
|
query = query.filter(Provider.quota_type == filters["quota_type"])
|
||||||
|
if additional_filters.get("quota_limit_check"):
|
||||||
|
query = query.filter(Provider.quota_limit > Provider.quota_used)
|
||||||
|
|
||||||
|
# Execute the update
|
||||||
|
rows_affected = query.update(values, synchronize_session=False)
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"Provider update ({description}): {rows_affected} rows affected. Filters: {filters}, Values: {values}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# If no rows were affected for quota updates, log a warning
|
||||||
|
if rows_affected == 0 and description == "quota_deduction_update":
|
||||||
|
logger.warning(
|
||||||
|
f"No Provider rows updated for quota deduction. "
|
||||||
|
f"This may indicate quota limit exceeded or provider not found. "
|
||||||
|
f"Filters: {filters}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Commit all updates in a single transaction
|
||||||
|
db.session.commit()
|
||||||
|
logger.debug(f"Successfully committed {len(updates_to_perform)} Provider updates")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
# Rollback on any error
|
||||||
|
try:
|
||||||
|
db.session.rollback()
|
||||||
|
logger.debug("Transaction rolled back successfully")
|
||||||
|
except Exception as rollback_error:
|
||||||
|
logger.exception("Error during rollback")
|
||||||
|
|
||||||
|
logger.exception("Failed to update Provider")
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
def _log_performance_stats():
|
||||||
|
"""Log performance statistics for Provider updates."""
|
||||||
|
stats = _update_stats.copy()
|
||||||
|
|
||||||
|
if stats["total_updates"] > 0:
|
||||||
|
success_rate = (stats["successful_updates"] / stats["total_updates"]) * 100
|
||||||
|
avg_duration = stats["total_duration"] / stats["total_updates"]
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Provider Update Performance Stats: "
|
||||||
|
f"Total: {stats['total_updates']}, "
|
||||||
|
f"Success Rate: {success_rate:.1f}%, "
|
||||||
|
f"Avg Duration: {avg_duration:.3f}s, "
|
||||||
|
f"Deadlock Retries: {stats['deadlock_retries']}, "
|
||||||
|
f"Failed: {stats['failed_updates']}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_update_stats():
|
||||||
|
"""Get current update statistics for monitoring."""
|
||||||
|
return _update_stats.copy()
|
||||||
@ -0,0 +1,248 @@
|
|||||||
|
import threading
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||||
|
from core.entities.provider_entities import QuotaUnit
|
||||||
|
from events.event_handlers.update_provider_when_message_created import (
|
||||||
|
handle,
|
||||||
|
get_update_stats,
|
||||||
|
)
|
||||||
|
from models.provider import ProviderType
|
||||||
|
from sqlalchemy.exc import OperationalError
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderUpdateDeadlockPrevention:
|
||||||
|
"""Test suite for deadlock prevention in Provider updates."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Setup test fixtures."""
|
||||||
|
self.mock_message = Mock()
|
||||||
|
self.mock_message.answer_tokens = 100
|
||||||
|
|
||||||
|
self.mock_app_config = Mock()
|
||||||
|
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||||
|
|
||||||
|
self.mock_model_conf = Mock()
|
||||||
|
self.mock_model_conf.provider = "openai"
|
||||||
|
|
||||||
|
self.mock_system_config = Mock()
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||||
|
|
||||||
|
self.mock_provider_config = Mock()
|
||||||
|
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||||
|
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||||
|
|
||||||
|
self.mock_provider_bundle = Mock()
|
||||||
|
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||||
|
|
||||||
|
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||||
|
|
||||||
|
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||||
|
self.mock_generate_entity.app_config = self.mock_app_config
|
||||||
|
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||||
|
"""Test that the consolidated handler performs both updates correctly."""
|
||||||
|
# Setup mock query chain
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1 # 1 row affected
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify db.session.query was called
|
||||||
|
assert mock_db.session.query.called
|
||||||
|
|
||||||
|
# Verify commit was called
|
||||||
|
mock_db.session.commit.assert_called_once()
|
||||||
|
|
||||||
|
# Verify no rollback was called
|
||||||
|
assert not mock_db.session.rollback.called
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_deadlock_retry_mechanism(self, mock_db):
|
||||||
|
"""Test that deadlock errors trigger retry logic."""
|
||||||
|
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# First call raises deadlock, second succeeds
|
||||||
|
mock_db.session.commit.side_effect = [
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
None, # Success on retry
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify commit was called twice (original + retry)
|
||||||
|
assert mock_db.session.commit.call_count == 2
|
||||||
|
|
||||||
|
# Verify rollback was called once (after first failure)
|
||||||
|
mock_db.session.rollback.assert_called_once()
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||||
|
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||||
|
"""Test that retry delays follow exponential backoff pattern."""
|
||||||
|
# Setup mock to fail twice, succeed on third attempt
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
mock_db.session.commit.side_effect = [
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
OperationalError("deadlock detected", None, None),
|
||||||
|
None, # Success on third attempt
|
||||||
|
]
|
||||||
|
|
||||||
|
# Call the handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify sleep was called twice with increasing delays
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
|
||||||
|
# First delay should be around 0.1s + jitter
|
||||||
|
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||||
|
assert 0.1 <= first_delay <= 0.3
|
||||||
|
|
||||||
|
# Second delay should be around 0.2s + jitter
|
||||||
|
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||||
|
assert 0.2 <= second_delay <= 0.4
|
||||||
|
|
||||||
|
def test_concurrent_handler_execution(self):
|
||||||
|
"""Test that multiple handlers can run concurrently without deadlock."""
|
||||||
|
results = []
|
||||||
|
errors = []
|
||||||
|
|
||||||
|
def run_handler():
|
||||||
|
try:
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
handle(
|
||||||
|
self.mock_message,
|
||||||
|
application_generate_entity=self.mock_generate_entity,
|
||||||
|
)
|
||||||
|
results.append("success")
|
||||||
|
except Exception as e:
|
||||||
|
errors.append(str(e))
|
||||||
|
|
||||||
|
# Run multiple handlers concurrently
|
||||||
|
threads = []
|
||||||
|
for _ in range(5):
|
||||||
|
thread = threading.Thread(target=run_handler)
|
||||||
|
threads.append(thread)
|
||||||
|
thread.start()
|
||||||
|
|
||||||
|
# Wait for all threads to complete
|
||||||
|
for thread in threads:
|
||||||
|
thread.join(timeout=5)
|
||||||
|
|
||||||
|
# Verify all handlers completed successfully
|
||||||
|
assert len(results) == 5
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
def test_performance_stats_tracking(self):
|
||||||
|
"""Test that performance statistics are tracked correctly."""
|
||||||
|
# Reset stats
|
||||||
|
stats = get_update_stats()
|
||||||
|
initial_total = stats["total_updates"]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(
|
||||||
|
self.mock_message, application_generate_entity=self.mock_generate_entity
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that stats were updated
|
||||||
|
updated_stats = get_update_stats()
|
||||||
|
assert updated_stats["total_updates"] == initial_total + 1
|
||||||
|
assert updated_stats["successful_updates"] >= initial_total + 1
|
||||||
|
|
||||||
|
def test_non_chat_entity_ignored(self):
|
||||||
|
"""Test that non-chat entities are ignored by the handler."""
|
||||||
|
# Create a non-chat entity
|
||||||
|
mock_non_chat_entity = Mock()
|
||||||
|
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"events.event_handlers.update_provider_when_message_created.db"
|
||||||
|
) as mock_db:
|
||||||
|
# Call handler with non-chat entity
|
||||||
|
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
|
||||||
|
|
||||||
|
# Verify no database operations were performed
|
||||||
|
assert not mock_db.session.query.called
|
||||||
|
assert not mock_db.session.commit.called
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_quota_calculation_tokens(self, mock_db):
|
||||||
|
"""Test quota calculation for token-based quotas."""
|
||||||
|
# Setup token-based quota
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||||
|
self.mock_message.answer_tokens = 150
|
||||||
|
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify update was called with token count
|
||||||
|
update_calls = mock_query.update.call_args_list
|
||||||
|
|
||||||
|
# Should have at least one call with quota_used update
|
||||||
|
quota_update_found = False
|
||||||
|
for call in update_calls:
|
||||||
|
values = call[0][0] # First argument to update()
|
||||||
|
if "quota_used" in values:
|
||||||
|
quota_update_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
assert quota_update_found
|
||||||
|
|
||||||
|
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||||
|
def test_quota_calculation_times(self, mock_db):
|
||||||
|
"""Test quota calculation for times-based quotas."""
|
||||||
|
# Setup times-based quota
|
||||||
|
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
|
||||||
|
|
||||||
|
mock_query = Mock()
|
||||||
|
mock_db.session.query.return_value = mock_query
|
||||||
|
mock_query.filter.return_value = mock_query
|
||||||
|
mock_query.order_by.return_value = mock_query
|
||||||
|
mock_query.update.return_value = 1
|
||||||
|
|
||||||
|
# Call handler
|
||||||
|
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||||
|
|
||||||
|
# Verify update was called
|
||||||
|
assert mock_query.update.called
|
||||||
|
assert mock_db.session.commit.called
|
||||||
Loading…
Reference in New Issue