feat: validate credentials

pull/9184/head
Yeuoly 2 years ago
parent 7a3e756020
commit 947bfdc807
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -35,4 +35,11 @@ class PluginToolProviderEntity(BaseModel):
provider: str provider: str
plugin_unique_identifier: str plugin_unique_identifier: str
plugin_id: str plugin_id: str
declaration: ToolProviderEntityWithPlugin declaration: ToolProviderEntityWithPlugin
class PluginBasicBooleanResponse(BaseModel):
"""
Basic boolean response from plugin daemon.
"""
result: bool

@ -1,5 +1,5 @@
import json import json
from collections.abc import Generator from collections.abc import Callable, Generator
from typing import TypeVar from typing import TypeVar
import requests import requests
@ -21,7 +21,7 @@ class BasePluginManager:
method: str, method: str,
path: str, path: str,
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | str | None = None,
params: dict | None = None, params: dict | None = None,
stream: bool = False, stream: bool = False,
) -> requests.Response: ) -> requests.Response:
@ -31,6 +31,10 @@ class BasePluginManager:
url = URL(str(plugin_daemon_inner_api_baseurl)) / path url = URL(str(plugin_daemon_inner_api_baseurl)) / path
headers = headers or {} headers = headers or {}
headers["X-Api-Key"] = plugin_daemon_inner_api_key headers["X-Api-Key"] = plugin_daemon_inner_api_key
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
response = requests.request( response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream method=method, url=str(url), headers=headers, data=data, params=params, stream=stream
) )
@ -48,7 +52,11 @@ class BasePluginManager:
Make a stream request to the plugin daemon inner API Make a stream request to the plugin daemon inner API
""" """
response = self._request(method, path, headers, data, params, stream=True) response = self._request(method, path, headers, data, params, stream=True)
yield from response.iter_lines() for line in response.iter_lines():
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
yield line
def _stream_request_with_model( def _stream_request_with_model(
self, self,
@ -88,17 +96,15 @@ class BasePluginManager:
headers: dict | None = None, headers: dict | None = None,
data: bytes | dict | None = None, data: bytes | dict | None = None,
params: dict | None = None, params: dict | None = None,
transformer: Callable[[dict], dict] | None = None,
) -> T: ) -> T:
""" """
Make a request to the plugin daemon inner API and return the response as a model. Make a request to the plugin daemon inner API and return the response as a model.
""" """
response = self._request(method, path, headers, data, params) response = self._request(method, path, headers, data, params)
json_response = response.json() json_response = response.json()
for provider in json_response.get("data", []): if transformer:
declaration = provider.get("declaration", {}) or {} json_response = transformer(json_response)
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
rep = PluginDaemonBasicResponse[type](**json_response) rep = PluginDaemonBasicResponse[type](**json_response)
if rep.code != 0: if rep.code != 0:
@ -128,3 +134,4 @@ class BasePluginManager:
if rep.data is None: if rep.data is None:
raise ValueError("got empty data from plugin daemon") raise ValueError("got empty data from plugin daemon")
yield rep.data yield rep.data

@ -1,7 +1,7 @@
from collections.abc import Generator from collections.abc import Generator
from typing import Any from typing import Any
from core.plugin.entities.plugin_daemon import PluginToolProviderEntity from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.manager.base import BasePluginManager from core.plugin.manager.base import BasePluginManager
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
@ -11,8 +11,22 @@ class PluginToolManager(BasePluginManager):
""" """
Fetch tool providers for the given asset. Fetch tool providers for the given asset.
""" """
def transformer(json_response: dict[str, Any]) -> dict:
for provider in json_response.get("data", []):
declaration = provider.get("declaration", {}) or {}
provider_name = declaration.get("identity", {}).get("name")
for tool in declaration.get("tools", []):
tool["identity"]["provider"] = provider_name
return json_response
response = self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response(
"GET", f"plugin/{tenant_id}/tools", list[PluginToolProviderEntity], params={"page": 1, "page_size": 256} "GET",
f"plugin/{tenant_id}/management/tools",
list[PluginToolProviderEntity],
params={"page": 1, "page_size": 256},
transformer=transformer,
) )
return response return response
@ -28,7 +42,7 @@ class PluginToolManager(BasePluginManager):
) -> Generator[ToolInvokeMessage, None, None]: ) -> Generator[ToolInvokeMessage, None, None]:
response = self._request_with_plugin_daemon_response_stream( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/tool/invoke", f"plugin/{tenant_id}/dispatch/tool/invoke",
ToolInvokeMessage, ToolInvokeMessage,
data={ data={
"plugin_unique_identifier": plugin_unique_identifier, "plugin_unique_identifier": plugin_unique_identifier,
@ -40,6 +54,10 @@ class PluginToolManager(BasePluginManager):
"tool_parameters": tool_parameters, "tool_parameters": tool_parameters,
}, },
}, },
headers={
"X-Plugin-Identifier": plugin_unique_identifier,
"Content-Type": "application/json",
}
) )
return response return response
@ -49,10 +67,10 @@ class PluginToolManager(BasePluginManager):
""" """
validate the credentials of the provider validate the credentials of the provider
""" """
response = self._request_with_plugin_daemon_response( response = self._request_with_plugin_daemon_response_stream(
"POST", "POST",
f"plugin/{tenant_id}/tool/validate_credentials", f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
bool, PluginBasicBooleanResponse,
data={ data={
"plugin_unique_identifier": plugin_unique_identifier, "plugin_unique_identifier": plugin_unique_identifier,
"user_id": user_id, "user_id": user_id,
@ -61,5 +79,13 @@ class PluginToolManager(BasePluginManager):
"credentials": credentials, "credentials": credentials,
}, },
}, },
headers={
"X-Plugin-Identifier": plugin_unique_identifier,
"Content-Type": "application/json",
}
) )
return response
for resp in response:
return resp.result
return False

Loading…
Cancel
Save