Merge branch 'main' into fix-list-operator
commit
3042494562
@ -1,48 +0,0 @@
|
||||
## Guidelines for Database Connection Management in App Runner and Task Pipeline
|
||||
|
||||
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
|
||||
|
||||
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
|
||||
|
||||
Examples:
|
||||
|
||||
1. Creating a new record:
|
||||
|
||||
```python
|
||||
app = App(id=1)
|
||||
db.session.add(app)
|
||||
db.session.commit()
|
||||
db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
|
||||
|
||||
# Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
|
||||
|
||||
db.session.close()
|
||||
|
||||
return app.id
|
||||
```
|
||||
|
||||
2. Fetching a record from the table:
|
||||
|
||||
```python
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
|
||||
created_at = app.created_at
|
||||
|
||||
db.session.close()
|
||||
|
||||
# Handle tasks (include long-running).
|
||||
|
||||
```
|
||||
|
||||
3. Updating a table field:
|
||||
|
||||
```python
|
||||
app = db.session.query(App).filter(App.id == app_id).first()
|
||||
|
||||
app.updated_at = time.utcnow()
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
return app_id
|
||||
```
|
||||
|
||||
@ -0,0 +1,34 @@
|
||||
"""oauth_refresh_token
|
||||
|
||||
Revision ID: 375fe79ead14
|
||||
Revises: 1a83934ad6d1
|
||||
Create Date: 2025-07-22 00:19:45.599636
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '375fe79ead14'
|
||||
down_revision = '1a83934ad6d1'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
|
||||
batch_op.drop_column('expires_at')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,42 @@
|
||||
"""add_tenant_plugin_autoupgrade_table
|
||||
|
||||
Revision ID: 8bcc02c9bd07
|
||||
Revises: 375fe79ead14
|
||||
Create Date: 2025-07-23 15:08:50.161441
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '8bcc02c9bd07'
|
||||
down_revision = '375fe79ead14'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_plugin_auto_upgrade_strategies',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
|
||||
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
|
||||
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
|
||||
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
|
||||
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
|
||||
op.drop_table('tenant_plugin_auto_upgrade_strategies')
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,49 @@
|
||||
import time
|
||||
|
||||
import click
|
||||
|
||||
import app
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
|
||||
|
||||
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
|
||||
|
||||
|
||||
@app.celery.task(queue="plugin")
|
||||
def check_upgradable_plugin_task():
|
||||
click.echo(click.style("Start check upgradable plugin.", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
|
||||
click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green"))
|
||||
|
||||
strategies = (
|
||||
db.session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
|
||||
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
|
||||
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
|
||||
TenantPluginAutoUpgradeStrategy.strategy_setting
|
||||
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for strategy in strategies:
|
||||
process_tenant_plugin_autoupgrade_check_task.delay(
|
||||
strategy.tenant_id,
|
||||
strategy.strategy_setting,
|
||||
strategy.upgrade_time_of_day,
|
||||
strategy.upgrade_mode,
|
||||
strategy.exclude_plugins,
|
||||
strategy.include_plugins,
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
click.echo(
|
||||
click.style(
|
||||
"Checked upgradable plugin success latency: {}".format(end_at - start_at),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
@ -0,0 +1,87 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
|
||||
class PluginAutoUpgradeService:
|
||||
@staticmethod
|
||||
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
|
||||
with Session(db.engine) as session:
|
||||
return (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def change_strategy(
|
||||
tenant_id: str,
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
upgrade_time_of_day: int,
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
strategy = TenantPluginAutoUpgradeStrategy(
|
||||
tenant_id=tenant_id,
|
||||
strategy_setting=strategy_setting,
|
||||
upgrade_time_of_day=upgrade_time_of_day,
|
||||
upgrade_mode=upgrade_mode,
|
||||
exclude_plugins=exclude_plugins,
|
||||
include_plugins=include_plugins,
|
||||
)
|
||||
session.add(strategy)
|
||||
else:
|
||||
exist_strategy.strategy_setting = strategy_setting
|
||||
exist_strategy.upgrade_time_of_day = upgrade_time_of_day
|
||||
exist_strategy.upgrade_mode = upgrade_mode
|
||||
exist_strategy.exclude_plugins = exclude_plugins
|
||||
exist_strategy.include_plugins = include_plugins
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
|
||||
with Session(db.engine) as session:
|
||||
exist_strategy = (
|
||||
session.query(TenantPluginAutoUpgradeStrategy)
|
||||
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
if not exist_strategy:
|
||||
# create for this tenant
|
||||
PluginAutoUpgradeService.change_strategy(
|
||||
tenant_id,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
|
||||
0,
|
||||
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
|
||||
[plugin_id],
|
||||
[],
|
||||
)
|
||||
return True
|
||||
else:
|
||||
if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
|
||||
if plugin_id not in exist_strategy.exclude_plugins:
|
||||
new_exclude_plugins = exist_strategy.exclude_plugins.copy()
|
||||
new_exclude_plugins.append(plugin_id)
|
||||
exist_strategy.exclude_plugins = new_exclude_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
|
||||
if plugin_id in exist_strategy.include_plugins:
|
||||
new_include_plugins = exist_strategy.include_plugins.copy()
|
||||
new_include_plugins.remove(plugin_id)
|
||||
exist_strategy.include_plugins = new_include_plugins
|
||||
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
|
||||
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
|
||||
exist_strategy.exclude_plugins = [plugin_id]
|
||||
|
||||
session.commit()
|
||||
return True
|
||||
@ -0,0 +1,166 @@
|
||||
import traceback
|
||||
import typing
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
|
||||
from core.helper import marketplace
|
||||
from core.helper.marketplace import MarketplacePluginDeclaration
|
||||
from core.plugin.entities.plugin import PluginInstallationSource
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from models.account import TenantPluginAutoUpgradeStrategy
|
||||
|
||||
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
|
||||
|
||||
|
||||
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
|
||||
|
||||
|
||||
def marketplace_batch_fetch_plugin_manifests(
|
||||
plugin_ids_plain_list: list[str],
|
||||
) -> list[MarketplacePluginDeclaration]:
|
||||
global cached_plugin_manifests
|
||||
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
|
||||
not_included_plugin_ids = [
|
||||
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
|
||||
]
|
||||
if not_included_plugin_ids:
|
||||
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
|
||||
for manifest in manifests:
|
||||
cached_plugin_manifests[manifest.plugin_id] = manifest
|
||||
|
||||
if (
|
||||
len(manifests) == 0
|
||||
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
|
||||
for plugin_id in not_included_plugin_ids:
|
||||
cached_plugin_manifests[plugin_id] = None
|
||||
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin_id in plugin_ids_plain_list:
|
||||
final_manifest = cached_plugin_manifests.get(plugin_id)
|
||||
if final_manifest is not None:
|
||||
result.append(final_manifest)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@shared_task(queue="plugin")
|
||||
def process_tenant_plugin_autoupgrade_check_task(
|
||||
tenant_id: str,
|
||||
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
|
||||
upgrade_time_of_day: int,
|
||||
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
|
||||
exclude_plugins: list[str],
|
||||
include_plugins: list[str],
|
||||
):
|
||||
try:
|
||||
manager = PluginInstaller()
|
||||
|
||||
click.echo(
|
||||
click.style(
|
||||
"Checking upgradable plugin for tenant: {}".format(tenant_id),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
|
||||
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
|
||||
return
|
||||
|
||||
# get plugin_ids to check
|
||||
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
|
||||
click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green"))
|
||||
|
||||
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
|
||||
for plugin in all_plugins:
|
||||
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
|
||||
plugin_ids.append(
|
||||
(
|
||||
plugin.plugin_id,
|
||||
plugin.version,
|
||||
plugin.plugin_unique_identifier,
|
||||
)
|
||||
)
|
||||
|
||||
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
|
||||
# get all plugins and remove excluded plugins
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [
|
||||
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
|
||||
for plugin in all_plugins
|
||||
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
|
||||
]
|
||||
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
|
||||
all_plugins = manager.list_plugins(tenant_id)
|
||||
plugin_ids = [
|
||||
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
|
||||
for plugin in all_plugins
|
||||
if plugin.source == PluginInstallationSource.Marketplace
|
||||
]
|
||||
|
||||
if not plugin_ids:
|
||||
return
|
||||
|
||||
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
|
||||
|
||||
manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
|
||||
|
||||
if not manifests:
|
||||
return
|
||||
|
||||
for manifest in manifests:
|
||||
for plugin_id, version, original_unique_identifier in plugin_ids:
|
||||
if manifest.plugin_id != plugin_id:
|
||||
continue
|
||||
|
||||
try:
|
||||
current_version = version
|
||||
latest_version = manifest.latest_version
|
||||
|
||||
def fix_only_checker(latest_version, current_version):
|
||||
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
|
||||
current_version_tuple = tuple(int(val) for val in current_version.split("."))
|
||||
|
||||
if (
|
||||
latest_version_tuple[0] == current_version_tuple[0]
|
||||
and latest_version_tuple[1] == current_version_tuple[1]
|
||||
):
|
||||
return latest_version_tuple[2] != current_version_tuple[2]
|
||||
return False
|
||||
|
||||
version_checker = {
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
|
||||
current_version: latest_version != current_version,
|
||||
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
|
||||
}
|
||||
|
||||
if version_checker[strategy_setting](latest_version, current_version):
|
||||
# execute upgrade
|
||||
new_unique_identifier = manifest.latest_package_identifier
|
||||
|
||||
marketplace.record_install_plugin_event(new_unique_identifier)
|
||||
click.echo(
|
||||
click.style(
|
||||
"Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
task_start_resp = manager.upgrade_plugin(
|
||||
tenant_id,
|
||||
original_unique_identifier,
|
||||
new_unique_identifier,
|
||||
PluginInstallationSource.Marketplace,
|
||||
{
|
||||
"plugin_unique_identifier": new_unique_identifier,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
|
||||
traceback.print_exc()
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
|
||||
traceback.print_exc()
|
||||
return
|
||||
@ -1,248 +0,0 @@
|
||||
import threading
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from events.event_handlers.update_provider_when_message_created import (
|
||||
handle,
|
||||
get_update_stats,
|
||||
)
|
||||
from models.provider import ProviderType
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
|
||||
class TestProviderUpdateDeadlockPrevention:
|
||||
"""Test suite for deadlock prevention in Provider updates."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Setup test fixtures."""
|
||||
self.mock_message = Mock()
|
||||
self.mock_message.answer_tokens = 100
|
||||
|
||||
self.mock_app_config = Mock()
|
||||
self.mock_app_config.tenant_id = "test-tenant-123"
|
||||
|
||||
self.mock_model_conf = Mock()
|
||||
self.mock_model_conf.provider = "openai"
|
||||
|
||||
self.mock_system_config = Mock()
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
|
||||
self.mock_provider_config = Mock()
|
||||
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
|
||||
self.mock_provider_config.system_configuration = self.mock_system_config
|
||||
|
||||
self.mock_provider_bundle = Mock()
|
||||
self.mock_provider_bundle.configuration = self.mock_provider_config
|
||||
|
||||
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
|
||||
|
||||
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
|
||||
self.mock_generate_entity.app_config = self.mock_app_config
|
||||
self.mock_generate_entity.model_conf = self.mock_model_conf
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_consolidated_handler_basic_functionality(self, mock_db):
|
||||
"""Test that the consolidated handler performs both updates correctly."""
|
||||
# Setup mock query chain
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1 # 1 row affected
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify db.session.query was called
|
||||
assert mock_db.session.query.called
|
||||
|
||||
# Verify commit was called
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
# Verify no rollback was called
|
||||
assert not mock_db.session.rollback.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_deadlock_retry_mechanism(self, mock_db):
|
||||
"""Test that deadlock errors trigger retry logic."""
|
||||
# Setup mock to raise deadlock error on first attempt, succeed on second
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# First call raises deadlock, second succeeds
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on retry
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify commit was called twice (original + retry)
|
||||
assert mock_db.session.commit.call_count == 2
|
||||
|
||||
# Verify rollback was called once (after first failure)
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
|
||||
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
|
||||
"""Test that retry delays follow exponential backoff pattern."""
|
||||
# Setup mock to fail twice, succeed on third attempt
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
mock_db.session.commit.side_effect = [
|
||||
OperationalError("deadlock detected", None, None),
|
||||
OperationalError("deadlock detected", None, None),
|
||||
None, # Success on third attempt
|
||||
]
|
||||
|
||||
# Call the handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify sleep was called twice with increasing delays
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
# First delay should be around 0.1s + jitter
|
||||
first_delay = mock_sleep.call_args_list[0][0][0]
|
||||
assert 0.1 <= first_delay <= 0.3
|
||||
|
||||
# Second delay should be around 0.2s + jitter
|
||||
second_delay = mock_sleep.call_args_list[1][0][0]
|
||||
assert 0.2 <= second_delay <= 0.4
|
||||
|
||||
def test_concurrent_handler_execution(self):
|
||||
"""Test that multiple handlers can run concurrently without deadlock."""
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def run_handler():
|
||||
try:
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
handle(
|
||||
self.mock_message,
|
||||
application_generate_entity=self.mock_generate_entity,
|
||||
)
|
||||
results.append("success")
|
||||
except Exception as e:
|
||||
errors.append(str(e))
|
||||
|
||||
# Run multiple handlers concurrently
|
||||
threads = []
|
||||
for _ in range(5):
|
||||
thread = threading.Thread(target=run_handler)
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join(timeout=5)
|
||||
|
||||
# Verify all handlers completed successfully
|
||||
assert len(results) == 5
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_performance_stats_tracking(self):
|
||||
"""Test that performance statistics are tracked correctly."""
|
||||
# Reset stats
|
||||
stats = get_update_stats()
|
||||
initial_total = stats["total_updates"]
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(
|
||||
self.mock_message, application_generate_entity=self.mock_generate_entity
|
||||
)
|
||||
|
||||
# Check that stats were updated
|
||||
updated_stats = get_update_stats()
|
||||
assert updated_stats["total_updates"] == initial_total + 1
|
||||
assert updated_stats["successful_updates"] >= initial_total + 1
|
||||
|
||||
def test_non_chat_entity_ignored(self):
|
||||
"""Test that non-chat entities are ignored by the handler."""
|
||||
# Create a non-chat entity
|
||||
mock_non_chat_entity = Mock()
|
||||
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
|
||||
|
||||
with patch(
|
||||
"events.event_handlers.update_provider_when_message_created.db"
|
||||
) as mock_db:
|
||||
# Call handler with non-chat entity
|
||||
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
|
||||
|
||||
# Verify no database operations were performed
|
||||
assert not mock_db.session.query.called
|
||||
assert not mock_db.session.commit.called
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_tokens(self, mock_db):
|
||||
"""Test quota calculation for token-based quotas."""
|
||||
# Setup token-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
|
||||
self.mock_message.answer_tokens = 150
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called with token count
|
||||
update_calls = mock_query.update.call_args_list
|
||||
|
||||
# Should have at least one call with quota_used update
|
||||
quota_update_found = False
|
||||
for call in update_calls:
|
||||
values = call[0][0] # First argument to update()
|
||||
if "quota_used" in values:
|
||||
quota_update_found = True
|
||||
break
|
||||
|
||||
assert quota_update_found
|
||||
|
||||
@patch("events.event_handlers.update_provider_when_message_created.db")
|
||||
def test_quota_calculation_times(self, mock_db):
|
||||
"""Test quota calculation for times-based quotas."""
|
||||
# Setup times-based quota
|
||||
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
|
||||
|
||||
mock_query = Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.filter.return_value = mock_query
|
||||
mock_query.order_by.return_value = mock_query
|
||||
mock_query.update.return_value = 1
|
||||
|
||||
# Call handler
|
||||
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
|
||||
|
||||
# Verify update was called
|
||||
assert mock_query.update.called
|
||||
assert mock_db.session.commit.called
|
||||
@ -0,0 +1,7 @@
|
||||
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M28.0049 16C28.0049 20.4183 24.4231 24 20.0049 24C15.5866 24 12.0049 20.4183 12.0049 16C12.0049 11.5817 15.5866 8 20.0049 8C24.4231 8 28.0049 11.5817 28.0049 16Z" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 16H6.67155" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 9.33334H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M4.00488 22.6667H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
<path d="M26 22L29.3333 25.3333" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 823 B |
@ -0,0 +1,4 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M5.46257 4.43262C7.21556 2.91688 9.5007 2 12 2C17.5228 2 22 6.47715 22 12C22 14.1361 21.3302 16.1158 20.1892 17.7406L17 12H20C20 7.58172 16.4183 4 12 4C9.84982 4 7.89777 4.84827 6.46023 6.22842L5.46257 4.43262ZM18.5374 19.5674C16.7844 21.0831 14.4993 22 12 22C6.47715 22 2 17.5228 2 12C2 9.86386 2.66979 7.88416 3.8108 6.25944L7 12H4C4 16.4183 7.58172 20 12 20C14.1502 20 16.1022 19.1517 17.5398 17.7716L18.5374 19.5674Z" fill="black"/>
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.3308 16H14.2915L13.6249 13.9476H10.3761L9.70846 16H7.66918L10.7759 7H13.2281L16.3308 16ZM10.8595 12.4622H13.1435L12.0378 9.05639H11.9673L10.8595 12.4622Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 772 B |
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue