diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index e41375e52b..c70bf84d2a 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -29,7 +29,7 @@ from libs.login import login_required from services.plugin.oauth_service import OAuthProxyService from services.tools.api_tools_manage_service import ApiToolManageService from services.tools.builtin_tools_manage_service import BuiltinToolManageService -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService from services.tools.tool_labels_service import ToolLabelsService from services.tools.tools_manage_service import ToolCommonService from services.tools.tools_transform_service import ToolTransformService diff --git a/api/core/mcp/auth/auth_provider.py b/api/core/mcp/auth/auth_provider.py index cd55dbf64f..00d5a25956 100644 --- a/api/core/mcp/auth/auth_provider.py +++ b/api/core/mcp/auth/auth_provider.py @@ -8,7 +8,7 @@ from core.mcp.types import ( OAuthTokens, ) from models.tools import MCPToolProvider -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService LATEST_PROTOCOL_VERSION = "1.0" diff --git a/api/core/mcp/mcp_client.py b/api/core/mcp/mcp_client.py index e9036de8c6..f7aa7bbd7b 100644 --- a/api/core/mcp/mcp_client.py +++ b/api/core/mcp/mcp_client.py @@ -68,15 +68,17 @@ class MCPClient: } parsed_url = urlparse(self.server_url) - path = parsed_url.path - method_name = path.rstrip("/").split("/")[-1] if path else "" - try: + path = parsed_url.path or "" + method_name = path.removesuffix("/").lower() + if method_name in connection_methods: client_factory = connection_methods[method_name] self.connect_server(client_factory, method_name) - except KeyError: + else: try: + logger.debug(f"Not supported method {method_name} found in URL path, trying default 'mcp' method.") self.connect_server(sse_client, "sse") except MCPConnectionError: + logger.debug("MCP connection failed with 'sse', falling back to 'mcp' method.") self.connect_server(streamablehttp_client, "mcp") def connect_server( @@ -91,7 +93,7 @@ class MCPClient: else {} ) self._streams_context = client_factory(url=self.server_url, headers=headers) - if self._streams_context is None: + if not self._streams_context: raise MCPConnectionError("Failed to create connection context") # Use exit_stack to manage context managers properly @@ -141,10 +143,11 @@ class MCPClient: try: # ExitStack will handle proper cleanup of all managed context managers self.exit_stack.close() + except Exception as e: + logging.exception("Error during cleanup") + raise ValueError(f"Error during cleanup: {e}") + finally: self._session = None self._session_context = None self._streams_context = None self._initialized = False - except Exception as e: - logging.exception("Error during cleanup") - raise ValueError(f"Error during cleanup: {e}") diff --git a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py index 095752ea8e..6f3e15d166 100644 --- a/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py +++ b/api/core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py @@ -233,6 +233,12 @@ class AnalyticdbVectorOpenAPI: def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" + score_threshold = kwargs.get("score_threshold") or 0.0 request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -245,7 +251,7 @@ class AnalyticdbVectorOpenAPI: vector=query_vector, content=None, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] @@ -265,6 +271,11 @@ class AnalyticdbVectorOpenAPI: def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: from alibabacloud_gpdb20160503 import models as gpdb_20160503_models + document_ids_filter = kwargs.get("document_ids_filter") + where_clause = "" + if document_ids_filter: + document_ids = ", ".join(f"'{id}'" for id in document_ids_filter) + where_clause += f"metadata_->>'document_id' IN ({document_ids})" score_threshold = float(kwargs.get("score_threshold") or 0.0) request = gpdb_20160503_models.QueryCollectionDataRequest( dbinstance_id=self.config.instance_id, @@ -277,7 +288,7 @@ class AnalyticdbVectorOpenAPI: vector=None, content=query, top_k=kwargs.get("top_k", 4), - filter=None, + filter=where_clause, ) response = self._client.query_collection_data(request) documents = [] diff --git a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py index 44cc5d3e98..ad39717183 100644 --- a/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py +++ b/api/core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py @@ -147,10 +147,17 @@ class ElasticSearchVector(BaseVector): return docs def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]: - query_str = {"match": {Field.CONTENT_KEY.value: query}} + query_str: dict[str, Any] = {"match": {Field.CONTENT_KEY.value: query}} document_ids_filter = kwargs.get("document_ids_filter") + if document_ids_filter: - query_str["filter"] = {"terms": {"metadata.document_id": document_ids_filter}} # type: ignore + query_str = { + "bool": { + "must": {"match": {Field.CONTENT_KEY.value: query}}, + "filter": {"terms": {"metadata.document_id": document_ids_filter}}, + } + } + results = self._client.search(index=self._collection_name, query=query_str, size=kwargs.get("top_k", 4)) docs = [] for hit in results["hits"]["hits"]: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index d61856a8f5..7822bc389c 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -21,7 +21,7 @@ from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.workflow.entities.variable_pool import VariablePool -from services.tools.mcp_tools_mange_service import MCPToolManageService +from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: from core.workflow.nodes.tool.entities import ToolEntity diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index a4616eda69..704eb6a3ac 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -270,7 +270,14 @@ class AgentNode(BaseNode): ) extra = tool.get("extra", {}) - runtime_variable_pool = variable_pool if self._node_data.version != "1" else None + + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + runtime_variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + runtime_variable_pool = variable_pool tool_runtime = ToolManager.get_agent_tool_runtime( self.tenant_id, self.app_id, entity, self.invoke_from, runtime_variable_pool ) diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 075a41fb2f..11b11068e7 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -13,6 +13,10 @@ class AgentNodeData(BaseNodeData): agent_strategy_name: str agent_strategy_label: str # redundancy memory: MemoryConfig | None = None + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None class AgentInput(BaseModel): value: Union[list[str], list[ToolSelector], Any] diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index e9122b1eec..f1767bdf9e 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -118,7 +118,7 @@ class KnowledgeRetrievalNodeData(BaseNodeData): multiple_retrieval_config: Optional[MultipleRetrievalConfig] = None single_retrieval_config: Optional[SingleRetrievalConfig] = None metadata_filtering_mode: Optional[Literal["disabled", "automatic", "manual"]] = "disabled" - metadata_model_config: ModelConfig + metadata_model_config: Optional[ModelConfig] = None metadata_filtering_conditions: Optional[MetadataFilteringCondition] = None vision: VisionConfig = Field(default_factory=VisionConfig) diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 4e9a38f552..5f092dc2f1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -509,6 +509,8 @@ class KnowledgeRetrievalNode(BaseNode): # get all metadata field metadata_fields = db.session.query(DatasetMetadata).filter(DatasetMetadata.dataset_id.in_(dataset_ids)).all() all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] + if node_data.metadata_model_config is None: + raise ValueError("metadata_model_config is required") # get metadata model instance and fetch model config model_instance, model_config = self.get_model_config(node_data.metadata_model_config) # fetch prompt messages @@ -701,7 +703,7 @@ class KnowledgeRetrievalNode(BaseNode): ) def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode(node_data.metadata_model_config.mode) + model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore input_text = query prompt_messages: list[LLMNodeChatModelMessage] = [] diff --git a/api/core/workflow/nodes/node_mapping.py b/api/core/workflow/nodes/node_mapping.py index ccfaec4a8c..294b47670b 100644 --- a/api/core/workflow/nodes/node_mapping.py +++ b/api/core/workflow/nodes/node_mapping.py @@ -73,6 +73,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.TOOL: { LATEST_VERSION: ToolNode, + # This is an issue that caused problems before. + # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. "2": ToolNode, "1": ToolNode, }, @@ -123,6 +126,9 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = { }, NodeType.AGENT: { LATEST_VERSION: AgentNode, + # This is an issue that caused problems before. + # Logically, we shouldn't use two different versions to point to the same class here, + # but in order to maintain compatibility with historical data, this approach has been retained. "2": AgentNode, "1": AgentNode, }, diff --git a/api/core/workflow/nodes/tool/entities.py b/api/core/workflow/nodes/tool/entities.py index 88c5160d14..f0a44d919b 100644 --- a/api/core/workflow/nodes/tool/entities.py +++ b/api/core/workflow/nodes/tool/entities.py @@ -59,6 +59,10 @@ class ToolNodeData(BaseNodeData, ToolEntity): return typ tool_parameters: dict[str, ToolInput] + # The version of the tool parameter. + # If this value is None, it indicates this is a previous version + # and requires using the legacy parameter parsing rules. + tool_node_version: str | None = None @field_validator("tool_parameters", mode="before") @classmethod diff --git a/api/core/workflow/nodes/tool/tool_node.py b/api/core/workflow/nodes/tool/tool_node.py index c565ad15c1..140fe71f60 100644 --- a/api/core/workflow/nodes/tool/tool_node.py +++ b/api/core/workflow/nodes/tool/tool_node.py @@ -70,7 +70,13 @@ class ToolNode(BaseNode): try: from core.tools.tool_manager import ToolManager - variable_pool = self.graph_runtime_state.variable_pool if self._node_data.version != "1" else None + # This is an issue that caused problems before. + # Logically, we shouldn't use the node_data.version field for judgment + # But for backward compatibility with historical data + # this version field judgment is still preserved here. + variable_pool: VariablePool | None = None + if node_data.version != "1" or node_data.tool_node_version != "1": + variable_pool = self.graph_runtime_state.variable_pool tool_runtime = ToolManager.get_workflow_tool_runtime( self.tenant_id, self.app_id, self.node_id, self._node_data, self.invoke_from, variable_pool ) diff --git a/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py new file mode 100644 index 0000000000..3bdbafda7c --- /dev/null +++ b/api/migrations/versions/2025_07_21_0935-1a83934ad6d1_update_models.py @@ -0,0 +1,51 @@ +"""update models + +Revision ID: 1a83934ad6d1 +Revises: 71f5020c6470 +Create Date: 2025-07-21 09:35:48.774794 + +""" +from alembic import op +import models as models +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '1a83934ad6d1' +down_revision = '71f5020c6470' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.VARCHAR(length=24), + type_=sa.String(length=64), + existing_nullable=False) + + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.VARCHAR(length=40), + type_=sa.String(length=128), + existing_nullable=False) + + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('tool_model_invokes', schema=None) as batch_op: + batch_op.alter_column('tool_name', + existing_type=sa.String(length=128), + type_=sa.VARCHAR(length=40), + existing_nullable=False) + + with op.batch_alter_table('tool_mcp_providers', schema=None) as batch_op: + batch_op.alter_column('server_identifier', + existing_type=sa.String(length=64), + type_=sa.VARCHAR(length=24), + existing_nullable=False) + + # ### end Alembic commands ### diff --git a/api/models/tools.py b/api/models/tools.py index f6baa77166..a0b7e54175 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -254,7 +254,7 @@ class MCPToolProvider(Base): # name of the mcp provider name: Mapped[str] = mapped_column(db.String(40), nullable=False) # server identifier of the mcp provider - server_identifier: Mapped[str] = mapped_column(db.String(24), nullable=False) + server_identifier: Mapped[str] = mapped_column(db.String(64), nullable=False) # encrypted url of the mcp provider server_url: Mapped[str] = mapped_column(db.Text, nullable=False) # hash of server_url for uniqueness check @@ -358,7 +358,7 @@ class ToolModelInvoke(Base): # type tool_type = mapped_column(db.String(40), nullable=False) # tool name - tool_name = mapped_column(db.String(40), nullable=False) + tool_name = mapped_column(db.String(128), nullable=False) # invoke parameters model_parameters = mapped_column(db.Text, nullable=False) # prompt messages diff --git a/api/services/account_service.py b/api/services/account_service.py index a664c312e6..c88e70e380 100644 --- a/api/services/account_service.py +++ b/api/services/account_service.py @@ -1067,15 +1067,6 @@ class TenantService: target_member_join.role = new_role db.session.commit() - @staticmethod - def dissolve_tenant(tenant: Tenant, operator: Account) -> None: - """Dissolve tenant""" - if not TenantService.check_member_permission(tenant, operator, operator, "remove"): - raise NoPermissionError("No permission to dissolve tenant.") - db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id).delete() - db.session.delete(tenant) - db.session.commit() - @staticmethod def get_custom_config(tenant_id: str) -> dict: tenant = db.get_or_404(Tenant, tenant_id) diff --git a/api/services/tools/mcp_tools_mange_service.py b/api/services/tools/mcp_tools_manage_service.py similarity index 95% rename from api/services/tools/mcp_tools_mange_service.py rename to api/services/tools/mcp_tools_manage_service.py index fda6da5983..e0e256912e 100644 --- a/api/services/tools/mcp_tools_mange_service.py +++ b/api/services/tools/mcp_tools_manage_service.py @@ -70,16 +70,15 @@ class MCPToolManageService: MCPToolProvider.server_url_hash == server_url_hash, MCPToolProvider.server_identifier == server_identifier, ), - MCPToolProvider.tenant_id == tenant_id, ) .first() ) if existing_provider: if existing_provider.name == name: raise ValueError(f"MCP tool {name} already exists") - elif existing_provider.server_url_hash == server_url_hash: + if existing_provider.server_url_hash == server_url_hash: raise ValueError(f"MCP tool {server_url} already exists") - elif existing_provider.server_identifier == server_identifier: + if existing_provider.server_identifier == server_identifier: raise ValueError(f"MCP tool {server_identifier} already exists") encrypted_server_url = encrypter.encrypt_token(tenant_id, server_url) mcp_tool = MCPToolProvider( @@ -111,15 +110,14 @@ class MCPToolManageService: ] @classmethod - def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str): + def list_mcp_tool_from_remote_server(cls, tenant_id: str, provider_id: str) -> ToolProviderApiEntity: mcp_provider = cls.get_mcp_provider_by_provider_id(provider_id, tenant_id) - try: with MCPClient( mcp_provider.decrypted_server_url, provider_id, tenant_id, authed=mcp_provider.authed, for_list=True ) as mcp_client: tools = mcp_client.list_tools() - except MCPAuthError as e: + except MCPAuthError: raise ValueError("Please auth the tool first") except MCPError as e: raise ValueError(f"Failed to connect to MCP server: {e}") @@ -184,12 +182,11 @@ class MCPToolManageService: error_msg = str(e.orig) if "unique_mcp_provider_name" in error_msg: raise ValueError(f"MCP tool {name} already exists") - elif "unique_mcp_provider_server_url" in error_msg: + if "unique_mcp_provider_server_url" in error_msg: raise ValueError(f"MCP tool {server_url} already exists") - elif "unique_mcp_provider_server_identifier" in error_msg: + if "unique_mcp_provider_server_identifier" in error_msg: raise ValueError(f"MCP tool {server_identifier} already exists") - else: - raise + raise @classmethod def update_mcp_provider_credentials( diff --git a/api/tests/unit_tests/services/auth/test_firecrawl_auth.py b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py new file mode 100644 index 0000000000..ffdf5897ed --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_firecrawl_auth.py @@ -0,0 +1,191 @@ +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from services.auth.firecrawl.firecrawl import FirecrawlAuth + + +class TestFirecrawlAuth: + @pytest.fixture + def valid_credentials(self): + """Fixture for valid bearer credentials""" + return {"auth_type": "bearer", "config": {"api_key": "test_api_key_123"}} + + @pytest.fixture + def auth_instance(self, valid_credentials): + """Fixture for FirecrawlAuth instance with valid credentials""" + return FirecrawlAuth(valid_credentials) + + def test_should_initialize_with_valid_bearer_credentials(self, valid_credentials): + """Test successful initialization with valid bearer credentials""" + auth = FirecrawlAuth(valid_credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://api.firecrawl.dev" + assert auth.credentials == valid_credentials + + def test_should_initialize_with_custom_base_url(self): + """Test initialization with custom base URL""" + credentials = { + "auth_type": "bearer", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"}, + } + auth = FirecrawlAuth(credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://custom.firecrawl.dev" + + @pytest.mark.parametrize( + ("auth_type", "expected_error"), + [ + ("basic", "Invalid auth type, Firecrawl auth type must be Bearer"), + ("x-api-key", "Invalid auth type, Firecrawl auth type must be Bearer"), + ("", "Invalid auth type, Firecrawl auth type must be Bearer"), + ], + ) + def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error): + """Test that non-bearer auth types raise ValueError""" + credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}} + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @pytest.mark.parametrize( + ("credentials", "expected_error"), + [ + ({"auth_type": "bearer", "config": {}}, "No API key provided"), + ({"auth_type": "bearer"}, "No API key provided"), + ({"auth_type": "bearer", "config": {"api_key": ""}}, "No API key provided"), + ({"auth_type": "bearer", "config": {"api_key": None}}, "No API key provided"), + ], + ) + def test_should_raise_error_for_missing_api_key(self, credentials, expected_error): + """Test that missing or empty API key raises ValueError""" + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_validate_valid_credentials_successfully(self, mock_post, auth_instance): + """Test successful credential validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + result = auth_instance.validate_credentials() + + assert result is True + expected_data = { + "url": "https://example.com", + "includePaths": [], + "excludePaths": [], + "limit": 1, + "scrapeOptions": {"onlyMainContent": True}, + } + mock_post.assert_called_once_with( + "https://api.firecrawl.dev/v1/crawl", + headers={"Content-Type": "application/json", "Authorization": "Bearer test_api_key_123"}, + json=expected_data, + ) + + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (402, "Payment required"), + (409, "Conflict error"), + (500, "Internal server error"), + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_http_errors(self, mock_post, status_code, error_message, auth_instance): + """Test handling of various HTTP error codes""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = {"error": error_message} + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}" + + @pytest.mark.parametrize( + ("status_code", "response_text", "has_json_error", "expected_error_contains"), + [ + (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"), + (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"), + (401, "Not JSON", True, "Expecting value"), # JSON decode error + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_unexpected_errors( + self, mock_post, status_code, response_text, has_json_error, expected_error_contains, auth_instance + ): + """Test handling of unexpected errors with various response formats""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = response_text + if has_json_error: + mock_response.json.side_effect = Exception("Not JSON") + mock_post.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert expected_error_contains in str(exc_info.value) + + @pytest.mark.parametrize( + ("exception_type", "exception_message"), + [ + (requests.ConnectionError, "Network error"), + (requests.Timeout, "Request timeout"), + (requests.ReadTimeout, "Read timeout"), + (requests.ConnectTimeout, "Connection timeout"), + ], + ) + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_network_errors(self, mock_post, exception_type, exception_message, auth_instance): + """Test handling of various network-related errors including timeouts""" + mock_post.side_effect = exception_type(exception_message) + + with pytest.raises(exception_type) as exc_info: + auth_instance.validate_credentials() + assert exception_message in str(exc_info.value) + + def test_should_not_expose_api_key_in_error_messages(self): + """Test that API key is not exposed in error messages""" + credentials = {"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}} + auth = FirecrawlAuth(credentials) + + # Verify API key is stored but not in any error message + assert auth.api_key == "super_secret_key_12345" + + # Test various error scenarios don't expose the key + with pytest.raises(ValueError) as exc_info: + FirecrawlAuth({"auth_type": "basic", "config": {"api_key": "super_secret_key_12345"}}) + assert "super_secret_key_12345" not in str(exc_info.value) + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_use_custom_base_url_in_validation(self, mock_post): + """Test that custom base URL is used in validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + credentials = { + "auth_type": "bearer", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.firecrawl.dev"}, + } + auth = FirecrawlAuth(credentials) + result = auth.validate_credentials() + + assert result is True + assert mock_post.call_args[0][0] == "https://custom.firecrawl.dev/v1/crawl" + + @patch("services.auth.firecrawl.firecrawl.requests.post") + def test_should_handle_timeout_with_retry_suggestion(self, mock_post, auth_instance): + """Test that timeout errors are handled gracefully with appropriate error message""" + mock_post.side_effect = requests.Timeout("The request timed out after 30 seconds") + + with pytest.raises(requests.Timeout) as exc_info: + auth_instance.validate_credentials() + + # Verify the timeout exception is raised with original message + assert "timed out" in str(exc_info.value) diff --git a/api/tests/unit_tests/services/auth/test_watercrawl_auth.py b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py new file mode 100644 index 0000000000..bacf0b24ea --- /dev/null +++ b/api/tests/unit_tests/services/auth/test_watercrawl_auth.py @@ -0,0 +1,205 @@ +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from services.auth.watercrawl.watercrawl import WatercrawlAuth + + +class TestWatercrawlAuth: + @pytest.fixture + def valid_credentials(self): + """Fixture for valid x-api-key credentials""" + return {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123"}} + + @pytest.fixture + def auth_instance(self, valid_credentials): + """Fixture for WatercrawlAuth instance with valid credentials""" + return WatercrawlAuth(valid_credentials) + + def test_should_initialize_with_valid_x_api_key_credentials(self, valid_credentials): + """Test successful initialization with valid x-api-key credentials""" + auth = WatercrawlAuth(valid_credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://app.watercrawl.dev" + assert auth.credentials == valid_credentials + + def test_should_initialize_with_custom_base_url(self): + """Test initialization with custom base URL""" + credentials = { + "auth_type": "x-api-key", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"}, + } + auth = WatercrawlAuth(credentials) + assert auth.api_key == "test_api_key_123" + assert auth.base_url == "https://custom.watercrawl.dev" + + @pytest.mark.parametrize( + ("auth_type", "expected_error"), + [ + ("bearer", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ("basic", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ("", "Invalid auth type, WaterCrawl auth type must be x-api-key"), + ], + ) + def test_should_raise_error_for_invalid_auth_type(self, auth_type, expected_error): + """Test that non-x-api-key auth types raise ValueError""" + credentials = {"auth_type": auth_type, "config": {"api_key": "test_api_key_123"}} + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @pytest.mark.parametrize( + ("credentials", "expected_error"), + [ + ({"auth_type": "x-api-key", "config": {}}, "No API key provided"), + ({"auth_type": "x-api-key"}, "No API key provided"), + ({"auth_type": "x-api-key", "config": {"api_key": ""}}, "No API key provided"), + ({"auth_type": "x-api-key", "config": {"api_key": None}}, "No API key provided"), + ], + ) + def test_should_raise_error_for_missing_api_key(self, credentials, expected_error): + """Test that missing or empty API key raises ValueError""" + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth(credentials) + assert str(exc_info.value) == expected_error + + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_validate_valid_credentials_successfully(self, mock_get, auth_instance): + """Test successful credential validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + result = auth_instance.validate_credentials() + + assert result is True + mock_get.assert_called_once_with( + "https://app.watercrawl.dev/api/v1/core/crawl-requests/", + headers={"Content-Type": "application/json", "X-API-KEY": "test_api_key_123"}, + ) + + @pytest.mark.parametrize( + ("status_code", "error_message"), + [ + (402, "Payment required"), + (409, "Conflict error"), + (500, "Internal server error"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_http_errors(self, mock_get, status_code, error_message, auth_instance): + """Test handling of various HTTP error codes""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.json.return_value = {"error": error_message} + mock_get.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert str(exc_info.value) == f"Failed to authorize. Status code: {status_code}. Error: {error_message}" + + @pytest.mark.parametrize( + ("status_code", "response_text", "has_json_error", "expected_error_contains"), + [ + (403, '{"error": "Forbidden"}', True, "Failed to authorize. Status code: 403. Error: Forbidden"), + (404, "", True, "Unexpected error occurred while trying to authorize. Status code: 404"), + (401, "Not JSON", True, "Expecting value"), # JSON decode error + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_unexpected_errors( + self, mock_get, status_code, response_text, has_json_error, expected_error_contains, auth_instance + ): + """Test handling of unexpected errors with various response formats""" + mock_response = MagicMock() + mock_response.status_code = status_code + mock_response.text = response_text + if has_json_error: + mock_response.json.side_effect = Exception("Not JSON") + mock_get.return_value = mock_response + + with pytest.raises(Exception) as exc_info: + auth_instance.validate_credentials() + assert expected_error_contains in str(exc_info.value) + + @pytest.mark.parametrize( + ("exception_type", "exception_message"), + [ + (requests.ConnectionError, "Network error"), + (requests.Timeout, "Request timeout"), + (requests.ReadTimeout, "Read timeout"), + (requests.ConnectTimeout, "Connection timeout"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_network_errors(self, mock_get, exception_type, exception_message, auth_instance): + """Test handling of various network-related errors including timeouts""" + mock_get.side_effect = exception_type(exception_message) + + with pytest.raises(exception_type) as exc_info: + auth_instance.validate_credentials() + assert exception_message in str(exc_info.value) + + def test_should_not_expose_api_key_in_error_messages(self): + """Test that API key is not exposed in error messages""" + credentials = {"auth_type": "x-api-key", "config": {"api_key": "super_secret_key_12345"}} + auth = WatercrawlAuth(credentials) + + # Verify API key is stored but not in any error message + assert auth.api_key == "super_secret_key_12345" + + # Test various error scenarios don't expose the key + with pytest.raises(ValueError) as exc_info: + WatercrawlAuth({"auth_type": "bearer", "config": {"api_key": "super_secret_key_12345"}}) + assert "super_secret_key_12345" not in str(exc_info.value) + + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_use_custom_base_url_in_validation(self, mock_get): + """Test that custom base URL is used in validation""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + credentials = { + "auth_type": "x-api-key", + "config": {"api_key": "test_api_key_123", "base_url": "https://custom.watercrawl.dev"}, + } + auth = WatercrawlAuth(credentials) + result = auth.validate_credentials() + + assert result is True + assert mock_get.call_args[0][0] == "https://custom.watercrawl.dev/api/v1/core/crawl-requests/" + + @pytest.mark.parametrize( + ("base_url", "expected_url"), + [ + ("https://app.watercrawl.dev", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ("https://app.watercrawl.dev/", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ("https://app.watercrawl.dev//", "https://app.watercrawl.dev/api/v1/core/crawl-requests/"), + ], + ) + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_use_urljoin_for_url_construction(self, mock_get, base_url, expected_url): + """Test that urljoin is used correctly for URL construction with various base URLs""" + mock_response = MagicMock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + credentials = {"auth_type": "x-api-key", "config": {"api_key": "test_api_key_123", "base_url": base_url}} + auth = WatercrawlAuth(credentials) + auth.validate_credentials() + + # Verify the correct URL was called + assert mock_get.call_args[0][0] == expected_url + + @patch("services.auth.watercrawl.watercrawl.requests.get") + def test_should_handle_timeout_with_retry_suggestion(self, mock_get, auth_instance): + """Test that timeout errors are handled gracefully with appropriate error message""" + mock_get.side_effect = requests.Timeout("The request timed out after 30 seconds") + + with pytest.raises(requests.Timeout) as exc_info: + auth_instance.validate_credentials() + + # Verify the timeout exception is raised with original message + assert "timed out" in str(exc_info.value) diff --git a/docker/.env.example b/docker/.env.example index ab98a40fef..6149f63165 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -283,11 +283,12 @@ REDIS_CLUSTERS_PASSWORD= # Celery Configuration # ------------------------------ -# Use redis as the broker, and redis db 1 for celery broker. -# Format as follows: `redis://:@:/` +# Use standalone redis as the broker, and redis db 1 for celery broker. (redis_username is usually set by defualt as empty) +# Format as follows: `redis://:@:/`. # Example: redis://:difyai123456@redis:6379/1 -# If use Redis Sentinel, format as follows: `sentinel://:@:/` -# Example: sentinel://localhost:26379/1;sentinel://localhost:26380/1;sentinel://localhost:26381/1 +# If use Redis Sentinel, format as follows: `sentinel://:@:/` +# For high availability, you can configure multiple Sentinel nodes (if provided) separated by semicolons like below example: +# Example: sentinel://:difyai123456@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1;sentinel://:difyai12345@localhost:26379/1 CELERY_BROKER_URL=redis://:difyai123456@redis:6379/1 CELERY_BACKEND=redis BROKER_USE_SSL=false diff --git a/web/app/(commonLayout)/datasets/template/template.ja.mdx b/web/app/(commonLayout)/datasets/template/template.ja.mdx index 23f78b5d7d..6c0e20e1bb 100644 --- a/web/app/(commonLayout)/datasets/template/template.ja.mdx +++ b/web/app/(commonLayout)/datasets/template/template.ja.mdx @@ -83,7 +83,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - subchunk_segmentation (object) 子チャンクルール - separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは *** - max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります - - chunk_overlap 隣接するチャンク間の重複を定義 (オプション) + - chunk_overlap 隣接するチャンク間の重なりを定義 (オプション) ナレッジベースにパラメータが設定されていない場合、最初のアップロードには以下のパラメータを提供する必要があります。提供されない場合、デフォルトパラメータが使用されます。 @@ -218,7 +218,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - subchunk_segmentation (object) 子チャンクルール - separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは *** - max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります - - chunk_overlap 隣接するチャンク間の重複を定義 (オプション) + - chunk_overlap 隣接するチャンク間の重なりを定義 (オプション) アップロードする必要があるファイル。 @@ -555,7 +555,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - subchunk_segmentation (object) 子チャンクルール - separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは *** - max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります - - chunk_overlap 隣接するチャンク間の重複を定義 (オプション) + - chunk_overlap 隣接するチャンク間の重なりを定義 (オプション) @@ -657,7 +657,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi - subchunk_segmentation (object) 子チャンクルール - separator セグメンテーション識別子。現在は 1 つの区切り文字のみ許可。デフォルトは *** - max_tokens 最大長 (トークン) は親チャンクの長さより短いことを検証する必要があります - - chunk_overlap 隣接するチャンク間の重複を定義 (オプション) + - chunk_overlap 隣接するチャンク間の重なりを定義 (オプション) diff --git a/web/app/(commonLayout)/layout.tsx b/web/app/(commonLayout)/layout.tsx index d07e2a99d9..64186a1b10 100644 --- a/web/app/(commonLayout)/layout.tsx +++ b/web/app/(commonLayout)/layout.tsx @@ -1,6 +1,6 @@ import React from 'react' import type { ReactNode } from 'react' -import SwrInitor from '@/app/components/swr-initor' +import SwrInitializer from '@/app/components/swr-initializer' import { AppContextProvider } from '@/context/app-context' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' @@ -13,7 +13,7 @@ const Layout = ({ children }: { children: ReactNode }) => { return ( <> - + @@ -26,7 +26,7 @@ const Layout = ({ children }: { children: ReactNode }) => { - + ) } diff --git a/web/app/account/account-page/index.tsx b/web/app/account/account-page/index.tsx index 55fa2983dd..47b8f045d2 100644 --- a/web/app/account/account-page/index.tsx +++ b/web/app/account/account-page/index.tsx @@ -1,5 +1,6 @@ 'use client' import { useState } from 'react' +import useSWR from 'swr' import { useTranslation } from 'react-i18next' import { RiGraduationCapFill, @@ -22,6 +23,8 @@ import PremiumBadge from '@/app/components/base/premium-badge' import { useGlobalPublicStore } from '@/context/global-public-context' import EmailChangeModal from './email-change-modal' import { validPassword } from '@/config' +import { fetchAppList } from '@/service/apps' +import type { App } from '@/types/app' const titleClassName = ` system-sm-semibold text-text-secondary @@ -33,7 +36,9 @@ const descriptionClassName = ` export default function AccountPage() { const { t } = useTranslation() const { systemFeatures } = useGlobalPublicStore() - const { mutateUserProfile, userProfile, apps } = useAppContext() + const { data: appList } = useSWR({ url: '/apps', params: { page: 1, limit: 100, name: '' } }, fetchAppList) + const apps = appList?.data || [] + const { mutateUserProfile, userProfile } = useAppContext() const { isEducationAccount } = useProviderContext() const { notify } = useContext(ToastContext) const [editNameModalVisible, setEditNameModalVisible] = useState(false) @@ -202,7 +207,7 @@ export default function AccountPage() { {!!apps.length && ( ({ ...app, key: app.id, name: app.name }))} + items={apps.map((app: App) => ({ ...app, key: app.id, name: app.name }))} renderItem={renderAppItem} wrapperClassName='mt-2' /> diff --git a/web/app/account/layout.tsx b/web/app/account/layout.tsx index e74716fb3b..b3225b5341 100644 --- a/web/app/account/layout.tsx +++ b/web/app/account/layout.tsx @@ -1,7 +1,7 @@ import React from 'react' import type { ReactNode } from 'react' import Header from './header' -import SwrInitor from '@/app/components/swr-initor' +import SwrInitor from '@/app/components/swr-initializer' import { AppContextProvider } from '@/context/app-context' import GA, { GaType } from '@/app/components/base/ga' import HeaderWrapper from '@/app/components/header/header-wrapper' diff --git a/web/app/components/app-sidebar/app-info.tsx b/web/app/components/app-sidebar/app-info.tsx index e85eaa2f53..c35047bbc5 100644 --- a/web/app/components/app-sidebar/app-info.tsx +++ b/web/app/components/app-sidebar/app-info.tsx @@ -1,6 +1,6 @@ import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import { useContext, useContextSelector } from 'use-context-selector' +import { useContext } from 'use-context-selector' import React, { useCallback, useState } from 'react' import { RiDeleteBinLine, @@ -15,7 +15,7 @@ import AppIcon from '../base/app-icon' import cn from '@/utils/classnames' import { useStore as useAppStore } from '@/app/components/app/store' import { ToastContext } from '@/app/components/base/toast' -import AppsContext, { useAppContext } from '@/context/app-context' +import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { copyApp, deleteApp, exportAppConfig, updateAppInfo } from '@/service/apps' import type { DuplicateAppModalProps } from '@/app/components/app/duplicate-modal' @@ -73,11 +73,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx const [showImportDSLModal, setShowImportDSLModal] = useState(false) const [secretEnvList, setSecretEnvList] = useState([]) - const mutateApps = useContextSelector( - AppsContext, - state => state.mutateApps, - ) - const onEdit: CreateAppModalProps['onConfirm'] = useCallback(async ({ name, icon_type, @@ -106,12 +101,11 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx message: t('app.editDone'), }) setAppDetail(app) - mutateApps() } catch { notify({ type: 'error', message: t('app.editFailed') }) } - }, [appDetail, mutateApps, notify, setAppDetail, t]) + }, [appDetail, notify, setAppDetail, t]) const onCopy: DuplicateAppModalProps['onConfirm'] = async ({ name, icon_type, icon, icon_background }) => { if (!appDetail) @@ -131,7 +125,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx message: t('app.newApp.appCreated'), }) localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') - mutateApps() onPlanInfoChanged() getRedirection(true, newApp, replace) } @@ -186,7 +179,6 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx try { await deleteApp(appDetail.id) notify({ type: 'success', message: t('app.appDeleted') }) - mutateApps() onPlanInfoChanged() setAppDetail() replace('/apps') @@ -198,7 +190,7 @@ const AppInfo = ({ expand, onlyShowDetail = false, openState = false, onDetailEx }) } setShowConfirmDelete(false) - }, [appDetail, mutateApps, notify, onPlanInfoChanged, replace, setAppDetail, t]) + }, [appDetail, notify, onPlanInfoChanged, replace, setAppDetail, t]) const { isCurrentWorkspaceEditor } = useAppContext() diff --git a/web/app/components/app/create-app-modal/index.tsx b/web/app/components/app/create-app-modal/index.tsx index f0a0da41a5..c37f7b051a 100644 --- a/web/app/components/app/create-app-modal/index.tsx +++ b/web/app/components/app/create-app-modal/index.tsx @@ -4,7 +4,7 @@ import { useCallback, useRef, useState } from 'react' import { useTranslation } from 'react-i18next' import { useRouter } from 'next/navigation' -import { useContext, useContextSelector } from 'use-context-selector' +import { useContext } from 'use-context-selector' import { RiArrowRightLine, RiArrowRightSLine, RiCommandLine, RiCornerDownLeftLine, RiExchange2Fill } from '@remixicon/react' import Link from 'next/link' import { useDebounceFn, useKeyPress } from 'ahooks' @@ -15,7 +15,7 @@ import Button from '@/app/components/base/button' import Divider from '@/app/components/base/divider' import cn from '@/utils/classnames' import { basePath } from '@/utils/var' -import AppsContext, { useAppContext } from '@/context/app-context' +import { useAppContext } from '@/context/app-context' import { useProviderContext } from '@/context/provider-context' import { ToastContext } from '@/app/components/base/toast' import type { AppMode } from '@/types/app' @@ -41,7 +41,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) const { t } = useTranslation() const { push } = useRouter() const { notify } = useContext(ToastContext) - const mutateApps = useContextSelector(AppsContext, state => state.mutateApps) const [appMode, setAppMode] = useState('advanced-chat') const [appIcon, setAppIcon] = useState({ type: 'emoji', icon: '🤖', background: '#FFEAD5' }) @@ -80,7 +79,6 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) notify({ type: 'success', message: t('app.newApp.appCreated') }) onSuccess() onClose() - mutateApps() localStorage.setItem(NEED_REFRESH_APP_LIST_KEY, '1') getRedirection(isCurrentWorkspaceEditor, app, push) } @@ -88,7 +86,7 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate }: CreateAppProps) notify({ type: 'error', message: t('app.newApp.appCreateFailed') }) } isCreatingRef.current = false - }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, mutateApps, push, isCurrentWorkspaceEditor]) + }, [name, notify, t, appMode, appIcon, description, onSuccess, onClose, push, isCurrentWorkspaceEditor]) const { run: handleCreateApp } = useDebounceFn(onCreate, { wait: 300 }) useKeyPress(['meta.enter', 'ctrl.enter'], () => { @@ -298,7 +296,7 @@ function AppTypeCard({ icon, title, description, active, onClick }: AppTypeCardP > {icon}
{title}
-
{description}
+
{description}
} diff --git a/web/app/components/app/overview/embedded/index.tsx b/web/app/components/app/overview/embedded/index.tsx index b48eac5458..9d97eae38d 100644 --- a/web/app/components/app/overview/embedded/index.tsx +++ b/web/app/components/app/overview/embedded/index.tsx @@ -90,10 +90,10 @@ const Embedded = ({ siteInfo, isShow, onClose, appBaseUrl, accessToken, classNam const [option, setOption] = useState