diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index c4d1ef70d8..1e3b55d3a0 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -200,6 +200,11 @@ class ToolBuiltinProviderGetCredentialsApi(Resource): class ToolBuiltinProviderIconApi(Resource): @setup_required def get(self, provider): + # Validate provider against an allowlist of known providers + valid_providers = ["provider1", "provider2", "provider3"] # Example allowlist + if provider not in valid_providers: + raise Forbidden("Invalid provider specified.") + icon_bytes, mimetype = BuiltinToolManageService.get_builtin_tool_provider_icon(provider) icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index f286466de0..0d23742063 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -476,14 +476,23 @@ class ToolManager: # get provider provider_controller = cls.get_hardcoded_provider(provider) - absolute_path = path.join( + base_path = path.join( path.dirname(path.realpath(__file__)), "builtin_tool", "providers", - provider, - "_assets", - provider_controller.entity.identity.icon, ) + absolute_path = path.normpath( + path.join( + base_path, + provider, + "_assets", + provider_controller.entity.identity.icon, + ) + ) + # Ensure the resolved path is within the base_path + if not absolute_path.startswith(base_path): + raise ToolProviderNotFoundError(f"Access to provider {provider} icon is not allowed") + # check if the icon exists if not path.exists(absolute_path): raise ToolProviderNotFoundError(f"builtin provider {provider} icon not found") diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 65f05d2986..755dc414a0 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -531,12 +531,12 @@ class BuiltinToolManageService: @staticmethod def get_builtin_tool_provider_icon(provider: str): """ - get tool provider icon and it's mimetype + get tool provider icon and its mimetype """ icon_path, mime_type = ToolManager.get_hardcoded_provider_icon(provider) - icon_bytes = Path(icon_path).read_bytes() - - return icon_bytes, mime_type + + # Normalize and validate the icon path + base_dir = Path("/safe/icon/directory") # Define the safe root directory @staticmethod def list_builtin_tools(user_id: str, tenant_id: str) -> list[ToolProviderApiEntity]: