diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py index 53ec972e0b..654ab56a2e 100644 --- a/api/events/event_handlers/update_provider_when_message_created.py +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -1,34 +1,56 @@ import logging -import secrets -import time import time as time_module from datetime import UTC, datetime from typing import Any, Optional -from sqlalchemy import text -from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError +from pydantic import BaseModel +from sqlalchemy import update +from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from core.entities.provider_entities import QuotaUnit +from core.entities.provider_entities import QuotaUnit, SystemConfiguration from core.plugin.entities.plugin import ModelProviderID from events.message_event import message_was_created from extensions.ext_database import db +from models.model import Message from models.provider import Provider, ProviderType logger = logging.getLogger(__name__) -# Performance monitoring counters -_update_stats = { - "total_updates": 0, - "successful_updates": 0, - "failed_updates": 0, - "deadlock_retries": 0, - "total_duration": 0.0, -} + +class _ProviderUpdateFilters(BaseModel): + """Filters for identifying Provider records to update.""" + + tenant_id: str + provider_name: str + provider_type: Optional[str] = None + quota_type: Optional[str] = None + + +class _ProviderUpdateAdditionalFilters(BaseModel): + """Additional filters for Provider updates.""" + + quota_limit_check: bool = False + + +class _ProviderUpdateValues(BaseModel): + """Values to update in Provider records.""" + + last_used: Optional[datetime] = None + quota_used: Optional[Any] = None # Can be Provider.quota_used + int expression + + +class _ProviderUpdateOperation(BaseModel): + """A single Provider update operation.""" + + filters: _ProviderUpdateFilters + values: _ProviderUpdateValues + additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters() + description: str = "unknown" @message_was_created.connect -def handle(sender, **kwargs): +def handle(sender: Message, **kwargs): """ Consolidated handler for Provider updates when a message is created. @@ -36,9 +58,8 @@ def handle(sender, **kwargs): - update_provider_last_used_at_when_message_created - deduct_quota_when_message_created - By performing all Provider updates in a single transaction, we eliminate - the deadlock that occurred when both handlers tried to update the same - Provider records concurrently. + By performing all Provider updates in a single transaction, we ensure + consistency and efficiency when updating Provider records. """ message = sender application_generate_entity = kwargs.get("application_generate_entity") @@ -51,17 +72,17 @@ def handle(sender, **kwargs): current_time = datetime.now(UTC).replace(tzinfo=None) # Prepare updates for both scenarios - updates_to_perform: list[dict[str, Any]] = [] + updates_to_perform: list[_ProviderUpdateOperation] = [] # 1. Always update last_used for the provider - basic_update = { - "filters": { - "tenant_id": tenant_id, - "provider_name": provider_name, - }, - "values": {"last_used": current_time}, - "description": "basic_last_used_update", - } + basic_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=provider_name, + ), + values=_ProviderUpdateValues(last_used=current_time), + description="basic_last_used_update", + ) updates_to_perform.append(basic_update) # 2. Check if we need to deduct quota (system provider only) @@ -71,43 +92,41 @@ def handle(sender, **kwargs): if ( provider_configuration.using_provider_type == ProviderType.SYSTEM - and hasattr(provider_configuration, "system_configuration") and provider_configuration.system_configuration and provider_configuration.system_configuration.current_quota_type is not None ): system_configuration = provider_configuration.system_configuration # Calculate quota usage - used_quota = _calculate_quota_usage(message, system_configuration) - - if used_quota is not None and used_quota > 0: - quota_update = { - "filters": { - "tenant_id": tenant_id, - "provider_name": ModelProviderID(model_config.provider).provider_name, - "provider_type": ProviderType.SYSTEM.value, - "quota_type": system_configuration.current_quota_type.value - if system_configuration.current_quota_type - else None, - }, - "values": {"quota_used": Provider.quota_used + used_quota, "last_used": current_time}, - "additional_filters": { - "quota_limit_check": True # Provider.quota_limit > Provider.quota_used - }, - "description": "quota_deduction_update", - } + used_quota = _calculate_quota_usage( + message=message, + system_configuration=system_configuration, + model_name=model_config.model, + ) + + if used_quota is not None: + quota_update = _ProviderUpdateOperation( + filters=_ProviderUpdateFilters( + tenant_id=tenant_id, + provider_name=ModelProviderID(model_config.provider).provider_name, + provider_type=ProviderType.SYSTEM.value, + quota_type=provider_configuration.system_configuration.current_quota_type.value, + ), + values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time), + additional_filters=_ProviderUpdateAdditionalFilters( + quota_limit_check=True # Provider.quota_limit > Provider.quota_used + ), + description="quota_deduction_update", + ) updates_to_perform.append(quota_update) - # Execute all updates with retry logic for deadlock prevention + # Execute all updates start_time = time_module.perf_counter() try: - _execute_provider_updates_with_retry(updates_to_perform) + _execute_provider_updates(updates_to_perform) - # Log successful completion with timing and update stats + # Log successful completion with timing duration = time_module.perf_counter() - start_time - _update_stats["total_updates"] += 1 - _update_stats["successful_updates"] += 1 - _update_stats["total_duration"] += duration logger.info( f"Provider updates completed successfully. " @@ -115,16 +134,9 @@ def handle(sender, **kwargs): f"Tenant: {tenant_id}, Provider: {provider_name}" ) - # Log performance stats periodically - if _update_stats["total_updates"] % 100 == 0: - _log_performance_stats() - except Exception as e: - # Log failure with timing and context, update stats + # Log failure with timing and context duration = time_module.perf_counter() - start_time - _update_stats["total_updates"] += 1 - _update_stats["failed_updates"] += 1 - _update_stats["total_duration"] += duration logger.exception( f"Provider updates failed after {duration:.3f}s. " @@ -134,13 +146,28 @@ def handle(sender, **kwargs): raise -def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[int]: +def _calculate_quota_usage( + *, message: Message, system_configuration: SystemConfiguration, model_name: str +) -> Optional[int]: """Calculate quota usage based on message tokens and quota type.""" + quota_unit = None + for quota_configuration in system_configuration.quota_configurations: + if quota_configuration.quota_type == system_configuration.current_quota_type: + quota_unit = quota_configuration.quota_unit + if quota_configuration.quota_limit == -1: + return None + break + if quota_unit is None: + return None + try: - if system_configuration.current_quota_type == QuotaUnit.TOKENS: - tokens = getattr(message, "answer_tokens", None) - return int(tokens) if tokens is not None else None - elif system_configuration.current_quota_type == QuotaUnit.TIMES: + if quota_unit == QuotaUnit.TOKENS: + tokens = message.message_tokens + message.answer_tokens + return tokens + if quota_unit == QuotaUnit.CREDITS: + tokens = dify_config.get_model_credits(model_name) + return tokens + elif quota_unit == QuotaUnit.TIMES: return 1 return None except Exception as e: @@ -148,149 +175,64 @@ def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[ return None -def _execute_provider_updates_with_retry(updates_to_perform: list[dict[str, Any]], max_retries: int = 3): - """ - Execute Provider updates with deadlock retry logic. - - Args: - updates_to_perform: List of update operations to perform - max_retries: Maximum number of retry attempts for deadlock recovery - """ - for attempt in range(max_retries + 1): - try: - _execute_provider_updates(updates_to_perform) - return # Success, exit retry loop - - except IntegrityError as e: - # Don't retry integrity constraint violations - logger.exception("Integrity constraint violation in Provider update") - raise - - except (OperationalError, DatabaseError) as e: - error_msg = str(e).lower() - - # Check for various deadlock/lock timeout patterns across different databases - is_deadlock = any( - pattern in error_msg - for pattern in [ - "deadlock detected", - "deadlock found", - "lock wait timeout", - "lock timeout", - "serialization failure", - "could not serialize access", - ] - ) - - if is_deadlock and attempt < max_retries: - # Track deadlock retry statistics - _update_stats["deadlock_retries"] += 1 - - # Exponential backoff with jitter for deadlock recovery - base_delay = (2**attempt) * 0.1 - jitter = secrets.randbelow(100) / 1000.0 # 0.0 to 0.099 - delay = base_delay + jitter - - logger.warning( - f"Database lock conflict detected in Provider update " - f"(attempt {attempt + 1}/{max_retries + 1}). " - f"Retrying in {delay:.2f}s. Error: {e}" - ) - time.sleep(delay) - continue - else: - # Not a deadlock or max retries exceeded - logger.exception(f"Failed to update Provider after {attempt + 1} attempts") - raise - - except Exception as e: - logger.exception("Unexpected error updating Provider") - raise - - -def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]): - """Execute all Provider updates in a single transaction with proper isolation.""" +def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]): + """Execute all Provider updates in a single transaction.""" if not updates_to_perform: return - # Start a new transaction with explicit isolation level + # Use SQLAlchemy's context manager for transaction management + # This automatically handles commit/rollback try: - # Set transaction isolation level to prevent phantom reads - # This helps reduce deadlock probability - db.session.execute(text("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")) - - # Use a single transaction for all updates - for update_info in updates_to_perform: - filters = update_info["filters"] - values = update_info["values"] - additional_filters = update_info.get("additional_filters", {}) - description = update_info.get("description", "unknown") - - # Build the query with consistent ordering to prevent deadlocks - # Order by tenant_id, provider_name to ensure consistent lock acquisition - query = ( - db.session.query(Provider) - .filter(Provider.tenant_id == filters["tenant_id"], Provider.provider_name == filters["provider_name"]) - .order_by(Provider.tenant_id, Provider.provider_name) - ) - - # Add additional filters if specified - if "provider_type" in filters: - query = query.filter(Provider.provider_type == filters["provider_type"]) - if "quota_type" in filters: - query = query.filter(Provider.quota_type == filters["quota_type"]) - if additional_filters.get("quota_limit_check"): - query = query.filter(Provider.quota_limit > Provider.quota_used) - - # Execute the update - rows_affected = query.update(values, synchronize_session=False) - - logger.debug( - f"Provider update ({description}): {rows_affected} rows affected. Filters: {filters}, Values: {values}" - ) + with db.session.begin(): + # Use a single transaction for all updates + for update_operation in updates_to_perform: + filters = update_operation.filters + values = update_operation.values + additional_filters = update_operation.additional_filters + description = update_operation.description + + # Build the where conditions + where_conditions = [ + Provider.tenant_id == filters.tenant_id, + Provider.provider_name == filters.provider_name, + ] - # If no rows were affected for quota updates, log a warning - if rows_affected == 0 and description == "quota_deduction_update": - logger.warning( - f"No Provider rows updated for quota deduction. " - f"This may indicate quota limit exceeded or provider not found. " - f"Filters: {filters}" + # Add additional filters if specified + if filters.provider_type is not None: + where_conditions.append(Provider.provider_type == filters.provider_type) + if filters.quota_type is not None: + where_conditions.append(Provider.quota_type == filters.quota_type) + if additional_filters.quota_limit_check: + where_conditions.append(Provider.quota_limit > Provider.quota_used) + + # Prepare values dict for SQLAlchemy update + update_values = {} + if values.last_used is not None: + update_values["last_used"] = values.last_used + if values.quota_used is not None: + update_values["quota_used"] = values.quota_used + + # Build and execute the update statement + stmt = update(Provider).where(*where_conditions).values(**update_values) + result = db.session.execute(stmt) + rows_affected = result.rowcount + + logger.debug( + f"Provider update ({description}): {rows_affected} rows affected. " + f"Filters: {filters.model_dump()}, Values: {update_values}" ) - # Commit all updates in a single transaction - db.session.commit() - logger.debug(f"Successfully committed {len(updates_to_perform)} Provider updates") + # If no rows were affected for quota updates, log a warning + if rows_affected == 0 and description == "quota_deduction_update": + logger.warning( + f"No Provider rows updated for quota deduction. " + f"This may indicate quota limit exceeded or provider not found. " + f"Filters: {filters.model_dump()}" + ) - except Exception as e: - # Rollback on any error - try: - db.session.rollback() - logger.debug("Transaction rolled back successfully") - except Exception as rollback_error: - logger.exception("Error during rollback") + logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates") + except Exception as e: + # The context manager automatically handles rollback logger.exception("Failed to update Provider") raise - - -def _log_performance_stats(): - """Log performance statistics for Provider updates.""" - stats = _update_stats.copy() - - if stats["total_updates"] > 0: - success_rate = (stats["successful_updates"] / stats["total_updates"]) * 100 - avg_duration = stats["total_duration"] / stats["total_updates"] - - logger.info( - f"Provider Update Performance Stats: " - f"Total: {stats['total_updates']}, " - f"Success Rate: {success_rate:.1f}%, " - f"Avg Duration: {avg_duration:.3f}s, " - f"Deadlock Retries: {stats['deadlock_retries']}, " - f"Failed: {stats['failed_updates']}" - ) - - -def get_update_stats(): - """Get current update statistics for monitoring.""" - return _update_stats.copy()