diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index e0061f933b..b2fcf3ce7b 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -5,7 +5,7 @@ from typing import cast from flask import request from flask_login import current_user -from flask_restful import Resource, fields, marshal, marshal_with, reqparse +from flask_restful import Resource, marshal, marshal_with, reqparse from sqlalchemy import asc, desc, select from werkzeug.exceptions import Forbidden, NotFound @@ -239,12 +239,10 @@ class DatasetDocumentListApi(Resource): return response - documents_and_batch_fields = {"documents": fields.List(fields.Nested(document_fields)), "batch": fields.String} - @setup_required @login_required @account_initialization_required - @marshal_with(documents_and_batch_fields) + @marshal_with(dataset_and_document_fields) @cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_rate_limit_check("knowledge") def post(self, dataset_id): @@ -290,6 +288,8 @@ class DatasetDocumentListApi(Resource): try: documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user) + dataset = DatasetService.get_dataset(dataset_id) + except ProviderTokenNotInitError as ex: raise ProviderNotInitializeError(ex.description) except QuotaExceededError: diff --git a/api/core/workflow/graph_engine/graph_engine.py b/api/core/workflow/graph_engine/graph_engine.py index 2809afad02..61a7a26652 100644 --- a/api/core/workflow/graph_engine/graph_engine.py +++ b/api/core/workflow/graph_engine/graph_engine.py @@ -53,6 +53,7 @@ from core.workflow.nodes.end.end_stream_processor import EndStreamProcessor from core.workflow.nodes.enums import ErrorStrategy, FailBranchSourceHandle from core.workflow.nodes.event import RunCompletedEvent, RunRetrieverResourceEvent, RunStreamChunkEvent from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.utils import variable_utils from libs.flask_utils import preserve_flask_contexts from models.enums import UserFrom from models.workflow import WorkflowType @@ -856,16 +857,12 @@ class GraphEngine: :param variable_value: variable value :return: """ - self.graph_runtime_state.variable_pool.add([node_id] + variable_key_list, variable_value) - - # if variable_value is a dict, then recursively append variables - if isinstance(variable_value, dict): - for key, value in variable_value.items(): - # construct new key list - new_key_list = variable_key_list + [key] - self._append_variables_recursively( - node_id=node_id, variable_key_list=new_key_list, variable_value=value - ) + variable_utils.append_variables_recursively( + self.graph_runtime_state.variable_pool, + node_id, + variable_key_list, + variable_value, + ) def _is_timed_out(self, start_at: float, max_execution_time: int) -> bool: """ diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 22c564c1fc..2f28363955 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -39,6 +39,10 @@ class AgentNode(ToolNode): _node_data_cls = AgentNodeData # type: ignore _node_type = NodeType.AGENT + @classmethod + def version(cls) -> str: + return "1" + def _run(self) -> Generator: """ Run the agent node diff --git a/api/core/workflow/nodes/document_extractor/node.py b/api/core/workflow/nodes/document_extractor/node.py index 9f48b48865..8e6150f9cc 100644 --- a/api/core/workflow/nodes/document_extractor/node.py +++ b/api/core/workflow/nodes/document_extractor/node.py @@ -451,7 +451,7 @@ def _extract_text_from_excel(file_content: bytes) -> str: df = df.applymap(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) # type: ignore # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(col.splitlines()) for col in df.columns]) + df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) # Manually construct the Markdown table markdown_table += _construct_markdown_table(df) + "\n\n" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 2995f0682f..0b9e98f28a 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -71,6 +71,10 @@ class KnowledgeRetrievalNode(LLMNode): _node_data_cls = KnowledgeRetrievalNodeData # type: ignore _node_type = NodeType.KNOWLEDGE_RETRIEVAL + @classmethod + def version(cls): + return "1" + def _run(self) -> NodeRunResult: # type: ignore node_data = cast(KnowledgeRetrievalNodeData, self.node_data) # extract variables diff --git a/api/core/workflow/nodes/question_classifier/question_classifier_node.py b/api/core/workflow/nodes/question_classifier/question_classifier_node.py index 1f50700c7e..a518167cc6 100644 --- a/api/core/workflow/nodes/question_classifier/question_classifier_node.py +++ b/api/core/workflow/nodes/question_classifier/question_classifier_node.py @@ -40,6 +40,10 @@ class QuestionClassifierNode(LLMNode): _node_data_cls = QuestionClassifierNodeData # type: ignore _node_type = NodeType.QUESTION_CLASSIFIER + @classmethod + def version(cls): + return "1" + def _run(self): node_data = cast(QuestionClassifierNodeData, self.node_data) variable_pool = self.graph_runtime_state.variable_pool diff --git a/api/core/workflow/utils/variable_utils.py b/api/core/workflow/utils/variable_utils.py new file mode 100644 index 0000000000..e68d990b60 --- /dev/null +++ b/api/core/workflow/utils/variable_utils.py @@ -0,0 +1,28 @@ +from core.variables.segments import ObjectSegment, Segment +from core.workflow.entities.variable_pool import VariablePool, VariableValue + + +def append_variables_recursively( + pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue | Segment +): + """ + Append variables recursively + :param node_id: node id + :param variable_key_list: variable key list + :param variable_value: variable value + :return: + """ + pool.add([node_id] + variable_key_list, variable_value) + + # if variable_value is a dict, then recursively append variables + if isinstance(variable_value, ObjectSegment): + variable_dict = variable_value.value + elif isinstance(variable_value, dict): + variable_dict = variable_value + else: + return + + for key, value in variable_dict.items(): + # construct new key list + new_key_list = variable_key_list + [key] + append_variables_recursively(pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value) diff --git a/api/core/workflow/variable_loader.py b/api/core/workflow/variable_loader.py index 4842ee00a5..1e13871d0a 100644 --- a/api/core/workflow/variable_loader.py +++ b/api/core/workflow/variable_loader.py @@ -3,7 +3,9 @@ from collections.abc import Mapping, Sequence from typing import Any, Protocol from core.variables import Variable +from core.variables.consts import MIN_SELECTORS_LENGTH from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils import variable_utils class VariableLoader(Protocol): @@ -76,4 +78,7 @@ def load_into_variable_pool( variables_to_load.append(list(selector)) loaded = variable_loader.load_variables(variables_to_load) for var in loaded: - variable_pool.add(var.selector, var) + assert len(var.selector) >= MIN_SELECTORS_LENGTH, f"Invalid variable {var}" + variable_utils.append_variables_recursively( + variable_pool, node_id=var.selector[0], variable_key_list=list(var.selector[1:]), variable_value=var + ) diff --git a/api/libs/smtp.py b/api/libs/smtp.py index 35561f071c..b94386660e 100644 --- a/api/libs/smtp.py +++ b/api/libs/smtp.py @@ -22,7 +22,11 @@ class SMTPClient: if self.use_tls: if self.opportunistic_tls: smtp = smtplib.SMTP(self.server, self.port, timeout=10) + # Send EHLO command with the HELO domain name as the server address + smtp.ehlo(self.server) smtp.starttls() + # Resend EHLO command to identify the TLS session + smtp.ehlo(self.server) else: smtp = smtplib.SMTP_SSL(self.server, self.port, timeout=10) else: diff --git a/api/models/model.py b/api/models/model.py index 75202d9367..c980e61baf 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -832,7 +832,12 @@ class Conversation(Base): @property def first_message(self): - return db.session.query(Message).filter(Message.conversation_id == self.id).first() + return ( + db.session.query(Message) + .filter(Message.conversation_id == self.id) + .order_by(Message.created_at.asc()) + .first() + ) @property def app(self): diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 49ca98624a..af1210c19d 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -280,174 +280,328 @@ class DatasetService: @staticmethod def update_dataset(dataset_id, data, user): + """ + Update dataset configuration and settings. + + Args: + dataset_id: The unique identifier of the dataset to update + data: Dictionary containing the update data + user: The user performing the update operation + + Returns: + Dataset: The updated dataset object + + Raises: + ValueError: If dataset not found or validation fails + NoPermissionError: If user lacks permission to update the dataset + """ + # Retrieve and validate dataset existence dataset = DatasetService.get_dataset(dataset_id) if not dataset: raise ValueError("Dataset not found") + # Verify user has permission to update this dataset DatasetService.check_dataset_permission(dataset, user) + + # Handle external dataset updates if dataset.provider == "external": - external_retrieval_model = data.get("external_retrieval_model", None) - if external_retrieval_model: - dataset.retrieval_model = external_retrieval_model - dataset.name = data.get("name", dataset.name) - dataset.description = data.get("description", "") - permission = data.get("permission") - if permission: - dataset.permission = permission - external_knowledge_id = data.get("external_knowledge_id", None) - db.session.add(dataset) - if not external_knowledge_id: - raise ValueError("External knowledge id is required.") - external_knowledge_api_id = data.get("external_knowledge_api_id", None) - if not external_knowledge_api_id: - raise ValueError("External knowledge api id is required.") - - with Session(db.engine) as session: - external_knowledge_binding = ( - session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() - ) + return DatasetService._update_external_dataset(dataset, data, user) + else: + return DatasetService._update_internal_dataset(dataset, data, user) - if not external_knowledge_binding: - raise ValueError("External knowledge binding not found.") + @staticmethod + def _update_external_dataset(dataset, data, user): + """ + Update external dataset configuration. - if ( - external_knowledge_binding.external_knowledge_id != external_knowledge_id - or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id - ): - external_knowledge_binding.external_knowledge_id = external_knowledge_id - external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id - db.session.add(external_knowledge_binding) - db.session.commit() - else: - data.pop("partial_member_list", None) - data.pop("external_knowledge_api_id", None) - data.pop("external_knowledge_id", None) - data.pop("external_retrieval_model", None) - filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} - action = None - if dataset.indexing_technique != data["indexing_technique"]: - # if update indexing_technique - if data["indexing_technique"] == "economy": - action = "remove" - filtered_data["embedding_model"] = None - filtered_data["embedding_model_provider"] = None - filtered_data["collection_binding_id"] = None - elif data["indexing_technique"] == "high_quality": - action = "add" - # get embedding model setting - try: - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) - else: - # add default plugin id to both setting sets, to make sure the plugin model provider is consistent - # Skip embedding model checks if not provided in the update request - if ( - "embedding_model_provider" not in data - or "embedding_model" not in data - or not data.get("embedding_model_provider") - or not data.get("embedding_model") - ): - # If the dataset already has embedding model settings, use those - if dataset.embedding_model_provider and dataset.embedding_model: - # Keep existing values - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - # If collection_binding_id exists, keep it too - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Otherwise, don't try to update embedding model settings at all - # Remove these fields from filtered_data if they exist but are None/empty - if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: - del filtered_data["embedding_model_provider"] - if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: - del filtered_data["embedding_model"] - else: - skip_embedding_update = False - try: - # Handle existing model provider - plugin_model_provider = dataset.embedding_model_provider - plugin_model_provider_str = None - if plugin_model_provider: - plugin_model_provider_str = str(ModelProviderID(plugin_model_provider)) - - # Handle new model provider from request - new_plugin_model_provider = data["embedding_model_provider"] - new_plugin_model_provider_str = None - if new_plugin_model_provider: - new_plugin_model_provider_str = str(ModelProviderID(new_plugin_model_provider)) - - # Only update embedding model if both values are provided and different from current - if ( - plugin_model_provider_str != new_plugin_model_provider_str - or data["embedding_model"] != dataset.embedding_model - ): - action = "update" - model_manager = ModelManager() - try: - embedding_model = model_manager.get_model_instance( - tenant_id=current_user.current_tenant_id, - provider=data["embedding_model_provider"], - model_type=ModelType.TEXT_EMBEDDING, - model=data["embedding_model"], - ) - except ProviderTokenNotInitError: - # If we can't get the embedding model, skip updating it - # and keep the existing settings if available - if dataset.embedding_model_provider and dataset.embedding_model: - filtered_data["embedding_model_provider"] = dataset.embedding_model_provider - filtered_data["embedding_model"] = dataset.embedding_model - if dataset.collection_binding_id: - filtered_data["collection_binding_id"] = dataset.collection_binding_id - # Skip the rest of the embedding model update - skip_embedding_update = True - if not skip_embedding_update: - filtered_data["embedding_model"] = embedding_model.model - filtered_data["embedding_model_provider"] = embedding_model.provider - dataset_collection_binding = ( - DatasetCollectionBindingService.get_dataset_collection_binding( - embedding_model.provider, embedding_model.model - ) - ) - filtered_data["collection_binding_id"] = dataset_collection_binding.id - except LLMBadRequestError: - raise ValueError( - "No Embedding Model available. Please configure a valid provider " - "in the Settings -> Model Provider." - ) - except ProviderTokenNotInitError as ex: - raise ValueError(ex.description) + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update - filtered_data["updated_by"] = user.id - filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + Returns: + Dataset: Updated dataset object + """ + # Update retrieval model if provided + external_retrieval_model = data.get("external_retrieval_model", None) + if external_retrieval_model: + dataset.retrieval_model = external_retrieval_model + + # Update basic dataset properties + dataset.name = data.get("name", dataset.name) + dataset.description = data.get("description", dataset.description) + + # Update permission if provided + permission = data.get("permission") + if permission: + dataset.permission = permission + + # Validate and update external knowledge configuration + external_knowledge_id = data.get("external_knowledge_id", None) + external_knowledge_api_id = data.get("external_knowledge_api_id", None) + + if not external_knowledge_id: + raise ValueError("External knowledge id is required.") + if not external_knowledge_api_id: + raise ValueError("External knowledge api id is required.") + # Update metadata fields + dataset.updated_by = user.id if user else None + dataset.updated_at = datetime.datetime.utcnow() + db.session.add(dataset) - # update Retrieval model - filtered_data["retrieval_model"] = data["retrieval_model"] + # Update external knowledge binding + DatasetService._update_external_knowledge_binding(dataset.id, external_knowledge_id, external_knowledge_api_id) - db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data) + # Commit changes to database + db.session.commit() + + return dataset + + @staticmethod + def _update_external_knowledge_binding(dataset_id, external_knowledge_id, external_knowledge_api_id): + """ + Update external knowledge binding configuration. + + Args: + dataset_id: Dataset identifier + external_knowledge_id: External knowledge identifier + external_knowledge_api_id: External knowledge API identifier + """ + with Session(db.engine) as session: + external_knowledge_binding = ( + session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id).first() + ) + + if not external_knowledge_binding: + raise ValueError("External knowledge binding not found.") + + # Update binding if values have changed + if ( + external_knowledge_binding.external_knowledge_id != external_knowledge_id + or external_knowledge_binding.external_knowledge_api_id != external_knowledge_api_id + ): + external_knowledge_binding.external_knowledge_id = external_knowledge_id + external_knowledge_binding.external_knowledge_api_id = external_knowledge_api_id + db.session.add(external_knowledge_binding) + + @staticmethod + def _update_internal_dataset(dataset, data, user): + """ + Update internal dataset configuration. + + Args: + dataset: The dataset object to update + data: Update data dictionary + user: User performing the update + + Returns: + Dataset: Updated dataset object + """ + # Remove external-specific fields from update data + data.pop("partial_member_list", None) + data.pop("external_knowledge_api_id", None) + data.pop("external_knowledge_id", None) + data.pop("external_retrieval_model", None) + + # Filter out None values except for description field + filtered_data = {k: v for k, v in data.items() if v is not None or k == "description"} + + # Handle indexing technique changes and embedding model updates + action = DatasetService._handle_indexing_technique_change(dataset, data, filtered_data) + + # Add metadata fields + filtered_data["updated_by"] = user.id + filtered_data["updated_at"] = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) + # update Retrieval model + filtered_data["retrieval_model"] = data["retrieval_model"] + + # Update dataset in database + db.session.query(Dataset).filter_by(id=dataset.id).update(filtered_data) + db.session.commit() + + # Trigger vector index task if indexing technique changed + if action: + deal_dataset_vector_index_task.delay(dataset.id, action) - db.session.commit() - if action: - deal_dataset_vector_index_task.delay(dataset_id, action) return dataset + @staticmethod + def _handle_indexing_technique_change(dataset, data, filtered_data): + """ + Handle changes in indexing technique and configure embedding models accordingly. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data + + Returns: + str: Action to perform ('add', 'remove', 'update', or None) + """ + if dataset.indexing_technique != data["indexing_technique"]: + if data["indexing_technique"] == "economy": + # Remove embedding model configuration for economy mode + filtered_data["embedding_model"] = None + filtered_data["embedding_model_provider"] = None + filtered_data["collection_binding_id"] = None + return "remove" + elif data["indexing_technique"] == "high_quality": + # Configure embedding model for high quality mode + DatasetService._configure_embedding_model_for_high_quality(data, filtered_data) + return "add" + else: + # Handle embedding model updates when indexing technique remains the same + return DatasetService._handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data) + return None + + @staticmethod + def _configure_embedding_model_for_high_quality(data, filtered_data): + """ + Configure embedding model settings for high quality indexing. + + Args: + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + try: + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + + @staticmethod + def _handle_embedding_model_update_when_technique_unchanged(dataset, data, filtered_data): + """ + Handle embedding model updates when indexing technique remains the same. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + # Skip embedding model checks if not provided in the update request + if ( + "embedding_model_provider" not in data + or "embedding_model" not in data + or not data.get("embedding_model_provider") + or not data.get("embedding_model") + ): + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + return None + else: + return DatasetService._update_embedding_model_settings(dataset, data, filtered_data) + + @staticmethod + def _preserve_existing_embedding_settings(dataset, filtered_data): + """ + Preserve existing embedding model settings when not provided in update. + + Args: + dataset: Current dataset object + filtered_data: Filtered update data to modify + """ + # If the dataset already has embedding model settings, use those + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + # If collection_binding_id exists, keep it too + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Otherwise, don't try to update embedding model settings at all + # Remove these fields from filtered_data if they exist but are None/empty + if "embedding_model_provider" in filtered_data and not filtered_data["embedding_model_provider"]: + del filtered_data["embedding_model_provider"] + if "embedding_model" in filtered_data and not filtered_data["embedding_model"]: + del filtered_data["embedding_model"] + + @staticmethod + def _update_embedding_model_settings(dataset, data, filtered_data): + """ + Update embedding model settings with new values. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + + Returns: + str: Action to perform ('update' or None) + """ + try: + # Compare current and new model provider settings + current_provider_str = ( + str(ModelProviderID(dataset.embedding_model_provider)) if dataset.embedding_model_provider else None + ) + new_provider_str = ( + str(ModelProviderID(data["embedding_model_provider"])) if data["embedding_model_provider"] else None + ) + + # Only update if values are different + if current_provider_str != new_provider_str or data["embedding_model"] != dataset.embedding_model: + DatasetService._apply_new_embedding_settings(dataset, data, filtered_data) + return "update" + except LLMBadRequestError: + raise ValueError( + "No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider." + ) + except ProviderTokenNotInitError as ex: + raise ValueError(ex.description) + return None + + @staticmethod + def _apply_new_embedding_settings(dataset, data, filtered_data): + """ + Apply new embedding model settings to the dataset. + + Args: + dataset: Current dataset object + data: Update data dictionary + filtered_data: Filtered update data to modify + """ + model_manager = ModelManager() + try: + embedding_model = model_manager.get_model_instance( + tenant_id=current_user.current_tenant_id, + provider=data["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=data["embedding_model"], + ) + except ProviderTokenNotInitError: + # If we can't get the embedding model, preserve existing settings + if dataset.embedding_model_provider and dataset.embedding_model: + filtered_data["embedding_model_provider"] = dataset.embedding_model_provider + filtered_data["embedding_model"] = dataset.embedding_model + if dataset.collection_binding_id: + filtered_data["collection_binding_id"] = dataset.collection_binding_id + # Skip the rest of the embedding model update + return + + # Apply new embedding model settings + filtered_data["embedding_model"] = embedding_model.model + filtered_data["embedding_model_provider"] = embedding_model.provider + dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding( + embedding_model.provider, embedding_model.model + ) + filtered_data["collection_binding_id"] = dataset_collection_binding.id + @staticmethod def delete_dataset(dataset_id, user): dataset = DatasetService.get_dataset(dataset_id) diff --git a/api/services/metadata_service.py b/api/services/metadata_service.py index 26d6d4ce18..cfcb121153 100644 --- a/api/services/metadata_service.py +++ b/api/services/metadata_service.py @@ -19,6 +19,10 @@ from services.entities.knowledge_entities.knowledge_entities import ( class MetadataService: @staticmethod def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: + # check if metadata name is too long + if len(metadata_args.name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + # check if metadata name already exists if ( db.session.query(DatasetMetadata) @@ -42,6 +46,10 @@ class MetadataService: @staticmethod def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore + # check if metadata name is too long + if len(name) > 255: + raise ValueError("Metadata name cannot exceed 255 characters.") + lock_key = f"dataset_metadata_lock_{dataset_id}" # check if metadata name already exists if ( diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index cd30440b4f..164693c2e1 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -129,7 +129,8 @@ class WorkflowDraftVariableService: ) -> list[WorkflowDraftVariable]: ors = [] for selector in selectors: - node_id, name = selector + assert len(selector) >= MIN_SELECTORS_LENGTH, f"Invalid selector to get: {selector}" + node_id, name = selector[:2] ors.append(and_(WorkflowDraftVariable.node_id == node_id, WorkflowDraftVariable.name == name)) # NOTE(QuantumGhost): Although the number of `or` expressions may be large, as long as diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py new file mode 100644 index 0000000000..8712b61a23 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -0,0 +1,36 @@ +from core.workflow.nodes.base.node import BaseNode +from core.workflow.nodes.enums import NodeType + +# Ensures that all node classes are imported. +from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING + +_ = NODE_TYPE_CLASSES_MAPPING + + +def _get_all_subclasses(root: type[BaseNode]) -> list[type[BaseNode]]: + subclasses = [] + queue = [root] + while queue: + cls = queue.pop() + + subclasses.extend(cls.__subclasses__()) + queue.extend(cls.__subclasses__()) + + return subclasses + + +def test_ensure_subclasses_of_base_node_has_node_type_and_version_method_defined(): + classes = _get_all_subclasses(BaseNode) # type: ignore + type_version_set: set[tuple[NodeType, str]] = set() + + for cls in classes: + # Validate that 'version' is directly defined in the class (not inherited) by checking the class's __dict__ + assert "version" in cls.__dict__, f"class {cls} should have version method defined (NOT INHERITED.)" + node_type = cls._node_type + node_version = cls.version() + + assert isinstance(cls._node_type, NodeType) + assert isinstance(node_version, str) + node_type_and_version = (node_type, node_version) + assert node_type_and_version not in type_version_set + type_version_set.add(node_type_and_version) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 76bb640d1c..66c7818adf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -342,3 +342,26 @@ def test_extract_text_from_excel_all_sheets_fail(mock_excel_file): assert result == "" assert mock_excel_instance.parse.call_count == 2 + + +@patch("pandas.ExcelFile") +def test_extract_text_from_excel_numeric_type_column(mock_excel_file): + """Test extracting text from Excel file with numeric column names.""" + + # Test numeric type column + data = {1: ["Test"], 1.1: ["Test"]} + + df = pd.DataFrame(data) + + # Mock ExcelFile + mock_excel_instance = Mock() + mock_excel_instance.sheet_names = ["Sheet1"] + mock_excel_instance.parse.return_value = df + mock_excel_file.return_value = mock_excel_instance + + file_content = b"fake_excel_content" + result = _extract_text_from_excel(file_content) + + expected_manual = "| 1.0 | 1.1 |\n| --- | --- |\n| Test | Test |\n\n" + + assert expected_manual == result diff --git a/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py new file mode 100644 index 0000000000..f1cb937bb3 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/utils/test_variable_utils.py @@ -0,0 +1,148 @@ +from typing import Any + +from core.variables.segments import ObjectSegment, StringSegment +from core.workflow.entities.variable_pool import VariablePool +from core.workflow.utils.variable_utils import append_variables_recursively + + +class TestAppendVariablesRecursively: + """Test cases for append_variables_recursively function""" + + def test_append_simple_dict_value(self): + """Test appending a simple dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["output"] + variable_value = {"name": "John", "age": 30} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Check that nested variables are added recursively + name_var = pool.get([node_id] + variable_key_list + ["name"]) + assert name_var is not None + assert name_var.value == "John" + + age_var = pool.get([node_id] + variable_key_list + ["age"]) + assert age_var is not None + assert age_var.value == 30 + + def test_append_object_segment_value(self): + """Test appending an ObjectSegment value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["result"] + + # Create an ObjectSegment + obj_data = {"status": "success", "code": 200} + variable_value = ObjectSegment(value=obj_data) + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, ObjectSegment) + assert main_var.value == obj_data + + # Check that nested variables are added recursively + status_var = pool.get([node_id] + variable_key_list + ["status"]) + assert status_var is not None + assert status_var.value == "success" + + code_var = pool.get([node_id] + variable_key_list + ["code"]) + assert code_var is not None + assert code_var.value == 200 + + def test_append_nested_dict_value(self): + """Test appending a nested dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["data"] + + variable_value = { + "user": { + "profile": {"name": "Alice", "email": "alice@example.com"}, + "settings": {"theme": "dark", "notifications": True}, + }, + "metadata": {"version": "1.0", "timestamp": 1234567890}, + } + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check deeply nested variables + name_var = pool.get([node_id] + variable_key_list + ["user", "profile", "name"]) + assert name_var is not None + assert name_var.value == "Alice" + + email_var = pool.get([node_id] + variable_key_list + ["user", "profile", "email"]) + assert email_var is not None + assert email_var.value == "alice@example.com" + + theme_var = pool.get([node_id] + variable_key_list + ["user", "settings", "theme"]) + assert theme_var is not None + assert theme_var.value == "dark" + + notifications_var = pool.get([node_id] + variable_key_list + ["user", "settings", "notifications"]) + assert notifications_var is not None + assert notifications_var.value == 1 # Boolean True is converted to integer 1 + + version_var = pool.get([node_id] + variable_key_list + ["metadata", "version"]) + assert version_var is not None + assert version_var.value == "1.0" + + def test_append_non_dict_value(self): + """Test appending a non-dictionary value (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["simple"] + variable_value = "simple_string" + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == variable_value + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_segment_non_object_value(self): + """Test appending a Segment that is not ObjectSegment (should not recurse)""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["text"] + variable_value = StringSegment(value="Hello World") + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that only the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert isinstance(main_var, StringSegment) + assert main_var.value == "Hello World" + + # Ensure no additional variables are created + assert len(pool.variable_dictionary[node_id]) == 1 + + def test_append_empty_dict_value(self): + """Test appending an empty dictionary value""" + pool = VariablePool() + node_id = "test_node" + variable_key_list = ["empty"] + variable_value: dict[str, Any] = {} + + append_variables_recursively(pool, node_id, variable_key_list, variable_value) + + # Check that the main variable is added + main_var = pool.get([node_id] + variable_key_list) + assert main_var is not None + assert main_var.value == {} + + # Ensure only the main variable is created (no recursion for empty dict) + assert len(pool.variable_dictionary[node_id]) == 1 diff --git a/api/tests/unit_tests/services/test_dataset_service.py b/api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py similarity index 100% rename from api/tests/unit_tests/services/test_dataset_service.py rename to api/tests/unit_tests/services/test_dataset_service_batch_update_document_status.py diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py new file mode 100644 index 0000000000..15e1b7569f --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py @@ -0,0 +1,826 @@ +import datetime + +# Mock redis_client before importing dataset_service +from unittest.mock import Mock, patch + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from models.dataset import Dataset, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError +from tests.unit_tests.conftest import redis_mock + + +class TestDatasetServiceUpdateDataset: + """ + Comprehensive unit tests for DatasetService.update_dataset method. + + This test suite covers all supported scenarios including: + - External dataset updates + - Internal dataset updates with different indexing techniques + - Embedding model updates + - Permission checks + - Error conditions and edge cases + """ + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_external_dataset_success(self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db): + """ + Test successful update of external dataset. + + Verifies that: + 1. External dataset attributes are updated correctly + 2. External knowledge binding is updated when values change + 3. Database changes are committed + 4. Permission check is performed + """ + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "external" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.retrieval_model = "old_model" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Create mock external knowledge binding + mock_binding = Mock(spec=ExternalKnowledgeBindings) + mock_binding.external_knowledge_id = "old_knowledge_id" + mock_binding.external_knowledge_api_id = "old_api_id" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query + with patch("services.dataset_service.Session") as mock_session: + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = mock_binding + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": "new_api_id", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify dataset attributes were updated + assert mock_dataset.name == "new_name" + assert mock_dataset.description == "new_description" + assert mock_dataset.retrieval_model == "new_model" + + # Verify external knowledge binding was updated + assert mock_binding.external_knowledge_id == "new_knowledge_id" + assert mock_binding.external_knowledge_api_id == "new_api_id" + + # Verify database operations + mock_db.add.assert_any_call(mock_dataset) + mock_db.add.assert_any_call(mock_binding) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_knowledge_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_id + update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_missing_api_id_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when external knowledge api id is missing. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without external_knowledge_api_id + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge api id is required" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.Session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_external_dataset_binding_not_found_error( + self, mock_check_permission, mock_get_dataset, mock_session, mock_db + ): + from unittest.mock import Mock, patch + + from extensions.ext_database import db + + with patch.object(db.__class__, "engine", new_callable=Mock): + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.provider = "external" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock external knowledge binding query returning None + mock_session_instance = Mock() + mock_session.return_value.__enter__.return_value = mock_session_instance + mock_session_instance.query.return_value.filter_by.return_value.first.return_value = None + + # Test data + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": "api_id", + } + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "External knowledge binding not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_basic_success( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test successful update of internal dataset with basic fields. + + Verifies that: + 1. Basic dataset attributes are updated correctly + 2. Filtered data excludes None values except description + 3. Timestamp fields are updated + 4. Database changes are committed + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.name = "old_name" + mock_dataset.description = "old_description" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.retrieval_model = "old_model" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify permission check was called + mock_check_permission.assert_called_once_with(mock_dataset, mock_user) + + # Verify database update was called with correct filtered data + expected_filtered_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_economy( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset indexing technique to economy. + + Verifies that: + 1. Embedding model fields are cleared when switching to economy + 2. Vector index task is triggered with 'remove' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data + update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with embedding model fields cleared + expected_filtered_data = { + "indexing_technique": "economy", + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "remove") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_not_found_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when dataset is not found. + """ + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval returning None + mock_get_dataset.return_value = None + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect ValueError + with pytest.raises(ValueError) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "Dataset not found" in str(context.value) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_dataset_permission_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when user doesn't have permission. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + + # Create mock user + mock_user = Mock() + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock permission check to raise error + mock_check_permission.side_effect = NoPermissionError("No permission") + + # Test data + update_data = {"name": "new_name"} + + # Call the method and expect NoPermissionError + with pytest.raises(NoPermissionError): + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_keep_existing_embedding_model( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing embedding model. + + Verifies that: + 1. Existing embedding model settings are preserved when not provided in update + 2. No vector index task is triggered + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data without embedding model fields + update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with existing embedding model preserved + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding") + @patch("services.dataset_service.ModelManager") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_indexing_technique_to_high_quality( + self, + mock_datetime, + mock_check_permission, + mock_get_dataset, + mock_task, + mock_model_manager, + mock_collection_binding, + mock_db, + ): + """ + Test updating internal dataset indexing technique to high_quality. + + Verifies that: + 1. Embedding model is validated and set + 2. Collection binding is retrieved + 3. Vector index task is triggered with 'add' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-ada-002" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-456" + + # Mock model manager + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + mock_collection_binding.return_value = mock_collection_binding_instance + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + # Call the method with current_user mock + with patch("services.dataset_service.current_user", mock_current_user): + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-ada-002") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-ada-002", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-456", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "add") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + def test_update_internal_dataset_embedding_model_error(self, mock_check_permission, mock_get_dataset, mock_db): + """ + Test error when embedding model is not available. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "economy" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager to raise error + with ( + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.side_effect = Exception("No Embedding Model available") + mock_model_manager.return_value = mock_model_manager_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + # Call the method and expect ValueError + with pytest.raises(Exception) as context: + DatasetService.update_dataset("dataset-123", update_data, mock_user) + + assert "No Embedding Model available".lower() in str(context.value).lower() + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_filter_none_values( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test that None values are filtered out except for description field. + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with None values + update_data = { + "name": "new_name", + "description": None, # Should be included + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": None, # Should be filtered out + "embedding_model": None, # Should be filtered out + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with filtered data + expected_filtered_data = { + "name": "new_name", + "description": None, # Description should be included even if None + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": mock_db.query.return_value.filter_by.return_value.update.call_args[0][0]["updated_at"], + } + + actual_call_args = mock_db.query.return_value.filter_by.return_value.update.call_args[0][0] + # Remove timestamp for comparison as it's dynamic + del actual_call_args["updated_at"] + del expected_filtered_data["updated_at"] + + del actual_call_args["collection_binding_id"] + del actual_call_args["embedding_model"] + del actual_call_args["embedding_model_provider"] + + assert actual_call_args == expected_filtered_data + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.deal_dataset_vector_index_task") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_embedding_model_update( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_task, mock_db + ): + """ + Test updating internal dataset with new embedding model. + + Verifies that: + 1. Embedding model is updated when different from current + 2. Vector index task is triggered with 'update' action + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Mock embedding model + mock_embedding_model = Mock() + mock_embedding_model.model = "text-embedding-3-small" + mock_embedding_model.provider = "openai" + + # Mock collection binding + mock_collection_binding_instance = Mock() + mock_collection_binding_instance.id = "binding-789" + + # Mock current_user + mock_current_user = Mock() + mock_current_user.current_tenant_id = "tenant-123" + + # Mock model manager + with patch("services.dataset_service.ModelManager") as mock_model_manager: + mock_model_manager_instance = Mock() + mock_model_manager_instance.get_model_instance.return_value = mock_embedding_model + mock_model_manager.return_value = mock_model_manager_instance + + # Mock collection binding service + with ( + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_collection_binding, + patch("services.dataset_service.current_user", mock_current_user), + ): + mock_collection_binding.return_value = mock_collection_binding_instance + + # Test data + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify embedding model was validated + mock_model_manager_instance.get_model_instance.assert_called_once_with( + tenant_id=mock_current_user.current_tenant_id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + + # Verify collection binding was retrieved + mock_collection_binding.assert_called_once_with("openai", "text-embedding-3-small") + + # Verify database update was called with correct data + expected_filtered_data = { + "indexing_technique": "high_quality", + "embedding_model": "text-embedding-3-small", + "embedding_model_provider": "openai", + "collection_binding_id": "binding-789", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify vector index task was triggered + mock_task.delay.assert_called_once_with("dataset-123", "update") + + # Verify return value + assert result == mock_dataset + + @patch("extensions.ext_database.db.session") + @patch("services.dataset_service.DatasetService.get_dataset") + @patch("services.dataset_service.DatasetService.check_dataset_permission") + @patch("services.dataset_service.datetime") + def test_update_internal_dataset_no_indexing_technique_change( + self, mock_datetime, mock_check_permission, mock_get_dataset, mock_db + ): + """ + Test updating internal dataset without changing indexing technique. + + Verifies that: + 1. No vector index task is triggered when indexing technique doesn't change + 2. Database update is performed normally + """ + # Create mock dataset + mock_dataset = Mock(spec=Dataset) + mock_dataset.id = "dataset-123" + mock_dataset.provider = "vendor" + mock_dataset.indexing_technique = "high_quality" + mock_dataset.embedding_model_provider = "openai" + mock_dataset.embedding_model = "text-embedding-ada-002" + mock_dataset.collection_binding_id = "binding-123" + + # Create mock user + mock_user = Mock() + mock_user.id = "user-789" + + # Set up mock return values + current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) + mock_datetime.datetime.now.return_value = current_time + mock_datetime.UTC = datetime.UTC + + # Mock dataset retrieval + mock_get_dataset.return_value = mock_dataset + + # Test data with same indexing technique + update_data = { + "name": "new_name", + "indexing_technique": "high_quality", # Same as current + "retrieval_model": "new_model", + } + + # Call the method + result = DatasetService.update_dataset("dataset-123", update_data, mock_user) + + # Verify database update was called with correct data + expected_filtered_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "collection_binding_id": "binding-123", + "retrieval_model": "new_model", + "updated_by": mock_user.id, + "updated_at": current_time.replace(tzinfo=None), + } + + mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_filtered_data) + mock_db.commit.assert_called_once() + + # Verify no vector index task was triggered + mock_db.query.return_value.filter_by.return_value.update.assert_called_once() + + # Verify return value + assert result == mock_dataset diff --git a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx index 1eec51992d..437e25fde4 100644 --- a/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/advanced-prompt-input.tsx @@ -227,7 +227,7 @@ const AdvancedPromptInput: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), diff --git a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx index 2a9a15296e..eb0f524386 100644 --- a/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx +++ b/web/app/components/app/configuration/config-prompt/simple-prompt-input.tsx @@ -97,20 +97,31 @@ const Prompt: FC = ({ }, }) } - const promptVariablesObj = (() => { - const obj: Record = {} - promptVariables.forEach((item) => { - obj[item.key] = true - }) - return obj - })() const [newPromptVariables, setNewPromptVariables] = React.useState(promptVariables) const [newTemplates, setNewTemplates] = React.useState('') const [isShowConfirmAddVar, { setTrue: showConfirmAddVar, setFalse: hideConfirmAddVar }] = useBoolean(false) const handleChange = (newTemplates: string, keys: string[]) => { - const newPromptVariables = keys.filter(key => !(key in promptVariablesObj) && !externalDataToolsConfig.find(item => item.variable === key)).map(key => getNewVar(key, '')) + // Filter out keys that are not properly defined (either not exist or exist but without valid name) + const newPromptVariables = keys.filter((key) => { + // Check if key exists in external data tools + if (externalDataToolsConfig.find((item: ExternalDataTool) => item.variable === key)) + return false + + // Check if key exists in prompt variables + const existingVar = promptVariables.find((item: PromptVariable) => item.key === key) + if (!existingVar) { + // Variable doesn't exist at all + return true + } + + // Variable exists but check if it has valid name and key + return !existingVar.name || !existingVar.name.trim() || !existingVar.key || !existingVar.key.trim() + + return false + }).map(key => getNewVar(key, '')) + if (newPromptVariables.length > 0) { setNewPromptVariables(newPromptVariables) setNewTemplates(newTemplates) @@ -210,14 +221,14 @@ const Prompt: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.filter(item => item.type !== 'api').map(item => ({ + variables: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type !== 'api' && item.key && item.key.trim() && item.name && item.name.trim()).map((item: PromptVariable) => ({ name: item.name, value: item.key, })), }} externalToolBlock={{ show: true, - externalTools: modelConfig.configs.prompt_variables.filter(item => item.type === 'api').map(item => ({ + externalTools: modelConfig.configs.prompt_variables.filter((item: PromptVariable) => item.type === 'api').map((item: PromptVariable) => ({ name: item.name, variableName: item.key, icon: item.icon, diff --git a/web/app/components/app/configuration/config/agent/prompt-editor.tsx b/web/app/components/app/configuration/config/agent/prompt-editor.tsx index 7f7f140eda..579b7c4d64 100644 --- a/web/app/components/app/configuration/config/agent/prompt-editor.tsx +++ b/web/app/components/app/configuration/config/agent/prompt-editor.tsx @@ -107,7 +107,7 @@ const Editor: FC = ({ }} variableBlock={{ show: true, - variables: modelConfig.configs.prompt_variables.map(item => ({ + variables: modelConfig.configs.prompt_variables.filter(item => item.key && item.key.trim() && item.name && item.name.trim()).map(item => ({ name: item.name, value: item.key, })), diff --git a/web/app/components/app/log/list.tsx b/web/app/components/app/log/list.tsx index 208fddecd1..dd72c6c810 100644 --- a/web/app/components/app/log/list.tsx +++ b/web/app/components/app/log/list.tsx @@ -41,6 +41,7 @@ import { buildChatItemTree, getThreadMessages } from '@/app/components/base/chat import { getProcessedFilesFromResponse } from '@/app/components/base/file-uploader/utils' import cn from '@/utils/classnames' import { noop } from 'lodash-es' +import PromptLogModal from '../../base/prompt-log-modal' dayjs.extend(utc) dayjs.extend(timezone) @@ -190,11 +191,13 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { const { userProfile: { timezone } } = useAppContext() const { formatTime } = useTimestamp() const { onClose, appDetail } = useContext(DrawerContext) - const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ + const { currentLogItem, setCurrentLogItem, showMessageLogModal, setShowMessageLogModal, showPromptLogModal, setShowPromptLogModal, currentLogModalActiveTab } = useAppStore(useShallow(state => ({ currentLogItem: state.currentLogItem, setCurrentLogItem: state.setCurrentLogItem, showMessageLogModal: state.showMessageLogModal, setShowMessageLogModal: state.setShowMessageLogModal, + showPromptLogModal: state.showPromptLogModal, + setShowPromptLogModal: state.setShowPromptLogModal, currentLogModalActiveTab: state.currentLogModalActiveTab, }))) const { t } = useTranslation() @@ -516,6 +519,16 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) { defaultTab={currentLogModalActiveTab} /> )} + {!isChatMode && showPromptLogModal && ( + { + setCurrentLogItem() + setShowPromptLogModal(false) + }} + /> + )} ) } diff --git a/web/app/components/app/text-generate/item/index.tsx b/web/app/components/app/text-generate/item/index.tsx index aa3ffa33c4..92d86351e0 100644 --- a/web/app/components/app/text-generate/item/index.tsx +++ b/web/app/components/app/text-generate/item/index.tsx @@ -171,7 +171,7 @@ const GenerationItem: FC = ({ appId: params.appId as string, messageId: messageId!, }) - const logItem = { + const logItem = Array.isArray(data.message) ? { ...data, log: [ ...data.message, @@ -185,6 +185,11 @@ const GenerationItem: FC = ({ ] : []), ], + } : { + ...data, + log: [typeof data.message === 'string' ? { + text: data.message, + } : data.message], } setCurrentLogItem(logItem) setShowPromptLogModal(true) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx index 20c03768d0..57de34a701 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/variable-option.tsx @@ -44,6 +44,10 @@ export const VariableMenuItem = memo(({ tabIndex={-1} ref={setRefElement} onMouseEnter={onMouseEnter} + onMouseDown={(e) => { + e.preventDefault() + e.stopPropagation() + }} onClick={onClick}>
{icon} diff --git a/web/app/components/datasets/common/document-picker/document-list.tsx b/web/app/components/datasets/common/document-picker/document-list.tsx index 5ac98b97fe..6b51973ccc 100644 --- a/web/app/components/datasets/common/document-picker/document-list.tsx +++ b/web/app/components/datasets/common/document-picker/document-list.tsx @@ -21,7 +21,7 @@ const DocumentList: FC = ({ }, [onChange]) return ( -
+
{list.map((item) => { const { id, name, extension } = item return ( diff --git a/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts b/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts index 45fe7675f1..aea7efd5c8 100644 --- a/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts +++ b/web/app/components/datasets/metadata/hooks/use-check-metadata-name.ts @@ -18,6 +18,12 @@ const useCheckMetadataName = () => { } } + if (name.length > 255) { + return { + errorMsg: t(`${i18nPrefix}.tooLong`, { max: 255 }), + } + } + return { errorMsg: '', } diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx index 2f57276559..3c79acb653 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/app-picker.tsx @@ -1,7 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo } from 'react' -import { useState } from 'react' +import React, { useCallback, useEffect, useRef } from 'react' import { PortalToFollowElem, PortalToFollowElemContent, @@ -14,9 +13,9 @@ import type { import Input from '@/app/components/base/input' import AppIcon from '@/app/components/base/app-icon' import type { App } from '@/types/app' +import { useTranslation } from 'react-i18next' type Props = { - appList: App[] scope: string disabled: boolean trigger: React.ReactNode @@ -25,11 +24,16 @@ type Props = { isShow: boolean onShowChange: (isShow: boolean) => void onSelect: (app: App) => void + apps: App[] + isLoading: boolean + hasMore: boolean + onLoadMore: () => void + searchText: string + onSearchChange: (text: string) => void } const AppPicker: FC = ({ scope, - appList, disabled, trigger, placement = 'right-start', @@ -37,19 +41,81 @@ const AppPicker: FC = ({ isShow, onShowChange, onSelect, + apps, + isLoading, + hasMore, + onLoadMore, + searchText, + onSearchChange, }) => { - const [searchText, setSearchText] = useState('') - const filteredAppList = useMemo(() => { - return (appList || []) - .filter(app => app.name.toLowerCase().includes(searchText.toLowerCase())) - .filter(app => (app.mode !== 'advanced-chat' && app.mode !== 'workflow') || !!app.workflow) - .filter(app => scope === 'all' - || (scope === 'completion' && app.mode === 'completion') - || (scope === 'workflow' && app.mode === 'workflow') - || (scope === 'chat' && app.mode === 'advanced-chat') - || (scope === 'chat' && app.mode === 'agent-chat') - || (scope === 'chat' && app.mode === 'chat')) - }, [appList, scope, searchText]) + const { t } = useTranslation() + const observerTarget = useRef(null) + const observerRef = useRef(null) + const loadingRef = useRef(false) + + const handleIntersection = useCallback((entries: IntersectionObserverEntry[]) => { + const target = entries[0] + if (!target.isIntersecting || loadingRef.current || !hasMore || isLoading) return + + loadingRef.current = true + onLoadMore() + // Reset loading state + setTimeout(() => { + loadingRef.current = false + }, 500) + }, [hasMore, isLoading, onLoadMore]) + + useEffect(() => { + if (!isShow) { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + return + } + + let mutationObserver: MutationObserver | null = null + + const setupIntersectionObserver = () => { + if (!observerTarget.current) return + + // Create new observer + observerRef.current = new IntersectionObserver(handleIntersection, { + root: null, + rootMargin: '100px', + threshold: 0.1, + }) + + observerRef.current.observe(observerTarget.current) + } + + // Set up MutationObserver to watch DOM changes + mutationObserver = new MutationObserver((mutations) => { + if (observerTarget.current) { + setupIntersectionObserver() + mutationObserver?.disconnect() + } + }) + + // Watch body changes since Portal adds content to body + mutationObserver.observe(document.body, { + childList: true, + subtree: true, + }) + + // If element exists, set up IntersectionObserver directly + if (observerTarget.current) + setupIntersectionObserver() + + return () => { + if (observerRef.current) { + observerRef.current.disconnect() + observerRef.current = null + } + mutationObserver?.disconnect() + } + }, [isShow, handleIntersection]) + const getAppType = (app: App) => { switch (app.mode) { case 'advanced-chat': @@ -84,18 +150,18 @@ const AppPicker: FC = ({ -
+
setSearchText(e.target.value)} - onClear={() => setSearchText('')} + onChange={e => onSearchChange(e.target.value)} + onClear={() => onSearchChange('')} />
-
- {filteredAppList.map(app => ( +
+ {apps.map(app => (
= ({
{getAppType(app)}
))} +
+ {isLoading && ( +
+
{t('common.loading')}
+
+ )} +
diff --git a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx index 39ca957953..9c16c66a70 100644 --- a/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx +++ b/web/app/components/plugins/plugin-detail-panel/app-selector/index.tsx @@ -1,6 +1,6 @@ 'use client' import type { FC } from 'react' -import React, { useMemo, useState } from 'react' +import React, { useCallback, useMemo, useState } from 'react' import { useTranslation } from 'react-i18next' import { PortalToFollowElem, @@ -10,12 +10,36 @@ import { import AppTrigger from '@/app/components/plugins/plugin-detail-panel/app-selector/app-trigger' import AppPicker from '@/app/components/plugins/plugin-detail-panel/app-selector/app-picker' import AppInputsPanel from '@/app/components/plugins/plugin-detail-panel/app-selector/app-inputs-panel' -import { useAppFullList } from '@/service/use-apps' import type { App } from '@/types/app' import type { OffsetOptions, Placement, } from '@floating-ui/react' +import useSWRInfinite from 'swr/infinite' +import { fetchAppList } from '@/service/apps' +import type { AppListResponse } from '@/models/app' + +const PAGE_SIZE = 20 + +const getKey = ( + pageIndex: number, + previousPageData: AppListResponse, + searchText: string, +) => { + if (pageIndex === 0 || (previousPageData && previousPageData.has_more)) { + const params: any = { + url: 'apps', + params: { + page: pageIndex + 1, + limit: PAGE_SIZE, + name: searchText, + }, + } + + return params + } + return null +} type Props = { value?: { @@ -34,6 +58,7 @@ type Props = { }) => void supportAddCustomTool?: boolean } + const AppSelector: FC = ({ value, scope, @@ -44,18 +69,47 @@ const AppSelector: FC = ({ }) => { const { t } = useTranslation() const [isShow, onShowChange] = useState(false) + const [searchText, setSearchText] = useState('') + const [isLoadingMore, setIsLoadingMore] = useState(false) + + const { data, isLoading, setSize } = useSWRInfinite( + (pageIndex: number, previousPageData: AppListResponse) => getKey(pageIndex, previousPageData, searchText), + fetchAppList, + { + revalidateFirstPage: true, + shouldRetryOnError: false, + dedupingInterval: 500, + errorRetryCount: 3, + }, + ) + + const displayedApps = useMemo(() => { + if (!data) return [] + return data.flatMap(({ data: apps }) => apps) + }, [data]) + + const hasMore = data?.at(-1)?.has_more ?? true + + const handleLoadMore = useCallback(async () => { + if (isLoadingMore || !hasMore) return + + setIsLoadingMore(true) + try { + await setSize((size: number) => size + 1) + } + finally { + // Add a small delay to ensure state updates are complete + setTimeout(() => { + setIsLoadingMore(false) + }, 300) + } + }, [isLoadingMore, hasMore, setSize]) + const handleTriggerClick = () => { if (disabled) return onShowChange(true) } - const { data: appList } = useAppFullList() - const currentAppInfo = useMemo(() => { - if (!appList?.data || !value) - return undefined - return appList.data.find(app => app.id === value.app_id) - }, [appList?.data, value]) - const [isShowChooseApp, setIsShowChooseApp] = useState(false) const handleSelectApp = (app: App) => { const clearValue = app.id !== value?.app_id @@ -67,6 +121,7 @@ const AppSelector: FC = ({ onSelect(appValue) setIsShowChooseApp(false) } + const handleFormChange = (inputs: Record) => { const newFiles = inputs['#image#'] delete inputs['#image#'] @@ -88,6 +143,12 @@ const AppSelector: FC = ({ } }, [value]) + const currentAppInfo = useMemo(() => { + if (!displayedApps || !value) + return undefined + return displayedApps.find(app => app.id === value.app_id) + }, [displayedApps, value]) + return ( <> = ({ isShow={isShowChooseApp} onShowChange={setIsShowChooseApp} disabled={false} - appList={appList?.data || []} onSelect={handleSelectApp} scope={scope || 'all'} + apps={displayedApps} + isLoading={isLoading || isLoadingMore} + hasMore={hasMore} + onLoadMore={handleLoadMore} + searchText={searchText} + onSearchChange={setSearchText} />
{/* app inputs config panel */} @@ -140,4 +206,5 @@ const AppSelector: FC = ({ ) } + export default React.memo(AppSelector) diff --git a/web/app/components/workflow-app/hooks/use-workflow-template.ts b/web/app/components/workflow-app/hooks/use-workflow-template.ts index 9f47b981dc..2bab08f205 100644 --- a/web/app/components/workflow-app/hooks/use-workflow-template.ts +++ b/web/app/components/workflow-app/hooks/use-workflow-template.ts @@ -22,7 +22,7 @@ export const useWorkflowTemplate = () => { ...nodesInitialData.llm, memory: { window: { enabled: false, size: 10 }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', }, selected: true, }, diff --git a/web/app/components/workflow/nodes/_base/components/memory-config.tsx b/web/app/components/workflow/nodes/_base/components/memory-config.tsx index 446fcfa8ae..168ad656b6 100644 --- a/web/app/components/workflow/nodes/_base/components/memory-config.tsx +++ b/web/app/components/workflow/nodes/_base/components/memory-config.tsx @@ -53,7 +53,7 @@ type Props = { const MEMORY_DEFAULT: Memory = { window: { enabled: false, size: WINDOW_SIZE_DEFAULT }, - query_prompt_template: '{{#sys.query#}}', + query_prompt_template: '{{#sys.query#}}\n\n{{#sys.files#}}', } const MemoryConfig: FC = ({ diff --git a/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts index c6d483a3ee..7ff31d91c7 100644 --- a/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts +++ b/web/app/components/workflow/nodes/assigner/use-single-run-form-params.ts @@ -24,7 +24,7 @@ const useSingleRunFormParams = ({ }: Params) => { const { inputs } = useNodeCrud(id, payload) - const vars = inputs.items.filter((item) => { + const vars = (inputs.items ?? []).filter((item) => { return item.operation !== WriteMode.clear && item.operation !== WriteMode.set && item.operation !== WriteMode.removeFirst && item.operation !== WriteMode.removeLast && !writeModeTypesNum.includes(item.operation) diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx index 9c82ee748b..3240496b62 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/variable-modal.tsx @@ -324,7 +324,7 @@ const ChatVariableModal = ({ {type === ChatVarType.String && ( // Input will remove \n\r, so use Textarea just like description area