Merge branch 'main' into feat/tool-plugin-oauth
commit
8a954c0b19
@ -1,65 +0,0 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
if not system_configuration.current_quota_type:
|
||||
return
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = message.message_tokens + message.answer_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_config.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_config.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
).update(
|
||||
{
|
||||
"quota_used": Provider.quota_used + used_quota,
|
||||
"last_used": datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
}
|
||||
)
|
||||
db.session.commit()
|
||||
@ -1,20 +0,0 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from models.provider import Provider
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender, **kwargs):
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
db.session.query(Provider).filter(
|
||||
Provider.tenant_id == application_generate_entity.app_config.tenant_id,
|
||||
Provider.provider_name == application_generate_entity.model_conf.provider,
|
||||
).update({"last_used": datetime.now(UTC).replace(tzinfo=None)})
|
||||
db.session.commit()
|
||||
@ -0,0 +1,234 @@
|
||||
import logging
|
||||
import time as time_module
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
from libs import datetime_utils
|
||||
from models.model import Message
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _ProviderUpdateFilters(BaseModel):
|
||||
"""Filters for identifying Provider records to update."""
|
||||
|
||||
tenant_id: str
|
||||
provider_name: str
|
||||
provider_type: Optional[str] = None
|
||||
quota_type: Optional[str] = None
|
||||
|
||||
|
||||
class _ProviderUpdateAdditionalFilters(BaseModel):
|
||||
"""Additional filters for Provider updates."""
|
||||
|
||||
quota_limit_check: bool = False
|
||||
|
||||
|
||||
class _ProviderUpdateValues(BaseModel):
|
||||
"""Values to update in Provider records."""
|
||||
|
||||
last_used: Optional[datetime] = None
|
||||
quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression
|
||||
|
||||
|
||||
class _ProviderUpdateOperation(BaseModel):
|
||||
"""A single Provider update operation."""
|
||||
|
||||
filters: _ProviderUpdateFilters
|
||||
values: _ProviderUpdateValues
|
||||
additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()
|
||||
description: str = "unknown"
|
||||
|
||||
|
||||
@message_was_created.connect
|
||||
def handle(sender: Message, **kwargs):
|
||||
"""
|
||||
Consolidated handler for Provider updates when a message is created.
|
||||
|
||||
This handler replaces both:
|
||||
- update_provider_last_used_at_when_message_created
|
||||
- deduct_quota_when_message_created
|
||||
|
||||
By performing all Provider updates in a single transaction, we ensure
|
||||
consistency and efficiency when updating Provider records.
|
||||
"""
|
||||
message = sender
|
||||
application_generate_entity = kwargs.get("application_generate_entity")
|
||||
|
||||
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
|
||||
return
|
||||
|
||||
tenant_id = application_generate_entity.app_config.tenant_id
|
||||
provider_name = application_generate_entity.model_conf.provider
|
||||
current_time = datetime_utils.naive_utc_now()
|
||||
|
||||
# Prepare updates for both scenarios
|
||||
updates_to_perform: list[_ProviderUpdateOperation] = []
|
||||
|
||||
# 1. Always update last_used for the provider
|
||||
basic_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=provider_name,
|
||||
),
|
||||
values=_ProviderUpdateValues(last_used=current_time),
|
||||
description="basic_last_used_update",
|
||||
)
|
||||
updates_to_perform.append(basic_update)
|
||||
|
||||
# 2. Check if we need to deduct quota (system provider only)
|
||||
model_config = application_generate_entity.model_conf
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if (
|
||||
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
||||
and provider_configuration.system_configuration
|
||||
and provider_configuration.system_configuration.current_quota_type is not None
|
||||
):
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
# Calculate quota usage
|
||||
used_quota = _calculate_quota_usage(
|
||||
message=message,
|
||||
system_configuration=system_configuration,
|
||||
model_name=model_config.model,
|
||||
)
|
||||
|
||||
if used_quota is not None:
|
||||
quota_update = _ProviderUpdateOperation(
|
||||
filters=_ProviderUpdateFilters(
|
||||
tenant_id=tenant_id,
|
||||
provider_name=ModelProviderID(model_config.provider).provider_name,
|
||||
provider_type=ProviderType.SYSTEM.value,
|
||||
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
||||
),
|
||||
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
||||
additional_filters=_ProviderUpdateAdditionalFilters(
|
||||
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
||||
),
|
||||
description="quota_deduction_update",
|
||||
)
|
||||
updates_to_perform.append(quota_update)
|
||||
|
||||
# Execute all updates
|
||||
start_time = time_module.perf_counter()
|
||||
try:
|
||||
_execute_provider_updates(updates_to_perform)
|
||||
|
||||
# Log successful completion with timing
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.info(
|
||||
f"Provider updates completed successfully. "
|
||||
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
|
||||
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Log failure with timing and context
|
||||
duration = time_module.perf_counter() - start_time
|
||||
|
||||
logger.exception(
|
||||
f"Provider updates failed after {duration:.3f}s. "
|
||||
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
|
||||
f"Provider: {provider_name}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def _calculate_quota_usage(
|
||||
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
||||
) -> Optional[int]:
|
||||
"""Calculate quota usage based on message tokens and quota type."""
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return None
|
||||
break
|
||||
if quota_unit is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
tokens = message.message_tokens + message.answer_tokens
|
||||
return tokens
|
||||
if quota_unit == QuotaUnit.CREDITS:
|
||||
tokens = dify_config.get_model_credits(model_name)
|
||||
return tokens
|
||||
elif quota_unit == QuotaUnit.TIMES:
|
||||
return 1
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.exception("Failed to calculate quota usage")
|
||||
return None
|
||||
|
||||
|
||||
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
|
||||
"""Execute all Provider updates in a single transaction."""
|
||||
if not updates_to_perform:
|
||||
return
|
||||
|
||||
# Use SQLAlchemy's context manager for transaction management
|
||||
# This automatically handles commit/rollback
|
||||
with Session(db.engine) as session:
|
||||
# Use a single transaction for all updates
|
||||
for update_operation in updates_to_perform:
|
||||
filters = update_operation.filters
|
||||
values = update_operation.values
|
||||
additional_filters = update_operation.additional_filters
|
||||
description = update_operation.description
|
||||
|
||||
# Build the where conditions
|
||||
where_conditions = [
|
||||
Provider.tenant_id == filters.tenant_id,
|
||||
Provider.provider_name == filters.provider_name,
|
||||
]
|
||||
|
||||
# Add additional filters if specified
|
||||
if filters.provider_type is not None:
|
||||
where_conditions.append(Provider.provider_type == filters.provider_type)
|
||||
if filters.quota_type is not None:
|
||||
where_conditions.append(Provider.quota_type == filters.quota_type)
|
||||
if additional_filters.quota_limit_check:
|
||||
where_conditions.append(Provider.quota_limit > Provider.quota_used)
|
||||
|
||||
# Prepare values dict for SQLAlchemy update
|
||||
update_values = {}
|
||||
if values.last_used is not None:
|
||||
update_values["last_used"] = values.last_used
|
||||
if values.quota_used is not None:
|
||||
update_values["quota_used"] = values.quota_used
|
||||
|
||||
# Build and execute the update statement
|
||||
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
||||
result = session.execute(stmt)
|
||||
rows_affected = result.rowcount
|
||||
|
||||
logger.debug(
|
||||
f"Provider update ({description}): {rows_affected} rows affected. "
|
||||
f"Filters: {filters.model_dump()}, Values: {update_values}"
|
||||
)
|
||||
|
||||
# If no rows were affected for quota updates, log a warning
|
||||
if rows_affected == 0 and description == "quota_deduction_update":
|
||||
logger.warning(
|
||||
f"No Provider rows updated for quota deduction. "
|
||||
f"This may indicate quota limit exceeded or provider not found. "
|
||||
f"Filters: {filters.model_dump()}"
|
||||
)
|
||||
|
||||
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
|
||||
@ -1,23 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig
|
||||
|
||||
|
||||
class ModerationService:
|
||||
def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
|
||||
app_model_config: Optional[AppModelConfig] = None
|
||||
|
||||
app_model_config = (
|
||||
db.session.query(AppModelConfig).filter(AppModelConfig.id == app_model.app_model_config_id).first()
|
||||
)
|
||||
|
||||
if not app_model_config:
|
||||
raise ValueError("app model config not found")
|
||||
|
||||
name = app_model_config.sensitive_word_avoidance_dict["type"]
|
||||
config = app_model_config.sensitive_word_avoidance_dict["config"]
|
||||
|
||||
moderation = ModerationFactory(name, app_id, app_model.tenant_id, config)
|
||||
return moderation.moderation_for_outputs(text)
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 60 KiB After Width: | Height: | Size: 187 KiB |
@ -0,0 +1,248 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from events.event_handlers.update_provider_when_message_created import (
|
||||
handle,
|
||||
get_update_stats,
|
||||
)
|
||||
from models.provider import ProviderType
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
|
||||
class TestProviderUpdateDeadlockPrevention:
|
||||
"""Test suite for deadlock prevention in Provider updates."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures."""
|
||||
self.mock_message = Mock()
|
||||
self.mock_message.answer_tokens = 100
|
||||
|
||||
self.mock_app_config = Mock()
|
||||
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||
|
||||
self.mock_model_conf = Mock()
|
||||
self.mock_model_conf.provider = "openai"
|
||||
|
||||
self.mock_system_config = Mock()
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
|
||||
self.mock_provider_config = Mock()
|
||||
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||
|
||||
self.mock_provider_bundle = Mock()
|
||||
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||
|
||||
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||
|
||||
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||
self.mock_generate_entity.app_config = self.mock_app_config
|
||||
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||
"""Test that the consolidated handler performs both updates correctly."""
|
||||
# Setup mock query chain
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1 # 1 row affected
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify db.session.query was called
|
||||
assert mock_db.session.query.called
|
||||
|
||||
# Verify commit was called
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
# Verify no rollback was called
|
||||
assert not mock_db.session.rollback.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_deadlock_retry_mechanism(self, mock_db):
|
||||
"""Test that deadlock errors trigger retry logic."""
|
||||
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# First call raises deadlock, second succeeds
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on retry
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify commit was called twice (original + retry)
|
||||
assert mock_db.session.commit.call_count == 2
|
||||
|
||||
# Verify rollback was called once (after first failure)
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||
"""Test that retry delays follow exponential backoff pattern."""
|
||||
# Setup mock to fail twice, succeed on third attempt
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on third attempt
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify sleep was called twice with increasing delays
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
# First delay should be around 0.1s + jitter
|
||||
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||
assert 0.1 <= first_delay <= 0.3
|
||||
|
||||
# Second delay should be around 0.2s + jitter
|
||||
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||
assert 0.2 <= second_delay <= 0.4
|
||||
|
||||
def test_concurrent_handler_execution(self):
|
||||
"""Test that multiple handlers can run concurrently without deadlock."""
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def run_handler():
|
||||
try:
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
handle(
|
||||
self.mock_message,
|
||||
application_generate_entity=self.mock_generate_entity,
|
||||
)
|
||||
results.append("success")
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
# Run multiple handlers concurrently
|
||||
threads = []
|
||||
for _ in range(5):
|
||||
thread = threading.Thread(target=run_handler)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join(timeout=5)
|
||||
|
||||
# Verify all handlers completed successfully
|
||||
assert len(results) == 5
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_performance_stats_tracking(self):
|
||||
"""Test that performance statistics are tracked correctly."""
|
||||
# Reset stats
|
||||
stats = get_update_stats()
|
||||
initial_total = stats["total_updates"]
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(
|
||||
self.mock_message, application_generate_entity=self.mock_generate_entity
|
||||
)
|
||||
|
||||
# Check that stats were updated
|
||||
updated_stats = get_update_stats()
|
||||
assert updated_stats["total_updates"] == initial_total + 1
|
||||
assert updated_stats["successful_updates"] >= initial_total + 1
|
||||
|
||||
def test_non_chat_entity_ignored(self):
|
||||
"""Test that non-chat entities are ignored by the handler."""
|
||||
# Create a non-chat entity
|
||||
mock_non_chat_entity = Mock()
|
||||
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
# Call handler with non-chat entity
|
||||
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
|
||||
|
||||
# Verify no database operations were performed
|
||||
assert not mock_db.session.query.called
|
||||
assert not mock_db.session.commit.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_tokens(self, mock_db):
|
||||
"""Test quota calculation for token-based quotas."""
|
||||
# Setup token-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
self.mock_message.answer_tokens = 150
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called with token count
|
||||
update_calls = mock_query.update.call_args_list
|
||||
|
||||
# Should have at least one call with quota_used update
|
||||
quota_update_found = False
|
||||
for call in update_calls:
|
||||
values = call[0][0] # First argument to update()
|
||||
if "quota_used" in values:
|
||||
quota_update_found = True
|
||||
break
|
||||
|
||||
assert quota_update_found
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_times(self, mock_db):
|
||||
"""Test quota calculation for times-based quotas."""
|
||||
# Setup times-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called
|
||||
assert mock_query.update.called
|
||||
assert mock_db.session.commit.called
|
||||
Loading…
Reference in New Issue