|
|
|
|
@ -1,34 +1,56 @@
|
|
|
|
|
import logging
|
|
|
|
|
import secrets
|
|
|
|
|
import time
|
|
|
|
|
import time as time_module
|
|
|
|
|
from datetime import UTC, datetime
|
|
|
|
|
from typing import Any, Optional
|
|
|
|
|
|
|
|
|
|
from sqlalchemy import text
|
|
|
|
|
from sqlalchemy.exc import DatabaseError, IntegrityError, OperationalError
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
from sqlalchemy import update
|
|
|
|
|
|
|
|
|
|
from configs import dify_config
|
|
|
|
|
from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, ChatAppGenerateEntity
|
|
|
|
|
from core.entities.provider_entities import QuotaUnit
|
|
|
|
|
from core.entities.provider_entities import QuotaUnit, SystemConfiguration
|
|
|
|
|
from core.plugin.entities.plugin import ModelProviderID
|
|
|
|
|
from events.message_event import message_was_created
|
|
|
|
|
from extensions.ext_database import db
|
|
|
|
|
from models.model import Message
|
|
|
|
|
from models.provider import Provider, ProviderType
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# Performance monitoring counters
|
|
|
|
|
_update_stats = {
|
|
|
|
|
"total_updates": 0,
|
|
|
|
|
"successful_updates": 0,
|
|
|
|
|
"failed_updates": 0,
|
|
|
|
|
"deadlock_retries": 0,
|
|
|
|
|
"total_duration": 0.0,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class _ProviderUpdateFilters(BaseModel):
    """Criteria that select which Provider records an update targets.

    ``tenant_id`` and ``provider_name`` are required. ``provider_type`` and
    ``quota_type`` are optional and default to ``None``.
    """

    # Required identifiers.
    tenant_id: str
    provider_name: str

    # Optional narrowing criteria.
    provider_type: str | None = None
    quota_type: str | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ProviderUpdateAdditionalFilters(BaseModel):
    """Extra, opt-in conditions attached to a Provider update."""

    # Disabled unless the caller explicitly passes True.
    quota_limit_check: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ProviderUpdateValues(BaseModel):
    """Column values to write to the matched Provider records.

    Both fields are optional and default to ``None``.
    """

    last_used: datetime | None = None

    # Typed as Any because the value may be an expression such as
    # ``Provider.quota_used + <int>`` rather than a plain number.
    quota_used: Any | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class _ProviderUpdateOperation(BaseModel):
    """One Provider update: which records to match and what to write to them."""

    # Selection criteria and the new column values.
    filters: _ProviderUpdateFilters
    values: _ProviderUpdateValues

    # Optional extras; when omitted, the defaults of the additional-filter model apply.
    additional_filters: _ProviderUpdateAdditionalFilters = _ProviderUpdateAdditionalFilters()

    # Free-form label for the operation; defaults to "unknown".
    description: str = "unknown"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@message_was_created.connect
|
|
|
|
|
def handle(sender, **kwargs):
|
|
|
|
|
def handle(sender: Message, **kwargs):
|
|
|
|
|
"""
|
|
|
|
|
Consolidated handler for Provider updates when a message is created.
|
|
|
|
|
|
|
|
|
|
@ -36,9 +58,8 @@ def handle(sender, **kwargs):
|
|
|
|
|
- update_provider_last_used_at_when_message_created
|
|
|
|
|
- deduct_quota_when_message_created
|
|
|
|
|
|
|
|
|
|
By performing all Provider updates in a single transaction, we eliminate
|
|
|
|
|
the deadlock that occurred when both handlers tried to update the same
|
|
|
|
|
Provider records concurrently.
|
|
|
|
|
By performing all Provider updates in a single transaction, we ensure
|
|
|
|
|
consistency and efficiency when updating Provider records.
|
|
|
|
|
"""
|
|
|
|
|
message = sender
|
|
|
|
|
application_generate_entity = kwargs.get("application_generate_entity")
|
|
|
|
|
@ -51,17 +72,17 @@ def handle(sender, **kwargs):
|
|
|
|
|
current_time = datetime.now(UTC).replace(tzinfo=None)
|
|
|
|
|
|
|
|
|
|
# Prepare updates for both scenarios
|
|
|
|
|
updates_to_perform: list[dict[str, Any]] = []
|
|
|
|
|
updates_to_perform: list[_ProviderUpdateOperation] = []
|
|
|
|
|
|
|
|
|
|
# 1. Always update last_used for the provider
|
|
|
|
|
basic_update = {
|
|
|
|
|
"filters": {
|
|
|
|
|
"tenant_id": tenant_id,
|
|
|
|
|
"provider_name": provider_name,
|
|
|
|
|
},
|
|
|
|
|
"values": {"last_used": current_time},
|
|
|
|
|
"description": "basic_last_used_update",
|
|
|
|
|
}
|
|
|
|
|
basic_update = _ProviderUpdateOperation(
|
|
|
|
|
filters=_ProviderUpdateFilters(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider_name=provider_name,
|
|
|
|
|
),
|
|
|
|
|
values=_ProviderUpdateValues(last_used=current_time),
|
|
|
|
|
description="basic_last_used_update",
|
|
|
|
|
)
|
|
|
|
|
updates_to_perform.append(basic_update)
|
|
|
|
|
|
|
|
|
|
# 2. Check if we need to deduct quota (system provider only)
|
|
|
|
|
@ -71,43 +92,41 @@ def handle(sender, **kwargs):
|
|
|
|
|
|
|
|
|
|
if (
|
|
|
|
|
provider_configuration.using_provider_type == ProviderType.SYSTEM
|
|
|
|
|
and hasattr(provider_configuration, "system_configuration")
|
|
|
|
|
and provider_configuration.system_configuration
|
|
|
|
|
and provider_configuration.system_configuration.current_quota_type is not None
|
|
|
|
|
):
|
|
|
|
|
system_configuration = provider_configuration.system_configuration
|
|
|
|
|
|
|
|
|
|
# Calculate quota usage
|
|
|
|
|
used_quota = _calculate_quota_usage(message, system_configuration)
|
|
|
|
|
|
|
|
|
|
if used_quota is not None and used_quota > 0:
|
|
|
|
|
quota_update = {
|
|
|
|
|
"filters": {
|
|
|
|
|
"tenant_id": tenant_id,
|
|
|
|
|
"provider_name": ModelProviderID(model_config.provider).provider_name,
|
|
|
|
|
"provider_type": ProviderType.SYSTEM.value,
|
|
|
|
|
"quota_type": system_configuration.current_quota_type.value
|
|
|
|
|
if system_configuration.current_quota_type
|
|
|
|
|
else None,
|
|
|
|
|
},
|
|
|
|
|
"values": {"quota_used": Provider.quota_used + used_quota, "last_used": current_time},
|
|
|
|
|
"additional_filters": {
|
|
|
|
|
"quota_limit_check": True # Provider.quota_limit > Provider.quota_used
|
|
|
|
|
},
|
|
|
|
|
"description": "quota_deduction_update",
|
|
|
|
|
}
|
|
|
|
|
used_quota = _calculate_quota_usage(
|
|
|
|
|
message=message,
|
|
|
|
|
system_configuration=system_configuration,
|
|
|
|
|
model_name=model_config.model,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if used_quota is not None:
|
|
|
|
|
quota_update = _ProviderUpdateOperation(
|
|
|
|
|
filters=_ProviderUpdateFilters(
|
|
|
|
|
tenant_id=tenant_id,
|
|
|
|
|
provider_name=ModelProviderID(model_config.provider).provider_name,
|
|
|
|
|
provider_type=ProviderType.SYSTEM.value,
|
|
|
|
|
quota_type=provider_configuration.system_configuration.current_quota_type.value,
|
|
|
|
|
),
|
|
|
|
|
values=_ProviderUpdateValues(quota_used=Provider.quota_used + used_quota, last_used=current_time),
|
|
|
|
|
additional_filters=_ProviderUpdateAdditionalFilters(
|
|
|
|
|
quota_limit_check=True # Provider.quota_limit > Provider.quota_used
|
|
|
|
|
),
|
|
|
|
|
description="quota_deduction_update",
|
|
|
|
|
)
|
|
|
|
|
updates_to_perform.append(quota_update)
|
|
|
|
|
|
|
|
|
|
# Execute all updates with retry logic for deadlock prevention
|
|
|
|
|
# Execute all updates
|
|
|
|
|
start_time = time_module.perf_counter()
|
|
|
|
|
try:
|
|
|
|
|
_execute_provider_updates_with_retry(updates_to_perform)
|
|
|
|
|
_execute_provider_updates(updates_to_perform)
|
|
|
|
|
|
|
|
|
|
# Log successful completion with timing and update stats
|
|
|
|
|
# Log successful completion with timing
|
|
|
|
|
duration = time_module.perf_counter() - start_time
|
|
|
|
|
_update_stats["total_updates"] += 1
|
|
|
|
|
_update_stats["successful_updates"] += 1
|
|
|
|
|
_update_stats["total_duration"] += duration
|
|
|
|
|
|
|
|
|
|
logger.info(
|
|
|
|
|
f"Provider updates completed successfully. "
|
|
|
|
|
@ -115,16 +134,9 @@ def handle(sender, **kwargs):
|
|
|
|
|
f"Tenant: {tenant_id}, Provider: {provider_name}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Log performance stats periodically
|
|
|
|
|
if _update_stats["total_updates"] % 100 == 0:
|
|
|
|
|
_log_performance_stats()
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Log failure with timing and context, update stats
|
|
|
|
|
# Log failure with timing and context
|
|
|
|
|
duration = time_module.perf_counter() - start_time
|
|
|
|
|
_update_stats["total_updates"] += 1
|
|
|
|
|
_update_stats["failed_updates"] += 1
|
|
|
|
|
_update_stats["total_duration"] += duration
|
|
|
|
|
|
|
|
|
|
logger.exception(
|
|
|
|
|
f"Provider updates failed after {duration:.3f}s. "
|
|
|
|
|
@ -134,13 +146,28 @@ def handle(sender, **kwargs):
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[int]:
|
|
|
|
|
def _calculate_quota_usage(
|
|
|
|
|
*, message: Message, system_configuration: SystemConfiguration, model_name: str
|
|
|
|
|
) -> Optional[int]:
|
|
|
|
|
"""Calculate quota usage based on message tokens and quota type."""
|
|
|
|
|
quota_unit = None
|
|
|
|
|
for quota_configuration in system_configuration.quota_configurations:
|
|
|
|
|
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
|
|
|
|
quota_unit = quota_configuration.quota_unit
|
|
|
|
|
if quota_configuration.quota_limit == -1:
|
|
|
|
|
return None
|
|
|
|
|
break
|
|
|
|
|
if quota_unit is None:
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
if system_configuration.current_quota_type == QuotaUnit.TOKENS:
|
|
|
|
|
tokens = getattr(message, "answer_tokens", None)
|
|
|
|
|
return int(tokens) if tokens is not None else None
|
|
|
|
|
elif system_configuration.current_quota_type == QuotaUnit.TIMES:
|
|
|
|
|
if quota_unit == QuotaUnit.TOKENS:
|
|
|
|
|
tokens = message.message_tokens + message.answer_tokens
|
|
|
|
|
return tokens
|
|
|
|
|
if quota_unit == QuotaUnit.CREDITS:
|
|
|
|
|
tokens = dify_config.get_model_credits(model_name)
|
|
|
|
|
return tokens
|
|
|
|
|
elif quota_unit == QuotaUnit.TIMES:
|
|
|
|
|
return 1
|
|
|
|
|
return None
|
|
|
|
|
except Exception as e:
|
|
|
|
|
@ -148,105 +175,51 @@ def _calculate_quota_usage(message: Any, system_configuration: Any) -> Optional[
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _execute_provider_updates_with_retry(updates_to_perform: list[dict[str, Any]], max_retries: int = 3):
    """
    Execute Provider updates, retrying when the database reports a lock conflict.

    A failure is retried only when its message contains one of the known
    deadlock / lock-timeout / serialization-failure phrases. Retries back
    off exponentially (0.1s, 0.2s, 0.4s, ...) plus up to 0.099s of jitter.

    Args:
        updates_to_perform: List of update operations to perform
        max_retries: Maximum number of retry attempts for deadlock recovery.
            Negative values are treated as 0, so the updates are always
            attempted at least once instead of being silently skipped.

    Raises:
        IntegrityError: On a constraint violation; never retried.
        OperationalError, DatabaseError: On a non-lock failure, or when a
            lock conflict persists after the final attempt.
        Exception: Any other error, re-raised after being logged.
    """
    # Message fragments that identify deadlock/lock-timeout conditions across
    # different databases. Matched against the lower-cased error text.
    lock_conflict_patterns = (
        "deadlock detected",
        "deadlock found",
        "lock wait timeout",
        "lock timeout",
        "serialization failure",
        "could not serialize access",
    )

    # Without this clamp a negative value makes the range below empty, and
    # the function would return without executing anything or raising.
    max_retries = max(max_retries, 0)

    for attempt in range(max_retries + 1):
        try:
            _execute_provider_updates(updates_to_perform)
            return  # Success, exit retry loop

        except IntegrityError:
            # Must be handled before the DatabaseError clause below, since
            # IntegrityError is itself a DatabaseError subclass.
            # Don't retry integrity constraint violations.
            logger.exception("Integrity constraint violation in Provider update")
            raise

        except (OperationalError, DatabaseError) as e:
            error_msg = str(e).lower()
            is_deadlock = any(pattern in error_msg for pattern in lock_conflict_patterns)

            if not is_deadlock or attempt >= max_retries:
                # Not a deadlock or max retries exceeded
                logger.exception(f"Failed to update Provider after {attempt + 1} attempts")
                raise

            # Track deadlock retry statistics
            _update_stats["deadlock_retries"] += 1

            # Exponential backoff with jitter for deadlock recovery
            base_delay = (2**attempt) * 0.1
            jitter = secrets.randbelow(100) / 1000.0  # 0.0 to 0.099
            delay = base_delay + jitter

            logger.warning(
                f"Database lock conflict detected in Provider update "
                f"(attempt {attempt + 1}/{max_retries + 1}). "
                f"Retrying in {delay:.2f}s. Error: {e}"
            )
            time.sleep(delay)

        except Exception:
            logger.exception("Unexpected error updating Provider")
            raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]):
|
|
|
|
|
"""Execute all Provider updates in a single transaction with proper isolation."""
|
|
|
|
|
def _execute_provider_updates(updates_to_perform: list[_ProviderUpdateOperation]):
|
|
|
|
|
"""Execute all Provider updates in a single transaction."""
|
|
|
|
|
if not updates_to_perform:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
# Start a new transaction with explicit isolation level
|
|
|
|
|
# Use SQLAlchemy's context manager for transaction management
|
|
|
|
|
# This automatically handles commit/rollback
|
|
|
|
|
try:
|
|
|
|
|
# Set transaction isolation level to prevent phantom reads
|
|
|
|
|
# This helps reduce deadlock probability
|
|
|
|
|
db.session.execute(text("SET TRANSACTION ISOLATION LEVEL READ COMMITTED"))
|
|
|
|
|
|
|
|
|
|
with db.session.begin():
|
|
|
|
|
# Use a single transaction for all updates
|
|
|
|
|
for update_info in updates_to_perform:
|
|
|
|
|
filters = update_info["filters"]
|
|
|
|
|
values = update_info["values"]
|
|
|
|
|
additional_filters = update_info.get("additional_filters", {})
|
|
|
|
|
description = update_info.get("description", "unknown")
|
|
|
|
|
|
|
|
|
|
# Build the query with consistent ordering to prevent deadlocks
|
|
|
|
|
# Order by tenant_id, provider_name to ensure consistent lock acquisition
|
|
|
|
|
query = (
|
|
|
|
|
db.session.query(Provider)
|
|
|
|
|
.filter(Provider.tenant_id == filters["tenant_id"], Provider.provider_name == filters["provider_name"])
|
|
|
|
|
.order_by(Provider.tenant_id, Provider.provider_name)
|
|
|
|
|
)
|
|
|
|
|
for update_operation in updates_to_perform:
|
|
|
|
|
filters = update_operation.filters
|
|
|
|
|
values = update_operation.values
|
|
|
|
|
additional_filters = update_operation.additional_filters
|
|
|
|
|
description = update_operation.description
|
|
|
|
|
|
|
|
|
|
# Build the where conditions
|
|
|
|
|
where_conditions = [
|
|
|
|
|
Provider.tenant_id == filters.tenant_id,
|
|
|
|
|
Provider.provider_name == filters.provider_name,
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
# Add additional filters if specified
|
|
|
|
|
if "provider_type" in filters:
|
|
|
|
|
query = query.filter(Provider.provider_type == filters["provider_type"])
|
|
|
|
|
if "quota_type" in filters:
|
|
|
|
|
query = query.filter(Provider.quota_type == filters["quota_type"])
|
|
|
|
|
if additional_filters.get("quota_limit_check"):
|
|
|
|
|
query = query.filter(Provider.quota_limit > Provider.quota_used)
|
|
|
|
|
|
|
|
|
|
# Execute the update
|
|
|
|
|
rows_affected = query.update(values, synchronize_session=False)
|
|
|
|
|
if filters.provider_type is not None:
|
|
|
|
|
where_conditions.append(Provider.provider_type == filters.provider_type)
|
|
|
|
|
if filters.quota_type is not None:
|
|
|
|
|
where_conditions.append(Provider.quota_type == filters.quota_type)
|
|
|
|
|
if additional_filters.quota_limit_check:
|
|
|
|
|
where_conditions.append(Provider.quota_limit > Provider.quota_used)
|
|
|
|
|
|
|
|
|
|
# Prepare values dict for SQLAlchemy update
|
|
|
|
|
update_values = {}
|
|
|
|
|
if values.last_used is not None:
|
|
|
|
|
update_values["last_used"] = values.last_used
|
|
|
|
|
if values.quota_used is not None:
|
|
|
|
|
update_values["quota_used"] = values.quota_used
|
|
|
|
|
|
|
|
|
|
# Build and execute the update statement
|
|
|
|
|
stmt = update(Provider).where(*where_conditions).values(**update_values)
|
|
|
|
|
result = db.session.execute(stmt)
|
|
|
|
|
rows_affected = result.rowcount
|
|
|
|
|
|
|
|
|
|
logger.debug(
|
|
|
|
|
f"Provider update ({description}): {rows_affected} rows affected. Filters: {filters}, Values: {values}"
|
|
|
|
|
f"Provider update ({description}): {rows_affected} rows affected. "
|
|
|
|
|
f"Filters: {filters.model_dump()}, Values: {update_values}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# If no rows were affected for quota updates, log a warning
|
|
|
|
|
@ -254,43 +227,12 @@ def _execute_provider_updates(updates_to_perform: list[dict[str, Any]]):
|
|
|
|
|
logger.warning(
|
|
|
|
|
f"No Provider rows updated for quota deduction. "
|
|
|
|
|
f"This may indicate quota limit exceeded or provider not found. "
|
|
|
|
|
f"Filters: {filters}"
|
|
|
|
|
f"Filters: {filters.model_dump()}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Commit all updates in a single transaction
|
|
|
|
|
db.session.commit()
|
|
|
|
|
logger.debug(f"Successfully committed {len(updates_to_perform)} Provider updates")
|
|
|
|
|
logger.debug(f"Successfully processed {len(updates_to_perform)} Provider updates")
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
# Rollback on any error
|
|
|
|
|
try:
|
|
|
|
|
db.session.rollback()
|
|
|
|
|
logger.debug("Transaction rolled back successfully")
|
|
|
|
|
except Exception as rollback_error:
|
|
|
|
|
logger.exception("Error during rollback")
|
|
|
|
|
|
|
|
|
|
# The context manager automatically handles rollback
|
|
|
|
|
logger.exception("Failed to update Provider")
|
|
|
|
|
raise
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _log_performance_stats():
    """Write a one-line INFO summary of the Provider update counters.

    Works on a snapshot of ``_update_stats``. Nothing is logged when no
    updates have been recorded yet, which also avoids dividing by zero.
    """
    snapshot = dict(_update_stats)
    total = snapshot["total_updates"]

    # Guard: with no recorded updates there is nothing meaningful to report.
    if not total > 0:
        return

    success_rate = (snapshot["successful_updates"] / total) * 100
    avg_duration = snapshot["total_duration"] / total

    logger.info(
        f"Provider Update Performance Stats: "
        f"Total: {total}, "
        f"Success Rate: {success_rate:.1f}%, "
        f"Avg Duration: {avg_duration:.3f}s, "
        f"Deadlock Retries: {snapshot['deadlock_retries']}, "
        f"Failed: {snapshot['failed_updates']}"
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_update_stats():
    """Return a shallow copy of the update counters for monitoring.

    The caller receives its own dict, so changing it does not alter the
    module-level ``_update_stats``.
    """
    return dict(_update_stats)
|
|
|
|
|
|