feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>pull/1462/head
parent
7699621983
commit
db43ed6f41
@ -0,0 +1,114 @@
|
|||||||
|
from flask_restful import Resource, reqparse, marshal_with
|
||||||
|
from flask_login import current_user
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.setup import setup_required
|
||||||
|
from controllers.console.wraps import account_initialization_required
|
||||||
|
from libs.login import login_required
|
||||||
|
from models.api_based_extension import APIBasedExtension
|
||||||
|
from fields.api_based_extension_fields import api_based_extension_fields
|
||||||
|
from services.code_based_extension_service import CodeBasedExtensionService
|
||||||
|
from services.api_based_extension_service import APIBasedExtensionService
|
||||||
|
|
||||||
|
|
||||||
|
class CodeBasedExtensionAPI(Resource):
    """Console resource that returns the code-based extensions of one module."""

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """
        Return the requested module name together with its extension data.

        Reads the required ``module`` query-string argument and passes it to
        ``CodeBasedExtensionService.get_code_based_extension``.

        :return: dict with the keys ``module`` and ``data``
        """
        request_parser = reqparse.RequestParser()
        request_parser.add_argument('module', type=str, required=True, location='args')
        module_name = request_parser.parse_args()['module']

        extension_data = CodeBasedExtensionService.get_code_based_extension(module_name)
        return {
            'module': module_name,
            'data': extension_data
        }
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionAPI(Resource):
    """Console resource for listing and creating API-based extensions."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_based_extension_fields)
    def get(self):
        """
        List the API-based extensions of the current user's tenant.

        :return: whatever ``APIBasedExtensionService.get_all_by_tenant_id`` returns,
            marshalled with ``api_based_extension_fields``
        """
        return APIBasedExtensionService.get_all_by_tenant_id(current_user.current_tenant_id)

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_based_extension_fields)
    def post(self):
        """
        Create an API-based extension from the JSON request body.

        The body must carry the string fields ``name``, ``api_endpoint`` and ``api_key``.

        :return: the result of ``APIBasedExtensionService.save``,
            marshalled with ``api_based_extension_fields``
        """
        request_parser = reqparse.RequestParser()
        for field_name in ('name', 'api_endpoint', 'api_key'):
            request_parser.add_argument(field_name, type=str, required=True, location='json')
        payload = request_parser.parse_args()

        new_extension = APIBasedExtension(
            tenant_id=current_user.current_tenant_id,
            name=payload['name'],
            api_endpoint=payload['api_endpoint'],
            api_key=payload['api_key']
        )

        return APIBasedExtensionService.save(new_extension)
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionDetailAPI(Resource):
    """Console resource for reading, updating and deleting one API-based extension."""

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_based_extension_fields)
    def get(self, id):
        """
        Fetch one API-based extension of the current user's tenant.

        :param id: the extension id taken from the URL
        :return: the result of ``APIBasedExtensionService.get_with_tenant_id``,
            marshalled with ``api_based_extension_fields``
        """
        extension_id = str(id)
        return APIBasedExtensionService.get_with_tenant_id(current_user.current_tenant_id, extension_id)

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_based_extension_fields)
    def post(self, id):
        """
        Update one API-based extension from the JSON request body.

        :param id: the extension id taken from the URL
        :return: the result of ``APIBasedExtensionService.save``,
            marshalled with ``api_based_extension_fields``
        """
        extension_id = str(id)

        # The stored record is looked up before the request body is parsed.
        stored_extension = APIBasedExtensionService.get_with_tenant_id(
            current_user.current_tenant_id,
            extension_id
        )

        request_parser = reqparse.RequestParser()
        for field_name in ('name', 'api_endpoint', 'api_key'):
            request_parser.add_argument(field_name, type=str, required=True, location='json')
        payload = request_parser.parse_args()

        stored_extension.name = payload['name']
        stored_extension.api_endpoint = payload['api_endpoint']

        # An api_key equal to '[__HIDDEN__]' leaves the stored key untouched
        # (presumably the masked placeholder sent back by the client; confirm against the front end).
        submitted_key = payload['api_key']
        if submitted_key != '[__HIDDEN__]':
            stored_extension.api_key = submitted_key

        return APIBasedExtensionService.save(stored_extension)

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, id):
        """
        Delete one API-based extension of the current user's tenant.

        :param id: the extension id taken from the URL
        :return: ``{'result': 'success'}``
        """
        extension_id = str(id)
        stored_extension = APIBasedExtensionService.get_with_tenant_id(
            current_user.current_tenant_id,
            extension_id
        )

        APIBasedExtensionService.delete(stored_extension)

        return {'result': 'success'}
|
||||||
|
|
||||||
|
|
||||||
|
# Register the extension resources on the console API, in the same order as before.
for _resource_class, _route in (
    (CodeBasedExtensionAPI, '/code-based-extension'),
    (APIBasedExtensionAPI, '/api-based-extension'),
    (APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>'),
):
    api.add_resource(_resource_class, _route)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
import core.moderation.base
|
||||||
@ -1,92 +0,0 @@
|
|||||||
import enum
|
|
||||||
import logging
|
|
||||||
from typing import List, Dict, Optional, Any
|
|
||||||
|
|
||||||
from langchain.callbacks.manager import CallbackManagerForChainRun
|
|
||||||
from langchain.chains.base import Chain
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from core.model_providers.error import LLMBadRequestError
|
|
||||||
from core.model_providers.model_factory import ModelFactory
|
|
||||||
from core.model_providers.models.llm.base import BaseLLM
|
|
||||||
from core.model_providers.models.moderation import openai_moderation
|
|
||||||
|
|
||||||
|
|
||||||
class SensitiveWordAvoidanceRule(BaseModel):
    """Configuration of a sensitive-word check: which kind to run and what to answer."""

    class Type(enum.Enum):
        """The kind of check a rule performs."""
        MODERATION = "moderation"
        KEYWORDS = "keywords"

    # which check to run
    type: Type
    # message used when the check rejects the text
    canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
    # extra settings; the keyword check reads the 'sensitive_words' entry from here
    extra_params: dict = {}
|
|
||||||
|
|
||||||
|
|
||||||
class SensitiveWordAvoidanceChain(Chain):
    """Chain that passes the input text through unchanged unless the configured rule rejects it."""

    input_key: str = "input"  #: :meta private:
    output_key: str = "output"  #: :meta private:

    model_instance: BaseLLM
    sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule

    @property
    def _chain_type(self) -> str:
        return "sensitive_word_avoidance_chain"

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Return output key.

        :meta private:
        """
        return [self.output_key]

    def _check_sensitive_word(self, text: str) -> bool:
        """
        Check the text against the rule's ``sensitive_words`` list.

        :param text: the text to check
        :return: False when any configured word occurs in the text, True otherwise
        """
        blocked_words = self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', [])
        return not any(blocked_word in text for blocked_word in blocked_words)

    def _check_moderation(self, text: str) -> bool:
        """
        Run the text through the default OpenAI moderation model of the tenant.

        :param text: the text to check
        :return: the value returned by the moderation model's ``run``
        :raises LLMBadRequestError: when the moderation call raises any exception
        """
        tenant_id = self.model_instance.model_provider.provider.tenant_id
        moderation_model = ModelFactory.get_moderation_model(
            tenant_id=tenant_id,
            model_provider_name='openai',
            model_name=openai_moderation.DEFAULT_MODEL
        )

        try:
            return moderation_model.run(text=text)
        except Exception as ex:
            # Every failure is logged and reported to the caller with the same message.
            logging.exception(ex)
            raise LLMBadRequestError('Rate limit exceeded, please try again later.')

    def _call(
            self,
            inputs: Dict[str, Any],
            run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        """
        Check the input text and return it under the output key when it passes.

        :param inputs: chain inputs; the text is read from ``self.input_key``
        :param run_manager: unused
        :return: dict mapping ``self.output_key`` to the unchanged text
        :raises SensitiveWordAvoidanceError: with the rule's canned response when the check fails
        """
        text = inputs[self.input_key]
        rule = self.sensitive_word_avoidance_rule

        # Keyword rules use the local word list; every other rule type goes to moderation.
        if rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
            passed = self._check_sensitive_word(text)
        else:
            passed = self._check_moderation(text)

        if passed:
            return {self.output_key: text}

        raise SensitiveWordAvoidanceError(rule.canned_response)
|
|
||||||
|
|
||||||
|
|
||||||
class SensitiveWordAvoidanceError(Exception):
    """Raised when a text is rejected; the message is also exposed as ``.message``."""

    def __init__(self, message):
        """
        :param message: the response to show instead of the rejected text
        """
        self.message = message
        super().__init__(message)
|
|
||||||
@ -0,0 +1,62 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from models.api_based_extension import APIBasedExtensionPoint
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionRequestor:
    """HTTP client that POSTs extension-point calls to an API-based extension endpoint."""

    # timeout for request connect and read, passed to ``requests`` as (connect, read)
    timeout: tuple[int, int] = (5, 60)

    def __init__(self, api_endpoint: str, api_key: str) -> None:
        """
        :param api_endpoint: the URL every request is posted to
        :param api_key: sent as a Bearer token in the Authorization header
        """
        self.api_endpoint = api_endpoint
        self.api_key = api_key

    def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
        """
        Request the api.

        :param point: the api point
        :param params: the request params
        :return: the response json
        :raises ValueError: on timeout, on connection error, or when the
            response status code is not 200
        """
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer {}".format(self.api_key)
        }

        url = self.api_endpoint

        # proxy support for security; proxies are only used when BOTH variables are set
        proxies = None
        http_proxy = os.environ.get("API_BASED_EXTENSION_HTTP_PROXY")
        https_proxy = os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY")
        if http_proxy and https_proxy:
            proxies = {
                'http': http_proxy,
                'https': https_proxy,
            }

        try:
            response = requests.request(
                method='POST',
                url=url,
                json={
                    'point': point.value,
                    'params': params
                },
                headers=headers,
                timeout=self.timeout,
                proxies=proxies
            )
        except requests.exceptions.Timeout as e:
            # keep the original error as the explicit cause for debugging
            raise ValueError("request timeout") from e
        except requests.exceptions.ConnectionError as e:
            raise ValueError("request connection error") from e

        if response.status_code != 200:
            # only the first 100 characters of the body are included in the message
            raise ValueError("request error, status_code: {}, content: {}".format(
                response.status_code,
                response.text[:100]
            ))

        return response.json()
|
||||||
@ -0,0 +1,111 @@
|
|||||||
|
import enum
|
||||||
|
import importlib.util
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class ExtensionModule(enum.Enum):
    """The modules that can be extended; each value is the module's string key."""

    # content moderation extensions
    MODERATION = 'moderation'
    # external data tool extensions
    EXTERNAL_DATA_TOOL = 'external_data_tool'
|
||||||
|
|
||||||
|
|
||||||
|
class ModuleExtension(BaseModel):
    """Description of one extension found by ``Extensible.scan_extensions``."""

    # the extension class loaded from the extension's python file
    extension_class: Any
    # the extension name (the name of its directory)
    name: str
    # value of the 'label' key of schema.json, if any
    label: Optional[dict] = None
    # value of the 'form_schema' key of schema.json, if any
    form_schema: Optional[list] = None
    # whether the extension directory contains a '__builtin__' marker file
    builtin: bool = True
    # sort position read from the '__builtin__' marker file, if any
    position: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class Extensible:
    """
    Base class of extensible modules.

    Subclasses are discovered by ``scan_extensions``, which looks for one
    sub-directory per extension next to the module that defines the class.
    """

    module: ExtensionModule

    name: str
    tenant_id: str
    config: Optional[dict] = None

    def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
        """
        :param tenant_id: the id of workspace
        :param config: the extension config data
        """
        self.tenant_id = tenant_id
        self.config = config

    @classmethod
    def scan_extensions(cls):
        """
        Discover the extensions stored next to the module that defines ``cls``.

        Every sub-directory ``<name>`` (names starting with ``__`` are ignored)
        must contain ``<name>.py`` with a subclass of ``cls``. A directory with
        a ``__builtin__`` file is a builtin extension and the integer in that
        file is its position; every other directory must also ship a
        ``schema.json``. Directories that miss a required file are skipped
        with a warning.

        :return: OrderedDict mapping extension name to ModuleExtension, sorted
            by position with extensions that have no position last
        """
        extensions = {}

        # get the path of the current class
        # NOTE: the module path is resolved against the current working directory.
        current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
        current_dir_path = os.path.dirname(current_path)

        # traverse subdirectories
        for subdir_name in os.listdir(current_dir_path):
            if subdir_name.startswith('__'):
                continue

            subdir_path = os.path.join(current_dir_path, subdir_name)
            if not os.path.isdir(subdir_path):
                continue

            module_extension = cls._scan_extension_dir(subdir_name, subdir_path)
            if module_extension is not None:
                extensions[subdir_name] = module_extension

        # positioned extensions first (ascending), then the ones without a position
        sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
        return OrderedDict(sorted_items)

    @classmethod
    def _scan_extension_dir(cls, extension_name: str, subdir_path: str) -> 'Optional[ModuleExtension]':
        """Build the ModuleExtension of one extension directory, or None when it must be skipped."""
        file_names = os.listdir(subdir_path)

        # is builtin extension, builtin extension
        # in the front-end page and business logic, there are special treatments.
        builtin = '__builtin__' in file_names
        position = cls._read_builtin_position(subdir_path) if builtin else None

        py_file_name = extension_name + '.py'
        if py_file_name not in file_names:
            logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
            return None

        # Dynamic loading {subdir_name}.py file and find the subclass of Extensible
        py_path = os.path.join(subdir_path, py_file_name)
        extension_class = cls._load_extension_class(extension_name, py_path)
        if not extension_class:
            logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
            return None

        json_data = {}
        if not builtin:
            # only non-builtin extensions need a schema.json
            if 'schema.json' not in file_names:
                logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
                return None

            json_path = os.path.join(subdir_path, 'schema.json')
            if os.path.exists(json_path):
                # JSON is read as UTF-8 so the result does not depend on the platform default encoding
                with open(json_path, 'r', encoding='utf-8') as f:
                    json_data = json.load(f)

        return ModuleExtension(
            extension_class=extension_class,
            name=extension_name,
            label=json_data.get('label'),
            form_schema=json_data.get('form_schema'),
            builtin=builtin,
            position=position
        )

    @staticmethod
    def _read_builtin_position(subdir_path: str) -> Optional[int]:
        """Read the integer position stored in the directory's ``__builtin__`` file, if it exists."""
        builtin_file_path = os.path.join(subdir_path, '__builtin__')
        if not os.path.exists(builtin_file_path):
            return None

        with open(builtin_file_path, 'r', encoding='utf-8') as f:
            return int(f.read().strip())

    @classmethod
    def _load_extension_class(cls, extension_name: str, py_path: str) -> Optional[type]:
        """Execute the python file and return the first strict subclass of ``cls`` it defines, if any."""
        spec = importlib.util.spec_from_file_location(extension_name, py_path)
        mod = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(mod)

        for obj in vars(mod).values():
            if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
                return obj

        return None
|
||||||
@ -0,0 +1,47 @@
|
|||||||
|
from core.extension.extensible import ModuleExtension, ExtensionModule
|
||||||
|
from core.external_data_tool.base import ExternalDataTool
|
||||||
|
from core.moderation.base import Moderation
|
||||||
|
|
||||||
|
|
||||||
|
class Extension:
    """Registry of the code-based extensions discovered for each extension module."""

    # module value -> (extension name -> ModuleExtension); class-level, filled by init()
    __module_extensions: dict[str, dict[str, ModuleExtension]] = {}

    # the base class that is scanned for each module
    module_classes = {
        ExtensionModule.MODERATION: Moderation,
        ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
    }

    def init(self):
        """Scan every registered module class and store the extensions it finds."""
        for extension_module, base_class in self.module_classes.items():
            self.__module_extensions[extension_module.value] = base_class.scan_extensions()

    def module_extensions(self, module: str) -> list[ModuleExtension]:
        """
        Return all extensions of a module.

        :param module: the module value (a string, not the enum member)
        :return: the module's extensions
        :raises ValueError: when nothing is registered for the module
        """
        registered = self.__module_extensions.get(module)
        if registered:
            return list(registered.values())

        raise ValueError(f"Extension Module {module} not found")

    def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
        """
        Return one extension of a module.

        :param module: the extension module
        :param extension_name: the name of the extension
        :return: the matching extension
        :raises ValueError: when the module or the extension is not found
        """
        registered = self.__module_extensions.get(module.value)
        if not registered:
            raise ValueError(f"Extension Module {module} not found")

        found = registered.get(extension_name)
        if not found:
            raise ValueError(f"Extension {extension_name} not found")

        return found

    def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
        """
        Return the class of one extension.

        :param module: the extension module
        :param extension_name: the name of the extension
        :return: the extension class
        :raises ValueError: when the module or the extension is not found
        """
        return self.module_extension(module, extension_name).extension_class

    def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
        """
        Look up the extension's form schema; the config is not validated against it yet.

        :param module: the extension module
        :param extension_name: the name of the extension
        :param config: the form config data (currently unused)
        :raises ValueError: when the module or the extension is not found
        """
        form_schema = self.module_extension(module, extension_name).form_schema

        # TODO validate form_schema
|
||||||
@ -0,0 +1 @@
|
|||||||
|
1
|
||||||
@ -0,0 +1,92 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||||
|
from core.external_data_tool.base import ExternalDataTool
|
||||||
|
from core.helper import encrypter
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||||
|
|
||||||
|
|
||||||
|
class ApiExternalDataTool(ExternalDataTool):
    """
    The api external data tool.
    """

    # the unique name of external data tool
    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        :raises ValueError: when ``api_based_extension_id`` is missing or does
            not match an API-based extension of the workspace
        """
        # own validation logic
        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        # get api_based_extension
        if not cls._get_api_based_extension(tenant_id, api_based_extension_id):
            raise ValueError("api_based_extension_id is invalid")

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        :raises ValueError: when the configured extension does not exist, the
            API request fails, or the response has no ``result`` key
        """
        # get params from config
        api_based_extension_id = self.config.get("api_based_extension_id")

        # get api_based_extension
        api_based_extension = self._get_api_based_extension(self.tenant_id, api_based_extension_id)
        if not api_based_extension:
            raise ValueError("[External data tool] API query failed, variable: {}, "
                             "error: api_based_extension_id is invalid"
                             .format(self.variable))

        # decrypt api_key
        api_key = encrypter.decrypt_token(
            tenant_id=self.tenant_id,
            token=api_based_extension.api_key
        )

        try:
            # request api; the request itself sits inside the try so that
            # transport errors are reported with the tool variable for context
            requestor = APIBasedExtensionRequestor(
                api_endpoint=api_based_extension.api_endpoint,
                api_key=api_key
            )
            response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
                'app_id': self.app_id,
                'tool_variable': self.variable,
                'inputs': inputs,
                'query': query
            })
        except Exception as e:
            raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
                self.variable,
                e
            )) from e

        if 'result' not in response_json:
            raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
                             .format(self.variable))

        return response_json['result']

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> Optional[APIBasedExtension]:
        """Return the workspace's API-based extension with the given id, or None when there is no match."""
        return db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id
        ).first()
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
from abc import abstractmethod, ABC
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.extension.extensible import Extensible, ExtensionModule
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalDataTool(Extensible, ABC):
    """
    The base class of external data tool.
    """

    module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL

    # the id of app
    app_id: str
    # the tool variable name of app tool
    variable: str

    def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
        """
        :param tenant_id: the id of workspace
        :param app_id: the id of app
        :param variable: the tool variable name of app tool
        :param config: the form config data
        """
        super().__init__(tenant_id, config)
        self.app_id, self.variable = app_id, variable

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        raise NotImplementedError
|
||||||
@ -0,0 +1,40 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.extension.extensible import ExtensionModule
|
||||||
|
from extensions.ext_code_based_extension import code_based_extension
|
||||||
|
|
||||||
|
|
||||||
|
class ExternalDataToolFactory:
    """Creates the external data tool registered under a name and forwards calls to it."""

    def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
        """
        :param name: the name of external data tool
        :param tenant_id: the id of workspace
        :param app_id: the id of app
        :param variable: the tool variable name of app tool
        :param config: the form config data
        """
        tool_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
        tool_kwargs = {
            'tenant_id': tenant_id,
            'app_id': app_id,
            'variable': variable,
            'config': config
        }
        self.__extension_instance = tool_class(**tool_kwargs)

    @classmethod
    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param name: the name of external data tool
        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        tool_module = ExtensionModule.EXTERNAL_DATA_TOOL

        # form schema check first, then the tool's own validation
        code_based_extension.validate_form_schema(tool_module, name, config)
        tool_class = code_based_extension.extension_class(tool_module, name)
        tool_class.validate_config(tenant_id, config)

    def query(self, inputs: dict, query: Optional[str] = None) -> str:
        """
        Query the external data tool.

        :param inputs: user inputs
        :param query: the query of chat app
        :return: the tool query result
        """
        tool = self.__extension_instance
        return tool.query(inputs, query)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
3
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
|
||||||
|
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
|
||||||
|
from core.helper.encrypter import decrypt_token
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.api_based_extension import APIBasedExtension
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationInputParams(BaseModel):
    """Request params sent to the API-based extension for input moderation."""

    # the id of app
    app_id: str = ""
    # user inputs
    inputs: dict = {}
    # the query string
    query: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationOutputParams(BaseModel):
    """Request params sent to the API-based extension for output moderation."""

    # the id of app
    app_id: str = ""
    # the output text to review (required)
    text: str
|
||||||
|
|
||||||
|
|
||||||
|
class ApiModeration(Moderation):
    """Moderation that delegates the review to an API-based extension."""

    name: str = "api"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        :raises ValueError: when the inputs/outputs config is invalid, or
            ``api_based_extension_id`` is missing or unknown
        """
        cls._validate_inputs_and_outputs_config(config, False)

        api_based_extension_id = config.get("api_based_extension_id")
        if not api_based_extension_id:
            raise ValueError("api_based_extension_id is required")

        extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
        if not extension:
            raise ValueError("API-based Extension not found. Please check it again.")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.

        :param inputs: user inputs
        :param query: query string
        :return: the result built from the extension's response when input
            moderation is enabled, otherwise an unflagged result
        """
        if self.config['inputs_config']['enabled']:
            params = ModerationInputParams(
                app_id=self.app_id,
                inputs=inputs,
                query=query
            )

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
            return ModerationInputsResult(**result)

        return ModerationInputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.

        :param text: LLM output content
        :return: the result built from the extension's response when output
            moderation is enabled, otherwise an unflagged result
        """
        if self.config['outputs_config']['enabled']:
            params = ModerationOutputParams(
                app_id=self.app_id,
                text=text
            )

            result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
            return ModerationOutputsResult(**result)

        return ModerationOutputsResult(flagged=False, action=ModerationAction.DIRECT_OUTPUT, preset_response="")

    def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
        """
        Call the configured API-based extension at the given extension point.

        :param extension_point: the extension point to call
        :param params: the request params
        :return: the response json
        :raises ValueError: when the configured extension no longer exists
        """
        extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
        if not extension:
            # the lookup returns None when no row matches; fail with a clear
            # message instead of an AttributeError on the next line
            raise ValueError("API-based Extension not found. Please check it again.")

        requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))

        return requestor.request(extension_point, params)

    @staticmethod
    def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """Return the workspace's API-based extension with the given id, or None when there is no match."""
        extension = db.session.query(APIBasedExtension).filter(
            APIBasedExtension.tenant_id == tenant_id,
            APIBasedExtension.id == api_based_extension_id
        ).first()

        return extension
|
||||||
@ -0,0 +1,113 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from enum import Enum
|
||||||
|
|
||||||
|
from core.extension.extensible import Extensible, ExtensionModule
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationAction(Enum):
    """The action reported by a moderation result; each value is its string key."""

    DIRECT_OUTPUT = 'direct_output'
    OVERRIDED = 'overrided'
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationInputsResult(BaseModel):
    """Result of moderating the user inputs."""

    # whether the inputs were flagged
    flagged: bool = False
    # the moderation action (required)
    action: ModerationAction
    # the preset response text
    preset_response: str = ""
    # the inputs carried by the result
    inputs: dict = {}
    # the query carried by the result
    query: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationOutputsResult(BaseModel):
    """Result of moderating the LLM output."""

    # whether the output was flagged
    flagged: bool = False
    # the moderation action (required)
    action: ModerationAction
    # the preset response text
    preset_response: str = ""
    # the text carried by the result
    text: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
class Moderation(Extensible, ABC):
    """
    The base class of moderation.
    """
    module: ExtensionModule = ExtensionModule.MODERATION

    def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
        """
        :param app_id: the id of app
        :param tenant_id: the id of workspace
        :param config: the form config data
        """
        super().__init__(tenant_id, config)
        self.app_id = app_id

    @classmethod
    @abstractmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.
        After the user inputs, this method will be called to perform sensitive content review
        on the user inputs and return the processed results.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        raise NotImplementedError

    @abstractmethod
    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.
        When LLM outputs content, the front end will pass the output content (may be segmented)
        to this method for sensitive content review, and the output content will be shielded if the review fails.

        :param text: LLM output content
        :return:
        """
        raise NotImplementedError

    @classmethod
    def _validate_inputs_and_outputs_config(cls, config: dict, is_preset_response_required: bool) -> None:
        """
        Validate the ``inputs_config`` and ``outputs_config`` sections of a config.

        Both sections must be dicts and at least one must be enabled. When a
        preset response is required, every enabled section must carry a
        non-empty ``preset_response`` of at most 100 characters.

        :param config: the form config data
        :param is_preset_response_required: whether enabled sections must define a preset response
        :return:
        :raises ValueError: when any of the rules above is violated
        """
        # inputs_config
        inputs_config = config.get("inputs_config")
        if not isinstance(inputs_config, dict):
            raise ValueError("inputs_config must be a dict")

        # outputs_config
        outputs_config = config.get("outputs_config")
        if not isinstance(outputs_config, dict):
            raise ValueError("outputs_config must be a dict")

        inputs_config_enabled = inputs_config.get("enabled")
        outputs_config_enabled = outputs_config.get("enabled")
        if not inputs_config_enabled and not outputs_config_enabled:
            raise ValueError("At least one of inputs_config or outputs_config must be enabled")

        # preset_response
        if not is_preset_response_required:
            return

        # inputs are checked before outputs, so the first error reported is unchanged
        for section_name, section_config, section_enabled in (
            ("inputs_config", inputs_config, inputs_config_enabled),
            ("outputs_config", outputs_config, outputs_config_enabled),
        ):
            if not section_enabled:
                continue

            preset_response = section_config.get("preset_response")
            if not preset_response:
                raise ValueError(f"{section_name}.preset_response is required")

            if len(preset_response) > 100:
                raise ValueError(f"{section_name}.preset_response must be less than 100 characters")
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationException(Exception):
    """Moderation-specific exception type; adds no state or behavior beyond ``Exception``."""
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
from core.extension.extensible import ExtensionModule
|
||||||
|
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
|
||||||
|
from extensions.ext_code_based_extension import code_based_extension
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationFactory:
    """Resolve a moderation extension by name and forward moderation calls to it."""

    # Name-mangled handle to the instantiated moderation extension.
    __extension_instance: Moderation

    def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
        """
        Look up the moderation extension class registered under ``name`` and instantiate it.

        :param name: the name of extension
        :param app_id: the id of app
        :param tenant_id: the id of workspace
        :param config: the form config data
        """
        moderation_cls = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
        self.__extension_instance = moderation_cls(app_id, tenant_id, config)

    @classmethod
    def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        The config is first checked against the extension's form schema, then handed to the
        extension class's own ``validate_config``.

        :param name: the name of extension
        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        module = ExtensionModule.MODERATION
        code_based_extension.validate_form_schema(module, name, config)

        moderation_cls = code_based_extension.extension_class(module, name)
        moderation_cls.validate_config(tenant_id, config)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Moderation for inputs.

        Delegates the sensitive content review of the user inputs to the wrapped extension
        and returns its result unchanged.

        :param inputs: user inputs
        :param query: query string (required in chat app)
        :return:
        """
        delegate = self.__extension_instance
        return delegate.moderation_for_inputs(inputs, query)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Moderation for outputs.

        Delegates the sensitive content review of LLM output content (which may be
        segmented) to the wrapped extension and returns its result unchanged.

        :param text: LLM output content
        :return:
        """
        delegate = self.__extension_instance
        return delegate.moderation_for_outputs(text)
|
||||||
@ -0,0 +1 @@
|
|||||||
|
2
|
||||||
@ -0,0 +1,60 @@
|
|||||||
|
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordsModeration(Moderation):
    """Moderation extension that flags content containing any configured keyword (case-insensitive)."""

    name: str = "keywords"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        :raises ValueError: if the inputs/outputs sections are invalid, or "keywords" is
            missing or longer than 1000 characters
        """
        cls._validate_inputs_and_outputs_config(config, True)

        keywords = config.get("keywords")
        if not keywords:
            raise ValueError("keywords is required")

        if len(keywords) > 1000:
            raise ValueError("keywords length must be less than 1000")

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Check the user inputs (and the query, if given) against the configured keywords.

        :param inputs: user inputs; this dict is not modified
        :param query: query string (required in chat app)
        :return: result flagged when input moderation is enabled and any value contains a keyword
        """
        flagged = False
        preset_response = ""

        if self.config['inputs_config']['enabled']:
            preset_response = self.config['inputs_config']['preset_response']

            # Work on a copy so the synthetic 'query__' entry never leaks into the caller's dict.
            values_to_check = dict(inputs)
            if query:
                values_to_check['query__'] = query

            keywords_list = self.config['keywords'].split('\n')
            flagged = self._is_violated(values_to_check, keywords_list)

        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Check LLM output content against the configured keywords.

        :param text: LLM output content
        :return: result flagged when output moderation is enabled and the text contains a keyword
        """
        flagged = False
        preset_response = ""

        if self.config['outputs_config']['enabled']:
            keywords_list = self.config['keywords'].split('\n')
            flagged = self._is_violated({'text': text}, keywords_list)
            preset_response = self.config['outputs_config']['preset_response']

        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
        """Return True if any value in ``inputs`` contains one of the keywords."""
        return any(self._check_keywords_in_value(keywords_list, value) for value in inputs.values())

    def _check_keywords_in_value(self, keywords_list, value):
        """
        Return True if ``value`` contains any keyword, ignoring case.

        Blank keywords (e.g. from an empty line or trailing newline in the config) are
        skipped: an empty string is a substring of every string and would flag everything.
        Non-string values are compared through ``str()``; ``None`` never matches.
        """
        if value is None:
            return False

        haystack = str(value).lower()
        return any(keyword.lower() in haystack for keyword in keywords_list if keyword.strip())
|
||||||
@ -0,0 +1 @@
|
|||||||
|
1
|
||||||
@ -0,0 +1,46 @@
|
|||||||
|
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
|
||||||
|
from core.model_providers.model_factory import ModelFactory
|
||||||
|
|
||||||
|
|
||||||
|
class OpenAIModeration(Moderation):
    """Moderation extension that delegates the content check to the tenant's OpenAI moderation model."""

    name: str = "openai_moderation"

    @classmethod
    def validate_config(cls, tenant_id: str, config: dict) -> None:
        """
        Validate the incoming form config data.

        :param tenant_id: the id of workspace
        :param config: the form config data
        :return:
        """
        cls._validate_inputs_and_outputs_config(config, True)

    def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
        """
        Check the user inputs (and the query, if given) with the moderation model.

        :param inputs: user inputs; this dict is not modified
        :param query: query string (required in chat app)
        :return: result flagged when input moderation is enabled and the model rejects the text
        """
        flagged = False
        preset_response = ""

        if self.config['inputs_config']['enabled']:
            preset_response = self.config['inputs_config']['preset_response']

            # Work on a copy so the synthetic 'query__' entry never leaks into the caller's dict.
            values_to_check = dict(inputs)
            if query:
                values_to_check['query__'] = query

            flagged = self._is_violated(values_to_check)

        return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
        """
        Check LLM output content with the moderation model.

        :param text: LLM output content
        :return: result flagged when output moderation is enabled and the model rejects the text
        """
        flagged = False
        preset_response = ""

        if self.config['outputs_config']['enabled']:
            flagged = self._is_violated({'text': text})
            preset_response = self.config['outputs_config']['preset_response']

        return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)

    def _is_violated(self, inputs: dict):
        """
        Join all input values into one text and run it through the moderation model.

        Values are coerced with ``str()`` (``None`` is skipped) because ``str.join`` raises
        TypeError on non-string items.

        :param inputs: mapping whose values are moderated together
        :return: the negation of the model's ``run`` result
        """
        text = '\n'.join(str(value) for value in inputs.values() if value is not None)
        openai_moderation = ModelFactory.get_moderation_model(self.tenant_id, "openai", "moderation")

        # NOTE(review): run() is treated as truthy when the text passes moderation — confirm
        # against the moderation model implementation.
        is_not_invalid = openai_moderation.run(text)
        return not is_not_invalid
|
||||||
@ -0,0 +1,8 @@
|
|||||||
|
from core.extension.extension import Extension
|
||||||
|
|
||||||
|
|
||||||
|
def init():
    """Initialize the module-level ``code_based_extension`` registry by calling its ``init()``."""
    code_based_extension.init()
|
||||||
|
|
||||||
|
|
||||||
|
# Shared module-level Extension instance; init() above calls its init() method.
code_based_extension = Extension()
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
from flask_restful import fields
|
||||||
|
|
||||||
|
from libs.helper import TimestampField
|
||||||
|
|
||||||
|
|
||||||
|
class HiddenAPIKey(fields.Raw):
    """Response field that renders an object's ``api_key`` in masked form."""

    def output(self, key, obj):
        """
        Return a masked representation of ``obj.api_key``.

        Keys longer than 8 characters keep their first and last three characters. Shorter
        keys keep only the first and last character, because the three-character slices
        would overlap and reveal a key of six characters or fewer in full.

        :param key: the field name being marshalled (unused)
        :param obj: object exposing an ``api_key`` attribute
        :return: the masked key
        """
        api_key = obj.api_key or ''
        if not api_key:
            return '***'

        if len(api_key) <= 8:
            return api_key[0] + '***' + api_key[-1]

        return api_key[:3] + '***' + api_key[-3:]
|
||||||
|
|
||||||
|
|
||||||
|
# Marshalling schema for API-based extension responses; api_key is rendered through HiddenAPIKey.
api_based_extension_fields = {
    'id': fields.String,
    'name': fields.String,
    'api_endpoint': fields.String,
    'api_key': HiddenAPIKey,
    'created_at': TimestampField,
}
|
||||||
@ -0,0 +1,45 @@
|
|||||||
|
"""add_api_based_extension
|
||||||
|
|
||||||
|
Revision ID: 968fff4c0ab9
|
||||||
|
Revises: b3a09c049e8e
|
||||||
|
Create Date: 2023-10-27 13:05:58.901858
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = '968fff4c0ab9'  # id of this migration
down_revision = 'b3a09c049e8e'  # revision this migration revises (see module docstring)
branch_labels = None
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Create the ``api_based_extensions`` table and a non-unique index on ``tenant_id``."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = 'api_based_extensions'

    op.create_table(
        table_name,
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('tenant_id', postgresql.UUID(), nullable=False),
        sa.Column('name', sa.String(length=255), nullable=False),
        sa.Column('api_endpoint', sa.String(length=255), nullable=False),
        sa.Column('api_key', sa.Text(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='api_based_extension_pkey'),
    )

    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.create_index('api_based_extension_tenant_idx', ['tenant_id'], unique=False)

    # ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Drop the ``tenant_id`` index and then the ``api_based_extensions`` table."""
    # ### commands auto generated by Alembic - please adjust! ###
    table_name = 'api_based_extensions'

    with op.batch_alter_table(table_name, schema=None) as batch_op:
        batch_op.drop_index('api_based_extension_tenant_idx')

    op.drop_table(table_name)

    # ### end Alembic commands ###
|
||||||
@ -0,0 +1,32 @@
|
|||||||
|
"""add external_data_tools in app model config
|
||||||
|
|
||||||
|
Revision ID: a9836e3baeee
|
||||||
|
Revises: 968fff4c0ab9
|
||||||
|
Create Date: 2023-11-02 04:04:57.609485
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'a9836e3baeee'  # id of this migration
down_revision = '968fff4c0ab9'  # revision this migration revises (see module docstring)
branch_labels = None
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
    """Add the nullable text column ``external_data_tools`` to ``app_model_configs``."""
    # ### commands auto generated by Alembic - please adjust! ###
    new_column = sa.Column('external_data_tools', sa.Text(), nullable=True)

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(new_column)

    # ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
    """Drop the ``external_data_tools`` column from ``app_model_configs``."""
    # ### commands auto generated by Alembic - please adjust! ###
    column_name = 'external_data_tools'

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column(column_name)

    # ### end Alembic commands ###
|
||||||
@ -0,0 +1,27 @@
|
|||||||
|
import enum
|
||||||
|
|
||||||
|
from sqlalchemy.dialects.postgresql import UUID
|
||||||
|
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionPoint(enum.Enum):
    """String identifiers of the extension points an API-based extension can be called at."""

    # External data tool lookup.
    APP_EXTERNAL_DATA_TOOL_QUERY = 'app.external_data_tool.query'

    # Connectivity check.
    PING = 'ping'

    # Moderation of user input and of LLM output.
    APP_MODERATION_INPUT = 'app.moderation.input'
    APP_MODERATION_OUTPUT = 'app.moderation.output'
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtension(db.Model):
    """ORM model for the ``api_based_extensions`` table: one API endpoint registered by a tenant."""

    __tablename__ = 'api_based_extensions'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='api_based_extension_pkey'),
        db.Index('api_based_extension_tenant_idx', 'tenant_id'),
    )

    # Primary key, generated by the database.
    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    # Owning workspace.
    tenant_id = db.Column(UUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    api_endpoint = db.Column(db.String(255), nullable=False)
    # Credential for the endpoint.
    api_key = db.Column(db.Text, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
|
||||||
@ -0,0 +1,98 @@
|
|||||||
|
from extensions.ext_database import db
|
||||||
|
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
|
||||||
|
from core.helper.encrypter import encrypt_token, decrypt_token
|
||||||
|
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
|
||||||
|
|
||||||
|
|
||||||
|
class APIBasedExtensionService:
    """Query, validate, persist and delete tenant-scoped API-based extensions."""

    @staticmethod
    def get_all_by_tenant_id(tenant_id: str) -> list[APIBasedExtension]:
        """
        Return all extensions of a tenant, newest first, with their api_key decrypted.

        :param tenant_id: the id of workspace
        :return: list of extensions ordered by created_at descending
        """
        extension_list = db.session.query(APIBasedExtension) \
            .filter_by(tenant_id=tenant_id) \
            .order_by(APIBasedExtension.created_at.desc()) \
            .all()

        # NOTE(review): the decrypted key is assigned onto the queried objects themselves;
        # presumably callers must not commit them without going through save() — confirm.
        for extension in extension_list:
            extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension_list

    @classmethod
    def save(cls, extension_data: APIBasedExtension) -> APIBasedExtension:
        """
        Validate the extension, encrypt its api_key and persist it.

        :param extension_data: extension carrying a plaintext api_key
        :return: the persisted extension (api_key now encrypted)
        :raises ValueError: if validation or the endpoint ping fails
        """
        cls._validation(extension_data)

        extension_data.api_key = encrypt_token(extension_data.tenant_id, extension_data.api_key)

        db.session.add(extension_data)
        db.session.commit()
        return extension_data

    @staticmethod
    def delete(extension_data: APIBasedExtension) -> None:
        """Delete the given extension and commit."""
        db.session.delete(extension_data)
        db.session.commit()

    @staticmethod
    def get_with_tenant_id(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
        """
        Return one extension of a tenant with its api_key decrypted.

        :param tenant_id: the id of workspace
        :param api_based_extension_id: the id of the extension
        :return: the matching extension
        :raises ValueError: if no extension matches both ids
        """
        extension = db.session.query(APIBasedExtension) \
            .filter_by(tenant_id=tenant_id) \
            .filter_by(id=api_based_extension_id) \
            .first()

        if not extension:
            raise ValueError("API based extension is not found")

        extension.api_key = decrypt_token(extension.tenant_id, extension.api_key)

        return extension

    @classmethod
    def _validation(cls, extension_data: APIBasedExtension) -> None:
        """
        Check required fields and name uniqueness, then ping the endpoint.

        :param extension_data: extension carrying a plaintext api_key
        :raises ValueError: on an empty name/api_endpoint/api_key, a duplicate name within
            the tenant, an api_key shorter than 5 characters, or a failed ping
        """
        # name
        if not extension_data.name:
            raise ValueError("name must not be empty")

        # name must be unique within the tenant; an existing record is excluded from its own check
        name_query = db.session.query(APIBasedExtension) \
            .filter_by(tenant_id=extension_data.tenant_id) \
            .filter_by(name=extension_data.name)
        if extension_data.id:
            name_query = name_query.filter(APIBasedExtension.id != extension_data.id)

        if name_query.first():
            raise ValueError("name must be unique, it is already existed")

        # api_endpoint
        if not extension_data.api_endpoint:
            raise ValueError("api_endpoint must not be empty")

        # api_key
        if not extension_data.api_key:
            raise ValueError("api_key must not be empty")

        if len(extension_data.api_key) < 5:
            raise ValueError("api_key must be at least 5 characters")

        # check endpoint
        cls._ping_connection(extension_data)

    @staticmethod
    def _ping_connection(extension_data: APIBasedExtension) -> None:
        """
        Send a PING request to the extension endpoint and require a 'pong' result.

        :param extension_data: extension providing api_endpoint and api_key
        :raises ValueError: if the request raises or the response result is not 'pong';
            the original error is attached as the cause
        """
        try:
            client = APIBasedExtensionRequestor(extension_data.api_endpoint, extension_data.api_key)
            resp = client.request(point=APIBasedExtensionPoint.PING, params={})
            if resp.get('result') != 'pong':
                raise ValueError(resp)
        except Exception as e:
            raise ValueError("connection error: {}".format(e)) from e
|
||||||
@ -0,0 +1,13 @@
|
|||||||
|
from extensions.ext_code_based_extension import code_based_extension
|
||||||
|
|
||||||
|
|
||||||
|
class CodeBasedExtensionService:
    """Read-only access to the code-based extensions registered for a module."""

    @staticmethod
    def get_code_based_extension(module: str) -> list[dict]:
        """
        Describe the non-builtin extensions of a module.

        :param module: the extension module to list
        :return: one dict per non-builtin extension with its name, label and form_schema
        """
        described = []
        for extension in code_based_extension.module_extensions(module):
            if extension.builtin:
                continue

            described.append({
                'name': extension.name,
                'label': extension.label,
                'form_schema': extension.form_schema
            })

        return described
|
||||||
@ -0,0 +1,20 @@
|
|||||||
|
from models.model import AppModelConfig, App
|
||||||
|
from core.moderation.factory import ModerationFactory, ModerationOutputsResult
|
||||||
|
from extensions.ext_database import db
|
||||||
|
|
||||||
|
|
||||||
|
class ModerationService:
    """Runs output moderation for an app using the moderation settings in its model config."""

    def moderation_for_outputs(self, app_id: str, app_model: App, text: str) -> ModerationOutputsResult:
        """
        Moderate LLM output text with the moderation extension configured for the app.

        :param app_id: the id of app, passed on to the moderation extension
        :param app_model: the app whose model config and tenant are used
        :param text: LLM output content
        :return: the moderation result produced by the configured extension
        :raises ValueError: if the app's model config cannot be found
        """
        app_model_config: AppModelConfig = db.session.query(AppModelConfig) \
            .filter(AppModelConfig.id == app_model.app_model_config_id) \
            .first()

        if not app_model_config:
            raise ValueError("app model config not found")

        # 'type' selects the moderation extension, 'config' is its form config.
        moderation_settings = app_model_config.sensitive_word_avoidance_dict
        name = moderation_settings['type']
        config = moderation_settings['config']

        factory = ModerationFactory(name, app_id, app_model.tenant_id, config)
        return factory.moderation_for_outputs(text)
|
||||||
Loading…
Reference in New Issue