Merge branch 'main' into feat/support-knowledge-metadata
# Conflicts: # api/core/rag/datasource/retrieval_service.py # api/core/workflow/nodes/code/code_node.py # api/services/dataset_service.pydev/plugin-deploy
commit
17f23f4798
@ -1,9 +1,30 @@
|
|||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
|
from threading import Lock
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from contexts.wrapper import RecyclableContextVar
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
|
||||||
|
from core.tools.plugin_tool.provider import PluginToolProviderController
|
||||||
from core.workflow.entities.variable_pool import VariablePool
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
|
||||||
|
|
||||||
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
tenant_id: ContextVar[str] = ContextVar("tenant_id")
|
||||||
|
|
||||||
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
workflow_variable_pool: ContextVar["VariablePool"] = ContextVar("workflow_variable_pool")
|
||||||
|
|
||||||
|
"""
|
||||||
|
To avoid race-conditions caused by gunicorn thread recycling, using RecyclableContextVar to replace with
|
||||||
|
"""
|
||||||
|
plugin_tool_providers: RecyclableContextVar[dict[str, "PluginToolProviderController"]] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_tool_providers")
|
||||||
|
)
|
||||||
|
plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock"))
|
||||||
|
|
||||||
|
plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_model_providers")
|
||||||
|
)
|
||||||
|
plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
|
||||||
|
ContextVar("plugin_model_providers_lock")
|
||||||
|
)
|
||||||
|
|||||||
@ -0,0 +1,65 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class HiddenValue:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
_default = HiddenValue()
|
||||||
|
|
||||||
|
|
||||||
|
class RecyclableContextVar(Generic[T]):
|
||||||
|
"""
|
||||||
|
RecyclableContextVar is a wrapper around ContextVar
|
||||||
|
It's safe to use in gunicorn with thread recycling, but features like `reset` are not available for now
|
||||||
|
|
||||||
|
NOTE: you need to call `increment_thread_recycles` before requests
|
||||||
|
"""
|
||||||
|
|
||||||
|
_thread_recycles: ContextVar[int] = ContextVar("thread_recycles")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def increment_thread_recycles(cls):
|
||||||
|
try:
|
||||||
|
recycles = cls._thread_recycles.get()
|
||||||
|
cls._thread_recycles.set(recycles + 1)
|
||||||
|
except LookupError:
|
||||||
|
cls._thread_recycles.set(0)
|
||||||
|
|
||||||
|
def __init__(self, context_var: ContextVar[T]):
|
||||||
|
self._context_var = context_var
|
||||||
|
self._updates = ContextVar[int](context_var.name + "_updates", default=0)
|
||||||
|
|
||||||
|
def get(self, default: T | HiddenValue = _default) -> T:
|
||||||
|
thread_recycles = self._thread_recycles.get(0)
|
||||||
|
self_updates = self._updates.get()
|
||||||
|
if thread_recycles > self_updates:
|
||||||
|
self._updates.set(thread_recycles)
|
||||||
|
|
||||||
|
# check if thread is recycled and should be updated
|
||||||
|
if thread_recycles < self_updates:
|
||||||
|
return self._context_var.get()
|
||||||
|
else:
|
||||||
|
# thread_recycles >= self_updates, means current context is invalid
|
||||||
|
if isinstance(default, HiddenValue) or default is _default:
|
||||||
|
raise LookupError
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
|
||||||
|
def set(self, value: T):
|
||||||
|
# it leads to a situation that self.updates is less than cls.thread_recycles if `set` was never called before
|
||||||
|
# increase it manually
|
||||||
|
thread_recycles = self._thread_recycles.get(0)
|
||||||
|
self_updates = self._updates.get()
|
||||||
|
if thread_recycles > self_updates:
|
||||||
|
self._updates.set(thread_recycles)
|
||||||
|
|
||||||
|
if self._updates.get() == self._thread_recycles.get(0):
|
||||||
|
# after increment,
|
||||||
|
self._updates.set(self._updates.get() + 1)
|
||||||
|
|
||||||
|
# set the context
|
||||||
|
self._context_var.set(value)
|
||||||
@ -0,0 +1,56 @@
|
|||||||
|
from functools import wraps
|
||||||
|
|
||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import TenantPluginPermission
|
||||||
|
|
||||||
|
|
||||||
|
def plugin_permission_required(
|
||||||
|
install_required: bool = False,
|
||||||
|
debug_required: bool = False,
|
||||||
|
):
|
||||||
|
def interceptor(view):
|
||||||
|
@wraps(view)
|
||||||
|
def decorated(*args, **kwargs):
|
||||||
|
user = current_user
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
permission = (
|
||||||
|
session.query(TenantPluginPermission)
|
||||||
|
.filter(
|
||||||
|
TenantPluginPermission.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not permission:
|
||||||
|
# no permission set, allow access for everyone
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
if install_required:
|
||||||
|
if permission.install_permission == TenantPluginPermission.InstallPermission.NOBODY:
|
||||||
|
raise Forbidden()
|
||||||
|
if permission.install_permission == TenantPluginPermission.InstallPermission.ADMINS:
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
if permission.install_permission == TenantPluginPermission.InstallPermission.EVERYONE:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if debug_required:
|
||||||
|
if permission.debug_permission == TenantPluginPermission.DebugPermission.NOBODY:
|
||||||
|
raise Forbidden()
|
||||||
|
if permission.debug_permission == TenantPluginPermission.DebugPermission.ADMINS:
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
if permission.debug_permission == TenantPluginPermission.DebugPermission.EVERYONE:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return view(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated
|
||||||
|
|
||||||
|
return interceptor
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from flask_restful import Resource # type: ignore
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from libs.login import login_required
|
||||||
|
from services.agent_service import AgentService
|
||||||
|
|
||||||
|
|
||||||
|
class AgentProviderListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
user_id = user.id
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
return jsonable_encoder(AgentService.list_agent_providers(user_id, tenant_id))
|
||||||
|
|
||||||
|
|
||||||
|
class AgentProviderApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, provider_name: str):
|
||||||
|
user = current_user
|
||||||
|
user_id = user.id
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
return jsonable_encoder(AgentService.get_agent_provider(user_id, tenant_id, provider_name))
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(AgentProviderListApi, "/workspaces/current/agent-providers")
|
||||||
|
api.add_resource(AgentProviderApi, "/workspaces/current/agent-provider/<path:provider_name>")
|
||||||
@ -0,0 +1,205 @@
|
|||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from libs.login import login_required
|
||||||
|
from services.plugin.endpoint_service import EndpointService
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointCreateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_unique_identifier", type=str, required=True)
|
||||||
|
parser.add_argument("settings", type=dict, required=True)
|
||||||
|
parser.add_argument("name", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
plugin_unique_identifier = args["plugin_unique_identifier"]
|
||||||
|
settings = args["settings"]
|
||||||
|
name = args["name"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": EndpointService.create_endpoint(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
plugin_unique_identifier=plugin_unique_identifier,
|
||||||
|
name=name,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
|
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
page = args["page"]
|
||||||
|
page_size = args["page_size"]
|
||||||
|
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"endpoints": EndpointService.list_endpoints(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointListForSinglePluginApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
|
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
page = args["page"]
|
||||||
|
page_size = args["page_size"]
|
||||||
|
plugin_id = args["plugin_id"]
|
||||||
|
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"endpoints": EndpointService.list_endpoints_for_single_plugin(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
page=page,
|
||||||
|
page_size=page_size,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointDeleteApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
endpoint_id = args["endpoint_id"]
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": EndpointService.delete_endpoint(
|
||||||
|
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointUpdateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
|
parser.add_argument("settings", type=dict, required=True)
|
||||||
|
parser.add_argument("name", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
endpoint_id = args["endpoint_id"]
|
||||||
|
settings = args["settings"]
|
||||||
|
name = args["name"]
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": EndpointService.update_endpoint(
|
||||||
|
tenant_id=user.current_tenant_id,
|
||||||
|
user_id=user.id,
|
||||||
|
endpoint_id=endpoint_id,
|
||||||
|
name=name,
|
||||||
|
settings=settings,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointEnableApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
endpoint_id = args["endpoint_id"]
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": EndpointService.enable_endpoint(
|
||||||
|
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class EndpointDisableApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("endpoint_id", type=str, required=True)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
endpoint_id = args["endpoint_id"]
|
||||||
|
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": EndpointService.disable_endpoint(
|
||||||
|
tenant_id=user.current_tenant_id, user_id=user.id, endpoint_id=endpoint_id
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(EndpointCreateApi, "/workspaces/current/endpoints/create")
|
||||||
|
api.add_resource(EndpointListApi, "/workspaces/current/endpoints/list")
|
||||||
|
api.add_resource(EndpointListForSinglePluginApi, "/workspaces/current/endpoints/list/plugin")
|
||||||
|
api.add_resource(EndpointDeleteApi, "/workspaces/current/endpoints/delete")
|
||||||
|
api.add_resource(EndpointUpdateApi, "/workspaces/current/endpoints/update")
|
||||||
|
api.add_resource(EndpointEnableApi, "/workspaces/current/endpoints/enable")
|
||||||
|
api.add_resource(EndpointDisableApi, "/workspaces/current/endpoints/disable")
|
||||||
@ -0,0 +1,475 @@
|
|||||||
|
import io
|
||||||
|
|
||||||
|
from flask import request, send_file
|
||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from flask_restful import Resource, reqparse # type: ignore
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.workspace import plugin_permission_required
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.manager.exc import PluginDaemonClientSideError
|
||||||
|
from libs.login import login_required
|
||||||
|
from models.account import TenantPluginPermission
|
||||||
|
from services.plugin.plugin_permission_service import PluginPermissionService
|
||||||
|
from services.plugin.plugin_service import PluginService
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDebuggingKeyApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return {
|
||||||
|
"key": PluginService.get_debugging_key(tenant_id),
|
||||||
|
"host": dify_config.PLUGIN_REMOTE_INSTALL_HOST,
|
||||||
|
"port": dify_config.PLUGIN_REMOTE_INSTALL_PORT,
|
||||||
|
}
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
try:
|
||||||
|
plugins = PluginService.list(tenant_id)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder({"plugins": plugins})
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListInstallationsFromIdsApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_ids", type=list, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
plugins = PluginService.list_installations_from_ids(tenant_id, args["plugin_ids"])
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder({"plugins": plugins})
|
||||||
|
|
||||||
|
|
||||||
|
class PluginIconApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
def get(self):
|
||||||
|
req = reqparse.RequestParser()
|
||||||
|
req.add_argument("tenant_id", type=str, required=True, location="args")
|
||||||
|
req.add_argument("filename", type=str, required=True, location="args")
|
||||||
|
args = req.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
icon_bytes, mimetype = PluginService.get_asset(args["tenant_id"], args["filename"])
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE
|
||||||
|
return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUploadFromPkgApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
file = request.files["pkg"]
|
||||||
|
|
||||||
|
# check file size
|
||||||
|
if file.content_length > dify_config.PLUGIN_MAX_PACKAGE_SIZE:
|
||||||
|
raise ValueError("File size exceeds the maximum allowed size")
|
||||||
|
|
||||||
|
content = file.read()
|
||||||
|
try:
|
||||||
|
response = PluginService.upload_pkg(tenant_id, content)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUploadFromGithubApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("repo", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("version", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("package", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = PluginService.upload_pkg_from_github(tenant_id, args["repo"], args["version"], args["package"])
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUploadFromBundleApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
file = request.files["bundle"]
|
||||||
|
|
||||||
|
# check file size
|
||||||
|
if file.content_length > dify_config.PLUGIN_MAX_BUNDLE_SIZE:
|
||||||
|
raise ValueError("File size exceeds the maximum allowed size")
|
||||||
|
|
||||||
|
content = file.read()
|
||||||
|
try:
|
||||||
|
response = PluginService.upload_bundle(tenant_id, content)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallFromPkgApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# check if all plugin_unique_identifiers are valid string
|
||||||
|
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||||
|
if not isinstance(plugin_unique_identifier, str):
|
||||||
|
raise ValueError("Invalid plugin unique identifier")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = PluginService.install_from_local_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallFromGithubApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("repo", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("version", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("package", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = PluginService.install_from_github(
|
||||||
|
tenant_id,
|
||||||
|
args["plugin_unique_identifier"],
|
||||||
|
args["repo"],
|
||||||
|
args["version"],
|
||||||
|
args["package"],
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallFromMarketplaceApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(install_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_unique_identifiers", type=list, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# check if all plugin_unique_identifiers are valid string
|
||||||
|
for plugin_unique_identifier in args["plugin_unique_identifiers"]:
|
||||||
|
if not isinstance(plugin_unique_identifier, str):
|
||||||
|
raise ValueError("Invalid plugin unique identifier")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = PluginService.install_from_marketplace_pkg(tenant_id, args["plugin_unique_identifiers"])
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
return jsonable_encoder(response)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchManifestApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("plugin_unique_identifier", type=str, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"manifest": PluginService.fetch_plugin_manifest(
|
||||||
|
tenant_id, args["plugin_unique_identifier"]
|
||||||
|
).model_dump()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchInstallTasksApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("page", type=int, required=True, location="args")
|
||||||
|
parser.add_argument("page_size", type=int, required=True, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
{"tasks": PluginService.fetch_install_tasks(tenant_id, args["page"], args["page_size"])}
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchInstallTaskApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def get(self, task_id: str):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)})
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDeleteInstallTaskApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self, task_id: str):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return {"success": PluginService.delete_install_task(tenant_id, task_id)}
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDeleteAllInstallTaskItemsApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return {"success": PluginService.delete_all_install_task_items(tenant_id)}
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDeleteInstallTaskItemApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self, task_id: str, identifier: str):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)}
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUpgradeFromMarketplaceApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
PluginService.upgrade_plugin_with_marketplace(
|
||||||
|
tenant_id, args["original_plugin_unique_identifier"], args["new_plugin_unique_identifier"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUpgradeFromGithubApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("original_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("new_plugin_unique_identifier", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("repo", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("version", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("package", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
PluginService.upgrade_plugin_with_github(
|
||||||
|
tenant_id,
|
||||||
|
args["original_plugin_unique_identifier"],
|
||||||
|
args["new_plugin_unique_identifier"],
|
||||||
|
args["repo"],
|
||||||
|
args["version"],
|
||||||
|
args["package"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUninstallApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@plugin_permission_required(debug_required=True)
|
||||||
|
def post(self):
|
||||||
|
req = reqparse.RequestParser()
|
||||||
|
req.add_argument("plugin_installation_id", type=str, required=True, location="json")
|
||||||
|
args = req.parse_args()
|
||||||
|
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
try:
|
||||||
|
return {"success": PluginService.uninstall(tenant_id, args["plugin_installation_id"])}
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
raise ValueError(e)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginChangePermissionApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
user = current_user
|
||||||
|
if not user.is_admin_or_owner:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
req = reqparse.RequestParser()
|
||||||
|
req.add_argument("install_permission", type=str, required=True, location="json")
|
||||||
|
req.add_argument("debug_permission", type=str, required=True, location="json")
|
||||||
|
args = req.parse_args()
|
||||||
|
|
||||||
|
install_permission = TenantPluginPermission.InstallPermission(args["install_permission"])
|
||||||
|
debug_permission = TenantPluginPermission.DebugPermission(args["debug_permission"])
|
||||||
|
|
||||||
|
tenant_id = user.current_tenant_id
|
||||||
|
|
||||||
|
return {"success": PluginPermissionService.change_permission(tenant_id, install_permission, debug_permission)}
|
||||||
|
|
||||||
|
|
||||||
|
class PluginFetchPermissionApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
tenant_id = current_user.current_tenant_id
|
||||||
|
|
||||||
|
permission = PluginPermissionService.get_permission(tenant_id)
|
||||||
|
if not permission:
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"install_permission": TenantPluginPermission.InstallPermission.EVERYONE,
|
||||||
|
"debug_permission": TenantPluginPermission.DebugPermission.EVERYONE,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return jsonable_encoder(
|
||||||
|
{
|
||||||
|
"install_permission": permission.install_permission,
|
||||||
|
"debug_permission": permission.debug_permission,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(PluginDebuggingKeyApi, "/workspaces/current/plugin/debugging-key")
|
||||||
|
api.add_resource(PluginListApi, "/workspaces/current/plugin/list")
|
||||||
|
api.add_resource(PluginListInstallationsFromIdsApi, "/workspaces/current/plugin/list/installations/ids")
|
||||||
|
api.add_resource(PluginIconApi, "/workspaces/current/plugin/icon")
|
||||||
|
api.add_resource(PluginUploadFromPkgApi, "/workspaces/current/plugin/upload/pkg")
|
||||||
|
api.add_resource(PluginUploadFromGithubApi, "/workspaces/current/plugin/upload/github")
|
||||||
|
api.add_resource(PluginUploadFromBundleApi, "/workspaces/current/plugin/upload/bundle")
|
||||||
|
api.add_resource(PluginInstallFromPkgApi, "/workspaces/current/plugin/install/pkg")
|
||||||
|
api.add_resource(PluginInstallFromGithubApi, "/workspaces/current/plugin/install/github")
|
||||||
|
api.add_resource(PluginUpgradeFromMarketplaceApi, "/workspaces/current/plugin/upgrade/marketplace")
|
||||||
|
api.add_resource(PluginUpgradeFromGithubApi, "/workspaces/current/plugin/upgrade/github")
|
||||||
|
api.add_resource(PluginInstallFromMarketplaceApi, "/workspaces/current/plugin/install/marketplace")
|
||||||
|
api.add_resource(PluginFetchManifestApi, "/workspaces/current/plugin/fetch-manifest")
|
||||||
|
api.add_resource(PluginFetchInstallTasksApi, "/workspaces/current/plugin/tasks")
|
||||||
|
api.add_resource(PluginFetchInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>")
|
||||||
|
api.add_resource(PluginDeleteInstallTaskApi, "/workspaces/current/plugin/tasks/<task_id>/delete")
|
||||||
|
api.add_resource(PluginDeleteAllInstallTaskItemsApi, "/workspaces/current/plugin/tasks/delete_all")
|
||||||
|
api.add_resource(PluginDeleteInstallTaskItemApi, "/workspaces/current/plugin/tasks/<task_id>/delete/<path:identifier>")
|
||||||
|
api.add_resource(PluginUninstallApi, "/workspaces/current/plugin/uninstall")
|
||||||
|
|
||||||
|
api.add_resource(PluginChangePermissionApi, "/workspaces/current/plugin/permission/change")
|
||||||
|
api.add_resource(PluginFetchPermissionApi, "/workspaces/current/plugin/permission/fetch")
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
from flask import request
|
||||||
|
from flask_restful import Resource, marshal_with # type: ignore
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from controllers.files import api
|
||||||
|
from controllers.files.error import UnsupportedFileTypeError
|
||||||
|
from controllers.inner_api.plugin.wraps import get_user
|
||||||
|
from controllers.service_api.app.error import FileTooLargeError
|
||||||
|
from core.file.helpers import verify_plugin_file_signature
|
||||||
|
from fields.file_fields import file_fields
|
||||||
|
from services.file_service import FileService
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUploadFileApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@marshal_with(file_fields)
|
||||||
|
def post(self):
|
||||||
|
# get file from request
|
||||||
|
file = request.files["file"]
|
||||||
|
|
||||||
|
timestamp = request.args.get("timestamp")
|
||||||
|
nonce = request.args.get("nonce")
|
||||||
|
sign = request.args.get("sign")
|
||||||
|
tenant_id = request.args.get("tenant_id")
|
||||||
|
if not tenant_id:
|
||||||
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
|
user_id = request.args.get("user_id")
|
||||||
|
user = get_user(tenant_id, user_id)
|
||||||
|
|
||||||
|
filename = file.filename
|
||||||
|
mimetype = file.mimetype
|
||||||
|
|
||||||
|
if not filename or not mimetype:
|
||||||
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
|
if not timestamp or not nonce or not sign:
|
||||||
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
|
if not verify_plugin_file_signature(
|
||||||
|
filename=filename,
|
||||||
|
mimetype=mimetype,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
timestamp=timestamp,
|
||||||
|
nonce=nonce,
|
||||||
|
sign=sign,
|
||||||
|
):
|
||||||
|
raise Forbidden("Invalid request.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
upload_file = FileService.upload_file(
|
||||||
|
filename=filename,
|
||||||
|
content=file.read(),
|
||||||
|
mimetype=mimetype,
|
||||||
|
user=user,
|
||||||
|
source=None,
|
||||||
|
)
|
||||||
|
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||||
|
raise FileTooLargeError(file_too_large_error.description)
|
||||||
|
except services.errors.file.UnsupportedFileTypeError:
|
||||||
|
raise UnsupportedFileTypeError()
|
||||||
|
|
||||||
|
return upload_file, 201
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(PluginUploadFileApi, "/files/upload/for-plugin")
|
||||||
@ -0,0 +1,293 @@
|
|||||||
|
from flask_restful import Resource # type: ignore
|
||||||
|
|
||||||
|
from controllers.console.wraps import setup_required
|
||||||
|
from controllers.inner_api import api
|
||||||
|
from controllers.inner_api.plugin.wraps import get_user_tenant, plugin_data
|
||||||
|
from controllers.inner_api.wraps import plugin_inner_api_only
|
||||||
|
from core.file.helpers import get_signed_file_url_for_plugin
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
|
||||||
|
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
|
||||||
|
from core.plugin.backwards_invocation.encrypt import PluginEncrypter
|
||||||
|
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
|
||||||
|
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
|
||||||
|
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
|
||||||
|
from core.plugin.entities.request import (
|
||||||
|
RequestInvokeApp,
|
||||||
|
RequestInvokeEncrypt,
|
||||||
|
RequestInvokeLLM,
|
||||||
|
RequestInvokeModeration,
|
||||||
|
RequestInvokeParameterExtractorNode,
|
||||||
|
RequestInvokeQuestionClassifierNode,
|
||||||
|
RequestInvokeRerank,
|
||||||
|
RequestInvokeSpeech2Text,
|
||||||
|
RequestInvokeSummary,
|
||||||
|
RequestInvokeTextEmbedding,
|
||||||
|
RequestInvokeTool,
|
||||||
|
RequestInvokeTTS,
|
||||||
|
RequestRequestUploadFile,
|
||||||
|
)
|
||||||
|
from core.tools.entities.tool_entities import ToolProviderType
|
||||||
|
from libs.helper import compact_generate_response
|
||||||
|
from models.account import Account, Tenant
|
||||||
|
from models.model import EndUser
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeLLMApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeLLM)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeLLM):
|
||||||
|
def generator():
|
||||||
|
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
|
||||||
|
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||||
|
|
||||||
|
return compact_generate_response(generator())
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeTextEmbeddingApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeTextEmbedding)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTextEmbedding):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginModelBackwardsInvocation.invoke_text_embedding(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeRerankApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeRerank)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeRerank):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginModelBackwardsInvocation.invoke_rerank(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeTTSApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeTTS)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTTS):
|
||||||
|
def generator():
|
||||||
|
response = PluginModelBackwardsInvocation.invoke_tts(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
|
||||||
|
|
||||||
|
return compact_generate_response(generator())
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeSpeech2TextApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeSpeech2Text)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSpeech2Text):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginModelBackwardsInvocation.invoke_speech2text(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeModerationApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeModeration)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeModeration):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginModelBackwardsInvocation.invoke_moderation(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeToolApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeTool)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeTool):
|
||||||
|
def generator():
|
||||||
|
return PluginToolBackwardsInvocation.convert_to_event_stream(
|
||||||
|
PluginToolBackwardsInvocation.invoke_tool(
|
||||||
|
tenant_id=tenant_model.id,
|
||||||
|
user_id=user_model.id,
|
||||||
|
tool_type=ToolProviderType.value_of(payload.tool_type),
|
||||||
|
provider=payload.provider,
|
||||||
|
tool_name=payload.tool,
|
||||||
|
tool_parameters=payload.tool_parameters,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return compact_generate_response(generator())
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeParameterExtractorNodeApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeParameterExtractorNode)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginNodeBackwardsInvocation.invoke_parameter_extractor(
|
||||||
|
tenant_id=tenant_model.id,
|
||||||
|
user_id=user_model.id,
|
||||||
|
parameters=payload.parameters,
|
||||||
|
model_config=payload.model,
|
||||||
|
instruction=payload.instruction,
|
||||||
|
query=payload.query,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeQuestionClassifierNodeApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
|
||||||
|
try:
|
||||||
|
return jsonable_encoder(
|
||||||
|
BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginNodeBackwardsInvocation.invoke_question_classifier(
|
||||||
|
tenant_id=tenant_model.id,
|
||||||
|
user_id=user_model.id,
|
||||||
|
query=payload.query,
|
||||||
|
model_config=payload.model,
|
||||||
|
classes=payload.classes,
|
||||||
|
instruction=payload.instruction,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
return jsonable_encoder(BaseBackwardsInvocationResponse(error=str(e)))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeAppApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeApp)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeApp):
|
||||||
|
response = PluginAppBackwardsInvocation.invoke_app(
|
||||||
|
app_id=payload.app_id,
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant_id=tenant_model.id,
|
||||||
|
conversation_id=payload.conversation_id,
|
||||||
|
query=payload.query,
|
||||||
|
stream=payload.response_mode == "streaming",
|
||||||
|
inputs=payload.inputs,
|
||||||
|
files=payload.files,
|
||||||
|
)
|
||||||
|
|
||||||
|
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeEncryptApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeEncrypt)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeEncrypt):
|
||||||
|
"""
|
||||||
|
encrypt or decrypt data
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return BaseBackwardsInvocationResponse(
|
||||||
|
data=PluginEncrypter.invoke_encrypt(tenant_model, payload)
|
||||||
|
).model_dump()
|
||||||
|
except Exception as e:
|
||||||
|
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInvokeSummaryApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestInvokeSummary)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestInvokeSummary):
|
||||||
|
try:
|
||||||
|
return BaseBackwardsInvocationResponse(
|
||||||
|
data={
|
||||||
|
"summary": PluginModelBackwardsInvocation.invoke_summary(
|
||||||
|
user_id=user_model.id,
|
||||||
|
tenant=tenant_model,
|
||||||
|
payload=payload,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
).model_dump()
|
||||||
|
except Exception as e:
|
||||||
|
return BaseBackwardsInvocationResponse(error=str(e)).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
class PluginUploadFileRequestApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@plugin_inner_api_only
|
||||||
|
@get_user_tenant
|
||||||
|
@plugin_data(payload_type=RequestRequestUploadFile)
|
||||||
|
def post(self, user_model: Account | EndUser, tenant_model: Tenant, payload: RequestRequestUploadFile):
|
||||||
|
# generate signed url
|
||||||
|
url = get_signed_file_url_for_plugin(payload.filename, payload.mimetype, tenant_model.id, user_model.id)
|
||||||
|
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
|
||||||
|
api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
|
||||||
|
api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
|
||||||
|
api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
|
||||||
|
api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
|
||||||
|
api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
|
||||||
|
api.add_resource(PluginInvokeToolApi, "/invoke/tool")
|
||||||
|
api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
|
||||||
|
api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
|
||||||
|
api.add_resource(PluginInvokeAppApi, "/invoke/app")
|
||||||
|
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")
|
||||||
|
api.add_resource(PluginInvokeSummaryApi, "/invoke/summary")
|
||||||
|
api.add_resource(PluginUploadFileRequestApi, "/upload/file/request")
|
||||||
@ -0,0 +1,116 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restful import reqparse # type: ignore
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.account import Account, Tenant
|
||||||
|
from models.model import EndUser
|
||||||
|
from services.account_service import AccountService
|
||||||
|
|
||||||
|
|
||||||
|
def get_user(tenant_id: str, user_id: str | None) -> Account | EndUser:
|
||||||
|
try:
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
if not user_id:
|
||||||
|
user_id = "DEFAULT-USER"
|
||||||
|
|
||||||
|
if user_id == "DEFAULT-USER":
|
||||||
|
user_model = session.query(EndUser).filter(EndUser.session_id == "DEFAULT-USER").first()
|
||||||
|
if not user_model:
|
||||||
|
user_model = EndUser(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
type="service_api",
|
||||||
|
is_anonymous=True if user_id == "DEFAULT-USER" else False,
|
||||||
|
session_id=user_id,
|
||||||
|
)
|
||||||
|
session.add(user_model)
|
||||||
|
session.commit()
|
||||||
|
else:
|
||||||
|
user_model = AccountService.load_user(user_id)
|
||||||
|
if not user_model:
|
||||||
|
user_model = session.query(EndUser).filter(EndUser.id == user_id).first()
|
||||||
|
if not user_model:
|
||||||
|
raise ValueError("user not found")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("user not found")
|
||||||
|
|
||||||
|
return user_model
|
||||||
|
|
||||||
|
|
||||||
|
def get_user_tenant(view: Optional[Callable] = None):
|
||||||
|
def decorator(view_func):
|
||||||
|
@wraps(view_func)
|
||||||
|
def decorated_view(*args, **kwargs):
|
||||||
|
# fetch json body
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("tenant_id", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("user_id", type=str, required=True, location="json")
|
||||||
|
|
||||||
|
kwargs = parser.parse_args()
|
||||||
|
|
||||||
|
user_id = kwargs.get("user_id")
|
||||||
|
tenant_id = kwargs.get("tenant_id")
|
||||||
|
|
||||||
|
if not tenant_id:
|
||||||
|
raise ValueError("tenant_id is required")
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
user_id = "DEFAULT-USER"
|
||||||
|
|
||||||
|
del kwargs["tenant_id"]
|
||||||
|
del kwargs["user_id"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
tenant_model = (
|
||||||
|
db.session.query(Tenant)
|
||||||
|
.filter(
|
||||||
|
Tenant.id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("tenant not found")
|
||||||
|
|
||||||
|
if not tenant_model:
|
||||||
|
raise ValueError("tenant not found")
|
||||||
|
|
||||||
|
kwargs["tenant_model"] = tenant_model
|
||||||
|
kwargs["user_model"] = get_user(tenant_id, user_id)
|
||||||
|
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
if view is None:
|
||||||
|
return decorator
|
||||||
|
else:
|
||||||
|
return decorator(view)
|
||||||
|
|
||||||
|
|
||||||
|
def plugin_data(view: Optional[Callable] = None, *, payload_type: type[BaseModel]):
|
||||||
|
def decorator(view_func):
|
||||||
|
def decorated_view(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
data = request.get_json()
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("invalid json")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = payload_type(**data)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"invalid payload: {str(e)}")
|
||||||
|
|
||||||
|
kwargs["payload"] = payload
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
if view is None:
|
||||||
|
return decorator
|
||||||
|
else:
|
||||||
|
return decorator(view)
|
||||||
@ -0,0 +1,89 @@
|
|||||||
|
import enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
from core.entities.parameter_entities import CommonParameterType
|
||||||
|
from core.plugin.entities.parameters import (
|
||||||
|
PluginParameter,
|
||||||
|
as_normal_type,
|
||||||
|
cast_parameter_value,
|
||||||
|
init_frontend_parameter,
|
||||||
|
)
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import (
|
||||||
|
ToolIdentity,
|
||||||
|
ToolProviderIdentity,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStrategyProviderIdentity(ToolProviderIdentity):
|
||||||
|
"""
|
||||||
|
Inherits from ToolProviderIdentity, without any additional fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStrategyParameter(PluginParameter):
|
||||||
|
class AgentStrategyParameterType(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
Keep all the types from PluginParameterType
|
||||||
|
"""
|
||||||
|
|
||||||
|
STRING = CommonParameterType.STRING.value
|
||||||
|
NUMBER = CommonParameterType.NUMBER.value
|
||||||
|
BOOLEAN = CommonParameterType.BOOLEAN.value
|
||||||
|
SELECT = CommonParameterType.SELECT.value
|
||||||
|
SECRET_INPUT = CommonParameterType.SECRET_INPUT.value
|
||||||
|
FILE = CommonParameterType.FILE.value
|
||||||
|
FILES = CommonParameterType.FILES.value
|
||||||
|
APP_SELECTOR = CommonParameterType.APP_SELECTOR.value
|
||||||
|
MODEL_SELECTOR = CommonParameterType.MODEL_SELECTOR.value
|
||||||
|
TOOLS_SELECTOR = CommonParameterType.TOOLS_SELECTOR.value
|
||||||
|
|
||||||
|
# deprecated, should not use.
|
||||||
|
SYSTEM_FILES = CommonParameterType.SYSTEM_FILES.value
|
||||||
|
|
||||||
|
def as_normal_type(self):
|
||||||
|
return as_normal_type(self)
|
||||||
|
|
||||||
|
def cast_value(self, value: Any):
|
||||||
|
return cast_parameter_value(self, value)
|
||||||
|
|
||||||
|
type: AgentStrategyParameterType = Field(..., description="The type of the parameter")
|
||||||
|
|
||||||
|
def init_frontend_parameter(self, value: Any):
|
||||||
|
return init_frontend_parameter(self, self.type, value)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStrategyProviderEntity(BaseModel):
|
||||||
|
identity: AgentStrategyProviderIdentity
|
||||||
|
plugin_id: Optional[str] = Field(None, description="The id of the plugin")
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStrategyIdentity(ToolIdentity):
|
||||||
|
"""
|
||||||
|
Inherits from ToolIdentity, without any additional fields.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AgentStrategyEntity(BaseModel):
|
||||||
|
identity: AgentStrategyIdentity
|
||||||
|
parameters: list[AgentStrategyParameter] = Field(default_factory=list)
|
||||||
|
description: I18nObject = Field(..., description="The description of the agent strategy")
|
||||||
|
output_schema: Optional[dict] = None
|
||||||
|
|
||||||
|
# pydantic configs
|
||||||
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
@field_validator("parameters", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[AgentStrategyParameter]:
|
||||||
|
return v or []
|
||||||
|
|
||||||
|
|
||||||
|
class AgentProviderEntityWithPlugin(AgentStrategyProviderEntity):
|
||||||
|
strategies: list[AgentStrategyEntity] = Field(default_factory=list)
|
||||||
@ -0,0 +1,42 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from collections.abc import Generator, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.agent.entities import AgentInvokeMessage
|
||||||
|
from core.agent.plugin_entities import AgentStrategyParameter
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgentStrategy(ABC):
|
||||||
|
"""
|
||||||
|
Agent Strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
def invoke(
|
||||||
|
self,
|
||||||
|
params: dict[str, Any],
|
||||||
|
user_id: str,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
app_id: Optional[str] = None,
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the agent strategy.
|
||||||
|
"""
|
||||||
|
yield from self._invoke(params, user_id, conversation_id, app_id, message_id)
|
||||||
|
|
||||||
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
|
"""
|
||||||
|
Get the parameters for the agent strategy.
|
||||||
|
"""
|
||||||
|
return []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
params: dict[str, Any],
|
||||||
|
user_id: str,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
app_id: Optional[str] = None,
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
|
pass
|
||||||
@ -0,0 +1,59 @@
|
|||||||
|
from collections.abc import Generator, Sequence
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.agent.entities import AgentInvokeMessage
|
||||||
|
from core.agent.plugin_entities import AgentStrategyEntity, AgentStrategyParameter
|
||||||
|
from core.agent.strategy.base import BaseAgentStrategy
|
||||||
|
from core.plugin.manager.agent import PluginAgentManager
|
||||||
|
from core.plugin.utils.converter import convert_parameters_to_plugin_format
|
||||||
|
|
||||||
|
|
||||||
|
class PluginAgentStrategy(BaseAgentStrategy):
|
||||||
|
"""
|
||||||
|
Agent Strategy
|
||||||
|
"""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
declaration: AgentStrategyEntity
|
||||||
|
|
||||||
|
def __init__(self, tenant_id: str, declaration: AgentStrategyEntity):
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.declaration = declaration
|
||||||
|
|
||||||
|
def get_parameters(self) -> Sequence[AgentStrategyParameter]:
|
||||||
|
return self.declaration.parameters
|
||||||
|
|
||||||
|
def initialize_parameters(self, params: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Initialize the parameters for the agent strategy.
|
||||||
|
"""
|
||||||
|
for parameter in self.declaration.parameters:
|
||||||
|
params[parameter.name] = parameter.init_frontend_parameter(params.get(parameter.name))
|
||||||
|
return params
|
||||||
|
|
||||||
|
def _invoke(
|
||||||
|
self,
|
||||||
|
params: dict[str, Any],
|
||||||
|
user_id: str,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
app_id: Optional[str] = None,
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
) -> Generator[AgentInvokeMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the agent strategy.
|
||||||
|
"""
|
||||||
|
manager = PluginAgentManager()
|
||||||
|
|
||||||
|
initialized_params = self.initialize_parameters(params)
|
||||||
|
params = convert_parameters_to_plugin_format(initialized_params)
|
||||||
|
|
||||||
|
yield from manager.invoke(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_provider=self.declaration.identity.provider,
|
||||||
|
agent_strategy=self.declaration.identity.name,
|
||||||
|
agent_params=params,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
app_id=app_id,
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
@ -1,5 +1,26 @@
|
|||||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
from collections.abc import Generator, Iterable, Mapping
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler, print_text
|
||||||
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||||
|
|
||||||
|
|
||||||
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
|
class DifyWorkflowCallbackHandler(DifyAgentCallbackHandler):
|
||||||
"""Callback Handler that prints to std out."""
|
"""Callback Handler that prints to std out."""
|
||||||
|
|
||||||
|
def on_tool_execution(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
tool_inputs: Mapping[str, Any],
|
||||||
|
tool_outputs: Iterable[ToolInvokeMessage],
|
||||||
|
message_id: Optional[str] = None,
|
||||||
|
timer: Optional[Any] = None,
|
||||||
|
trace_manager: Optional[TraceQueueManager] = None,
|
||||||
|
) -> Generator[ToolInvokeMessage, None, None]:
|
||||||
|
for tool_output in tool_outputs:
|
||||||
|
print_text("\n[on_tool_execution]\n", color=self.color)
|
||||||
|
print_text("Tool: " + tool_name + "\n", color=self.color)
|
||||||
|
print_text("Outputs: " + tool_output.model_dump_json()[:1000] + "\n", color=self.color)
|
||||||
|
print_text("\n")
|
||||||
|
yield tool_output
|
||||||
|
|||||||
@ -0,0 +1 @@
|
|||||||
|
DEFAULT_PLUGIN_ID = "langgenius"
|
||||||
@ -0,0 +1,42 @@
|
|||||||
|
from enum import StrEnum
|
||||||
|
|
||||||
|
|
||||||
|
class CommonParameterType(StrEnum):
|
||||||
|
SECRET_INPUT = "secret-input"
|
||||||
|
TEXT_INPUT = "text-input"
|
||||||
|
SELECT = "select"
|
||||||
|
STRING = "string"
|
||||||
|
NUMBER = "number"
|
||||||
|
FILE = "file"
|
||||||
|
FILES = "files"
|
||||||
|
SYSTEM_FILES = "system-files"
|
||||||
|
BOOLEAN = "boolean"
|
||||||
|
APP_SELECTOR = "app-selector"
|
||||||
|
MODEL_SELECTOR = "model-selector"
|
||||||
|
TOOLS_SELECTOR = "array[tools]"
|
||||||
|
|
||||||
|
# TOOL_SELECTOR = "tool-selector"
|
||||||
|
|
||||||
|
|
||||||
|
class AppSelectorScope(StrEnum):
|
||||||
|
ALL = "all"
|
||||||
|
CHAT = "chat"
|
||||||
|
WORKFLOW = "workflow"
|
||||||
|
COMPLETION = "completion"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelSelectorScope(StrEnum):
|
||||||
|
LLM = "llm"
|
||||||
|
TEXT_EMBEDDING = "text-embedding"
|
||||||
|
RERANK = "rerank"
|
||||||
|
TTS = "tts"
|
||||||
|
SPEECH2TEXT = "speech2text"
|
||||||
|
MODERATION = "moderation"
|
||||||
|
VISION = "vision"
|
||||||
|
|
||||||
|
|
||||||
|
class ToolSelectorScope(StrEnum):
|
||||||
|
ALL = "all"
|
||||||
|
CUSTOM = "custom"
|
||||||
|
BUILTIN = "builtin"
|
||||||
|
WORKFLOW = "workflow"
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue