diff --git a/api/events/event_handlers/__init__.py b/api/events/event_handlers/__init__.py index 1d6ad35333..ebc55d5ef8 100644 --- a/api/events/event_handlers/__init__.py +++ b/api/events/event_handlers/__init__.py @@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle from .create_document_index import handle from .create_installed_app_when_app_created import handle from .create_site_record_when_app_created import handle -from .deduct_quota_when_message_created import handle from .delete_tool_parameters_cache_when_sync_draft_workflow import handle from .update_app_dataset_join_when_app_model_config_updated import handle from .update_app_dataset_join_when_app_published_workflow_updated import handle -from .update_provider_last_used_at_when_message_created import handle + +# Consolidated handler replaces both deduct_quota_when_message_created and +# update_provider_last_used_at_when_message_created +from .update_provider_when_message_created import handle diff --git a/api/events/event_handlers/deduct_quota_when_message_created.py b/api/events/event_handlers/deduct_quota_when_message_created.py index b8e7019446..666e21a113 100644 --- a/api/events/event_handlers/deduct_quota_when_message_created.py +++ b/api/events/event_handlers/deduct_quota_when_message_created.py @@ -4,13 +4,14 @@ from configs import dify_config from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity from core.entities.provider_entities import QuotaUnit from core.plugin.entities.plugin import ModelProviderID -from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider, ProviderType -@message_was_created.connect -def handle(sender, **kwargs): +# DEPRECATED: This handler has been replaced by update_provider_when_message_created.py +# to prevent deadlocks. This file is kept for reference but the handler is disabled. +# @message_was_created.connect # DISABLED +def handle_deprecated(sender, **kwargs): message = sender application_generate_entity = kwargs.get("application_generate_entity") diff --git a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py index 59412cf87c..6319ce2d39 100644 --- a/api/events/event_handlers/update_provider_last_used_at_when_message_created.py +++ b/api/events/event_handlers/update_provider_last_used_at_when_message_created.py @@ -1,13 +1,14 @@ from datetime import UTC, datetime from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity -from events.message_event import message_was_created from extensions.ext_database import db from models.provider import Provider -@message_was_created.connect -def handle(sender, **kwargs): +# DEPRECATED: This handler has been replaced by update_provider_when_message_created.py +# to prevent deadlocks. This file is kept for reference but the handler is disabled. +# @message_was_created.connect # DISABLED +def handle_deprecated(sender, **kwargs): application_generate_entity = kwargs.get("application_generate_entity") if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): diff --git a/api/events/event_handlers/update_provider_when_message_created.py b/api/events/event_handlers/update_provider_when_message_created.py new file mode 100644 index 0000000000..53ec972e0b --- /dev/null +++ b/api/events/event_handlers/update_provider_when_message_created.py @@ -0,0 +1,296 @@ +import logging +import secrets +import time +import time as time_module +from datetime import UTC, datetime +from typing import Any, Optional + +from sqlalchemy import text +from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError + +from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity +from core.entities.provider_entities import QuotaUnit +from core.plugin.entities.plugin import ModelProviderID +from events.message_event import message_was_created +from extensions.ext_database import db +from models.provider import Provider, ProviderType + +logger = logging.getLogger(__name__) + +# Performance monitoring counters +_update_stats = { + "total_updates": 0, + "successful_updates": 0, + "failed_updates": 0, + "deadlock_retries": 0, + "total_duration": 0.0, +} + + +@message_was_created.connect +def handle(sender, **kwargs): + """ + Consolidated handler for Provider updates when a message is created. + + This handler replaces both: + - update_provider_last_used_at_when_message_created + - deduct_quota_when_message_created + + By performing all Provider updates in a single transaction, we eliminate + the deadlock that occurred when both handlers tried to update the same + Provider records concurrently. + """ + message = sender + application_generate_entity = kwargs.get("application_generate_entity") + + if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity): + return + + tenant_id = application_generate_entity.app_config.tenant_id + provider_name = application_generate_entity.model_conf.provider + current_time = datetime.now(UTC).replace(tzinfo=None) + + # Prepare updates for both scenarios + updates_to_perform: list[dict[str, Any]] = [] + + # 1. Always update last_used for the provider + basic_update = { + "filters": { + "tenant_id": tenant_id, + "provider_name": provider_name, + }, + "values": {"last_used": current_time}, + "description": "basic_last_used_update", + } + updates_to_perform.append(basic_update) + + # 2. Check if we need to deduct quota (system provider only) + model_config = application_generate_entity.model_conf + provider_model_bundle = model_config.provider_model_bundle + provider_configuration = provider_model_bundle.configuration + + if ( + provider_configuration.using_provider_type == ProviderType.SYSTEM + and hasattr(provider_configuration, "system_configuration") + and provider_configuration.system_configuration + and provider_configuration.system_configuration.current_quota_type is not None + ): + system_configuration = provider_configuration.system_configuration + + # Calculate quota usage + used_quota = _calculate_quota_usage(message, system_configuration) + + if used_quota is not None and used_quota > 0: + quota_update = { + "filters": { + "tenant_id": tenant_id, + "provider_name": ModelProviderID(model_config.provider).provider_name, + "provider_type": ProviderType.SYSTEM.value, + "quota_type": system_configuration.current_quota_type.value + if system_configuration.current_quota_type + else None, + }, + "values": {"quota_used": Provider.quota_used + used_quota, "last_used": current_time}, + "additional_filters": { + "quota_limit_check": True # Provider.quota_limit > Provider.quota_used + }, + "description": "quota_deduction_update", + } + updates_to_perform.append(quota_update) + + # Execute all updates with retry logic for deadlock prevention + start_time = time_module.perf_counter() + try: + _execute_provider_updates_with_retry(updates_to_perform) + + # Log successful completion with timing and update stats + duration = time_module.perf_counter() - start_time + _update_stats["total_updates"] += 1 + _update_stats["successful_updates"] += 1 + _update_stats["total_duration"] += duration + + logger.info( + f"Provider updates completed successfully. " + f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, " + f"Tenant: {tenant_id}, Provider: {provider_name}" + ) + + # Log performance stats periodically + if _update_stats["total_updates"] % 100 == 0: + _log_performance_stats() + + except Exception as e: + # Log failure with timing and context, update stats + duration = time_module.perf_counter() - start_time + _update_stats["total_updates"] += 1 + _update_stats["failed_updates"] += 1 + _update_stats["total_duration"] += duration + + logger.exception( + f"Provider updates failed after {duration:.3f}s. " + f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, " + f"Provider: {provider_name}" + ) + raise + + +def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[int]: + """Calculate quota usage based on message tokens and quota type.""" + try: + if system_configuration.current_quota_type == QuotaUnit.TOKENS: + tokens = getattr(message, "answer_tokens", None) + return int(tokens) if tokens is not None else None + elif system_configuration.current_quota_type == QuotaUnit.TIMES: + return 1 + return None + except Exception as e: + logger.warning(f"Failed to calculate quota usage: {e}") + return None + + +def _execute_provider_updates_with_retry(updates_to_perform: list[dict[str, Any]], max_retries: int = 3): + """ + Execute Provider updates with deadlock retry logic. + + Args: + updates_to_perform: List of update operations to perform + max_retries: Maximum number of retry attempts for deadlock recovery + """ + for attempt in range(max_retries + 1): + try: + _execute_provider_updates(updates_to_perform) + return # Success, exit retry loop + + except IntegrityError as e: + # Don't retry integrity constraint violations + logger.exception("Integrity constraint violation in Provider update") + raise + + except (OperationalError, DatabaseError) as e: + error_msg = str(e).lower() + + # Check for various deadlock/lock timeout patterns across different databases + is_deadlock = any( + pattern in error_msg + for pattern in [ + "deadlock detected", + "deadlock found", + "lock wait timeout", + "lock timeout", + "serialization failure", + "could not serialize access", + ] + ) + + if is_deadlock and attempt < max_retries: + # Track deadlock retry statistics + _update_stats["deadlock_retries"] += 1 + + # Exponential backoff with jitter for deadlock recovery + base_delay = (2**attempt) * 0.1 + jitter = secrets.randbelow(100) / 1000.0 # 0.0 to 0.099 + delay = base_delay + jitter + + logger.warning( + f"Database lock conflict detected in Provider update " + f"(attempt {attempt + 1}/{max_retries + 1}). " + f"Retrying in {delay:.2f}s. Error: {e}" + ) + time.sleep(delay) + continue + else: + # Not a deadlock or max retries exceeded + logger.exception(f"Failed to update Provider after {attempt + 1} attempts") + raise + + except Exception as e: + logger.exception("Unexpected error updating Provider") + raise + + +def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]): + """Execute all Provider updates in a single transaction with proper isolation.""" + if not updates_to_perform: + return + + # Start a new transaction with explicit isolation level + try: + # Set transaction isolation level to prevent phantom reads + # This helps reduce deadlock probability + db.session.execute(text("SET TRANSACTION ISOLATION LEVEL READ COMMITTED")) + + # Use a single transaction for all updates + for update_info in updates_to_perform: + filters = update_info["filters"] + values = update_info["values"] + additional_filters = update_info.get("additional_filters", {}) + description = update_info.get("description", "unknown") + + # Build the query with consistent ordering to prevent deadlocks + # Order by tenant_id, provider_name to ensure consistent lock acquisition + query = ( + db.session.query(Provider) + .filter(Provider.tenant_id == filters["tenant_id"], Provider.provider_name == filters["provider_name"]) + .order_by(Provider.tenant_id, Provider.provider_name) + ) + + # Add additional filters if specified + if "provider_type" in filters: + query = query.filter(Provider.provider_type == filters["provider_type"]) + if "quota_type" in filters: + query = query.filter(Provider.quota_type == filters["quota_type"]) + if additional_filters.get("quota_limit_check"): + query = query.filter(Provider.quota_limit > Provider.quota_used) + + # Execute the update + rows_affected = query.update(values, synchronize_session=False) + + logger.debug( + f"Provider update ({description}): {rows_affected} rows affected. Filters: {filters}, Values: {values}" + ) + + # If no rows were affected for quota updates, log a warning + if rows_affected == 0 and description == "quota_deduction_update": + logger.warning( + f"No Provider rows updated for quota deduction. " + f"This may indicate quota limit exceeded or provider not found. " + f"Filters: {filters}" + ) + + # Commit all updates in a single transaction + db.session.commit() + logger.debug(f"Successfully committed {len(updates_to_perform)} Provider updates") + + except Exception as e: + # Rollback on any error + try: + db.session.rollback() + logger.debug("Transaction rolled back successfully") + except Exception as rollback_error: + logger.exception("Error during rollback") + + logger.exception("Failed to update Provider") + raise + + +def _log_performance_stats(): + """Log performance statistics for Provider updates.""" + stats = _update_stats.copy() + + if stats["total_updates"] > 0: + success_rate = (stats["successful_updates"] / stats["total_updates"]) * 100 + avg_duration = stats["total_duration"] / stats["total_updates"] + + logger.info( + f"Provider Update Performance Stats: " + f"Total: {stats['total_updates']}, " + f"Success Rate: {success_rate:.1f}%, " + f"Avg Duration: {avg_duration:.3f}s, " + f"Deadlock Retries: {stats['deadlock_retries']}, " + f"Failed: {stats['failed_updates']}" + ) + + +def get_update_stats(): + """Get current update statistics for monitoring.""" + return _update_stats.copy() diff --git a/tests/unit_tests/events/test_provider_update_deadlock_prevention.py b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py new file mode 100644 index 0000000000..47c175acd7 --- /dev/null +++ b/tests/unit_tests/events/test_provider_update_deadlock_prevention.py @@ -0,0 +1,248 @@ +import threading +from unittest.mock import Mock, patch + +from core.app.entities.app_invoke_entities import ChatAppGenerateEntity +from core.entities.provider_entities import QuotaUnit +from events.event_handlers.update_provider_when_message_created import ( + handle, + get_update_stats, +) +from models.provider import ProviderType +from sqlalchemy.exc import OperationalError + + +class TestProviderUpdateDeadlockPrevention: + """Test suite for deadlock prevention in Provider updates.""" + + def setup_method(self): + """Setup test fixtures.""" + self.mock_message = Mock() + self.mock_message.answer_tokens = 100 + + self.mock_app_config = Mock() + self.mock_app_config.tenant_id = "test-tenant-123" + + self.mock_model_conf = Mock() + self.mock_model_conf.provider = "openai" + + self.mock_system_config = Mock() + self.mock_system_config.current_quota_type = QuotaUnit.TOKENS + + self.mock_provider_config = Mock() + self.mock_provider_config.using_provider_type = ProviderType.SYSTEM + self.mock_provider_config.system_configuration = self.mock_system_config + + self.mock_provider_bundle = Mock() + self.mock_provider_bundle.configuration = self.mock_provider_config + + self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle + + self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity) + self.mock_generate_entity.app_config = self.mock_app_config + self.mock_generate_entity.model_conf = self.mock_model_conf + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_consolidated_handler_basic_functionality(self, mock_db): + """Test that the consolidated handler performs both updates correctly.""" + # Setup mock query chain + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 # 1 row affected + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify db.session.query was called + assert mock_db.session.query.called + + # Verify commit was called + mock_db.session.commit.assert_called_once() + + # Verify no rollback was called + assert not mock_db.session.rollback.called + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_deadlock_retry_mechanism(self, mock_db): + """Test that deadlock errors trigger retry logic.""" + # Setup mock to raise deadlock error on first attempt, succeed on second + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # First call raises deadlock, second succeeds + mock_db.session.commit.side_effect = [ + OperationalError("deadlock detected", None, None), + None, # Success on retry + ] + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify commit was called twice (original + retry) + assert mock_db.session.commit.call_count == 2 + + # Verify rollback was called once (after first failure) + mock_db.session.rollback.assert_called_once() + + @patch("events.event_handlers.update_provider_when_message_created.db") + @patch("events.event_handlers.update_provider_when_message_created.time.sleep") + def test_exponential_backoff_timing(self, mock_sleep, mock_db): + """Test that retry delays follow exponential backoff pattern.""" + # Setup mock to fail twice, succeed on third attempt + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + mock_db.session.commit.side_effect = [ + OperationalError("deadlock detected", None, None), + OperationalError("deadlock detected", None, None), + None, # Success on third attempt + ] + + # Call the handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify sleep was called twice with increasing delays + assert mock_sleep.call_count == 2 + + # First delay should be around 0.1s + jitter + first_delay = mock_sleep.call_args_list[0][0][0] + assert 0.1 <= first_delay <= 0.3 + + # Second delay should be around 0.2s + jitter + second_delay = mock_sleep.call_args_list[1][0][0] + assert 0.2 <= second_delay <= 0.4 + + def test_concurrent_handler_execution(self): + """Test that multiple handlers can run concurrently without deadlock.""" + results = [] + errors = [] + + def run_handler(): + try: + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + handle( + self.mock_message, + application_generate_entity=self.mock_generate_entity, + ) + results.append("success") + except Exception as e: + errors.append(str(e)) + + # Run multiple handlers concurrently + threads = [] + for _ in range(5): + thread = threading.Thread(target=run_handler) + threads.append(thread) + thread.start() + + # Wait for all threads to complete + for thread in threads: + thread.join(timeout=5) + + # Verify all handlers completed successfully + assert len(results) == 5 + assert len(errors) == 0 + + def test_performance_stats_tracking(self): + """Test that performance statistics are tracked correctly.""" + # Reset stats + stats = get_update_stats() + initial_total = stats["total_updates"] + + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle( + self.mock_message, application_generate_entity=self.mock_generate_entity + ) + + # Check that stats were updated + updated_stats = get_update_stats() + assert updated_stats["total_updates"] == initial_total + 1 + assert updated_stats["successful_updates"] >= initial_total + 1 + + def test_non_chat_entity_ignored(self): + """Test that non-chat entities are ignored by the handler.""" + # Create a non-chat entity + mock_non_chat_entity = Mock() + mock_non_chat_entity.__class__.__name__ = "NonChatEntity" + + with patch( + "events.event_handlers.update_provider_when_message_created.db" + ) as mock_db: + # Call handler with non-chat entity + handle(self.mock_message, application_generate_entity=mock_non_chat_entity) + + # Verify no database operations were performed + assert not mock_db.session.query.called + assert not mock_db.session.commit.called + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_quota_calculation_tokens(self, mock_db): + """Test quota calculation for token-based quotas.""" + # Setup token-based quota + self.mock_system_config.current_quota_type = QuotaUnit.TOKENS + self.mock_message.answer_tokens = 150 + + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify update was called with token count + update_calls = mock_query.update.call_args_list + + # Should have at least one call with quota_used update + quota_update_found = False + for call in update_calls: + values = call[0][0] # First argument to update() + if "quota_used" in values: + quota_update_found = True + break + + assert quota_update_found + + @patch("events.event_handlers.update_provider_when_message_created.db") + def test_quota_calculation_times(self, mock_db): + """Test quota calculation for times-based quotas.""" + # Setup times-based quota + self.mock_system_config.current_quota_type = QuotaUnit.TIMES + + mock_query = Mock() + mock_db.session.query.return_value = mock_query + mock_query.filter.return_value = mock_query + mock_query.order_by.return_value = mock_query + mock_query.update.return_value = 1 + + # Call handler + handle(self.mock_message, application_generate_entity=self.mock_generate_entity) + + # Verify update was called + assert mock_query.update.called + assert mock_db.session.commit.called