diff --git a/api/commands.py b/api/commands.py index c5394c6f87..3881439ddf 100644 --- a/api/commands.py +++ b/api/commands.py @@ -17,6 +17,7 @@ from core.rag.models.document import Document from events.app_event import app_was_created from extensions.ext_database import db from extensions.ext_redis import redis_client +from extensions.ext_storage import storage from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair @@ -815,3 +816,274 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[ ClearFreePlanTenantExpiredLogs.process(days, batch, tenant_ids) click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green")) + + +@click.command("clear-orphaned-file-records", help="Clear orphaned file records.") +def clear_orphaned_file_records(): + """ + Clear orphaned file records in the database. + """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "id_column": "id", "key_column": "key"}, + {"table": "tool_files", "id_column": "id", "key_column": "file_key"}, + ] + ids_tables = [ + {"type": "uuid", "table": "message_files", "column": "upload_file_id"}, + {"type": "text", "table": "documents", "column": "data_source_info"}, + {"type": "text", "table": "document_segments", "column": "content"}, + {"type": "text", "table": "messages", "column": "answer"}, + {"type": "text", "table": "workflow_node_executions", "column": "inputs"}, + {"type": "text", "table": "workflow_node_executions", "column": "process_data"}, + {"type": "text", "table": "workflow_node_executions", "column": "outputs"}, + {"type": "text", "table": "conversations", "column": "introduction"}, + {"type": "text", "table": "conversations", "column": "system_instruction"}, + {"type": "json", "table": "messages", "column": "inputs"}, + {"type": "json", "table": "messages", "column": "message"}, + ] + + # notify user and ask for confirmation + click.echo( + click.style("This command will find and delete orphaned file records in the following tables:", fg="yellow") + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo( + click.style("The following tables and columns will be scanned to find orphaned file records:", fg="yellow") + ) + for ids_table in ids_tables: + click.echo(click.style(f"- {ids_table['table']} ({ids_table['column']})", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + ( + "Since not all patterns have been fully tested, " + "please note that this command may delete unintended file records." + ), + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your database before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." + ), + fg="yellow", + ) + ) + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned file records cleanup.", fg="white")) + + try: + # fetch file id and keys from each table + all_files_in_tables = [] + for files_table in files_tables: + click.echo(click.style(f"- Listing file records in table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['id_column']}, {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(db.text(query)) + for i in rs: + all_files_in_tables.append({"table": files_table["table"], "id": str(i[0]), "key": i[1]}) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + + # fetch referred table and columns + guid_regexp = "[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + all_ids_in_tables = [] + for ids_table in ids_tables: + query = "" + if ids_table["type"] == "uuid": + click.echo( + click.style( + f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white" + ) + ) + query = ( + f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL" + ) + with db.engine.begin() as conn: + rs = conn.execute(db.text(query)) + for i in rs: + all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])}) + elif ids_table["type"] == "text": + click.echo( + click.style( + f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}", + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(db.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + elif ids_table["type"] == "json": + click.echo( + click.style( + ( + f"- Listing file-id-like JSON string in column {ids_table['column']} " + f"in table {ids_table['table']}" + ), + fg="white", + ) + ) + query = ( + f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id " + f"FROM {ids_table['table']}" + ) + with db.engine.begin() as conn: + rs = conn.execute(db.text(query)) + for i in rs: + for j in i[0]: + all_ids_in_tables.append({"table": ids_table["table"], "id": j}) + click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white")) + + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + return + + # find orphaned files + all_files = [file["id"] for file in all_files_in_tables] + all_ids = [file["id"] for file in all_ids_in_tables] + orphaned_files = list(set(all_files) - set(all_ids)) + if not orphaned_files: + click.echo(click.style("No orphaned file records found. There is nothing to delete.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file id: {file}", fg="black")) + click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True) + + # delete orphaned records for each file + try: + for files_table in files_tables: + click.echo(click.style(f"- Deleting orphaned file records in table {files_table['table']}", fg="white")) + query = f"DELETE FROM {files_table['table']} WHERE {files_table['id_column']} IN :ids" + with db.engine.begin() as conn: + conn.execute(db.text(query), {"ids": tuple(orphaned_files)}) + except Exception as e: + click.echo(click.style(f"Error deleting orphaned file records: {str(e)}", fg="red")) + return + click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green")) + + +@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.") +def remove_orphaned_files_on_storage(): + """ + Remove orphaned files on the storage. + """ + + # define tables and columns to process + files_tables = [ + {"table": "upload_files", "key_column": "key"}, + {"table": "tool_files", "key_column": "file_key"}, + ] + storage_paths = ["image_files", "tools", "upload_files"] + + # notify user and ask for confirmation + click.echo(click.style("This command will find and remove orphaned files on the storage,", fg="yellow")) + click.echo( + click.style("by comparing the files on the storage with the records in the following tables:", fg="yellow") + ) + for files_table in files_tables: + click.echo(click.style(f"- {files_table['table']}", fg="yellow")) + click.echo(click.style("The following paths on the storage will be scanned to find orphaned files:", fg="yellow")) + for storage_path in storage_paths: + click.echo(click.style(f"- {storage_path}", fg="yellow")) + click.echo("") + + click.echo(click.style("!!! USE WITH CAUTION !!!", fg="red")) + click.echo( + click.style( + "Currently, this command will work only for opendal based storage (STORAGE_TYPE=opendal).", fg="yellow" + ) + ) + click.echo( + click.style( + "Since not all patterns have been fully tested, please note that this command may delete unintended files.", + fg="yellow", + ) + ) + click.echo( + click.style("This cannot be undone. Please make sure to back up your storage before proceeding.", fg="yellow") + ) + click.echo( + click.style( + ( + "It is also recommended to run this during the maintenance window, " + "as this may cause high load on your instance." + ), + fg="yellow", + ) + ) + click.confirm("Do you want to proceed?", abort=True) + + # start the cleanup process + click.echo(click.style("Starting orphaned files cleanup.", fg="white")) + + # fetch file id and keys from each table + all_files_in_tables = [] + try: + for files_table in files_tables: + click.echo(click.style(f"- Listing files from table {files_table['table']}", fg="white")) + query = f"SELECT {files_table['key_column']} FROM {files_table['table']}" + with db.engine.begin() as conn: + rs = conn.execute(db.text(query)) + for i in rs: + all_files_in_tables.append(str(i[0])) + click.echo(click.style(f"Found {len(all_files_in_tables)} files in tables.", fg="white")) + except Exception as e: + click.echo(click.style(f"Error fetching keys: {str(e)}", fg="red")) + + all_files_on_storage = [] + for storage_path in storage_paths: + try: + click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white")) + files = storage.scan(path=storage_path, files=True, directories=False) + all_files_on_storage.extend(files) + except FileNotFoundError as e: + click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow")) + continue + except Exception as e: + click.echo(click.style(f" -> Error scanning files on storage path {storage_path}: {str(e)}", fg="red")) + continue + click.echo(click.style(f"Found {len(all_files_on_storage)} files on storage.", fg="white")) + + # find orphaned files + orphaned_files = list(set(all_files_on_storage) - set(all_files_in_tables)) + if not orphaned_files: + click.echo(click.style("No orphaned files found. There is nothing to remove.", fg="green")) + return + click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white")) + for file in orphaned_files: + click.echo(click.style(f"- orphaned file: {file}", fg="black")) + click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True) + + # delete orphaned files + removed_files = 0 + error_files = 0 + for file in orphaned_files: + try: + storage.delete(file) + removed_files += 1 + click.echo(click.style(f"- Removing orphaned file: {file}", fg="white")) + except Exception as e: + error_files += 1 + click.echo(click.style(f"- Error deleting orphaned file {file}: {str(e)}", fg="red")) + continue + if error_files == 0: + click.echo(click.style(f"Removed {removed_files} orphaned files without errors.", fg="green")) + else: + click.echo(click.style(f"Removed {removed_files} orphaned files, with {error_files} errors.", fg="yellow")) diff --git a/api/controllers/console/app/annotation.py b/api/controllers/console/app/annotation.py index fcd8ed1882..48353a63af 100644 --- a/api/controllers/console/app/annotation.py +++ b/api/controllers/console/app/annotation.py @@ -186,7 +186,7 @@ class AnnotationUpdateDeleteApi(Resource): app_id = str(app_id) annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_id, annotation_id) - return {"result": "success"}, 200 + return {"result": "success"}, 204 class AnnotationBatchImportApi(Resource): diff --git a/api/controllers/console/app/ops_trace.py b/api/controllers/console/app/ops_trace.py index dd25af8ebf..7176440e16 100644 --- a/api/controllers/console/app/ops_trace.py +++ b/api/controllers/console/app/ops_trace.py @@ -84,7 +84,7 @@ class TraceAppConfigApi(Resource): result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args["tracing_provider"]) if not result: raise TracingConfigNotExist() - return {"result": "success"} + return {"result": "success"}, 204 except Exception as e: raise BadRequest(str(e)) diff --git a/api/controllers/console/auth/data_source_bearer_auth.py b/api/controllers/console/auth/data_source_bearer_auth.py index ea00c2b8c2..5f0762e4a5 100644 --- a/api/controllers/console/auth/data_source_bearer_auth.py +++ b/api/controllers/console/auth/data_source_bearer_auth.py @@ -65,7 +65,7 @@ class ApiKeyAuthDataSourceBindingDelete(Resource): ApiKeyAuthService.delete_provider_auth(current_user.current_tenant_id, binding_id) - return {"result": "success"}, 200 + return {"result": "success"}, 204 api.add_resource(ApiKeyAuthDataSource, "/api-key-auth/data-source") diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 0b40312368..3588abeff5 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -40,7 +40,7 @@ from core.indexing_runner import IndexingRunner from core.model_manager import ModelManager from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.errors.invoke import InvokeAuthorizationError -from core.plugin.manager.exc import PluginDaemonClientSideError +from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.extract_setting import ExtractSetting from extensions.ext_database import db from extensions.ext_redis import redis_client diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 696aaa94db..5c54ecbe81 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -131,7 +131,7 @@ class DatasetDocumentSegmentListApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segments(segment_ids, document, dataset) - return {"result": "success"}, 200 + return {"result": "success"}, 204 class DatasetDocumentSegmentApi(Resource): @@ -333,7 +333,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): except services.errors.account.NoPermissionError as e: raise Forbidden(str(e)) SegmentService.delete_segment(segment, document, dataset) - return {"result": "success"}, 200 + return {"result": "success"}, 204 class DatasetDocumentSegmentBatchImportApi(Resource): @@ -590,7 +590,7 @@ class ChildChunkUpdateApi(Resource): SegmentService.delete_child_chunk(child_chunk, dataset) except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return {"result": "success"}, 200 + return {"result": "success"}, 204 @setup_required @login_required diff --git a/api/controllers/console/datasets/external.py b/api/controllers/console/datasets/external.py index 2c031172bf..aee8323f23 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -135,7 +135,7 @@ class ExternalApiTemplateApi(Resource): raise Forbidden() ExternalDatasetService.delete_external_knowledge_api(current_user.current_tenant_id, external_knowledge_api_id) - return {"result": "success"}, 200 + return {"result": "success"}, 204 class ExternalApiUseCheckApi(Resource): diff --git a/api/controllers/console/datasets/metadata.py b/api/controllers/console/datasets/metadata.py index fc9711169f..e4cac40ca1 100644 --- a/api/controllers/console/datasets/metadata.py +++ b/api/controllers/console/datasets/metadata.py @@ -82,7 +82,7 @@ class DatasetMetadataApi(Resource): DatasetService.check_dataset_permission(dataset, current_user) MetadataService.delete_metadata(dataset_id_str, metadata_id_str) - return 200 + return {"result": "success"}, 204 class DatasetMetadataBuiltInFieldApi(Resource): diff --git a/api/controllers/console/explore/installed_app.py b/api/controllers/console/explore/installed_app.py index 86550b2bdf..132da11878 100644 --- a/api/controllers/console/explore/installed_app.py +++ b/api/controllers/console/explore/installed_app.py @@ -113,7 +113,7 @@ class InstalledAppApi(InstalledAppResource): db.session.delete(installed_app) db.session.commit() - return {"result": "success", "message": "App uninstalled successfully"} + return {"result": "success", "message": "App uninstalled successfully"}, 204 def patch(self, installed_app): parser = reqparse.RequestParser() diff --git a/api/controllers/console/explore/saved_message.py b/api/controllers/console/explore/saved_message.py index 9f0c496645..3a1655d0ee 100644 --- a/api/controllers/console/explore/saved_message.py +++ b/api/controllers/console/explore/saved_message.py @@ -72,7 +72,7 @@ class SavedMessageApi(InstalledAppResource): SavedMessageService.delete(app_model, current_user, message_id) - return {"result": "success"} + return {"result": "success"}, 204 api.add_resource( diff --git a/api/controllers/console/extension.py b/api/controllers/console/extension.py index ed6cedb220..833da0d03c 100644 --- a/api/controllers/console/extension.py +++ b/api/controllers/console/extension.py @@ -99,7 +99,7 @@ class APIBasedExtensionDetailAPI(Resource): APIBasedExtensionService.delete(extension_data_from_db) - return {"result": "success"} + return {"result": "success"}, 204 api.add_resource(CodeBasedExtensionAPI, "/code-based-extension") diff --git a/api/controllers/console/tag/tags.py b/api/controllers/console/tag/tags.py index da83f64019..0d0d7ae95f 100644 --- a/api/controllers/console/tag/tags.py +++ b/api/controllers/console/tag/tags.py @@ -86,7 +86,7 @@ class TagUpdateDeleteApi(Resource): TagService.delete_tag(tag_id) - return 200 + return 204 class TagBindingCreateApi(Resource): diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 46dee20f8b..aa1a78935d 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -5,7 +5,7 @@ from werkzeug.exceptions import Forbidden from controllers.console import api from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.manager.exc import PluginPermissionDeniedError +from core.plugin.impl.exc import PluginPermissionDeniedError from libs.login import login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index e9c1884c60..6f9ae18750 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -10,7 +10,7 @@ from controllers.console import api from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, setup_required from core.model_runtime.utils.encoders import jsonable_encoder -from core.plugin.manager.exc import PluginDaemonClientSideError +from core.plugin.impl.exc import PluginDaemonClientSideError from libs.login import login_required from models.account import TenantPluginPermission from services.plugin.plugin_permission_service import PluginPermissionService diff --git a/api/controllers/service_api/app/annotation.py b/api/controllers/service_api/app/annotation.py index cffa3665b1..522a96b791 100644 --- a/api/controllers/service_api/app/annotation.py +++ b/api/controllers/service_api/app/annotation.py @@ -98,7 +98,7 @@ class AnnotationUpdateDeleteApi(Resource): annotation_id = str(annotation_id) AppAnnotationService.delete_app_annotation(app_model.id, annotation_id) - return {"result": "success"}, 200 + return {"result": "success"}, 204 api.add_resource(AnnotationReplyActionApi, "/apps/annotation-reply/") diff --git a/api/controllers/service_api/app/conversation.py b/api/controllers/service_api/app/conversation.py index 55600a3fd0..dfc357e1ab 100644 --- a/api/controllers/service_api/app/conversation.py +++ b/api/controllers/service_api/app/conversation.py @@ -72,7 +72,7 @@ class ConversationDetailApi(Resource): ConversationService.delete(app_model, conversation_id, end_user) except services.errors.conversation.ConversationNotExistsError: raise NotFound("Conversation Not Exists.") - return {"result": "success"}, 200 + return {"result": "success"}, 204 class ConversationRenameApi(Resource): diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index eec6afc9ef..9e943e2b2d 100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -323,7 +323,7 @@ class DocumentDeleteApi(DatasetApiResource): except services.errors.document.DocumentIndexingError: raise DocumentIndexingError("Cannot delete document during indexing.") - return {"result": "success"}, 200 + return {"result": "success"}, 204 class DocumentListApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/metadata.py b/api/controllers/service_api/dataset/metadata.py index 298c8a8df8..35578eae54 100644 --- a/api/controllers/service_api/dataset/metadata.py +++ b/api/controllers/service_api/dataset/metadata.py @@ -63,7 +63,7 @@ class DatasetMetadataServiceApi(DatasetApiResource): DatasetService.check_dataset_permission(dataset, current_user) MetadataService.delete_metadata(dataset_id_str, metadata_id_str) - return 200 + return 204 class DatasetMetadataBuiltInFieldServiceApi(DatasetApiResource): diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 2a79e15cc5..95753cfd67 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -159,7 +159,7 @@ class DatasetSegmentApi(DatasetApiResource): if not segment: raise NotFound("Segment not found.") SegmentService.delete_segment(segment, document, dataset) - return {"result": "success"}, 200 + return {"result": "success"}, 204 @cloud_edition_billing_resource_check("vector_space", "dataset") def post(self, tenant_id, dataset_id, document_id, segment_id): @@ -344,7 +344,7 @@ class DatasetChildChunkApi(DatasetApiResource): except ChildChunkDeleteIndexServiceError as e: raise ChildChunkDeleteIndexError(str(e)) - return {"result": "success"}, 200 + return {"result": "success"}, 204 @cloud_edition_billing_resource_check("vector_space", "dataset") @cloud_edition_billing_knowledge_limit_check("add_segment", "dataset") diff --git a/api/controllers/web/saved_message.py b/api/controllers/web/saved_message.py index 6a9b818907..ab2d4abcd3 100644 --- a/api/controllers/web/saved_message.py +++ b/api/controllers/web/saved_message.py @@ -67,7 +67,7 @@ class SavedMessageApi(WebApiResource): SavedMessageService.delete(app_model, end_user, message_id) - return {"result": "success"} + return {"result": "success"}, 204 api.add_resource(SavedMessageListApi, "/saved-messages") diff --git a/api/core/agent/strategy/plugin.py b/api/core/agent/strategy/plugin.py index a4b25f46e6..79b074cf95 100644 --- a/api/core/agent/strategy/plugin.py +++ b/api/core/agent/strategy/plugin.py @@ -4,7 +4,7 @@ from typing import Any, Optional from core.agent.entities import AgentInvokeMessage from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter from core.agent.strategy.base import BaseAgentStrategy -from core.plugin.manager.agent import PluginAgentManager +from core.plugin.impl.agent import PluginAgentClient from core.plugin.utils.converter import convert_parameters_to_plugin_format @@ -42,7 +42,7 @@ class PluginAgentStrategy(BaseAgentStrategy): """ Invoke the agent strategy. """ - manager = PluginAgentManager() + manager = PluginAgentClient() initialized_params = self.initialize_parameters(params) params = convert_parameters_to_plugin_format(initialized_params) diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index bd05590018..3c5a2dce4f 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -26,7 +26,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient class AIModel(BaseModel): @@ -141,7 +141,7 @@ class AIModel(BaseModel): :param credentials: model credentials :return: model schema """ - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" # sort credentials sorted_credentials = sorted(credentials.items()) if credentials else [] diff --git a/api/core/model_runtime/model_providers/__base/large_language_model.py b/api/core/model_runtime/model_providers/__base/large_language_model.py index 1b799131e7..6312587861 100644 --- a/api/core/model_runtime/model_providers/__base/large_language_model.py +++ b/api/core/model_runtime/model_providers/__base/large_language_model.py @@ -2,7 +2,7 @@ import logging import time import uuid from collections.abc import Generator, Sequence -from typing import Optional, Union +from typing import Optional, Union, cast from pydantic import ConfigDict @@ -20,7 +20,8 @@ from core.model_runtime.entities.model_entities import ( PriceType, ) from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str +from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) @@ -140,7 +141,7 @@ class LargeLanguageModel(AIModel): result: Union[LLMResult, Generator[LLMResultChunk, None, None]] try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() result = plugin_model_manager.invoke_llm( tenant_id=self.tenant_id, user_id=user or "unknown", @@ -280,7 +281,9 @@ class LargeLanguageModel(AIModel): callbacks=callbacks, ) - assistant_message.content += chunk.delta.message.content + text = convert_llm_result_chunk_to_str(chunk.delta.message.content) + current_content = cast(str, assistant_message.content) + assistant_message.content = current_content + text real_model = chunk.model if chunk.delta.usage: usage = chunk.delta.usage @@ -326,7 +329,7 @@ class LargeLanguageModel(AIModel): :return: """ if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.get_llm_num_tokens( tenant_id=self.tenant_id, user_id="unknown", diff --git a/api/core/model_runtime/model_providers/__base/moderation_model.py b/api/core/model_runtime/model_providers/__base/moderation_model.py index f98d7572c7..19dc1d599a 100644 --- a/api/core/model_runtime/model_providers/__base/moderation_model.py +++ b/api/core/model_runtime/model_providers/__base/moderation_model.py @@ -5,7 +5,7 @@ from pydantic import ConfigDict from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient class ModerationModel(AIModel): @@ -31,7 +31,7 @@ class ModerationModel(AIModel): self.started_at = time.perf_counter() try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_moderation( tenant_id=self.tenant_id, user_id=user or "unknown", diff --git a/api/core/model_runtime/model_providers/__base/rerank_model.py b/api/core/model_runtime/model_providers/__base/rerank_model.py index e905cb18d4..569e756a3b 100644 --- a/api/core/model_runtime/model_providers/__base/rerank_model.py +++ b/api/core/model_runtime/model_providers/__base/rerank_model.py @@ -3,7 +3,7 @@ from typing import Optional from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.rerank_entities import RerankResult from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient class RerankModel(AIModel): @@ -36,7 +36,7 @@ class RerankModel(AIModel): :return: rerank result """ try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_rerank( tenant_id=self.tenant_id, user_id=user or "unknown", diff --git a/api/core/model_runtime/model_providers/__base/speech2text_model.py b/api/core/model_runtime/model_providers/__base/speech2text_model.py index 97ff322f09..c69f65b681 100644 --- a/api/core/model_runtime/model_providers/__base/speech2text_model.py +++ b/api/core/model_runtime/model_providers/__base/speech2text_model.py @@ -4,7 +4,7 @@ from pydantic import ConfigDict from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient class Speech2TextModel(AIModel): @@ -28,7 +28,7 @@ class Speech2TextModel(AIModel): :return: text for given audio file """ try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_speech_to_text( tenant_id=self.tenant_id, user_id=user or "unknown", diff --git a/api/core/model_runtime/model_providers/__base/text_embedding_model.py b/api/core/model_runtime/model_providers/__base/text_embedding_model.py index c4c1f92177..f7bba0eba1 100644 --- a/api/core/model_runtime/model_providers/__base/text_embedding_model.py +++ b/api/core/model_runtime/model_providers/__base/text_embedding_model.py @@ -6,7 +6,7 @@ from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient class TextEmbeddingModel(AIModel): @@ -38,7 +38,7 @@ class TextEmbeddingModel(AIModel): :return: embeddings result """ try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_text_embedding( tenant_id=self.tenant_id, user_id=user or "unknown", @@ -61,7 +61,7 @@ class TextEmbeddingModel(AIModel): :param texts: texts to embed :return: """ - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.get_text_embedding_num_tokens( tenant_id=self.tenant_id, user_id="unknown", diff --git a/api/core/model_runtime/model_providers/__base/tts_model.py b/api/core/model_runtime/model_providers/__base/tts_model.py index 1f248d11ac..d51831900c 100644 --- a/api/core/model_runtime/model_providers/__base/tts_model.py +++ b/api/core/model_runtime/model_providers/__base/tts_model.py @@ -6,7 +6,7 @@ from pydantic import ConfigDict from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.model_providers.__base.ai_model import AIModel -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) @@ -42,7 +42,7 @@ class TTSModel(AIModel): :return: translated audio file """ try: - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.invoke_tts( tenant_id=self.tenant_id, user_id=user or "unknown", @@ -65,7 +65,7 @@ class TTSModel(AIModel): :param credentials: The credentials required to access the TTS model. :return: A list of voices supported by the TTS model. """ - plugin_model_manager = PluginModelManager() + plugin_model_manager = PluginModelClient() return plugin_model_manager.get_tts_model_voices( tenant_id=self.tenant_id, user_id="unknown", diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index d2fd4916a4..ad46f64ec3 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -22,8 +22,8 @@ from core.model_runtime.schema_validators.model_credential_schema_validator impo from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator from core.plugin.entities.plugin import ModelProviderID from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.manager.asset import PluginAssetManager -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient logger = logging.getLogger(__name__) @@ -40,7 +40,7 @@ class ModelProviderFactory: self.provider_position_map = {} self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelManager() + self.plugin_model_manager = PluginModelClient() if not self.provider_position_map: # get the path of current classes diff --git a/api/core/model_runtime/utils/helper.py b/api/core/model_runtime/utils/helper.py index 5e8a723ec7..53789a8e91 100644 --- a/api/core/model_runtime/utils/helper.py +++ b/api/core/model_runtime/utils/helper.py @@ -1,6 +1,8 @@ import pydantic from pydantic import BaseModel +from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes + def dump_model(model: BaseModel) -> dict: if hasattr(pydantic, "model_dump"): @@ -8,3 +10,18 @@ def dump_model(model: BaseModel) -> dict: return pydantic.model_dump(model) # type: ignore else: return model.model_dump() + + +def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str: + if content is None: + message_text = "" + elif isinstance(content, str): + message_text = content + elif isinstance(content, list): + # Assuming the list contains PromptMessageContent objects with a "data" attribute + message_text = "".join( + item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content + ) + else: + message_text = str(content) + return message_text diff --git a/api/core/plugin/manager/agent.py b/api/core/plugin/impl/agent.py similarity index 97% rename from api/core/plugin/manager/agent.py rename to api/core/plugin/impl/agent.py index 50172f12f2..66b77c7489 100644 --- a/api/core/plugin/manager/agent.py +++ b/api/core/plugin/impl/agent.py @@ -6,10 +6,10 @@ from core.plugin.entities.plugin import GenericProviderID from core.plugin.entities.plugin_daemon import ( PluginAgentProviderEntity, ) -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginAgentManager(BasePluginManager): +class PluginAgentClient(BasePluginClient): def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]: """ Fetch agent providers for the given tenant. diff --git a/api/core/plugin/manager/asset.py b/api/core/plugin/impl/asset.py similarity index 76% rename from api/core/plugin/manager/asset.py rename to api/core/plugin/impl/asset.py index 17755d3561..b9bfe2d2cf 100644 --- a/api/core/plugin/manager/asset.py +++ b/api/core/plugin/impl/asset.py @@ -1,7 +1,7 @@ -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginAssetManager(BasePluginManager): +class PluginAssetManager(BasePluginClient): def fetch_asset(self, tenant_id: str, id: str) -> bytes: """ Fetch an asset by id. diff --git a/api/core/plugin/manager/base.py b/api/core/plugin/impl/base.py similarity index 99% rename from api/core/plugin/manager/base.py rename to api/core/plugin/impl/base.py index d8d7b3e860..4f1d808a3e 100644 --- a/api/core/plugin/manager/base.py +++ b/api/core/plugin/impl/base.py @@ -18,7 +18,7 @@ from core.model_runtime.errors.invoke import ( ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError -from core.plugin.manager.exc import ( +from core.plugin.impl.exc import ( PluginDaemonBadRequestError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, @@ -37,7 +37,7 @@ T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) logger = logging.getLogger(__name__) -class BasePluginManager: +class BasePluginClient: def _request( self, method: str, diff --git a/api/core/plugin/manager/debugging.py b/api/core/plugin/impl/debugging.py similarity index 78% rename from api/core/plugin/manager/debugging.py rename to api/core/plugin/impl/debugging.py index fb6bad7fa3..523377895c 100644 --- a/api/core/plugin/manager/debugging.py +++ b/api/core/plugin/impl/debugging.py @@ -1,9 +1,9 @@ from pydantic import BaseModel -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginDebuggingManager(BasePluginManager): +class PluginDebuggingClient(BasePluginClient): def get_debugging_key(self, tenant_id: str) -> str: """ Get the debugging key for the given tenant. diff --git a/api/core/plugin/manager/endpoint.py b/api/core/plugin/impl/endpoint.py similarity index 97% rename from api/core/plugin/manager/endpoint.py rename to api/core/plugin/impl/endpoint.py index 415b981ffb..5b88742be5 100644 --- a/api/core/plugin/manager/endpoint.py +++ b/api/core/plugin/impl/endpoint.py @@ -1,8 +1,8 @@ from core.plugin.entities.endpoint import EndpointEntityWithInstance -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginEndpointManager(BasePluginManager): +class PluginEndpointClient(BasePluginClient): def create_endpoint( self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict ) -> bool: diff --git a/api/core/plugin/manager/exc.py b/api/core/plugin/impl/exc.py similarity index 100% rename from api/core/plugin/manager/exc.py rename to api/core/plugin/impl/exc.py diff --git a/api/core/plugin/manager/model.py b/api/core/plugin/impl/model.py similarity index 99% rename from api/core/plugin/manager/model.py rename to api/core/plugin/impl/model.py index 5ebc0c2320..f7607eef8d 100644 --- a/api/core/plugin/manager/model.py +++ b/api/core/plugin/impl/model.py @@ -18,10 +18,10 @@ from core.plugin.entities.plugin_daemon import ( PluginTextEmbeddingNumTokensResponse, PluginVoicesResponse, ) -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginModelManager(BasePluginManager): +class PluginModelClient(BasePluginClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. diff --git a/api/core/plugin/impl/oauth.py b/api/core/plugin/impl/oauth.py new file mode 100644 index 0000000000..1d40edb086 --- /dev/null +++ b/api/core/plugin/impl/oauth.py @@ -0,0 +1,6 @@ +from core.plugin.impl.base import BasePluginClient + + +class OAuthHandler(BasePluginClient): + def get_authorization_url(self, tenant_id: str, user_id: str, provider_name: str) -> str: + return "1234567890" diff --git a/api/core/plugin/manager/plugin.py b/api/core/plugin/impl/plugin.py similarity index 98% rename from api/core/plugin/manager/plugin.py rename to api/core/plugin/impl/plugin.py index 15dcd6cb34..3349463ce5 100644 --- a/api/core/plugin/manager/plugin.py +++ b/api/core/plugin/impl/plugin.py @@ -10,10 +10,10 @@ from core.plugin.entities.plugin import ( PluginInstallationSource, ) from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient -class PluginInstallationManager(BasePluginManager): +class PluginInstaller(BasePluginClient): def fetch_plugin_by_identifier( self, tenant_id: str, diff --git a/api/core/plugin/manager/tool.py b/api/core/plugin/impl/tool.py similarity index 98% rename from api/core/plugin/manager/tool.py rename to api/core/plugin/impl/tool.py index 7592f867e1..19b26c8fe3 100644 --- a/api/core/plugin/manager/tool.py +++ b/api/core/plugin/impl/tool.py @@ -5,11 +5,11 @@ from pydantic import BaseModel from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity -from core.plugin.manager.base import BasePluginManager +from core.plugin.impl.base import BasePluginClient from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -class PluginToolManager(BasePluginManager): +class PluginToolManager(BasePluginClient): def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: """ Fetch tool providers for the given tenant. diff --git a/api/core/rag/extractor/watercrawl/provider.py b/api/core/rag/extractor/watercrawl/provider.py index b8003b386b..21fbb2100f 100644 --- a/api/core/rag/extractor/watercrawl/provider.py +++ b/api/core/rag/extractor/watercrawl/provider.py @@ -20,7 +20,7 @@ class WaterCrawlProvider: } if options.get("crawl_sub_pages", True): spider_options["page_limit"] = options.get("limit", 1) - spider_options["max_depth"] = options.get("depth", 1) + spider_options["max_depth"] = options.get("max_depth", 1) spider_options["include_paths"] = options.get("includes", "").split(",") if options.get("includes") else [] spider_options["exclude_paths"] = options.get("excludes", "").split(",") if options.get("excludes") else [] diff --git a/api/core/tools/plugin_tool/provider.py b/api/core/tools/plugin_tool/provider.py index 3616e426b9..494b8e209c 100644 --- a/api/core/tools/plugin_tool/provider.py +++ b/api/core/tools/plugin_tool/provider.py @@ -1,6 +1,6 @@ from typing import Any -from core.plugin.manager.tool import PluginToolManager +from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin, ToolProviderType diff --git a/api/core/tools/plugin_tool/tool.py b/api/core/tools/plugin_tool/tool.py index f31a9a0d3e..d21e3d7d1c 100644 --- a/api/core/tools/plugin_tool/tool.py +++ b/api/core/tools/plugin_tool/tool.py @@ -1,7 +1,7 @@ from collections.abc import Generator from typing import Any, Optional -from core.plugin.manager.tool import PluginToolManager +from core.plugin.impl.tool import PluginToolManager from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f2d0b74f7c..aa2661fe63 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -10,7 +10,7 @@ from yarl import URL import contexts from core.plugin.entities.plugin import ToolProviderID -from core.plugin.manager.tool import PluginToolManager +from core.plugin.impl.tool import PluginToolManager from core.tools.__base.tool_provider import ToolProviderController from core.tools.__base.tool_runtime import ToolRuntime from core.tools.plugin_tool.provider import PluginToolProviderController diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index da40cbcdea..771e0ca7a5 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -7,8 +7,8 @@ from core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.model_entities import AIModelEntity, ModelType -from core.plugin.manager.exc import PluginDaemonClientSideError -from core.plugin.manager.plugin import PluginInstallationManager +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.plugin.impl.plugin import PluginInstaller from core.provider_manager import ProviderManager from core.tools.entities.tool_entities import ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager @@ -297,7 +297,7 @@ class AgentNode(ToolNode): Get agent strategy icon :return: """ - manager = PluginInstallationManager() + manager = PluginInstaller() plugins = manager.list_plugins(self.tenant_id) try: current_plugin = next( diff --git a/api/core/workflow/nodes/llm/node.py b/api/core/workflow/nodes/llm/node.py index 1089e7168e..35b146e5d9 100644 --- a/api/core/workflow/nodes/llm/node.py +++ b/api/core/workflow/nodes/llm/node.py @@ -38,6 +38,7 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder +from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str from core.plugin.entities.plugin import ModelProviderID from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig from core.prompt.utils.prompt_message_util import PromptMessageUtil @@ -269,18 +270,7 @@ class LLMNode(BaseNode[LLMNodeData]): def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]: if isinstance(invoke_result, LLMResult): - content = invoke_result.message.content - if content is None: - message_text = "" - elif isinstance(content, str): - message_text = content - elif isinstance(content, list): - # Assuming the list contains PromptMessageContent objects with a "data" attribute - message_text = "".join( - item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content - ) - else: - message_text = str(content) + message_text = convert_llm_result_chunk_to_str(invoke_result.message.content) yield ModelInvokeCompletedEvent( text=message_text, @@ -295,7 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]): usage = None finish_reason = None for result in invoke_result: - text = result.delta.message.content + text = convert_llm_result_chunk_to_str(result.delta.message.content) full_text += text yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"]) diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index 6f0cc3f6d2..c72ae5b69b 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -6,8 +6,8 @@ from sqlalchemy.orm import Session from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.file import File, FileTransferMethod -from core.plugin.manager.exc import PluginDaemonClientSideError -from core.plugin.manager.plugin import PluginInstallationManager +from core.plugin.impl.exc import PluginDaemonClientSideError +from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.errors import ToolInvokeError from core.tools.tool_engine import ToolEngine @@ -307,7 +307,7 @@ class ToolNode(BaseNode[ToolNodeData]): icon = tool_info.get("icon", "") dict_metadata = dict(message.message.metadata) if dict_metadata.get("provider"): - manager = PluginInstallationManager() + manager = PluginInstaller() plugins = manager.list_plugins(self.tenant_id) try: current_plugin = next( diff --git a/api/extensions/ext_commands.py b/api/extensions/ext_commands.py index be43f55ea7..ddc2158a02 100644 --- a/api/extensions/ext_commands.py +++ b/api/extensions/ext_commands.py @@ -5,6 +5,7 @@ def init_app(app: DifyApp): from commands import ( add_qdrant_index, clear_free_plan_tenant_expired_logs, + clear_orphaned_file_records, convert_to_agent_apps, create_tenant, extract_plugins, @@ -13,6 +14,7 @@ def init_app(app: DifyApp): install_plugins, migrate_data_for_plugin, old_metadata_migration, + remove_orphaned_files_on_storage, reset_email, reset_encrypt_key_pair, reset_password, @@ -36,6 +38,8 @@ def init_app(app: DifyApp): install_plugins, old_metadata_migration, clear_free_plan_tenant_expired_logs, + clear_orphaned_file_records, + remove_orphaned_files_on_storage, ] for cmd in cmds_to_register: app.cli.add_command(cmd) diff --git a/api/extensions/ext_storage.py b/api/extensions/ext_storage.py index 4c811c66ba..bd35278544 100644 --- a/api/extensions/ext_storage.py +++ b/api/extensions/ext_storage.py @@ -102,6 +102,9 @@ class Storage: def delete(self, filename): return self.storage_runner.delete(filename) + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + return self.storage_runner.scan(path, files=files, directories=directories) + storage = Storage() diff --git a/api/extensions/storage/base_storage.py b/api/extensions/storage/base_storage.py index 0dedd7ff8c..0393206e54 100644 --- a/api/extensions/storage/base_storage.py +++ b/api/extensions/storage/base_storage.py @@ -30,3 +30,11 @@ class BaseStorage(ABC): @abstractmethod def delete(self, filename): raise NotImplementedError + + def scan(self, path, files=True, directories=False) -> list[str]: + """ + Scan files and directories in the given path. + This method is implemented only in some storage backends. + If a storage backend doesn't support scanning, it will raise NotImplementedError. + """ + raise NotImplementedError("This storage backend doesn't support scanning") diff --git a/api/extensions/storage/opendal_storage.py b/api/extensions/storage/opendal_storage.py index ee8cfa9179..12e2738e9d 100644 --- a/api/extensions/storage/opendal_storage.py +++ b/api/extensions/storage/opendal_storage.py @@ -80,3 +80,20 @@ class OpenDALStorage(BaseStorage): logger.debug(f"file {filename} deleted") return logger.debug(f"file {filename} not found, skip delete") + + def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]: + if not self.exists(path): + raise FileNotFoundError("Path not found") + + all_files = self.op.scan(path=path) + if files and directories: + logger.debug(f"files and directories on {path} scanned") + return [f.path for f in all_files] + if files: + logger.debug(f"files on {path} scanned") + return [f.path for f in all_files if not f.path.endswith("/")] + elif directories: + logger.debug(f"directories on {path} scanned") + return [f.path for f in all_files if f.path.endswith("/")] + else: + raise ValueError("At least one of files or directories must be True") diff --git a/api/factories/agent_factory.py b/api/factories/agent_factory.py index 4b2d2cc769..4b12afb528 100644 --- a/api/factories/agent_factory.py +++ b/api/factories/agent_factory.py @@ -1,12 +1,12 @@ from core.agent.strategy.plugin import PluginAgentStrategy -from core.plugin.manager.agent import PluginAgentManager +from core.plugin.impl.agent import PluginAgentClient def get_plugin_agent_strategy( tenant_id: str, agent_strategy_provider_name: str, agent_strategy_name: str ) -> PluginAgentStrategy: # TODO: use contexts to cache the agent provider - manager = PluginAgentManager() + manager = PluginAgentClient() agent_provider = manager.fetch_agent_strategy_provider(tenant_id, agent_strategy_provider_name) for agent_strategy in agent_provider.declaration.strategies: if agent_strategy.identity.name == agent_strategy_name: diff --git a/api/services/agent_service.py b/api/services/agent_service.py index 0ff144052f..4c63611bb3 100644 --- a/api/services/agent_service.py +++ b/api/services/agent_service.py @@ -6,8 +6,8 @@ from flask_login import current_user # type: ignore import contexts from core.app.app_config.easy_ui_based_app.agent.manager import AgentConfigManager -from core.plugin.manager.agent import PluginAgentManager -from core.plugin.manager.exc import PluginDaemonClientSideError +from core.plugin.impl.agent import PluginAgentClient +from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.tool_manager import ToolManager from extensions.ext_database import db from models.account import Account @@ -161,7 +161,7 @@ class AgentService: """ List agent providers """ - manager = PluginAgentManager() + manager = PluginAgentClient() return manager.fetch_agent_strategy_providers(tenant_id) @classmethod @@ -169,7 +169,7 @@ class AgentService: """ Get agent provider """ - manager = PluginAgentManager() + manager = PluginAgentClient() try: return manager.fetch_agent_strategy_provider(tenant_id, provider_name) except PluginDaemonClientSideError as e: diff --git a/api/services/plugin/dependencies_analysis.py b/api/services/plugin/dependencies_analysis.py index 07e624b4e8..830d3a4769 100644 --- a/api/services/plugin/dependencies_analysis.py +++ b/api/services/plugin/dependencies_analysis.py @@ -1,7 +1,7 @@ from configs import dify_config from core.helper import marketplace from core.plugin.entities.plugin import ModelProviderID, PluginDependency, PluginInstallationSource, ToolProviderID -from core.plugin.manager.plugin import PluginInstallationManager +from core.plugin.impl.plugin import PluginInstaller class DependenciesAnalysisService: @@ -38,7 +38,7 @@ class DependenciesAnalysisService: for dependency in dependencies: required_plugin_unique_identifiers.append(dependency.value.plugin_unique_identifier) - manager = PluginInstallationManager() + manager = PluginInstaller() # get leaked dependencies missing_plugins = manager.fetch_missing_dependencies(tenant_id, required_plugin_unique_identifiers) @@ -64,7 +64,7 @@ class DependenciesAnalysisService: Generate dependencies through the list of plugin ids """ dependencies = list(set(dependencies)) - manager = PluginInstallationManager() + manager = PluginInstaller() plugins = manager.fetch_plugin_installation_by_ids(tenant_id, dependencies) result = [] for plugin in plugins: diff --git a/api/services/plugin/endpoint_service.py b/api/services/plugin/endpoint_service.py index 35961345a8..11b8e0a3d9 100644 --- a/api/services/plugin/endpoint_service.py +++ b/api/services/plugin/endpoint_service.py @@ -1,10 +1,10 @@ -from core.plugin.manager.endpoint import PluginEndpointManager +from core.plugin.impl.endpoint import PluginEndpointClient class EndpointService: @classmethod def create_endpoint(cls, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict): - return PluginEndpointManager().create_endpoint( + return PluginEndpointClient().create_endpoint( tenant_id=tenant_id, user_id=user_id, plugin_unique_identifier=plugin_unique_identifier, @@ -14,7 +14,7 @@ class EndpointService: @classmethod def list_endpoints(cls, tenant_id: str, user_id: str, page: int, page_size: int): - return PluginEndpointManager().list_endpoints( + return PluginEndpointClient().list_endpoints( tenant_id=tenant_id, user_id=user_id, page=page, @@ -23,7 +23,7 @@ class EndpointService: @classmethod def list_endpoints_for_single_plugin(cls, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int): - return PluginEndpointManager().list_endpoints_for_single_plugin( + return PluginEndpointClient().list_endpoints_for_single_plugin( tenant_id=tenant_id, user_id=user_id, plugin_id=plugin_id, @@ -33,7 +33,7 @@ class EndpointService: @classmethod def update_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict): - return PluginEndpointManager().update_endpoint( + return PluginEndpointClient().update_endpoint( tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id, @@ -43,7 +43,7 @@ class EndpointService: @classmethod def delete_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str): - return PluginEndpointManager().delete_endpoint( + return PluginEndpointClient().delete_endpoint( tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id, @@ -51,7 +51,7 @@ class EndpointService: @classmethod def enable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str): - return PluginEndpointManager().enable_endpoint( + return PluginEndpointClient().enable_endpoint( tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id, @@ -59,7 +59,7 @@ class EndpointService: @classmethod def disable_endpoint(cls, tenant_id: str, user_id: str, endpoint_id: str): - return PluginEndpointManager().disable_endpoint( + return PluginEndpointClient().disable_endpoint( tenant_id=tenant_id, user_id=user_id, endpoint_id=endpoint_id, diff --git a/api/services/plugin/oauth_service.py b/api/services/plugin/oauth_service.py new file mode 100644 index 0000000000..461247419b --- /dev/null +++ b/api/services/plugin/oauth_service.py @@ -0,0 +1,7 @@ +from core.plugin.impl.base import BasePluginClient + + +class OAuthService(BasePluginClient): + @classmethod + def get_authorization_url(cls, tenant_id: str, user_id: str, provider_name: str) -> str: + return "1234567890" diff --git a/api/services/plugin/plugin_migration.py b/api/services/plugin/plugin_migration.py index ec9e0aa8dc..dbaaa7160e 100644 --- a/api/services/plugin/plugin_migration.py +++ b/api/services/plugin/plugin_migration.py @@ -17,7 +17,7 @@ from core.agent.entities import AgentToolEntity from core.helper import marketplace from core.plugin.entities.plugin import ModelProviderID, PluginInstallationSource, ToolProviderID from core.plugin.entities.plugin_daemon import PluginInstallTaskStatus -from core.plugin.manager.plugin import PluginInstallationManager +from core.plugin.impl.plugin import PluginInstaller from core.tools.entities.tool_entities import ToolProviderType from models.account import Tenant from models.engine import db @@ -331,7 +331,7 @@ class PluginMigration: """ Install plugins. """ - manager = PluginInstallationManager() + manager = PluginInstaller() plugins = cls.extract_unique_plugins(extracted_plugins) not_installed = [] @@ -426,7 +426,7 @@ class PluginMigration: """ Install plugins for a tenant. """ - manager = PluginInstallationManager() + manager = PluginInstaller() # download all the plugins and upload thread_pool = ThreadPoolExecutor(max_workers=10) diff --git a/api/services/plugin/plugin_service.py b/api/services/plugin/plugin_service.py index 4d213dd761..be722a59ad 100644 --- a/api/services/plugin/plugin_service.py +++ b/api/services/plugin/plugin_service.py @@ -18,9 +18,9 @@ from core.plugin.entities.plugin import ( PluginInstallationSource, ) from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginUploadResponse -from core.plugin.manager.asset import PluginAssetManager -from core.plugin.manager.debugging import PluginDebuggingManager -from core.plugin.manager.plugin import PluginInstallationManager +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.debugging import PluginDebuggingClient +from core.plugin.impl.plugin import PluginInstaller from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -91,7 +91,7 @@ class PluginService: """ get the debugging key of the tenant """ - manager = PluginDebuggingManager() + manager = PluginDebuggingClient() return manager.get_debugging_key(tenant_id) @staticmethod @@ -106,7 +106,7 @@ class PluginService: """ list all plugins of the tenant """ - manager = PluginInstallationManager() + manager = PluginInstaller() plugins = manager.list_plugins(tenant_id) return plugins @@ -115,7 +115,7 @@ class PluginService: """ List plugin installations from ids """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.fetch_plugin_installation_by_ids(tenant_id, ids) @staticmethod @@ -133,7 +133,7 @@ class PluginService: """ check if the plugin unique identifier is already installed by other tenant """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.fetch_plugin_by_identifier(tenant_id, plugin_unique_identifier) @staticmethod @@ -141,7 +141,7 @@ class PluginService: """ Fetch plugin manifest """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) @staticmethod @@ -149,12 +149,12 @@ class PluginService: """ Fetch plugin installation tasks """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.fetch_plugin_installation_tasks(tenant_id, page, page_size) @staticmethod def fetch_install_task(tenant_id: str, task_id: str) -> PluginInstallTask: - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.fetch_plugin_installation_task(tenant_id, task_id) @staticmethod @@ -162,7 +162,7 @@ class PluginService: """ Delete a plugin installation task """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.delete_plugin_installation_task(tenant_id, task_id) @staticmethod @@ -172,7 +172,7 @@ class PluginService: """ Delete all plugin installation task items """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.delete_all_plugin_installation_task_items(tenant_id) @staticmethod @@ -180,7 +180,7 @@ class PluginService: """ Delete a plugin installation task item """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.delete_plugin_installation_task_item(tenant_id, task_id, identifier) @staticmethod @@ -197,7 +197,7 @@ class PluginService: raise ValueError("you should not upgrade plugin with the same plugin") # check if plugin pkg is already downloaded - manager = PluginInstallationManager() + manager = PluginInstaller() try: manager.fetch_plugin_manifest(tenant_id, new_plugin_unique_identifier) @@ -230,7 +230,7 @@ class PluginService: """ Upgrade plugin with github """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.upgrade_plugin( tenant_id, original_plugin_unique_identifier, @@ -250,7 +250,7 @@ class PluginService: returns: plugin_unique_identifier """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.upload_pkg(tenant_id, pkg, verify_signature) @staticmethod @@ -265,7 +265,7 @@ class PluginService: f"https://github.com/{repo}/releases/download/{version}/{package}", dify_config.PLUGIN_MAX_PACKAGE_SIZE ) - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.upload_pkg( tenant_id, pkg, @@ -279,12 +279,12 @@ class PluginService: """ Upload a plugin bundle and return the dependencies. """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.upload_bundle(tenant_id, bundle, verify_signature) @staticmethod def install_from_local_pkg(tenant_id: str, plugin_unique_identifiers: Sequence[str]): - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.install_from_identifiers( tenant_id, plugin_unique_identifiers, @@ -298,7 +298,7 @@ class PluginService: Install plugin from github release package files, returns plugin_unique_identifier """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.install_from_identifiers( tenant_id, [plugin_unique_identifier], @@ -322,7 +322,7 @@ class PluginService: if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") - manager = PluginInstallationManager() + manager = PluginInstaller() try: declaration = manager.fetch_plugin_manifest(tenant_id, plugin_unique_identifier) except Exception: @@ -342,7 +342,7 @@ class PluginService: if not dify_config.MARKETPLACE_ENABLED: raise ValueError("marketplace is not enabled") - manager = PluginInstallationManager() + manager = PluginInstaller() # check if already downloaded for plugin_unique_identifier in plugin_unique_identifiers: @@ -368,7 +368,7 @@ class PluginService: @staticmethod def uninstall(tenant_id: str, plugin_installation_id: str) -> bool: - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.uninstall(tenant_id, plugin_installation_id) @staticmethod @@ -376,5 +376,5 @@ class PluginService: """ Check if the tools exist """ - manager = PluginInstallationManager() + manager = PluginInstaller() return manager.check_tools_existence(tenant_id, provider_ids) diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 075c60842b..3ccd14415d 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -8,7 +8,7 @@ from configs import dify_config from core.helper.position_helper import is_filtered from core.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin import GenericProviderID, ToolProviderID -from core.plugin.manager.exc import PluginDaemonClientSideError +from core.plugin.impl.exc import PluginDaemonClientSideError from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity from core.tools.errors import ToolNotFoundError, ToolProviderCredentialValidationError, ToolProviderNotFoundError diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py b/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py index 6dfc01ab4c..e3c592b583 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_daemon.py @@ -6,7 +6,7 @@ import pytest # import monkeypatch from _pytest.monkeypatch import MonkeyPatch -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass @@ -23,9 +23,9 @@ def mock_plugin_daemon( def unpatch() -> None: monkeypatch.undo() - monkeypatch.setattr(PluginModelManager, "invoke_llm", MockModelClass.invoke_llm) - monkeypatch.setattr(PluginModelManager, "fetch_model_providers", MockModelClass.fetch_model_providers) - monkeypatch.setattr(PluginModelManager, "get_model_schema", MockModelClass.get_model_schema) + monkeypatch.setattr(PluginModelClient, "invoke_llm", MockModelClass.invoke_llm) + monkeypatch.setattr(PluginModelClient, "fetch_model_providers", MockModelClass.fetch_model_providers) + monkeypatch.setattr(PluginModelClient, "get_model_schema", MockModelClass.get_model_schema) return unpatch diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 50913662e2..d699866fb4 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -19,10 +19,10 @@ from core.model_runtime.entities.model_entities import ( ) from core.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.manager.model import PluginModelManager +from core.plugin.impl.model import PluginModelClient -class MockModelClass(PluginModelManager): +class MockModelClass(PluginModelClient): def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -232,7 +232,7 @@ class MockModelClass(PluginModelManager): ) def invoke_llm( - self: PluginModelManager, + self: PluginModelClient, *, tenant_id: str, user_id: str, diff --git a/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py b/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py index c6d836ed6d..b6d583e338 100644 --- a/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py +++ b/api/tests/integration_tests/plugin/tools/test_fetch_all_tools.py @@ -1,4 +1,4 @@ -from core.plugin.manager.tool import PluginToolManager +from core.plugin.impl.tool import PluginToolManager from tests.integration_tests.plugin.__mock.http import setup_http_mock diff --git a/web/app/components/app/app-publisher/index.tsx b/web/app/components/app/app-publisher/index.tsx index 360741ab2e..d4357a0955 100644 --- a/web/app/components/app/app-publisher/index.tsx +++ b/web/app/components/app/app-publisher/index.tsx @@ -231,7 +231,7 @@ const AppPublisher = ({ > {t('workflow.common.runApp')} - {appDetail?.mode === 'workflow' + {appDetail?.mode === 'workflow' || appDetail?.mode === 'completion' ? ( void + onConfirm: () => void + confirmDisabled?: boolean +} +const DSLConfirmModal = ({ + versions = { importedVersion: '', systemVersion: '' }, + onCancel, + onConfirm, + confirmDisabled = false, +}: DSLConfirmModalProps) => { + const { t } = useTranslation() + + return ( + onCancel()} + className='w-[480px]' + > +
+
{t('app.newApp.appCreateDSLErrorTitle')}
+
+
{t('app.newApp.appCreateDSLErrorPart1')}
+
{t('app.newApp.appCreateDSLErrorPart2')}
+
+
{t('app.newApp.appCreateDSLErrorPart3')}{versions.importedVersion}
+
{t('app.newApp.appCreateDSLErrorPart4')}{versions.systemVersion}
+
+
+
+ + +
+
+ ) +} + +export default DSLConfirmModal diff --git a/web/app/components/explore/app-list/index.tsx b/web/app/components/explore/app-list/index.tsx index e217dda2b2..7e2d990bc8 100644 --- a/web/app/components/explore/app-list/index.tsx +++ b/web/app/components/explore/app-list/index.tsx @@ -1,12 +1,10 @@ 'use client' -import React, { useMemo, useState } from 'react' -import { useRouter } from 'next/navigation' +import React, { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { useContext } from 'use-context-selector' import useSWR from 'swr' import { useDebounceFn } from 'ahooks' -import Toast from '../../base/toast' import s from './style.module.css' import cn from '@/utils/classnames' import ExploreContext from '@/context/explore-context' @@ -14,17 +12,16 @@ import type { App } from '@/models/explore' import Category from '@/app/components/explore/category' import AppCard from '@/app/components/explore/app-card' import { fetchAppDetail, fetchAppList } from '@/service/explore' -import { importDSL } from '@/service/apps' import { useTabSearchParams } from '@/hooks/use-tab-searchparams' import CreateAppModal from '@/app/components/explore/create-app-modal' import type { CreateAppModalProps } from '@/app/components/explore/create-app-modal' import Loading from '@/app/components/base/loading' -import { NEED_REFRESH_APP_LIST_KEY } from '@/config' -import { useAppContext } from '@/context/app-context' -import { getRedirection } from '@/utils/app-redirection' import Input from '@/app/components/base/input' -import { DSLImportMode } from '@/models/app' -import { usePluginDependencies } from '@/app/components/workflow/plugin-dependency/hooks' +import { + DSLImportMode, +} from '@/models/app' +import { useImportDSL } from '@/hooks/use-import-dsl' +import DSLConfirmModal from '@/app/components/app/create-from-dsl-modal/dsl-confirm-modal' type AppsProps = { onSuccess?: () => void @@ -39,8 +36,6 @@ const Apps = ({ onSuccess, }: AppsProps) => { const { t } = useTranslation() - const { isCurrentWorkspaceEditor } = useAppContext() - const { push } = useRouter() const { hasEditPermission } = useContext(ExploreContext) const allCategoriesEn = t('explore.apps.allCategories', { lng: 'en' }) @@ -115,7 +110,14 @@ const Apps = ({ const [currApp, setCurrApp] = React.useState(null) const [isShowCreateModal, setIsShowCreateModal] = React.useState(false) - const { handleCheckPluginDependencies } = usePluginDependencies() + + const { + handleImportDSL, + handleImportDSLConfirm, + versions, + isFetching, + } = useImportDSL() + const [showDSLConfirmModal, setShowDSLConfirmModal] = useState(false) const onCreate: CreateAppModalProps['onConfirm'] = async ({ name, icon_type, @@ -123,36 +125,34 @@ const Apps = ({ icon_background, description, }) => { - const { export_data, mode } = await fetchAppDetail( + const { export_data } = await fetchAppDetail( currApp?.app.id as string, ) - try { - const app = await importDSL({ - mode: DSLImportMode.YAML_CONTENT, - yaml_content: export_data, - name, - icon_type, - icon, - icon_background, - description, - }) - setIsShowCreateModal(false) - Toast.notify({ - type: 'success', - message: t('app.newApp.appCreated'), - }) - if (onSuccess) - onSuccess() - if (app.app_id) - await handleCheckPluginDependencies(app.app_id) - localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - getRedirection(isCurrentWorkspaceEditor, { id: app.app_id!, mode }, push) - } - catch { - Toast.notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) + const payload = { + mode: DSLImportMode.YAML_CONTENT, + yaml_content: export_data, + name, + icon_type, + icon, + icon_background, + description, } + await handleImportDSL(payload, { + onSuccess: () => { + setIsShowCreateModal(false) + }, + onPending: () => { + setShowDSLConfirmModal(true) + }, + }) } + const onConfirmDSL = useCallback(async () => { + await handleImportDSLConfirm({ + onSuccess, + }) + }, [handleImportDSLConfirm, onSuccess]) + if (!categories || categories.length === 0) { return (
@@ -225,9 +225,20 @@ const Apps = ({ appDescription={currApp?.app.description || ''} show={isShowCreateModal} onConfirm={onCreate} + confirmDisabled={isFetching} onHide={() => setIsShowCreateModal(false)} /> )} + { + showDSLConfirmModal && ( + setShowDSLConfirmModal(false)} + onConfirm={onConfirmDSL} + confirmDisabled={isFetching} + /> + ) + }
) } diff --git a/web/app/components/explore/create-app-modal/index.tsx b/web/app/components/explore/create-app-modal/index.tsx index d6d521833a..f30b286786 100644 --- a/web/app/components/explore/create-app-modal/index.tsx +++ b/web/app/components/explore/create-app-modal/index.tsx @@ -35,6 +35,7 @@ export type CreateAppModalProps = { description: string use_icon_as_answer_icon?: boolean }) => Promise + confirmDisabled?: boolean onHide: () => void } @@ -50,6 +51,7 @@ const CreateAppModal = ({ appMode, appUseIconAsAnswerIcon, onConfirm, + confirmDisabled, onHide, }: CreateAppModalProps) => { const { t } = useTranslation() @@ -160,7 +162,7 @@ const CreateAppModal = ({