Merge branch 'main' into fix-list-operator

pull/22803/head
leslie2046 11 months ago
commit 3042494562

@ -471,6 +471,16 @@ APP_MAX_ACTIVE_REQUESTS=0
# Celery beat configuration
CELERY_BEAT_SCHEDULER_TIME=1
# Celery schedule tasks configuration
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
# Position configuration
POSITION_TOOL_PINS=
POSITION_TOOL_INCLUDES=

@ -74,7 +74,12 @@
10. If you need to handle and debug the async tasks (e.g. dataset importing and documents indexing), please start the worker service.
```bash
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion
uv run celery -A app.celery worker -P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion,plugin
```
Addition, if you want to debug the celery scheduled tasks, you can use the following command in another terminal:
```bash
uv run celery -A app.celery beat
```
## Testing

@ -832,6 +832,41 @@ class CeleryBeatConfig(BaseSettings):
)
class CeleryScheduleTasksConfig(BaseSettings):
ENABLE_CLEAN_EMBEDDING_CACHE_TASK: bool = Field(
description="Enable clean embedding cache task",
default=False,
)
ENABLE_CLEAN_UNUSED_DATASETS_TASK: bool = Field(
description="Enable clean unused datasets task",
default=False,
)
ENABLE_CREATE_TIDB_SERVERLESS_TASK: bool = Field(
description="Enable create tidb service job task",
default=False,
)
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: bool = Field(
description="Enable update tidb service job status task",
default=False,
)
ENABLE_CLEAN_MESSAGES: bool = Field(
description="Enable clean messages task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,
)
ENABLE_DATASETS_QUEUE_MONITOR: bool = Field(
description="Enable queue monitor task",
default=False,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,
)
class PositionConfig(BaseSettings):
POSITION_PROVIDER_PINS: str = Field(
description="Comma-separated list of pinned model providers",
@ -961,5 +996,6 @@ class FeatureConfig(
# hosted services config
HostedServiceConfig,
CeleryBeatConfig,
CeleryScheduleTasksConfig,
):
pass

@ -12,7 +12,8 @@ from controllers.console.wraps import account_initialization_required, setup_req
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.impl.exc import PluginDaemonClientSideError
from libs.login import login_required
from models.account import TenantPluginPermission
from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission
from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService
from services.plugin.plugin_parameter_service import PluginParameterService
from services.plugin.plugin_permission_service import PluginPermissionService
from services.plugin.plugin_service import PluginService
@ -534,6 +535,114 @@ class PluginFetchDynamicSelectOptionsApi(Resource):
return jsonable_encoder({"options": options})
class PluginChangePreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
user = current_user
if not user.is_admin_or_owner:
raise Forbidden()
req = reqparse.RequestParser()
req.add_argument("permission", type=dict, required=True, location="json")
req.add_argument("auto_upgrade", type=dict, required=True, location="json")
args = req.parse_args()
tenant_id = user.current_tenant_id
permission = args["permission"]
install_permission = TenantPluginPermission.InstallPermission(permission.get("install_permission", "everyone"))
debug_permission = TenantPluginPermission.DebugPermission(permission.get("debug_permission", "everyone"))
auto_upgrade = args["auto_upgrade"]
strategy_setting = TenantPluginAutoUpgradeStrategy.StrategySetting(
auto_upgrade.get("strategy_setting", "fix_only")
)
upgrade_time_of_day = auto_upgrade.get("upgrade_time_of_day", 0)
upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode(auto_upgrade.get("upgrade_mode", "exclude"))
exclude_plugins = auto_upgrade.get("exclude_plugins", [])
include_plugins = auto_upgrade.get("include_plugins", [])
# set permission
set_permission_result = PluginPermissionService.change_permission(
tenant_id,
install_permission,
debug_permission,
)
if not set_permission_result:
return jsonable_encoder({"success": False, "message": "Failed to set permission"})
# set auto upgrade strategy
set_auto_upgrade_strategy_result = PluginAutoUpgradeService.change_strategy(
tenant_id,
strategy_setting,
upgrade_time_of_day,
upgrade_mode,
exclude_plugins,
include_plugins,
)
if not set_auto_upgrade_strategy_result:
return jsonable_encoder({"success": False, "message": "Failed to set auto upgrade strategy"})
return jsonable_encoder({"success": True})
class PluginFetchPreferencesApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
tenant_id = current_user.current_tenant_id
permission = PluginPermissionService.get_permission(tenant_id)
permission_dict = {
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
}
if permission:
permission_dict["install_permission"] = permission.install_permission
permission_dict["debug_permission"] = permission.debug_permission
auto_upgrade = PluginAutoUpgradeService.get_strategy(tenant_id)
auto_upgrade_dict = {
"strategy_setting": TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
"upgrade_time_of_day": 0,
"upgrade_mode": TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
"exclude_plugins": [],
"include_plugins": [],
}
if auto_upgrade:
auto_upgrade_dict = {
"strategy_setting": auto_upgrade.strategy_setting,
"upgrade_time_of_day": auto_upgrade.upgrade_time_of_day,
"upgrade_mode": auto_upgrade.upgrade_mode,
"exclude_plugins": auto_upgrade.exclude_plugins,
"include_plugins": auto_upgrade.include_plugins,
}
return jsonable_encoder({"permission": permission_dict, "auto_upgrade": auto_upgrade_dict})
class PluginAutoUpgradeExcludePluginApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
# exclude one single plugin
tenant_id = current_user.current_tenant_id
req = reqparse.RequestParser()
req.add_argument("plugin_id", type=str, required=True, location="json")
args = req.parse_args()
return jsonable_encoder({"success": PluginAutoUpgradeService.exclude_plugin(tenant_id, args["plugin_id"])})
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
api.add_resource(PluginListLatestVersionsApi, "/workspaces/current/plugin/list/latest-versions")
@ -560,3 +669,7 @@ api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permissi
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
api.add_resource(PluginFetchDynamicSelectOptionsApi, "/workspaces/current/plugin/parameters/dynamic-options")
api.add_resource(PluginFetchPreferencesApi, "/workspaces/current/plugin/preferences/fetch")
api.add_resource(PluginChangePreferencesApi, "/workspaces/current/plugin/preferences/change")
api.add_resource(PluginAutoUpgradeExcludePluginApi, "/workspaces/current/plugin/preferences/autoupgrade/exclude")

@ -739,7 +739,7 @@ class ToolOAuthCallback(Resource):
raise Forbidden("no oauth available client config found for this tool provider")
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider}/tool/callback"
credentials = oauth_handler.get_credentials(
credentials_response = oauth_handler.get_credentials(
tenant_id=tenant_id,
user_id=user_id,
plugin_id=plugin_id,
@ -747,7 +747,10 @@ class ToolOAuthCallback(Resource):
redirect_uri=redirect_uri,
system_credentials=oauth_client_params,
request=request,
).credentials
)
credentials = credentials_response.credentials
expires_at = credentials_response.expires_at
if not credentials:
raise Exception("the plugin credentials failed")
@ -758,6 +761,7 @@ class ToolOAuthCallback(Resource):
tenant_id=tenant_id,
provider=provider,
credentials=dict(credentials),
expires_at=expires_at,
api_type=CredentialType.OAUTH2,
)
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")

@ -1,48 +0,0 @@
## Guidelines for Database Connection Management in App Runner and Task Pipeline
Due to the presence of tasks in App Runner that require long execution times, such as LLM generation and external requests, Flask-Sqlalchemy's strategy for database connection pooling is to allocate one connection (transaction) per request. This approach keeps a connection occupied even during non-DB tasks, leading to the inability to acquire new connections during high concurrency requests due to multiple long-running tasks.
Therefore, the database operations in App Runner and Task Pipeline must ensure connections are closed immediately after use, and it's better to pass IDs rather than Model objects to avoid detach errors.
Examples:
1. Creating a new record:
```python
app = App(id=1)
db.session.add(app)
db.session.commit()
db.session.refresh(app) # Retrieve table default values, like created_at, cached in the app object, won't affect after close
# Handle non-long-running tasks or store the content of the App instance in memory (via variable assignment).
db.session.close()
return app.id
```
2. Fetching a record from the table:
```python
app = db.session.query(App).filter(App.id == app_id).first()
created_at = app.created_at
db.session.close()
# Handle tasks (include long-running).
```
3. Updating a table field:
```python
app = db.session.query(App).filter(App.id == app_id).first()
app.updated_at = time.utcnow()
db.session.commit()
db.session.close()
return app_id
```

@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@ -486,21 +487,52 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
# chatbot app
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
# get conversation and message
conversation = self._get_conversation(conversation_id)
message = self._get_message(message_id)
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
app = session.scalar(select(App).where(App.id == application_generate_entity.app_config.app_id))
if app is None:
raise ValueError("App not found")
runner = AdvancedChatAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation=conversation,
message=message,
dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
app=app,
)
try:
runner.run()
except GenerateTaskStoppedError:
pass

@ -1,6 +1,6 @@
import logging
from collections.abc import Mapping
from typing import Any, cast
from typing import Any, Optional, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
@ -9,13 +9,19 @@ from configs import dify_config
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AppGenerateEntity,
InvokeFrom,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueStopEvent,
QueueTextChunkEvent,
)
from core.app.features.annotation_reply.annotation_reply import AnnotationReplyFeature
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool
@ -23,8 +29,9 @@ from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, EndUser, Message
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable, WorkflowType
logger = logging.getLogger(__name__)
@ -37,21 +44,29 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
*,
application_generate_entity: AdvancedChatAppGenerateEntity,
queue_manager: AppQueueManager,
conversation: Conversation,
message: Message,
dialogue_count: int,
variable_loader: VariableLoader,
workflow: Workflow,
system_user_id: str,
app: App,
) -> None:
super().__init__(queue_manager, variable_loader)
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.conversation = conversation
self.message = message
self._dialogue_count = dialogue_count
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
self._workflow = workflow
self.system_user_id = system_user_id
self._app = app
def run(self) -> None:
app_config = self.application_generate_entity.app_config
@ -61,18 +76,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
user_id: str | None = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@ -80,14 +83,14 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=dict(self.application_generate_entity.single_iteration_run.inputs),
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=dict(self.application_generate_entity.single_loop_run.inputs),
)
@ -98,7 +101,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# moderation
if self.handle_input_moderation(
app_record=app_record,
app_record=self._app,
app_generate_entity=self.application_generate_entity,
inputs=inputs,
query=query,
@ -108,7 +111,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
# annotation reply
if self.handle_annotation_reply(
app_record=app_record,
app_record=self._app,
message=self.message,
query=query,
app_generate_entity=self.application_generate_entity,
@ -128,7 +131,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
ConversationVariable.from_variable(
app_id=self.conversation.app_id, conversation_id=self.conversation.id, variable=variable
)
for variable in workflow.conversation_variables
for variable in self._workflow.conversation_variables
]
session.add_all(db_conversation_variables)
# Convert database entities to variables.
@ -141,7 +144,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
query=query,
files=files,
conversation_id=self.conversation.id,
user_id=user_id,
user_id=self.system_user_id,
dialogue_count=self._dialogue_count,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
@ -152,25 +155,25 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables),
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
db.session.close()
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
@ -241,3 +244,51 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._publish_event(QueueTextChunkEvent(text=text))
self._publish_event(QueueStopEvent(stopped_by=stopped_by))
def query_app_annotations_to_reply(
self, app_record: App, message: Message, query: str, user_id: str, invoke_from: InvokeFrom
) -> Optional[MessageAnnotation]:
"""
Query app annotations to reply
:param app_record: app record
:param message: message
:param query: query
:param user_id: user id
:param invoke_from: invoke from
:return:
"""
annotation_reply_feature = AnnotationReplyFeature()
return annotation_reply_feature.query(
app_record=app_record, message=message, query=query, user_id=user_id, invoke_from=invoke_from
)
def moderation_for_inputs(
self,
*,
app_id: str,
tenant_id: str,
app_generate_entity: AppGenerateEntity,
inputs: Mapping[str, Any],
query: str | None = None,
message_id: str,
) -> tuple[bool, Mapping[str, Any], str]:
"""
Process sensitive_word_avoidance.
:param app_id: app id
:param tenant_id: tenant id
:param app_generate_entity: app generate entity
:param inputs: inputs
:param query: query
:param message_id: message id
:return:
"""
moderation_feature = InputModeration()
return moderation_feature.check(
app_id=app_id,
tenant_id=tenant_id,
app_config=app_generate_entity.app_config,
inputs=dict(inputs),
query=query or "",
message_id=message_id,
trace_manager=app_generate_entity.trace_manager,
)

@ -7,7 +7,8 @@ from typing import Any, Literal, Optional, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
from sqlalchemy.orm import sessionmaker
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
import contexts
from configs import dify_config
@ -445,17 +446,44 @@ class WorkflowAppGenerator(BaseAppGenerator):
"""
with preserve_flask_contexts(flask_app, context_vars=context):
try:
# workflow app
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
with Session(db.engine, expire_on_commit=False) as session:
workflow = session.scalar(
select(Workflow).where(
Workflow.tenant_id == application_generate_entity.app_config.tenant_id,
Workflow.app_id == application_generate_entity.app_config.app_id,
Workflow.id == application_generate_entity.app_config.workflow_id,
)
)
if workflow is None:
raise ValueError("Workflow not found")
# Determine system_user_id based on invocation source
is_external_api_call = application_generate_entity.invoke_from in {
InvokeFrom.WEB_APP,
InvokeFrom.SERVICE_API,
}
if is_external_api_call:
# For external API calls, use end user's session ID
end_user = session.scalar(select(EndUser).where(EndUser.id == application_generate_entity.user_id))
system_user_id = end_user.session_id if end_user else ""
else:
# For internal calls, use the original user ID
system_user_id = application_generate_entity.user_id
runner = WorkflowAppRunner(
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
workflow=workflow,
system_user_id=system_user_id,
)
try:
runner.run()
except GenerateTaskStoppedError:
except GenerateTaskStoppedError as e:
logger.warning(f"Task stopped: {str(e)}")
pass
except InvokeAuthorizationError:
queue_manager.publish_error(
@ -471,8 +499,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
except Exception as e:
logger.exception("Unknown Error when generating")
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
finally:
db.session.close()
def _handle_response(
self,

@ -14,10 +14,8 @@ from core.workflow.entities.variable_pool import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.enums import UserFrom
from models.model import App, EndUser
from models.workflow import WorkflowType
from models.workflow import Workflow, WorkflowType
logger = logging.getLogger(__name__)
@ -29,22 +27,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
def __init__(
self,
*,
application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None,
workflow: Workflow,
system_user_id: str,
) -> None:
"""
:param application_generate_entity: application generate entity
:param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id
"""
super().__init__(queue_manager, variable_loader)
super().__init__(
queue_manager=queue_manager,
variable_loader=variable_loader,
app_id=application_generate_entity.app_config.app_id,
)
self.application_generate_entity = application_generate_entity
self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str:
return self.application_generate_entity.app_config.app_id
self._workflow = workflow
self._sys_user_id = system_user_id
def run(self) -> None:
"""
@ -53,24 +52,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
else:
user_id = self.application_generate_entity.user_id
app_record = db.session.query(App).filter(App.id == app_config.app_id).first()
if not app_record:
raise ValueError("App not found")
workflow = self.get_workflow(app_model=app_record, workflow_id=app_config.workflow_id)
if not workflow:
raise ValueError("Workflow not initialized")
db.session.close()
workflow_callbacks: list[WorkflowCallback] = []
if dify_config.DEBUG:
workflow_callbacks.append(WorkflowLoggingCallback())
@ -79,14 +60,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
if self.application_generate_entity.single_iteration_run:
# if only single iteration run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
)
elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
workflow=workflow,
workflow=self._workflow,
node_id=self.application_generate_entity.single_loop_run.node_id,
user_inputs=self.application_generate_entity.single_loop_run.inputs,
)
@ -98,7 +79,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
system_inputs = SystemVariable(
files=files,
user_id=user_id,
user_id=self._sys_user_id,
app_id=app_config.app_id,
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
@ -107,21 +88,21 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=workflow.environment_variables,
environment_variables=self._workflow.environment_variables,
conversation_variables=[],
)
# init graph
graph = self._init_graph(graph_config=workflow.graph_dict)
graph = self._init_graph(graph_config=self._workflow.graph_dict)
# RUN WORKFLOW
workflow_entry = WorkflowEntry(
tenant_id=workflow.tenant_id,
app_id=workflow.app_id,
workflow_id=workflow.id,
workflow_type=WorkflowType.value_of(workflow.type),
tenant_id=self._workflow.tenant_id,
app_id=self._workflow.app_id,
workflow_id=self._workflow.id,
workflow_type=WorkflowType.value_of(self._workflow.type),
graph=graph,
graph_config=workflow.graph_dict,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT

@ -1,8 +1,7 @@
from collections.abc import Mapping
from typing import Any, Optional, cast
from typing import Any, cast
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@ -65,18 +64,20 @@ from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.system_variable import SystemVariable
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.model import App
from models.workflow import Workflow
class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager
class WorkflowBasedAppRunner:
def __init__(
self,
*,
queue_manager: AppQueueManager,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
app_id: str,
) -> None:
self._queue_manager = queue_manager
self._variable_loader = variable_loader
def _get_app_id(self) -> str:
raise NotImplementedError("not implemented")
self._app_id = app_id
def _init_graph(self, graph_config: Mapping[str, Any]) -> Graph:
"""
@ -693,21 +694,5 @@ class WorkflowBasedAppRunner(AppRunner):
)
)
def get_workflow(self, app_model: App, workflow_id: str) -> Optional[Workflow]:
"""
Get workflow
"""
# fetch workflow by workflow_id
workflow = (
db.session.query(Workflow)
.filter(
Workflow.tenant_id == app_model.tenant_id, Workflow.app_id == app_model.id, Workflow.id == workflow_id
)
.first()
)
# return workflow
return workflow
def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

@ -25,9 +25,29 @@ def batch_fetch_plugin_manifests(plugin_ids: list[str]) -> Sequence[MarketplaceP
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
return [MarketplacePluginDeclaration(**plugin) for plugin in response.json()["data"]["plugins"]]
def batch_fetch_plugin_manifests_ignore_deserialization_error(
plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = requests.post(url, json={"plugin_ids": plugin_ids})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration(**plugin))
except Exception as e:
pass
return result
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = requests.post(url, json={"unique_identifier": plugin_unique_identifier})

@ -114,7 +114,8 @@ class LLMGenerator:
),
)
questions = output_parser.parse(cast(str, response.message.content))
text_content = response.message.get_text_content()
questions = output_parser.parse(text_content) if text_content else []
except InvokeError:
questions = []
except Exception:

@ -15,5 +15,4 @@ class SuggestedQuestionsAfterAnswerOutputParser:
json_obj = json.loads(action_match.group(0).strip())
else:
json_obj = []
return json_obj

@ -156,6 +156,23 @@ class PromptMessage(ABC, BaseModel):
"""
return not self.content
def get_text_content(self) -> str:
"""
Get text content from prompt message.
:return: Text content as string, empty string if no text content
"""
if isinstance(self.content, str):
return self.content
elif isinstance(self.content, list):
text_parts = []
for item in self.content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return "".join(text_parts)
else:
return ""
@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):

@ -182,6 +182,10 @@ class PluginOAuthAuthorizationUrlResponse(BaseModel):
class PluginOAuthCredentialsResponse(BaseModel):
metadata: Mapping[str, Any] = Field(
default_factory=dict, description="The metadata of the OAuth, like avatar url, name, etc."
)
expires_at: int = Field(default=-1, description="The expires at time of the credentials. UTC timestamp.")
credentials: Mapping[str, Any] = Field(description="The credentials of the OAuth.")

@ -84,6 +84,41 @@ class OAuthHandler(BasePluginClient):
except Exception as e:
raise ValueError(f"Error getting credentials: {e}")
def refresh_credentials(
self,
tenant_id: str,
user_id: str,
plugin_id: str,
provider: str,
redirect_uri: str,
system_credentials: Mapping[str, Any],
credentials: Mapping[str, Any],
) -> PluginOAuthCredentialsResponse:
try:
response = self._request_with_plugin_daemon_response_stream(
"POST",
f"plugin/{tenant_id}/dispatch/oauth/refresh_credentials",
PluginOAuthCredentialsResponse,
data={
"user_id": user_id,
"data": {
"provider": provider,
"redirect_uri": redirect_uri,
"system_credentials": system_credentials,
"credentials": credentials,
},
},
headers={
"X-Plugin-ID": plugin_id,
"Content-Type": "application/json",
},
)
for resp in response:
return resp
raise ValueError("No response received from plugin daemon for refresh credentials request.")
except Exception as e:
raise ValueError(f"Error refreshing credentials: {e}")
def _convert_request_to_raw_data(self, request: Request) -> bytes:
"""
Convert a Request object to raw HTTP data.

@ -284,7 +284,8 @@ class TencentVector(BaseVector):
# Compatible with version 1.1.3 and below.
meta = json.loads(meta)
score = 1 - result.get("score", 0.0)
score = result.get("score", 0.0)
else:
score = result.get("score", 0.0)
if score > score_threshold:
meta["score"] = score
doc = Document(page_content=result.get(self.field_text), metadata=meta)

@ -1,16 +1,19 @@
import json
import logging
import mimetypes
from collections.abc import Generator
import time
from collections.abc import Generator, Mapping
from os import listdir, path
from threading import Lock
from typing import TYPE_CHECKING, Any, Literal, Optional, Union, cast
from pydantic import TypeAdapter
from yarl import URL
import contexts
from core.helper.provider_cache import ToolProviderCredentialsCache
from core.plugin.entities.plugin import ToolProviderID
from core.plugin.impl.oauth import OAuthHandler
from core.plugin.impl.tool import PluginToolManager
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.__base.tool_runtime import ToolRuntime
@ -244,12 +247,47 @@ class ToolManager:
tenant_id=tenant_id, provider=provider_id, credential_id=builtin_provider.id
),
)
# decrypt the credentials
decrypted_credentials: Mapping[str, Any] = encrypter.decrypt(builtin_provider.credentials)
# check if the credentials is expired
if builtin_provider.expires_at != -1 and (builtin_provider.expires_at - 60) < int(time.time()):
# TODO: circular import
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
# refresh the credentials
tool_provider = ToolProviderID(provider_id)
provider_name = tool_provider.provider_name
redirect_uri = f"{dify_config.CONSOLE_API_URL}/console/api/oauth/plugin/{provider_id}/tool/callback"
system_credentials = BuiltinToolManageService.get_oauth_client(tenant_id, provider_id)
oauth_handler = OAuthHandler()
# refresh the credentials
refreshed_credentials = oauth_handler.refresh_credentials(
tenant_id=tenant_id,
user_id=builtin_provider.user_id,
plugin_id=tool_provider.plugin_id,
provider=provider_name,
redirect_uri=redirect_uri,
system_credentials=system_credentials or {},
credentials=decrypted_credentials,
)
# update the credentials
builtin_provider.encrypted_credentials = (
TypeAdapter(dict[str, Any])
.dump_json(encrypter.encrypt(dict(refreshed_credentials.credentials)))
.decode("utf-8")
)
builtin_provider.expires_at = refreshed_credentials.expires_at
db.session.commit()
decrypted_credentials = refreshed_credentials.credentials
return cast(
BuiltinTool,
builtin_tool.fork_tool_runtime(
runtime=ToolRuntime(
tenant_id=tenant_id,
credentials=encrypter.decrypt(builtin_provider.credentials),
credentials=dict(decrypted_credentials),
credential_type=CredentialType.of(builtin_provider.credential_type),
runtime_parameters={},
invoke_from=invoke_from,

@ -317,8 +317,13 @@ class ToolNode(BaseNode):
elif message.type == ToolInvokeMessage.MessageType.FILE:
assert message.meta is not None
assert isinstance(message.meta, dict)
assert "file" in message.meta
assert isinstance(message.meta["file"], File)
# Validate that meta contains a 'file' key
if "file" not in message.meta:
raise ToolNodeError("File message is missing 'file' key in meta")
# Validate that the file is an instance of File
if not isinstance(message.meta["file"], File):
raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
files.append(message.meta["file"])
elif message.type == ToolInvokeMessage.MessageType.LOG:
assert isinstance(message.message, ToolInvokeMessage.LogMessage)

@ -22,7 +22,7 @@ if [[ "${MODE}" == "worker" ]]; then
exec celery -A app.celery worker -P ${CELERY_WORKER_CLASS:-gevent} $CONCURRENCY_OPTION \
--max-tasks-per-child ${MAX_TASK_PRE_CHILD:-50} --loglevel ${LOG_LEVEL:-INFO} \
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion}
-Q ${CELERY_QUEUES:-dataset,mail,ops_trace,app_deletion,plugin}
elif [[ "${MODE}" == "beat" ]]; then
exec celery -A app.celery beat --loglevel ${LOG_LEVEL:-INFO}

@ -64,49 +64,62 @@ def init_app(app: DifyApp) -> Celery:
celery_app.set_default()
app.extensions["celery"] = celery_app
imports = [
"schedule.clean_embedding_cache_task",
"schedule.clean_unused_datasets_task",
"schedule.create_tidb_serverless_task",
"schedule.update_tidb_serverless_status_task",
"schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
"schedule.queue_monitor_task",
]
imports = []
day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = {
"clean_embedding_cache_task": {
# if you add a new task, please add the switch to CeleryScheduleTasksConfig
beat_schedule = {}
if dify_config.ENABLE_CLEAN_EMBEDDING_CACHE_TASK:
imports.append("schedule.clean_embedding_cache_task")
beat_schedule["clean_embedding_cache_task"] = {
"task": "schedule.clean_embedding_cache_task.clean_embedding_cache_task",
"schedule": timedelta(days=day),
},
"clean_unused_datasets_task": {
}
if dify_config.ENABLE_CLEAN_UNUSED_DATASETS_TASK:
imports.append("schedule.clean_unused_datasets_task")
beat_schedule["clean_unused_datasets_task"] = {
"task": "schedule.clean_unused_datasets_task.clean_unused_datasets_task",
"schedule": timedelta(days=day),
},
"create_tidb_serverless_task": {
}
if dify_config.ENABLE_CREATE_TIDB_SERVERLESS_TASK:
imports.append("schedule.create_tidb_serverless_task")
beat_schedule["create_tidb_serverless_task"] = {
"task": "schedule.create_tidb_serverless_task.create_tidb_serverless_task",
"schedule": crontab(minute="0", hour="*"),
},
"update_tidb_serverless_status_task": {
}
if dify_config.ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:
imports.append("schedule.update_tidb_serverless_status_task")
beat_schedule["update_tidb_serverless_status_task"] = {
"task": "schedule.update_tidb_serverless_status_task.update_tidb_serverless_status_task",
"schedule": timedelta(minutes=10),
},
"clean_messages": {
}
if dify_config.ENABLE_CLEAN_MESSAGES:
imports.append("schedule.clean_messages")
beat_schedule["clean_messages"] = {
"task": "schedule.clean_messages.clean_messages",
"schedule": timedelta(days=day),
},
# every Monday
"mail_clean_document_notify_task": {
}
if dify_config.ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:
imports.append("schedule.mail_clean_document_notify_task")
beat_schedule["mail_clean_document_notify_task"] = {
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
"datasets-queue-monitor": {
}
if dify_config.ENABLE_DATASETS_QUEUE_MONITOR:
imports.append("schedule.queue_monitor_task")
beat_schedule["datasets-queue-monitor"] = {
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(
minutes=dify_config.QUEUE_MONITOR_INTERVAL if dify_config.QUEUE_MONITOR_INTERVAL else 30
),
},
}
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:
imports.append("schedule.check_upgradable_plugin_task")
beat_schedule["check_upgradable_plugin_task"] = {
"task": "schedule.check_upgradable_plugin_task.check_upgradable_plugin_task",
"schedule": crontab(minute="*/15"),
}
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

@ -0,0 +1,34 @@
"""oauth_refresh_token
Revision ID: 375fe79ead14
Revises: 1a83934ad6d1
Create Date: 2025-07-22 00:19:45.599636
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '375fe79ead14'
down_revision = '1a83934ad6d1'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.add_column(sa.Column('expires_at', sa.BigInteger(), server_default=sa.text('-1'), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('tool_builtin_providers', schema=None) as batch_op:
batch_op.drop_column('expires_at')
# ### end Alembic commands ###

@ -0,0 +1,42 @@
"""add_tenant_plugin_autoupgrade_table
Revision ID: 8bcc02c9bd07
Revises: 375fe79ead14
Create Date: 2025-07-23 15:08:50.161441
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8bcc02c9bd07'
down_revision = '375fe79ead14'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_plugin_auto_upgrade_strategies',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('strategy_setting', sa.String(length=16), server_default='fix_only', nullable=False),
sa.Column('upgrade_time_of_day', sa.Integer(), nullable=False),
sa.Column('upgrade_mode', sa.String(length=16), server_default='exclude', nullable=False),
sa.Column('exclude_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('include_plugins', sa.ARRAY(sa.String(length=255)), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_plugin_auto_upgrade_strategy_pkey'),
sa.UniqueConstraint('tenant_id', name='unique_tenant_plugin_auto_upgrade_strategy')
)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('tenant_plugin_auto_upgrade_strategies')
# ### end Alembic commands ###

@ -297,6 +297,40 @@ class TenantPluginPermission(Base):
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID)
install_permission: Mapped[InstallPermission] = mapped_column(db.String(16), server_default="everyone")
debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), server_default="noone")
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
install_permission: Mapped[InstallPermission] = mapped_column(
db.String(16), nullable=False, server_default="everyone"
)
debug_permission: Mapped[DebugPermission] = mapped_column(db.String(16), nullable=False, server_default="noone")
class TenantPluginAutoUpgradeStrategy(Base):
class StrategySetting(enum.StrEnum):
DISABLED = "disabled"
FIX_ONLY = "fix_only"
LATEST = "latest"
class UpgradeMode(enum.StrEnum):
ALL = "all"
PARTIAL = "partial"
EXCLUDE = "exclude"
__tablename__ = "tenant_plugin_auto_upgrade_strategies"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_plugin_auto_upgrade_strategy_pkey"),
db.UniqueConstraint("tenant_id", name="unique_tenant_plugin_auto_upgrade_strategy"),
)
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
strategy_setting: Mapped[StrategySetting] = mapped_column(db.String(16), nullable=False, server_default="fix_only")
upgrade_time_of_day: Mapped[int] = mapped_column(db.Integer, nullable=False, default=0) # seconds of the day
upgrade_mode: Mapped[UpgradeMode] = mapped_column(db.String(16), nullable=False, server_default="exclude")
exclude_plugins: Mapped[list[str]] = mapped_column(
db.ARRAY(db.String(255)), nullable=False
) # plugin_id (author/name)
include_plugins: Mapped[list[str]] = mapped_column(
db.ARRAY(db.String(255)), nullable=False
) # plugin_id (author/name)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@ -93,6 +93,7 @@ class BuiltinToolProvider(Base):
credential_type: Mapped[str] = mapped_column(
db.String(32), nullable=False, server_default=db.text("'api-key'::character varying")
)
expires_at: Mapped[int] = mapped_column(db.BigInteger, nullable=False, server_default=db.text("-1"))
@property
def credentials(self) -> dict:

@ -0,0 +1,49 @@
import time
import click
import app
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
from tasks.process_tenant_plugin_autoupgrade_check_task import process_tenant_plugin_autoupgrade_check_task
AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL = 15 * 60 # 15 minutes
@app.celery.task(queue="plugin")
def check_upgradable_plugin_task():
click.echo(click.style("Start check upgradable plugin.", fg="green"))
start_at = time.perf_counter()
now_seconds_of_day = time.time() % 86400 - 30 # we assume the tz is UTC
click.echo(click.style("Now seconds of day: {}".format(now_seconds_of_day), fg="green"))
strategies = (
db.session.query(TenantPluginAutoUpgradeStrategy)
.filter(
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day >= now_seconds_of_day,
TenantPluginAutoUpgradeStrategy.upgrade_time_of_day
< now_seconds_of_day + AUTO_UPGRADE_MINIMAL_CHECKING_INTERVAL,
TenantPluginAutoUpgradeStrategy.strategy_setting
!= TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED,
)
.all()
)
for strategy in strategies:
process_tenant_plugin_autoupgrade_check_task.delay(
strategy.tenant_id,
strategy.strategy_setting,
strategy.upgrade_time_of_day,
strategy.upgrade_mode,
strategy.exclude_plugins,
strategy.include_plugins,
)
end_at = time.perf_counter()
click.echo(
click.style(
"Checked upgradable plugin success latency: {}".format(end_at - start_at),
fg="green",
)
)

@ -29,6 +29,7 @@ from models.account import (
Tenant,
TenantAccountJoin,
TenantAccountRole,
TenantPluginAutoUpgradeStrategy,
TenantStatus,
)
from models.model import DifySetup
@ -828,6 +829,17 @@ class TenantService:
db.session.add(tenant)
db.session.commit()
plugin_upgrade_strategy = TenantPluginAutoUpgradeStrategy(
tenant_id=tenant.id,
strategy_setting=TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
upgrade_time_of_day=0,
upgrade_mode=TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
exclude_plugins=[],
include_plugins=[],
)
db.session.add(plugin_upgrade_strategy)
db.session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id)
db.session.commit()
return tenant

@ -0,0 +1,87 @@
from sqlalchemy.orm import Session
from extensions.ext_database import db
from models.account import TenantPluginAutoUpgradeStrategy
class PluginAutoUpgradeService:
@staticmethod
def get_strategy(tenant_id: str) -> TenantPluginAutoUpgradeStrategy | None:
with Session(db.engine) as session:
return (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
@staticmethod
def change_strategy(
tenant_id: str,
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
upgrade_time_of_day: int,
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
strategy = TenantPluginAutoUpgradeStrategy(
tenant_id=tenant_id,
strategy_setting=strategy_setting,
upgrade_time_of_day=upgrade_time_of_day,
upgrade_mode=upgrade_mode,
exclude_plugins=exclude_plugins,
include_plugins=include_plugins,
)
session.add(strategy)
else:
exist_strategy.strategy_setting = strategy_setting
exist_strategy.upgrade_time_of_day = upgrade_time_of_day
exist_strategy.upgrade_mode = upgrade_mode
exist_strategy.exclude_plugins = exclude_plugins
exist_strategy.include_plugins = include_plugins
session.commit()
return True
@staticmethod
def exclude_plugin(tenant_id: str, plugin_id: str) -> bool:
with Session(db.engine) as session:
exist_strategy = (
session.query(TenantPluginAutoUpgradeStrategy)
.filter(TenantPluginAutoUpgradeStrategy.tenant_id == tenant_id)
.first()
)
if not exist_strategy:
# create for this tenant
PluginAutoUpgradeService.change_strategy(
tenant_id,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY,
0,
TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE,
[plugin_id],
[],
)
return True
else:
if exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
if plugin_id not in exist_strategy.exclude_plugins:
new_exclude_plugins = exist_strategy.exclude_plugins.copy()
new_exclude_plugins.append(plugin_id)
exist_strategy.exclude_plugins = new_exclude_plugins
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL:
if plugin_id in exist_strategy.include_plugins:
new_include_plugins = exist_strategy.include_plugins.copy()
new_include_plugins.remove(plugin_id)
exist_strategy.include_plugins = new_include_plugins
elif exist_strategy.upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
exist_strategy.upgrade_mode = TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE
exist_strategy.exclude_plugins = [plugin_id]
session.commit()
return True

@ -38,6 +38,7 @@ logger = logging.getLogger(__name__)
class BuiltinToolManageService:
__MAX_BUILTIN_TOOL_PROVIDER_COUNT__ = 100
__DEFAULT_EXPIRES_AT__ = 2147483647
@staticmethod
def delete_custom_oauth_client_params(tenant_id: str, provider: str):
@ -212,6 +213,7 @@ class BuiltinToolManageService:
tenant_id: str,
provider: str,
credentials: dict,
expires_at: int = -1,
name: str | None = None,
):
"""
@ -269,6 +271,9 @@ class BuiltinToolManageService:
encrypted_credentials=json.dumps(encrypter.encrypt(credentials)),
credential_type=api_type.value,
name=name,
expires_at=expires_at
if expires_at is not None
else BuiltinToolManageService.__DEFAULT_EXPIRES_AT__,
)
session.add(db_provider)

@ -112,19 +112,27 @@ class MCPToolManageService:
@classmethod
def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
server_url = mcp_provider.decrypted_server_url
authed = mcp_provider.authed
try:
with MCPClient(
mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True
) as mcp_client:
with MCPClient(server_url, provider_id, tenant_id, authed=authed, for_list=True) as mcp_client:
tools = mcp_client.list_tools()
except MCPAuthError:
raise ValueError("Please auth the tool first")
except MCPError as e:
raise ValueError(f"Failed to connect to MCP server: {e}")
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
try:
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
mcp_provider.authed = True
mcp_provider.updated_at = datetime.now()
db.session.commit()
except Exception:
db.session.rollback()
raise
user = mcp_provider.load_user()
return ToolProviderApiEntity(
id=mcp_provider.id,
@ -160,22 +168,35 @@ class MCPToolManageService:
server_identifier: str,
):
mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id)
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.server_identifier = server_identifier
reconnect_result = None
encrypted_server_url = None
server_url_hash = None
if UNCHANGED_SERVER_URL_PLACEHOLDER not in server_url:
encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url)
mcp_provider.server_url = encrypted_server_url
server_url_hash = hashlib.sha256(server_url.encode()).hexdigest()
if server_url_hash != mcp_provider.server_url_hash:
cls._re_connect_mcp_provider(mcp_provider, provider_id, tenant_id)
mcp_provider.server_url_hash = server_url_hash
reconnect_result = cls._re_connect_mcp_provider(server_url, provider_id, tenant_id)
try:
mcp_provider.updated_at = datetime.now()
mcp_provider.name = name
mcp_provider.icon = (
json.dumps({"content": icon, "background": icon_background}) if icon_type == "emoji" else icon
)
mcp_provider.server_identifier = server_identifier
if encrypted_server_url is not None and server_url_hash is not None:
mcp_provider.server_url = encrypted_server_url
mcp_provider.server_url_hash = server_url_hash
if reconnect_result:
mcp_provider.authed = reconnect_result["authed"]
mcp_provider.tools = reconnect_result["tools"]
mcp_provider.encrypted_credentials = reconnect_result["encrypted_credentials"]
db.session.commit()
except IntegrityError as e:
db.session.rollback()
@ -187,6 +208,9 @@ class MCPToolManageService:
if "unique_mcp_provider_server_identifier" in error_msg:
raise ValueError(f"MCP tool {server_identifier} already exists")
raise
except Exception:
db.session.rollback()
raise
@classmethod
def update_mcp_provider_credentials(
@ -207,23 +231,22 @@ class MCPToolManageService:
db.session.commit()
@classmethod
def _re_connect_mcp_provider(cls, mcp_provider: MCPToolProvider, provider_id: str, tenant_id: str):
"""re-connect mcp provider"""
def _re_connect_mcp_provider(cls, server_url: str, provider_id: str, tenant_id: str):
try:
with MCPClient(
mcp_provider.decrypted_server_url,
server_url,
provider_id,
tenant_id,
authed=False,
for_list=True,
) as mcp_client:
tools = mcp_client.list_tools()
mcp_provider.authed = True
mcp_provider.tools = json.dumps([tool.model_dump() for tool in tools])
return {
"authed": True,
"tools": json.dumps([tool.model_dump() for tool in tools]),
"encrypted_credentials": "{}",
}
except MCPAuthError:
mcp_provider.authed = False
mcp_provider.tools = "[]"
return {"authed": False, "tools": "[]", "encrypted_credentials": "{}"}
except MCPError as e:
raise ValueError(f"Failed to re-connect MCP server: {e}") from e
# reset credentials
mcp_provider.encrypted_credentials = "{}"

@ -0,0 +1,166 @@
import traceback
import typing
import click
from celery import shared_task # type: ignore
from core.helper import marketplace
from core.helper.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.plugin import PluginInstallationSource
from core.plugin.impl.plugin import PluginInstaller
from models.account import TenantPluginAutoUpgradeStrategy
RETRY_TIMES_OF_ONE_PLUGIN_IN_ONE_TENANT = 3
cached_plugin_manifests: dict[str, typing.Union[MarketplacePluginDeclaration, None]] = {}
def marketplace_batch_fetch_plugin_manifests(
plugin_ids_plain_list: list[str],
) -> list[MarketplacePluginDeclaration]:
global cached_plugin_manifests
# return marketplace.batch_fetch_plugin_manifests(plugin_ids_plain_list)
not_included_plugin_ids = [
plugin_id for plugin_id in plugin_ids_plain_list if plugin_id not in cached_plugin_manifests
]
if not_included_plugin_ids:
manifests = marketplace.batch_fetch_plugin_manifests_ignore_deserialization_error(not_included_plugin_ids)
for manifest in manifests:
cached_plugin_manifests[manifest.plugin_id] = manifest
if (
len(manifests) == 0
): # this indicates that the plugin not found in marketplace, should set None in cache to prevent future check
for plugin_id in not_included_plugin_ids:
cached_plugin_manifests[plugin_id] = None
result: list[MarketplacePluginDeclaration] = []
for plugin_id in plugin_ids_plain_list:
final_manifest = cached_plugin_manifests.get(plugin_id)
if final_manifest is not None:
result.append(final_manifest)
return result
@shared_task(queue="plugin")
def process_tenant_plugin_autoupgrade_check_task(
tenant_id: str,
strategy_setting: TenantPluginAutoUpgradeStrategy.StrategySetting,
upgrade_time_of_day: int,
upgrade_mode: TenantPluginAutoUpgradeStrategy.UpgradeMode,
exclude_plugins: list[str],
include_plugins: list[str],
):
try:
manager = PluginInstaller()
click.echo(
click.style(
"Checking upgradable plugin for tenant: {}".format(tenant_id),
fg="green",
)
)
if strategy_setting == TenantPluginAutoUpgradeStrategy.StrategySetting.DISABLED:
return
# get plugin_ids to check
plugin_ids: list[tuple[str, str, str]] = [] # plugin_id, version, unique_identifier
click.echo(click.style("Upgrade mode: {}".format(upgrade_mode), fg="green"))
if upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.PARTIAL and include_plugins:
all_plugins = manager.list_plugins(tenant_id)
for plugin in all_plugins:
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id in include_plugins:
plugin_ids.append(
(
plugin.plugin_id,
plugin.version,
plugin.plugin_unique_identifier,
)
)
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.EXCLUDE:
# get all plugins and remove excluded plugins
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace and plugin.plugin_id not in exclude_plugins
]
elif upgrade_mode == TenantPluginAutoUpgradeStrategy.UpgradeMode.ALL:
all_plugins = manager.list_plugins(tenant_id)
plugin_ids = [
(plugin.plugin_id, plugin.version, plugin.plugin_unique_identifier)
for plugin in all_plugins
if plugin.source == PluginInstallationSource.Marketplace
]
if not plugin_ids:
return
plugin_ids_plain_list = [plugin_id for plugin_id, _, _ in plugin_ids]
manifests = marketplace_batch_fetch_plugin_manifests(plugin_ids_plain_list)
if not manifests:
return
for manifest in manifests:
for plugin_id, version, original_unique_identifier in plugin_ids:
if manifest.plugin_id != plugin_id:
continue
try:
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version, current_version):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
if (
latest_version_tuple[0] == current_version_tuple[0]
and latest_version_tuple[1] == current_version_tuple[1]
):
return latest_version_tuple[2] != current_version_tuple[2]
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}
if version_checker[strategy_setting](latest_version, current_version):
# execute upgrade
new_unique_identifier = manifest.latest_package_identifier
marketplace.record_install_plugin_event(new_unique_identifier)
click.echo(
click.style(
"Upgrade plugin: {} -> {}".format(original_unique_identifier, new_unique_identifier),
fg="green",
)
)
task_start_resp = manager.upgrade_plugin(
tenant_id,
original_unique_identifier,
new_unique_identifier,
PluginInstallationSource.Marketplace,
{
"plugin_unique_identifier": new_unique_identifier,
},
)
except Exception as e:
click.echo(click.style("Error when upgrading plugin: {}".format(e), fg="red"))
traceback.print_exc()
break
except Exception as e:
click.echo(click.style("Error when checking upgradable plugin: {}".format(e), fg="red"))
traceback.print_exc()
return

@ -1168,3 +1168,13 @@ QUEUE_MONITOR_THRESHOLD=200
QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30
# Celery schedule tasks configuration
ENABLE_CLEAN_EMBEDDING_CACHE_TASK=false
ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true

@ -55,6 +55,25 @@ services:
- ssrf_proxy_network
- default
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
<<: *shared-api-worker-env
# Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks.
MODE: beat
depends_on:
db:
condition: service_healthy
redis:
condition: service_started
networks:
- ssrf_proxy_network
- default
# Frontend web application.
web:
image: langgenius/dify-web:1.6.0

@ -527,6 +527,14 @@ x-shared-env: &shared-api-worker-env
QUEUE_MONITOR_THRESHOLD: ${QUEUE_MONITOR_THRESHOLD:-200}
QUEUE_MONITOR_ALERT_EMAILS: ${QUEUE_MONITOR_ALERT_EMAILS:-}
QUEUE_MONITOR_INTERVAL: ${QUEUE_MONITOR_INTERVAL:-30}
ENABLE_CLEAN_EMBEDDING_CACHE_TASK: ${ENABLE_CLEAN_EMBEDDING_CACHE_TASK:-false}
ENABLE_CLEAN_UNUSED_DATASETS_TASK: ${ENABLE_CLEAN_UNUSED_DATASETS_TASK:-false}
ENABLE_CREATE_TIDB_SERVERLESS_TASK: ${ENABLE_CREATE_TIDB_SERVERLESS_TASK:-false}
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK: ${ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK:-false}
ENABLE_CLEAN_MESSAGES: ${ENABLE_CLEAN_MESSAGES:-false}
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: ${ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK:-false}
ENABLE_DATASETS_QUEUE_MONITOR: ${ENABLE_DATASETS_QUEUE_MONITOR:-false}
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: ${ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK:-true}
services:
# API service
@ -584,6 +592,25 @@ services:
- ssrf_proxy_network
- default
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.5.0
restart: always
environment:
# Use the shared environment variables.
<<: *shared-api-worker-env
# Startup mode, 'worker_beat' starts the Celery beat for scheduling periodic tasks.
MODE: beat
depends_on:
db:
condition: service_healthy
redis:
condition: service_started
networks:
- ssrf_proxy_network
- default
# Frontend web application.
web:
image: langgenius/dify-web:1.6.0

@ -1,248 +0,0 @@
import threading
from unittest.mock import Mock, patch
from core.app.entities.app_invoke_entities import ChatAppGenerateEntity
from core.entities.provider_entities import QuotaUnit
from events.event_handlers.update_provider_when_message_created import (
handle,
get_update_stats,
)
from models.provider import ProviderType
from sqlalchemy.exc import OperationalError
class TestProviderUpdateDeadlockPrevention:
"""Test suite for deadlock prevention in Provider updates."""
def setup_method(self):
"""Setup test fixtures."""
self.mock_message = Mock()
self.mock_message.answer_tokens = 100
self.mock_app_config = Mock()
self.mock_app_config.tenant_id = "test-tenant-123"
self.mock_model_conf = Mock()
self.mock_model_conf.provider = "openai"
self.mock_system_config = Mock()
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_provider_config = Mock()
self.mock_provider_config.using_provider_type = ProviderType.SYSTEM
self.mock_provider_config.system_configuration = self.mock_system_config
self.mock_provider_bundle = Mock()
self.mock_provider_bundle.configuration = self.mock_provider_config
self.mock_model_conf.provider_model_bundle = self.mock_provider_bundle
self.mock_generate_entity = Mock(spec=ChatAppGenerateEntity)
self.mock_generate_entity.app_config = self.mock_app_config
self.mock_generate_entity.model_conf = self.mock_model_conf
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_consolidated_handler_basic_functionality(self, mock_db):
"""Test that the consolidated handler performs both updates correctly."""
# Setup mock query chain
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1 # 1 row affected
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify db.session.query was called
assert mock_db.session.query.called
# Verify commit was called
mock_db.session.commit.assert_called_once()
# Verify no rollback was called
assert not mock_db.session.rollback.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_deadlock_retry_mechanism(self, mock_db):
"""Test that deadlock errors trigger retry logic."""
# Setup mock to raise deadlock error on first attempt, succeed on second
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# First call raises deadlock, second succeeds
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
None, # Success on retry
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify commit was called twice (original + retry)
assert mock_db.session.commit.call_count == 2
# Verify rollback was called once (after first failure)
mock_db.session.rollback.assert_called_once()
@patch("events.event_handlers.update_provider_when_message_created.db")
@patch("events.event_handlers.update_provider_when_message_created.time.sleep")
def test_exponential_backoff_timing(self, mock_sleep, mock_db):
"""Test that retry delays follow exponential backoff pattern."""
# Setup mock to fail twice, succeed on third attempt
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
mock_db.session.commit.side_effect = [
OperationalError("deadlock detected", None, None),
OperationalError("deadlock detected", None, None),
None, # Success on third attempt
]
# Call the handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify sleep was called twice with increasing delays
assert mock_sleep.call_count == 2
# First delay should be around 0.1s + jitter
first_delay = mock_sleep.call_args_list[0][0][0]
assert 0.1 <= first_delay <= 0.3
# Second delay should be around 0.2s + jitter
second_delay = mock_sleep.call_args_list[1][0][0]
assert 0.2 <= second_delay <= 0.4
def test_concurrent_handler_execution(self):
"""Test that multiple handlers can run concurrently without deadlock."""
results = []
errors = []
def run_handler():
try:
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
handle(
self.mock_message,
application_generate_entity=self.mock_generate_entity,
)
results.append("success")
except Exception as e:
errors.append(str(e))
# Run multiple handlers concurrently
threads = []
for _ in range(5):
thread = threading.Thread(target=run_handler)
threads.append(thread)
thread.start()
# Wait for all threads to complete
for thread in threads:
thread.join(timeout=5)
# Verify all handlers completed successfully
assert len(results) == 5
assert len(errors) == 0
def test_performance_stats_tracking(self):
"""Test that performance statistics are tracked correctly."""
# Reset stats
stats = get_update_stats()
initial_total = stats["total_updates"]
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(
self.mock_message, application_generate_entity=self.mock_generate_entity
)
# Check that stats were updated
updated_stats = get_update_stats()
assert updated_stats["total_updates"] == initial_total + 1
assert updated_stats["successful_updates"] >= initial_total + 1
def test_non_chat_entity_ignored(self):
"""Test that non-chat entities are ignored by the handler."""
# Create a non-chat entity
mock_non_chat_entity = Mock()
mock_non_chat_entity.__class__.__name__ = "NonChatEntity"
with patch(
"events.event_handlers.update_provider_when_message_created.db"
) as mock_db:
# Call handler with non-chat entity
handle(self.mock_message, application_generate_entity=mock_non_chat_entity)
# Verify no database operations were performed
assert not mock_db.session.query.called
assert not mock_db.session.commit.called
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_tokens(self, mock_db):
"""Test quota calculation for token-based quotas."""
# Setup token-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TOKENS
self.mock_message.answer_tokens = 150
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called with token count
update_calls = mock_query.update.call_args_list
# Should have at least one call with quota_used update
quota_update_found = False
for call in update_calls:
values = call[0][0] # First argument to update()
if "quota_used" in values:
quota_update_found = True
break
assert quota_update_found
@patch("events.event_handlers.update_provider_when_message_created.db")
def test_quota_calculation_times(self, mock_db):
"""Test quota calculation for times-based quotas."""
# Setup times-based quota
self.mock_system_config.current_quota_type = QuotaUnit.TIMES
mock_query = Mock()
mock_db.session.query.return_value = mock_query
mock_query.filter.return_value = mock_query
mock_query.order_by.return_value = mock_query
mock_query.update.return_value = 1
# Call handler
handle(self.mock_message, application_generate_entity=self.mock_generate_entity)
# Verify update was called
assert mock_query.update.called
assert mock_db.session.commit.called

@ -83,27 +83,50 @@ const Panel: FC = () => {
const hasConfiguredTracing = !!(langSmithConfig || langFuseConfig || opikConfig || weaveConfig || arizeConfig || phoenixConfig || aliyunConfig)
const fetchTracingConfig = async () => {
const { tracing_config: arizeConfig, has_not_configured: arizeHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.arize })
if (!arizeHasNotConfig)
setArizeConfig(arizeConfig as ArizeConfig)
const { tracing_config: phoenixConfig, has_not_configured: phoenixHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.phoenix })
if (!phoenixHasNotConfig)
setPhoenixConfig(phoenixConfig as PhoenixConfig)
const { tracing_config: langSmithConfig, has_not_configured: langSmithHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langSmith })
if (!langSmithHasNotConfig)
setLangSmithConfig(langSmithConfig as LangSmithConfig)
const { tracing_config: langFuseConfig, has_not_configured: langFuseHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langfuse })
if (!langFuseHasNotConfig)
setLangFuseConfig(langFuseConfig as LangFuseConfig)
const { tracing_config: opikConfig, has_not_configured: OpikHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.opik })
if (!OpikHasNotConfig)
setOpikConfig(opikConfig as OpikConfig)
const { tracing_config: weaveConfig, has_not_configured: weaveHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.weave })
if (!weaveHasNotConfig)
setWeaveConfig(weaveConfig as WeaveConfig)
const { tracing_config: aliyunConfig, has_not_configured: aliyunHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.aliyun })
if (!aliyunHasNotConfig)
setAliyunConfig(aliyunConfig as AliyunConfig)
const getArizeConfig = async () => {
const { tracing_config: arizeConfig, has_not_configured: arizeHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.arize })
if (!arizeHasNotConfig)
setArizeConfig(arizeConfig as ArizeConfig)
}
const getPhoenixConfig = async () => {
const { tracing_config: phoenixConfig, has_not_configured: phoenixHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.phoenix })
if (!phoenixHasNotConfig)
setPhoenixConfig(phoenixConfig as PhoenixConfig)
}
const getLangSmithConfig = async () => {
const { tracing_config: langSmithConfig, has_not_configured: langSmithHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langSmith })
if (!langSmithHasNotConfig)
setLangSmithConfig(langSmithConfig as LangSmithConfig)
}
const getLangFuseConfig = async () => {
const { tracing_config: langFuseConfig, has_not_configured: langFuseHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.langfuse })
if (!langFuseHasNotConfig)
setLangFuseConfig(langFuseConfig as LangFuseConfig)
}
const getOpikConfig = async () => {
const { tracing_config: opikConfig, has_not_configured: OpikHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.opik })
if (!OpikHasNotConfig)
setOpikConfig(opikConfig as OpikConfig)
}
const getWeaveConfig = async () => {
const { tracing_config: weaveConfig, has_not_configured: weaveHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.weave })
if (!weaveHasNotConfig)
setWeaveConfig(weaveConfig as WeaveConfig)
}
const getAliyunConfig = async () => {
const { tracing_config: aliyunConfig, has_not_configured: aliyunHasNotConfig } = await doFetchTracingConfig({ appId, provider: TracingProvider.aliyun })
if (!aliyunHasNotConfig)
setAliyunConfig(aliyunConfig as AliyunConfig)
}
Promise.all([
getArizeConfig(),
getPhoenixConfig(),
getLangSmithConfig(),
getLangFuseConfig(),
getOpikConfig(),
getWeaveConfig(),
getAliyunConfig(),
])
}
const handleTracingConfigUpdated = async (provider: TracingProvider) => {
@ -155,7 +178,6 @@ const Panel: FC = () => {
await fetchTracingConfig()
setLoaded()
})()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
const [controlShowPopup, setControlShowPopup] = useState<number>(0)

@ -4,17 +4,19 @@ import cn from '@/utils/classnames'
type OptionListItemProps = {
isSelected: boolean
onClick: () => void
noAutoScroll?: boolean
} & React.LiHTMLAttributes<HTMLLIElement>
const OptionListItem: FC<OptionListItemProps> = ({
isSelected,
onClick,
noAutoScroll,
children,
}) => {
const listItemRef = useRef<HTMLLIElement>(null)
useEffect(() => {
if (isSelected)
if (isSelected && !noAutoScroll)
listItemRef.current?.scrollIntoView({ behavior: 'instant' })
}, [])

@ -1,13 +1,18 @@
import React from 'react'
import { useTranslation } from 'react-i18next'
const Header = () => {
type Props = {
title?: string
}
const Header = ({
title,
}: Props) => {
const { t } = useTranslation()
return (
<div className='flex flex-col border-b-[0.5px] border-divider-regular'>
<div className='system-md-semibold flex items-center px-2 py-1.5 text-text-primary'>
{t('time.title.pickTime')}
{title || t('time.title.pickTime')}
</div>
</div>
)

@ -20,6 +20,9 @@ const TimePicker = ({
onChange,
onClear,
renderTrigger,
title,
minuteFilter,
popupClassName,
}: TimePickerProps) => {
const { t } = useTranslation()
const [isOpen, setIsOpen] = useState(false)
@ -49,7 +52,6 @@ const TimePicker = ({
else {
setSelectedTime(prev => prev ? getDateWithTimezone({ date: prev, timezone }) : undefined)
}
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [timezone])
const handleClickTrigger = (e: React.MouseEvent) => {
@ -108,6 +110,15 @@ const TimePicker = ({
const displayValue = value?.format(timeFormat) || ''
const placeholderDate = isOpen && selectedTime ? selectedTime.format(timeFormat) : (placeholder || t('time.defaultPlaceholder'))
const inputElem = (
<input
className='system-xs-regular flex-1 cursor-pointer appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder'
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
/>
)
return (
<PortalToFollowElem
open={isOpen}
@ -115,18 +126,16 @@ const TimePicker = ({
placement='bottom-end'
>
<PortalToFollowElemTrigger>
{renderTrigger ? (renderTrigger()) : (
{renderTrigger ? (renderTrigger({
inputElem,
onClick: handleClickTrigger,
isOpen,
})) : (
<div
className='group flex w-[252px] cursor-pointer items-center gap-x-0.5 rounded-lg bg-components-input-bg-normal px-2 py-1 hover:bg-state-base-hover-alt'
onClick={handleClickTrigger}
>
<input
className='system-xs-regular flex-1 cursor-pointer appearance-none truncate bg-transparent p-1
text-components-input-text-filled outline-none placeholder:text-components-input-text-placeholder'
readOnly
value={isOpen ? '' : displayValue}
placeholder={placeholderDate}
/>
{inputElem}
<RiTimeLine className={cn(
'h-4 w-4 shrink-0 text-text-quaternary',
isOpen ? 'text-text-secondary' : 'group-hover:text-text-secondary',
@ -142,14 +151,15 @@ const TimePicker = ({
</div>
)}
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-50'>
<PortalToFollowElemContent className={cn('z-50', popupClassName)}>
<div className='mt-1 w-[252px] rounded-xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg shadow-shadow-shadow-5'>
{/* Header */}
<Header />
<Header title={title} />
{/* Time Options */}
<Options
selectedTime={selectedTime}
minuteFilter={minuteFilter}
handleSelectHour={handleSelectHour}
handleSelectMinute={handleSelectMinute}
handleSelectPeriod={handleSelectPeriod}

@ -5,6 +5,7 @@ import OptionListItem from '../common/option-list-item'
const Options: FC<TimeOptionsProps> = ({
selectedTime,
minuteFilter,
handleSelectHour,
handleSelectMinute,
handleSelectPeriod,
@ -33,7 +34,7 @@ const Options: FC<TimeOptionsProps> = ({
{/* Minute */}
<ul className='no-scrollbar flex h-[208px] flex-col gap-y-0.5 overflow-y-auto pb-[184px]'>
{
minuteOptions.map((minute) => {
(minuteFilter ? minuteFilter(minuteOptions) : minuteOptions).map((minute) => {
const isSelected = selectedTime?.format('mm') === minute
return (
<OptionListItem
@ -57,6 +58,7 @@ const Options: FC<TimeOptionsProps> = ({
key={period}
isSelected={isSelected}
onClick={handleSelectPeriod.bind(null, period)}
noAutoScroll // if choose PM which would hide(scrolled) AM that may make user confused that there's no am.
>
{period}
</OptionListItem>

@ -28,6 +28,7 @@ export type DatePickerProps = {
onClear: () => void
triggerWrapClassName?: string
renderTrigger?: (props: TriggerProps) => React.ReactNode
minuteFilter?: (minutes: string[]) => string[]
popupZIndexClassname?: string
}
@ -47,13 +48,21 @@ export type DatePickerFooterProps = {
handleConfirmDate: () => void
}
export type TriggerParams = {
isOpen: boolean
inputElem: React.ReactNode
onClick: (e: React.MouseEvent) => void
}
export type TimePickerProps = {
value: Dayjs | undefined
timezone?: string
placeholder?: string
onChange: (date: Dayjs | undefined) => void
onClear: () => void
renderTrigger?: () => React.ReactNode
renderTrigger?: (props: TriggerParams) => React.ReactNode
title?: string
minuteFilter?: (minutes: string[]) => string[]
popupClassName?: string
}
export type TimePickerFooterProps = {
@ -81,6 +90,7 @@ export type CalendarItemProps = {
export type TimeOptionsProps = {
selectedTime: Dayjs | undefined
minuteFilter?: (minutes: string[]) => string[]
handleSelectHour: (hour: string) => void
handleSelectMinute: (minute: string) => void
handleSelectPeriod: (period: Period) => void

@ -2,6 +2,7 @@ import dayjs, { type Dayjs } from 'dayjs'
import type { Day } from '../types'
import utc from 'dayjs/plugin/utc'
import timezone from 'dayjs/plugin/timezone'
import tz from '@/utils/timezone.json'
dayjs.extend(utc)
dayjs.extend(timezone)
@ -78,3 +79,14 @@ export const getHourIn12Hour = (date: Dayjs) => {
export const getDateWithTimezone = (props: { date?: Dayjs, timezone?: string }) => {
return props.date ? dayjs.tz(props.date, props.timezone) : dayjs().tz(props.timezone)
}
// Asia/Shanghai -> UTC+8
const DEFAULT_OFFSET_STR = 'UTC+0'
export const convertTimezoneToOffsetStr = (timezone?: string) => {
if (!timezone)
return DEFAULT_OFFSET_STR
const tzItem = tz.find(item => item.value === timezone)
if(!tzItem)
return DEFAULT_OFFSET_STR
return `UTC${tzItem.name.charAt(0)}${tzItem.name.charAt(2)}`
}

@ -0,0 +1,7 @@
<svg width="32" height="32" viewBox="0 0 32 32" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M28.0049 16C28.0049 20.4183 24.4231 24 20.0049 24C15.5866 24 12.0049 20.4183 12.0049 16C12.0049 11.5817 15.5866 8 20.0049 8C24.4231 8 28.0049 11.5817 28.0049 16Z" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4.00488 16H6.67155" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4.00488 9.33334H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M4.00488 22.6667H8.00488" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M26 22L29.3333 25.3333" stroke="#676F83" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 823 B

@ -0,0 +1,4 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.46257 4.43262C7.21556 2.91688 9.5007 2 12 2C17.5228 2 22 6.47715 22 12C22 14.1361 21.3302 16.1158 20.1892 17.7406L17 12H20C20 7.58172 16.4183 4 12 4C9.84982 4 7.89777 4.84827 6.46023 6.22842L5.46257 4.43262ZM18.5374 19.5674C16.7844 21.0831 14.4993 22 12 22C6.47715 22 2 17.5228 2 12C2 9.86386 2.66979 7.88416 3.8108 6.25944L7 12H4C4 16.4183 7.58172 20 12 20C14.1502 20 16.1022 19.1517 17.5398 17.7716L18.5374 19.5674Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M16.3308 16H14.2915L13.6249 13.9476H10.3761L9.70846 16H7.66918L10.7759 7H13.2281L16.3308 16ZM10.8595 12.4622H13.1435L12.0378 9.05639H11.9673L10.8595 12.4622Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 772 B

@ -75,7 +75,7 @@ Icon.displayName = '<%= svgName %>'
export default Icon
`.trim())
await writeFile(path.resolve(currentPath, `${fileName}.json`), JSON.stringify(svgData, '', '\t'))
await writeFile(path.resolve(currentPath, `${fileName}.json`), `${JSON.stringify(svgData, '', '\t')}\n`)
await writeFile(path.resolve(currentPath, `${fileName}.tsx`), `${componentRender({ svgName: fileName })}\n`)
const indexingRender = template(`

File diff suppressed because one or more lines are too long

@ -4,12 +4,16 @@
import * as React from 'react'
import data from './AliyunIcon.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'AliyunIcon'

File diff suppressed because one or more lines are too long

@ -4,12 +4,16 @@
import * as React from 'react'
import data from './AliyunIconBig.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'AliyunIconBig'

@ -4,12 +4,16 @@
import * as React from 'react'
import data from './WeaveIcon.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'WeaveIcon'

@ -4,12 +4,16 @@
import * as React from 'react'
import data from './WeaveIconBig.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
import type { IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
const Icon = (
{
ref,
...props
}: React.SVGProps<SVGSVGElement> & {
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
},
) => <IconBase {...props} ref={ref} data={data as IconData} />
Icon.displayName = 'WeaveIconBig'

@ -1,3 +1,5 @@
export { default as AliyunIconBig } from './AliyunIconBig'
export { default as AliyunIcon } from './AliyunIcon'
export { default as ArizeIconBig } from './ArizeIconBig'
export { default as ArizeIcon } from './ArizeIcon'
export { default as LangfuseIconBig } from './LangfuseIconBig'
@ -11,5 +13,3 @@ export { default as PhoenixIcon } from './PhoenixIcon'
export { default as TracingIcon } from './TracingIcon'
export { default as WeaveIconBig } from './WeaveIconBig'
export { default as WeaveIcon } from './WeaveIcon'
export { default as AliyunIconBig } from './AliyunIconBig'
export { default as AliyunIcon } from './AliyunIcon'

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save