fix(event_handlers): DB dead lock

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/21468/head
-LAN- 11 months ago
parent 164e5481c5
commit 4ac2715c2a
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -3,8 +3,10 @@ from .clean_when_document_deleted import handle
from .create_document_index import handle
from .create_installed_app_when_app_created import handle
from .create_site_record_when_app_created import handle
from .deduct_quota_when_message_created import handle
from .delete_tool_parameters_cache_when_sync_draft_workflow import handle
from .update_app_dataset_join_when_app_model_config_updated import handle
from .update_app_dataset_join_when_app_published_workflow_updated import handle
from .update_provider_last_used_at_when_message_created import handle
# Consolidated handler replaces both deduct_quota_when_message_created and
# update_provider_last_used_at_when_message_created
from .update_provider_when_message_created import handle

@ -4,13 +4,14 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType
@message_was_created.connect
def handle(sender, **kwargs):
# DEPRECATED: This handler has been replaced by update_provider_when_message_created.py
# to prevent deadlocks. This file is kept for reference but the handler is disabled.
# @message_was_created.connect # DISABLED
def handle_deprecated(sender, **kwargs):
message = sender
application_generate_entity = kwargs.get("application_generate_entity")

@ -1,13 +1,14 @@
from datetime import UTC, datetime
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider
@message_was_created.connect
def handle(sender, **kwargs):
# DEPRECATED: This handler has been replaced by update_provider_when_message_created.py
# to prevent deadlocks. This file is kept for reference but the handler is disabled.
# @message_was_created.connect # DISABLED
def handle_deprecated(sender, **kwargs):
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):

@ -0,0 +1,296 @@
import logging
import secrets
import time
import time as time_module
from datetime import UTC, datetime
from typing import Any, Optional
from sqlalchemy import text
from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from core.plugin.entities.plugin import ModelProviderID
from events.message_event import message_was_created
from extensions.ext_database import db
from models.provider import Provider, ProviderType
logger = logging.getLogger(__name__)
# Performance monitoring counters
_update_stats = {
"total_updates": 0,
"successful_updates": 0,
"failed_updates": 0,
"deadlock_retries": 0,
"total_duration": 0.0,
}
@message_was_created.connect
def handle(sender, **kwargs):
"""
Consolidated handler for Provider updates when a message is created.
This handler replaces both:
- update_provider_last_used_at_when_message_created
- deduct_quota_when_message_created
By performing all Provider updates in a single transaction, we eliminate
the deadlock that occurred when both handlers tried to update the same
Provider records concurrently.
"""
message = sender
application_generate_entity = kwargs.get("application_generate_entity")
if not isinstance(application_generate_entity, ChatAppGenerateEntity | AgentChatAppGenerateEntity):
return
tenant_id = application_generate_entity.app_config.tenant_id
provider_name = application_generate_entity.model_conf.provider
current_time = datetime.now(UTC).replace(tzinfo=None)
# Prepare updates for both scenarios
updates_to_perform: list[dict[str, Any]] = []
# 1. Always update last_used for the provider
basic_update = {
"filters": {
"tenant_id": tenant_id,
"provider_name": provider_name,
},
"values": {"last_used": current_time},
"description": "basic_last_used_update",
}
updates_to_perform.append(basic_update)
# 2. Check if we need to deduct quota (system provider only)
model_config = application_generate_entity.model_conf
provider_model_bundle = model_config.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if (
provider_configuration.using_provider_type == ProviderType.SYSTEM
and hasattr(provider_configuration, "system_configuration")
and provider_configuration.system_configuration
and provider_configuration.system_configuration.current_quota_type is not None
):
system_configuration = provider_configuration.system_configuration
# Calculate quota usage
used_quota = _calculate_quota_usage(message, system_configuration)
if used_quota is not None and used_quota > 0:
quota_update = {
"filters": {
"tenant_id": tenant_id,
"provider_name": ModelProviderID(model_config.provider).provider_name,
"provider_type": ProviderType.SYSTEM.value,
"quota_type": system_configuration.current_quota_type.value
if system_configuration.current_quota_type
else None,
},
"values": {"quota_used": Provider.quota_used + used_quota, "last_used": current_time},
"additional_filters": {
"quota_limit_check": True # Provider.quota_limit > Provider.quota_used
},
"description": "quota_deduction_update",
}
updates_to_perform.append(quota_update)
# Execute all updates with retry logic for deadlock prevention
start_time = time_module.perf_counter()
try:
_execute_provider_updates_with_retry(updates_to_perform)
# Log successful completion with timing and update stats
duration = time_module.perf_counter() - start_time
_update_stats["total_updates"] += 1
_update_stats["successful_updates"] += 1
_update_stats["total_duration"] += duration
logger.info(
f"Provider updates completed successfully. "
f"Updates: {len(updates_to_perform)}, Duration: {duration:.3f}s, "
f"Tenant: {tenant_id}, Provider: {provider_name}"
)
# Log performance stats periodically
if _update_stats["total_updates"] % 100 == 0:
_log_performance_stats()
except Exception as e:
# Log failure with timing and context, update stats
duration = time_module.perf_counter() - start_time
_update_stats["total_updates"] += 1
_update_stats["failed_updates"] += 1
_update_stats["total_duration"] += duration
logger.exception(
f"Provider updates failed after {duration:.3f}s. "
f"Updates: {len(updates_to_perform)}, Tenant: {tenant_id}, "
f"Provider: {provider_name}"
)
raise
def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[int]:
"""Calculate quota usage based on message tokens and quota type."""
try:
if system_configuration.current_quota_type == QuotaUnit.TOKENS:
tokens = getattr(message, "answer_tokens", None)
return int(tokens) if tokens is not None else None
elif system_configuration.current_quota_type == QuotaUnit.TIMES:
return 1
return None
except Exception as e:
logger.warning(f"Failed to calculate quota usage: {e}")
return None
def _execute_provider_updates_with_retry(updates_to_perform: list[dict[str, Any]], max_retries: int = 3):
"""
Execute Provider updates with deadlock retry logic.
Args:
updates_to_perform: List of update operations to perform
max_retries: Maximum number of retry attempts for deadlock recovery
"""
for attempt in range(max_retries + 1):
try:
_execute_provider_updates(updates_to_perform)
return # Success, exit retry loop
except IntegrityError as e:
# Don't retry integrity constraint violations
logger.exception("Integrity constraint violation in Provider update")
raise
except (OperationalError, DatabaseError) as e:
error_msg = str(e).lower()
# Check for various deadlock/lock timeout patterns across different databases
is_deadlock = any(
pattern in error_msg
for pattern in [
"deadlock detected",
"deadlock found",
"lock wait timeout",
"lock timeout",
"serialization failure",
"could not serialize access",
]
)
if is_deadlock and attempt < max_retries:
# Track deadlock retry statistics
_update_stats["deadlock_retries"] += 1
# Exponential backoff with jitter for deadlock recovery
base_delay = (2**attempt) * 0.1
jitter = secrets.randbelow(100) / 1000.0 # 0.0 to 0.099
delay = base_delay + jitter
logger.warning(
f"Database lock conflict detected in Provider update "
f"(attempt {attempt + 1}/{max_retries + 1}). "
f"Retrying in {delay:.2f}s. Error: {e}"
)
time.sleep(delay)
continue
else:
# Not a deadlock or max retries exceeded
logger.exception(f"Failed to update Provider after {attempt + 1} attempts")
raise
except Exception as e:
logger.exception("Unexpected error updating Provider")
raise
def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]):
"""Execute all Provider updates in a single transaction with proper isolation."""
if not updates_to_perform:
return
# Start a new transaction with explicit isolation level
try:
# Set transaction isolation level to prevent phantom reads
# This helps reduce deadlock probability
db.session.execute(text("SET TRANSACTION ISOLATION LEVEL READ COMMITTED"))
# Use a single transaction for all updates
for update_info in updates_to_perform:
filters = update_info["filters"]
values = update_info["values"]
additional_filters = update_info.get("additional_filters", {})
description = update_info.get("description", "unknown")
# Build the query with consistent ordering to prevent deadlocks
# Order by tenant_id, provider_name to ensure consistent lock acquisition
query = (
db.session.query(Provider)
.filter(Provider.tenant_id == filters["tenant_id"], Provider.provider_name == filters["provider_name"])
.order_by(Provider.tenant_id, Provider.provider_name)
)
# Add additional filters if specified
if "provider_type" in filters:
query = query.filter(Provider.provider_type == filters["provider_type"])
if "quota_type" in filters:
query = query.filter(Provider.quota_type == filters["quota_type"])
if additional_filters.get("quota_limit_check"):
query = query.filter(Provider.quota_limit > Provider.quota_used)
# Execute the update
rows_affected = query.update(values, synchronize_session=False)
logger.debug(
f"Provider update ({description}): {rows_affected} rows affected. Filters: {filters}, Values: {values}"
)
# If no rows were affected for quota updates, log a warning
if rows_affected == 0 and description == "quota_deduction_update":
logger.warning(
f"No Provider rows updated for quota deduction. "
f"This may indicate quota limit exceeded or provider not found. "
f"Filters: {filters}"
)
# Commit all updates in a single transaction
db.session.commit()
logger.debug(f"Successfully committed {len(updates_to_perform)} Provider updates")
except Exception as e:
# Rollback on any error
try:
db.session.rollback()
logger.debug("Transaction rolled back successfully")
except Exception as rollback_error:
logger.exception("Error during rollback")
logger.exception("Failed to update Provider")
raise
def _log_performance_stats():
"""Log performance statistics for Provider updates."""
stats = _update_stats.copy()
if stats["total_updates"] > 0:
success_rate = (stats["successful_updates"] / stats["total_updates"]) * 100
avg_duration = stats["total_duration"] / stats["total_updates"]
logger.info(
f"Provider Update Performance Stats: "
f"Total: {stats['total_updates']}, "
f"Success Rate: {success_rate:.1f}%, "
f"Avg Duration: {avg_duration:.3f}s, "
f"Deadlock Retries: {stats['deadlock_retries']}, "
f"Failed: {stats['failed_updates']}"
)
def get_update_stats():
"""Get current update statistics for monitoring."""
return _update_stats.copy()

@ -0,0 +1,248 @@
import threading
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.event_handlers.update_provider_when_message_created import (
handle,
get_update_stats,
)
from models.provider import ProviderType
from sqlalchemy.exc import OperationalError
class TestProviderUpdateDeadlockPrevention:
"""Test suite for deadlock prevention in Provider updates."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_message = Mock()
self.mock_message.answer_tokens = 100
self.mock_app_config = Mock()
self.mock_app_config.tenant_id = "test-tenant-123"
self.mock_model_conf = Mock()
self.mock_model_conf.provider = "openai"
self.mock_system_config = Mock()
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_provider_config = Mock()
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
self.mock_provider_config.system_configuration = self.mock_system_config
self.mock_provider_bundle = Mock()
self.mock_provider_bundle.configuration = self.mock_provider_config
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
self.mock_generate_entity.app_config = self.mock_app_config
self.mock_generate_entity.model_conf = self.mock_model_conf
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_consolidated_handler_basic_functionality(self, mock_db):
"""Test that the consolidated handler performs both updates correctly."""
# Setup mock query chain
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1 # 1 row affected
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify db.session.query was called
assert mock_db.session.query.called
# Verify commit was called
mock_db.session.commit.assert_called_once()
# Verify no rollback was called
assert not mock_db.session.rollback.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_deadlock_retry_mechanism(self, mock_db):
"""Test that deadlock errors trigger retry logic."""
# Setup mock to raise deadlock error on first attempt, succeed on second
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# First call raises deadlock, second succeeds
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
None, # Success on retry
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify commit was called twice (original + retry)
assert mock_db.session.commit.call_count == 2
# Verify rollback was called once (after first failure)
mock_db.session.rollback.assert_called_once()
@patch("events.event_handlers.update_provider_when_message_created.db")
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
"""Test that retry delays follow exponential backoff pattern."""
# Setup mock to fail twice, succeed on third attempt
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
OperationalError("deadlock detected", None, None),
None, # Success on third attempt
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify sleep was called twice with increasing delays
assert mock_sleep.call_count == 2
# First delay should be around 0.1s + jitter
first_delay = mock_sleep.call_args_list[0][0][0]
assert 0.1 <= first_delay <= 0.3
# Second delay should be around 0.2s + jitter
second_delay = mock_sleep.call_args_list[1][0][0]
assert 0.2 <= second_delay <= 0.4
def test_concurrent_handler_execution(self):
"""Test that multiple handlers can run concurrently without deadlock."""
results = []
errors = []
def run_handler():
try:
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
handle(
self.mock_message,
application_generate_entity=self.mock_generate_entity,
)
results.append("success")
except Exception as e:
errors.append(str(e))
# Run multiple handlers concurrently
threads = []
for _ in range(5):
thread = threading.Thread(target=run_handler)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join(timeout=5)
# Verify all handlers completed successfully
assert len(results) == 5
assert len(errors) == 0
def test_performance_stats_tracking(self):
"""Test that performance statistics are tracked correctly."""
# Reset stats
stats = get_update_stats()
initial_total = stats["total_updates"]
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(
self.mock_message, application_generate_entity=self.mock_generate_entity
)
# Check that stats were updated
updated_stats = get_update_stats()
assert updated_stats["total_updates"] == initial_total + 1
assert updated_stats["successful_updates"] >= initial_total + 1
def test_non_chat_entity_ignored(self):
"""Test that non-chat entities are ignored by the handler."""
# Create a non-chat entity
mock_non_chat_entity = Mock()
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
# Call handler with non-chat entity
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
# Verify no database operations were performed
assert not mock_db.session.query.called
assert not mock_db.session.commit.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_tokens(self, mock_db):
"""Test quota calculation for token-based quotas."""
# Setup token-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_message.answer_tokens = 150
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called with token count
update_calls = mock_query.update.call_args_list
# Should have at least one call with quota_used update
quota_update_found = False
for call in update_calls:
values = call[0][0] # First argument to update()
if "quota_used" in values:
quota_update_found = True
break
assert quota_update_found
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_times(self, mock_db):
"""Test quota calculation for times-based quotas."""
# Setup times-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called
assert mock_query.update.called
assert mock_db.session.commit.called
Loading…
Cancel
Save