初始化合并后端代码
parent
0720bc7408
commit
ae25db7ad1
@ -0,0 +1,42 @@
|
||||
#data_source:
|
||||
# type: upload_file
|
||||
# info_list:
|
||||
# data_source_type: upload_file
|
||||
# file_info_list:
|
||||
# file_ids:
|
||||
# - none
|
||||
indexing_technique: high_quality
|
||||
process_rule:
|
||||
rules:
|
||||
pre_processing_rules:
|
||||
- id: remove_extra_spaces
|
||||
enabled: true
|
||||
- id: remove_urls_emails
|
||||
enabled: true
|
||||
segmentation:
|
||||
separator: '&&&&&'
|
||||
max_tokens: 500
|
||||
chunk_overlap: 50
|
||||
mode: custom
|
||||
doc_form: text_model
|
||||
doc_language: Chinese
|
||||
retrieval_model:
|
||||
search_method: hybrid_search
|
||||
reranking_enable: true
|
||||
reranking_mode: weighted_score
|
||||
reranking_model:
|
||||
reranking_provider_name: langgenius/huggingface_tei/huggingface_tei
|
||||
reranking_model_name: bge-reranker-large
|
||||
weights:
|
||||
weight_type: customized
|
||||
vector_setting:
|
||||
vector_weight: 0.7
|
||||
embedding_provider_name: ''
|
||||
embedding_model_name: ''
|
||||
keyword_setting:
|
||||
keyword_weight: 0.3
|
||||
top_k: 10
|
||||
score_threshold_enabled: false
|
||||
score_threshold: 0
|
||||
embedding_model: 'bge-m3:latest'
|
||||
embedding_model_provider: langgenius/ollama/ollama
|
||||
@ -0,0 +1,11 @@
|
||||
model: ${INIT_MODEL_TEXT_EMBEDDING_NAME}
|
||||
model_type: text-embedding
|
||||
credentials:
|
||||
mode: chat
|
||||
context_size: ${INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE}
|
||||
max_tokens: ${INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS}
|
||||
vision_support: false
|
||||
function_call_support: false
|
||||
base_url: ${INIT_MODEL_TEXT_EMBEDDING_BASE_URL}
|
||||
load_balancing:
|
||||
enabled: false
|
||||
@ -0,0 +1,6 @@
|
||||
model: ${INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME}
|
||||
model_type: rerank
|
||||
credentials:
|
||||
server_url: ${INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL}
|
||||
load_balancing:
|
||||
enabled: false
|
||||
@ -0,0 +1,11 @@
|
||||
model: ${INIT_MODEL_LLM_NAME}
|
||||
model_type: llm
|
||||
credentials:
|
||||
mode: chat
|
||||
context_size: ${INIT_MODEL_LLM_CONTEXT_SIZE}
|
||||
max_tokens: ${INIT_MODEL_LLM_MAX_TOKENS}
|
||||
vision_support: false
|
||||
function_call_support: false
|
||||
base_url: ${INIT_MODEL_LLM_BASE_URL}
|
||||
load_balancing:
|
||||
enabled: false
|
||||
@ -0,0 +1,38 @@
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
def get_init_knowledge_config(config:dict) -> dict :
|
||||
return get_ext_config(file_name="dataset_config.yml", config=config)
|
||||
|
||||
def get_ext_config(file_name:str, config:dict = None,params : dict = None) -> dict :
|
||||
# 获取当前脚本所在的目录
|
||||
current_dir = Path(__file__).resolve().parent
|
||||
# 构造绝对路径
|
||||
config_path = current_dir / "ext" / file_name
|
||||
# 读取 YAML 文件
|
||||
with open(config_path, "r") as f:
|
||||
config_data = yaml.safe_load(f) # 使用 safe_load 避免执行任意代码
|
||||
config_data = replace_placeholders(data = config_data, params = params)
|
||||
if config is not None:
|
||||
config_data={**config_data,**config}
|
||||
|
||||
return config_data
|
||||
|
||||
# 定义一个函数,用于替换 YAML 中的占位符
|
||||
def replace_placeholders(data, params:dict = None) -> dict:
|
||||
if params is not None:
|
||||
if isinstance(data, dict):
|
||||
# 如果是字典,递归处理每个键值对
|
||||
return {k: replace_placeholders(v, params) for k, v in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# 如果是列表,递归处理每个元素
|
||||
return [replace_placeholders(item, params) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# 如果是字符串,尝试替换占位符
|
||||
for key, value in params.items():
|
||||
placeholder = f"${{{key}}}" # 构造占位符格式,例如 ${DB_HOST}
|
||||
data = data.replace(placeholder, value)
|
||||
return data
|
||||
|
||||
# 其他类型直接返回
|
||||
return data
|
||||
@ -0,0 +1,87 @@
|
||||
|
||||
from flask_restful import Resource, reqparse # type: ignore
|
||||
import flask_login
|
||||
from unstructured.utils import first
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.wraps import setup_required
|
||||
from services.ext.account_ext_service import AccountExtService, TenantExtService
|
||||
from models.account import (
|
||||
Account,
|
||||
Tenant,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
|
||||
class AccountsApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("accounts",
|
||||
type=lambda x: x if isinstance(x, list) else [] ,
|
||||
required=True,
|
||||
location="json")
|
||||
parser.add_argument("target_tenant_id", type=str,
|
||||
required=True,
|
||||
location="json")
|
||||
args = parser.parse_args()
|
||||
target_tenant_id = args["target_tenant_id"]
|
||||
accounts = args["accounts"]
|
||||
AccountExtService.update_account_list(accounts=accounts,
|
||||
target_tenant_id=target_tenant_id)
|
||||
|
||||
return {}
|
||||
|
||||
class LoginAccountInfo:
|
||||
|
||||
def __init__(self, id, name, tenant_id):
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"id": self.id,
|
||||
"name": self.name,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
class LoginAccountsApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def get(self):
|
||||
current_user = flask_login.current_user
|
||||
# current_user_info = db.session.query(Account).filter(Account.id==current_user.id).first()
|
||||
tenant = current_user.current_tenant
|
||||
login_account = LoginAccountInfo(id=current_user.id, name=current_user.name, tenant_id=tenant.id)
|
||||
return login_account.to_dict()
|
||||
|
||||
class TenantEnableApi(Resource):
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("target_tenant_id", type=str, required=True, location="json")
|
||||
parser.add_argument("target_tenant_name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
target_tenant_id = args["target_tenant_id"]
|
||||
target_tenant_name = args["target_tenant_name"]
|
||||
tenant_account_info = TenantExtService.enable_tenant(target_tenant_id=target_tenant_id,target_tenant_name=target_tenant_name)
|
||||
return tenant_account_info.to_dict(),200
|
||||
|
||||
class TenantInitApi(Resource):
|
||||
|
||||
@setup_required
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("target_tenant_id", type=str, required=True, location="json")
|
||||
parser.add_argument("target_tenant_name", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
target_tenant_id = args["target_tenant_id"]
|
||||
target_tenant_name = args["target_tenant_name"]
|
||||
tenant_data = TenantExtService.init_tenant(target_tenant_id=target_tenant_id,target_tenant_name=target_tenant_name)
|
||||
return tenant_data.to_dict(),200
|
||||
|
||||
api.add_resource(AccountsApi, "/accounts/update")
|
||||
api.add_resource(TenantEnableApi, "/tenant/enable")
|
||||
api.add_resource(TenantInitApi, "/tenant/init")
|
||||
api.add_resource(LoginAccountsApi, "/login/account/info")
|
||||
@ -0,0 +1,156 @@
|
||||
import base64
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import queue
|
||||
import re
|
||||
import threading
|
||||
from collections.abc import Iterable
|
||||
from typing import Optional
|
||||
|
||||
from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
QueueAgentMessageEvent,
|
||||
QueueLLMChunkEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueTextChunkEvent,
|
||||
WorkflowQueueMessage,
|
||||
)
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
|
||||
class AudioTrunk:
|
||||
def __init__(self, status: str, audio):
|
||||
self.audio = audio
|
||||
self.status = status
|
||||
|
||||
|
||||
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
|
||||
if not text_content or text_content.isspace():
|
||||
return
|
||||
return model_instance.invoke_tts(
|
||||
content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice
|
||||
)
|
||||
|
||||
|
||||
def _process_future(
|
||||
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
|
||||
audio_queue: queue.Queue[AudioTrunk],
|
||||
):
|
||||
while True:
|
||||
try:
|
||||
future = future_queue.get()
|
||||
if future is None:
|
||||
break
|
||||
invoke_result = future.result()
|
||||
if not invoke_result:
|
||||
continue
|
||||
for audio in invoke_result:
|
||||
audio_base64 = base64.b64encode(bytes(audio))
|
||||
audio_queue.put(AudioTrunk("responding", audio=audio_base64))
|
||||
except Exception as e:
|
||||
logging.getLogger(__name__).warning(e)
|
||||
break
|
||||
audio_queue.put(AudioTrunk("finish", b""))
|
||||
|
||||
|
||||
class AppGeneratorTTSPublisher:
|
||||
def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.tenant_id = tenant_id
|
||||
self.msg_text = ""
|
||||
self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
|
||||
self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
|
||||
self.match = re.compile(r"[。.!?]")
|
||||
self.model_manager = ModelManager()
|
||||
self.model_instance = self.model_manager.get_default_model_instance(
|
||||
tenant_id=self.tenant_id, model_type=ModelType.TTS
|
||||
)
|
||||
self.voices = self.model_instance.get_tts_voices(language=language)
|
||||
values = [voice.get("value") for voice in self.voices]
|
||||
self.voice = voice
|
||||
if not voice or voice not in values:
|
||||
self.voice = self.voices[0].get("value")
|
||||
self.MAX_SENTENCE = 2
|
||||
self._last_audio_event: Optional[AudioTrunk] = None
|
||||
# FIXME better way to handle this threading.start
|
||||
threading.Thread(target=self._runtime).start()
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
|
||||
|
||||
def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
|
||||
self._msg_queue.put(message)
|
||||
|
||||
def _runtime(self):
|
||||
future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
|
||||
threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
|
||||
while True:
|
||||
try:
|
||||
message = self._msg_queue.get()
|
||||
if message is None:
|
||||
if self.msg_text and len(self.msg_text.strip()) > 0:
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
break
|
||||
elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
|
||||
message_content = message.event.chunk.delta.message.content
|
||||
if not message_content:
|
||||
continue
|
||||
if isinstance(message_content, str):
|
||||
self.msg_text += message_content
|
||||
elif isinstance(message_content, list):
|
||||
for content in message_content:
|
||||
if not isinstance(content, TextPromptMessageContent):
|
||||
continue
|
||||
self.msg_text += content.data
|
||||
elif isinstance(message.event, QueueTextChunkEvent):
|
||||
self.msg_text += message.event.text
|
||||
elif isinstance(message.event, QueueNodeSucceededEvent):
|
||||
if message.event.outputs is None:
|
||||
continue
|
||||
self.msg_text += message.event.outputs.get("output", "")
|
||||
self.last_message = message
|
||||
sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
|
||||
if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
|
||||
self.MAX_SENTENCE += 1
|
||||
text_content = "".join(sentence_arr)
|
||||
futures_result = self.executor.submit(
|
||||
_invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
|
||||
)
|
||||
future_queue.put(futures_result)
|
||||
if text_tmp:
|
||||
self.msg_text = text_tmp
|
||||
else:
|
||||
self.msg_text = ""
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(e)
|
||||
break
|
||||
future_queue.put(None)
|
||||
|
||||
def check_and_get_audio(self):
|
||||
try:
|
||||
if self._last_audio_event and self._last_audio_event.status == "finish":
|
||||
if self.executor:
|
||||
self.executor.shutdown(wait=False)
|
||||
return self._last_audio_event
|
||||
audio = self._audio_queue.get_nowait()
|
||||
if audio and audio.status == "finish":
|
||||
self.executor.shutdown(wait=False)
|
||||
if audio:
|
||||
self._last_audio_event = audio
|
||||
return audio
|
||||
except queue.Empty:
|
||||
return None
|
||||
|
||||
def _extract_sentence(self, org_text):
|
||||
tx = self.match.finditer(org_text)
|
||||
start = 0
|
||||
result = []
|
||||
for i in tx:
|
||||
end = i.regs[0][1]
|
||||
result.append(org_text[start:end])
|
||||
start = end
|
||||
return result, org_text[start:]
|
||||
@ -0,0 +1,191 @@
|
||||
import logging
|
||||
from threading import Thread
|
||||
from typing import Optional, Union
|
||||
|
||||
from flask import Flask, current_app
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueMessageFileEvent,
|
||||
QueueRetrieverResourcesEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
EasyUITaskState,
|
||||
MessageFileStreamResponse,
|
||||
MessageReplaceStreamResponse,
|
||||
MessageStreamResponse,
|
||||
WorkflowTaskState,
|
||||
)
|
||||
from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
|
||||
class MessageCycleManage:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AdvancedChatAppGenerateEntity,
|
||||
],
|
||||
task_state: Union[EasyUITaskState, WorkflowTaskState],
|
||||
) -> None:
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._task_state = task_state
|
||||
|
||||
def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
|
||||
"""
|
||||
Generate conversation name.
|
||||
:param conversation: conversation
|
||||
:param query: query
|
||||
:return: thread
|
||||
"""
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
thread = Thread(
|
||||
target=self._generate_conversation_name_worker,
|
||||
kwargs={
|
||||
"flask_app": current_app._get_current_object(), # type: ignore
|
||||
"conversation_id": conversation_id,
|
||||
"query": query,
|
||||
},
|
||||
)
|
||||
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
|
||||
return None
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
# get conversation and message
|
||||
conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()
|
||||
|
||||
if not conversation:
|
||||
return
|
||||
|
||||
if conversation.mode != AppMode.COMPLETION.value:
|
||||
app_model = conversation.app
|
||||
if not app_model:
|
||||
return
|
||||
|
||||
# generate conversation name
|
||||
try:
|
||||
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
|
||||
conversation.name = name
|
||||
except Exception as e:
|
||||
if dify_config.DEBUG:
|
||||
logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")
|
||||
pass
|
||||
|
||||
db.session.merge(conversation)
|
||||
db.session.commit()
|
||||
db.session.close()
|
||||
|
||||
def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
|
||||
"""
|
||||
Handle annotation reply.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
|
||||
if annotation:
|
||||
account = annotation.account
|
||||
self._task_state.metadata["annotation_reply"] = {
|
||||
"id": annotation.id,
|
||||
"account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
|
||||
}
|
||||
|
||||
return annotation
|
||||
|
||||
return None
|
||||
|
||||
def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
|
||||
"""
|
||||
Handle retriever resources.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
|
||||
self._task_state.metadata["retriever_resources"] = event.retriever_resources
|
||||
|
||||
def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
|
||||
"""
|
||||
Message file to stream response.
|
||||
:param event: event
|
||||
:return:
|
||||
"""
|
||||
message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()
|
||||
|
||||
if message_file and message_file.url is not None:
|
||||
# get tool file id
|
||||
tool_file_id = message_file.url.split("/")[-1]
|
||||
# trim extension
|
||||
tool_file_id = tool_file_id.split(".")[0]
|
||||
|
||||
# get extension
|
||||
if "." in message_file.url:
|
||||
extension = f".{message_file.url.split('.')[-1]}"
|
||||
if len(extension) > 10:
|
||||
extension = ".bin"
|
||||
else:
|
||||
extension = ".bin"
|
||||
# add sign url to local file
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
else:
|
||||
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
|
||||
|
||||
return MessageFileStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
url=url,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _message_to_stream_response(
|
||||
self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
|
||||
) -> MessageStreamResponse:
|
||||
"""
|
||||
Message to stream response.
|
||||
:param answer: answer
|
||||
:param message_id: message id
|
||||
:return:
|
||||
"""
|
||||
return MessageStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_id,
|
||||
answer=answer,
|
||||
from_variable_selector=from_variable_selector,
|
||||
)
|
||||
|
||||
def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
|
||||
"""
|
||||
Message replace to stream response.
|
||||
:param answer: answer
|
||||
:return:
|
||||
"""
|
||||
return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)
|
||||
@ -0,0 +1,964 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
QueueLoopCompletedEvent,
|
||||
QueueLoopNextEvent,
|
||||
QueueLoopStartEvent,
|
||||
QueueNodeExceptionEvent,
|
||||
QueueNodeFailedEvent,
|
||||
QueueNodeInIterationFailedEvent,
|
||||
QueueNodeInLoopFailedEvent,
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueParallelBranchRunFailedEvent,
|
||||
QueueParallelBranchRunStartedEvent,
|
||||
QueueParallelBranchRunSucceededEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
LoopNodeCompletedStreamResponse,
|
||||
LoopNodeNextStreamResponse,
|
||||
LoopNodeStartStreamResponse,
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
ParallelBranchFinishedStreamResponse,
|
||||
ParallelBranchStartStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.workflow.entities.node_entities import NodeRunMetadataKey
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.account import Account
|
||||
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
|
||||
from models.model import EndUser
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
WorkflowNodeExecutionTriggeredFrom,
|
||||
WorkflowRun,
|
||||
WorkflowRunStatus,
|
||||
)
|
||||
|
||||
from .exc import WorkflowRunNotFoundError
|
||||
|
||||
|
||||
class WorkflowCycleManage:
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
|
||||
workflow_system_variables: dict[SystemVariableKey, Any],
|
||||
) -> None:
|
||||
self._workflow_run: WorkflowRun | None = None
|
||||
self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_system_variables = workflow_system_variables
|
||||
|
||||
def _handle_workflow_run_start(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_id: str,
|
||||
user_id: str,
|
||||
created_by_role: CreatedByRole,
|
||||
) -> WorkflowRun:
|
||||
workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
|
||||
workflow = session.scalar(workflow_stmt)
|
||||
if not workflow:
|
||||
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||
|
||||
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
|
||||
WorkflowRun.tenant_id == workflow.tenant_id,
|
||||
WorkflowRun.app_id == workflow.app_id,
|
||||
)
|
||||
max_sequence = session.scalar(max_sequence_stmt) or 0
|
||||
new_sequence_number = max_sequence + 1
|
||||
|
||||
inputs = {**self._application_generate_entity.inputs}
|
||||
for key, value in (self._workflow_system_variables or {}).items():
|
||||
if key.value == "conversation":
|
||||
continue
|
||||
inputs[f"sys.{key.value}"] = value
|
||||
|
||||
triggered_from = (
|
||||
WorkflowRunTriggeredFrom.DEBUGGING
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER
|
||||
else WorkflowRunTriggeredFrom.APP_RUN
|
||||
)
|
||||
|
||||
# handle special values
|
||||
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
|
||||
|
||||
# init workflow run
|
||||
# TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
|
||||
workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())
|
||||
|
||||
workflow_run = WorkflowRun()
|
||||
workflow_run.id = workflow_run_id
|
||||
workflow_run.tenant_id = workflow.tenant_id
|
||||
workflow_run.app_id = workflow.app_id
|
||||
workflow_run.sequence_number = new_sequence_number
|
||||
workflow_run.workflow_id = workflow.id
|
||||
workflow_run.type = workflow.type
|
||||
workflow_run.triggered_from = triggered_from.value
|
||||
workflow_run.version = workflow.version
|
||||
workflow_run.graph = workflow.graph
|
||||
workflow_run.inputs = json.dumps(inputs)
|
||||
workflow_run.status = WorkflowRunStatus.RUNNING
|
||||
workflow_run.created_by_role = created_by_role
|
||||
workflow_run.created_by = user_id
|
||||
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_run)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run success
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param outputs: outputs
|
||||
:param conversation_id: conversation id
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
|
||||
outputs = WorkflowEntry.handle_special_values(outputs)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.SUCCEEDED.value
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_partial_success(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
outputs: Mapping[str, Any] | None = None,
|
||||
exceptions_count: int = 0,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
) -> WorkflowRun:
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
|
||||
|
||||
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
|
||||
workflow_run.outputs = json.dumps(outputs or {})
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_workflow_run_failed(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
workflow_run_id: str,
|
||||
start_at: float,
|
||||
total_tokens: int,
|
||||
total_steps: int,
|
||||
status: WorkflowRunStatus,
|
||||
error: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
trace_manager: Optional[TraceQueueManager] = None,
|
||||
exceptions_count: int = 0,
|
||||
) -> WorkflowRun:
|
||||
"""
|
||||
Workflow run failed
|
||||
:param workflow_run: workflow run
|
||||
:param start_at: start time
|
||||
:param total_tokens: total tokens
|
||||
:param total_steps: total steps
|
||||
:param status: status
|
||||
:param error: error message
|
||||
:return:
|
||||
"""
|
||||
workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
|
||||
|
||||
workflow_run.status = status.value
|
||||
workflow_run.error = error
|
||||
workflow_run.elapsed_time = time.perf_counter() - start_at
|
||||
workflow_run.total_tokens = total_tokens
|
||||
workflow_run.total_steps = total_steps
|
||||
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_run.exceptions_count = exceptions_count
|
||||
|
||||
stmt = select(WorkflowNodeExecution.node_execution_id).where(
|
||||
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
|
||||
WorkflowNodeExecution.app_id == workflow_run.app_id,
|
||||
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
|
||||
)
|
||||
ids = session.scalars(stmt).all()
|
||||
# Use self._get_workflow_node_execution here to make sure the cache is updated
|
||||
running_workflow_node_executions = [
|
||||
self._get_workflow_node_execution(session=session, node_execution_id=id) for id in ids if id
|
||||
]
|
||||
|
||||
for workflow_node_execution in running_workflow_node_executions:
|
||||
now = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
|
||||
workflow_node_execution.error = error
|
||||
workflow_node_execution.finished_at = now
|
||||
workflow_node_execution.elapsed_time = (now - workflow_node_execution.created_at).total_seconds()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_run=workflow_run,
|
||||
conversation_id=conversation_id,
|
||||
user_id=trace_manager.user_id,
|
||||
)
|
||||
)
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _handle_node_execution_start(
|
||||
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.execution_metadata = json.dumps(
|
||||
{
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
)
|
||||
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_success(
|
||||
self, *, session: Session, event: QueueNodeSucceededEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
workflow_node_execution = self._get_workflow_node_execution(
|
||||
session=session, node_execution_id=event.node_execution_id
|
||||
)
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
execution_metadata_dict = dict(event.execution_metadata or {})
|
||||
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_failed(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
event: QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
workflow_node_execution = self._get_workflow_node_execution(
|
||||
session=session, node_execution_id=event.node_execution_id
|
||||
)
|
||||
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - event.start_at).total_seconds()
|
||||
execution_metadata = (
|
||||
json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
|
||||
)
|
||||
process_data = WorkflowEntry.handle_special_values(event.process_data)
|
||||
workflow_node_execution.status = (
|
||||
WorkflowNodeExecutionStatus.FAILED.value
|
||||
if not isinstance(event, QueueNodeExceptionEvent)
|
||||
else WorkflowNodeExecutionStatus.EXCEPTION.value
|
||||
)
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
|
||||
workflow_node_execution = session.merge(workflow_node_execution)
|
||||
return workflow_node_execution
|
||||
|
||||
def _handle_workflow_node_execution_retried(
|
||||
self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
|
||||
) -> WorkflowNodeExecution:
|
||||
"""
|
||||
Workflow node execution failed
|
||||
:param event: queue node failed event
|
||||
:return:
|
||||
"""
|
||||
created_at = event.start_at
|
||||
finished_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
inputs = WorkflowEntry.handle_special_values(event.inputs)
|
||||
outputs = WorkflowEntry.handle_special_values(event.outputs)
|
||||
origin_metadata = {
|
||||
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
|
||||
NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
}
|
||||
merged_metadata = (
|
||||
{**jsonable_encoder(event.execution_metadata), **origin_metadata}
|
||||
if event.execution_metadata is not None
|
||||
else origin_metadata
|
||||
)
|
||||
execution_metadata = json.dumps(merged_metadata)
|
||||
|
||||
workflow_node_execution = WorkflowNodeExecution()
|
||||
workflow_node_execution.id = str(uuid4())
|
||||
workflow_node_execution.tenant_id = workflow_run.tenant_id
|
||||
workflow_node_execution.app_id = workflow_run.app_id
|
||||
workflow_node_execution.workflow_id = workflow_run.workflow_id
|
||||
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
|
||||
workflow_node_execution.workflow_run_id = workflow_run.id
|
||||
workflow_node_execution.predecessor_node_id = event.predecessor_node_id
|
||||
workflow_node_execution.node_execution_id = event.node_execution_id
|
||||
workflow_node_execution.node_id = event.node_id
|
||||
workflow_node_execution.node_type = event.node_type.value
|
||||
workflow_node_execution.title = event.node_data.title
|
||||
workflow_node_execution.status = WorkflowNodeExecutionStatus.RETRY.value
|
||||
workflow_node_execution.created_by_role = workflow_run.created_by_role
|
||||
workflow_node_execution.created_by = workflow_run.created_by
|
||||
workflow_node_execution.created_at = created_at
|
||||
workflow_node_execution.finished_at = finished_at
|
||||
workflow_node_execution.elapsed_time = elapsed_time
|
||||
workflow_node_execution.error = event.error
|
||||
workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
|
||||
workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
|
||||
workflow_node_execution.execution_metadata = execution_metadata
|
||||
workflow_node_execution.index = event.node_run_index
|
||||
|
||||
session.add(workflow_node_execution)
|
||||
|
||||
self._workflow_node_executions[event.node_execution_id] = workflow_node_execution
|
||||
return workflow_node_execution
|
||||
|
||||
#################################################
|
||||
# to stream responses #
|
||||
#################################################
|
||||
|
||||
def _workflow_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return WorkflowStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowStartStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
inputs=dict(workflow_run.inputs_dict or {}),
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
created_by = None
|
||||
if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
|
||||
stmt = select(Account).where(Account.id == workflow_run.created_by)
|
||||
account = session.scalar(stmt)
|
||||
if account:
|
||||
created_by = {
|
||||
"id": account.id,
|
||||
"name": account.name,
|
||||
"email": account.email,
|
||||
}
|
||||
elif workflow_run.created_by_role == CreatedByRole.END_USER:
|
||||
stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
|
||||
end_user = session.scalar(stmt)
|
||||
if end_user:
|
||||
created_by = {
|
||||
"id": end_user.id,
|
||||
"user": end_user.session_id,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=workflow_run.id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
sequence_number=workflow_run.sequence_number,
|
||||
status=workflow_run.status,
|
||||
outputs=dict(workflow_run.outputs_dict) if workflow_run.outputs_dict else None,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=workflow_run.elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(workflow_run.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(dict(workflow_run.outputs_dict)),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
event: QueueNodeStartedEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeStartStreamResponse]:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
|
||||
response = NodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeStartStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
title=workflow_node_execution.title,
|
||||
index=workflow_node_execution.index,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
parallel_run_id=event.parallel_mode_run_id,
|
||||
agent_strategy=event.agent_strategy,
|
||||
),
|
||||
)
|
||||
|
||||
# extras logic
|
||||
if event.node_type == NodeType.TOOL:
|
||||
node_data = cast(ToolNodeData, event.node_data)
|
||||
response.data.extras["icon"] = ToolManager.get_tool_icon(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
provider_type=node_data.provider_type,
|
||||
provider_id=node_data.provider_id,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _workflow_node_finish_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
event: QueueNodeSucceededEvent
|
||||
| QueueNodeFailedEvent
|
||||
| QueueNodeInIterationFailedEvent
|
||||
| QueueNodeInLoopFailedEvent
|
||||
| QueueNodeExceptionEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[NodeFinishStreamResponse]:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeFinishStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_node_retry_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
event: QueueNodeRetryEvent,
|
||||
task_id: str,
|
||||
workflow_node_execution: WorkflowNodeExecution,
|
||||
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
|
||||
return None
|
||||
if not workflow_node_execution.workflow_run_id:
|
||||
return None
|
||||
if not workflow_node_execution.finished_at:
|
||||
return None
|
||||
|
||||
return NodeRetryStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_node_execution.workflow_run_id,
|
||||
data=NodeRetryStreamResponse.Data(
|
||||
id=workflow_node_execution.id,
|
||||
node_id=workflow_node_execution.node_id,
|
||||
node_type=workflow_node_execution.node_type,
|
||||
index=workflow_node_execution.index,
|
||||
title=workflow_node_execution.title,
|
||||
predecessor_node_id=workflow_node_execution.predecessor_node_id,
|
||||
inputs=workflow_node_execution.inputs_dict,
|
||||
process_data=workflow_node_execution.process_data_dict,
|
||||
outputs=workflow_node_execution.outputs_dict,
|
||||
status=workflow_node_execution.status,
|
||||
error=workflow_node_execution.error,
|
||||
elapsed_time=workflow_node_execution.elapsed_time,
|
||||
execution_metadata=workflow_node_execution.execution_metadata_dict,
|
||||
created_at=int(workflow_node_execution.created_at.timestamp()),
|
||||
finished_at=int(workflow_node_execution.finished_at.timestamp()),
|
||||
files=self._fetch_files_from_node_outputs(workflow_node_execution.outputs_dict or {}),
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
retry_index=event.retry_index,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
|
||||
) -> ParallelBranchStartStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return ParallelBranchStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchStartStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_parallel_branch_finished_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
|
||||
) -> ParallelBranchFinishedStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return ParallelBranchFinishedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=ParallelBranchFinishedStreamResponse.Data(
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_branch_id=event.parallel_start_node_id,
|
||||
parent_parallel_id=event.parent_parallel_id,
|
||||
parent_parallel_start_node_id=event.parent_parallel_start_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
status="succeeded" if isinstance(event, QueueParallelBranchRunSucceededEvent) else "failed",
|
||||
error=event.error if isinstance(event, QueueParallelBranchRunFailedEvent) else None,
|
||||
created_at=int(time.time()),
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
|
||||
) -> IterationNodeStartStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return IterationNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_next_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
|
||||
) -> IterationNodeNextStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return IterationNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_iteration_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_iteration_completed_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
|
||||
) -> IterationNodeCompletedStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return IterationNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=IterationNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_start_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
|
||||
) -> LoopNodeStartStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeStartStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeStartStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
metadata=event.metadata or {},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_next_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
|
||||
) -> LoopNodeNextStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeNextStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeNextStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
index=event.index,
|
||||
pre_loop_output=event.output,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
parallel_mode_run_id=event.parallel_mode_run_id,
|
||||
duration=event.duration,
|
||||
),
|
||||
)
|
||||
|
||||
def _workflow_loop_completed_to_stream_response(
|
||||
self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
|
||||
) -> LoopNodeCompletedStreamResponse:
|
||||
# receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
|
||||
_ = session
|
||||
return LoopNodeCompletedStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
data=LoopNodeCompletedStreamResponse.Data(
|
||||
id=event.node_id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type.value,
|
||||
title=event.node_data.title,
|
||||
outputs=event.outputs,
|
||||
created_at=int(time.time()),
|
||||
extras={},
|
||||
inputs=event.inputs or {},
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
if event.error is None
|
||||
else WorkflowNodeExecutionStatus.FAILED,
|
||||
error=None,
|
||||
elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
|
||||
total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
|
||||
execution_metadata=event.metadata,
|
||||
finished_at=int(time.time()),
|
||||
steps=event.steps,
|
||||
parallel_id=event.parallel_id,
|
||||
parallel_start_node_id=event.parallel_start_node_id,
|
||||
),
|
||||
)
|
||||
|
||||
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
:return:
|
||||
"""
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
# Flatten the list of sequences into a single list of mappings
|
||||
flattened_files = [file for sublist in files if sublist for file in sublist]
|
||||
|
||||
# Convert to tuple to match Sequence type
|
||||
return tuple(flattened_files)
|
||||
|
||||
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from variable value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return []
|
||||
|
||||
files = []
|
||||
if isinstance(value, list):
|
||||
for item in value:
|
||||
file = self._get_file_var_from_value(item)
|
||||
if file:
|
||||
files.append(file)
|
||||
elif isinstance(value, dict):
|
||||
file = self._get_file_var_from_value(value)
|
||||
if file:
|
||||
files.append(file)
|
||||
|
||||
return files
|
||||
|
||||
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Get file var from value
|
||||
:param value: variable value
|
||||
:return:
|
||||
"""
|
||||
if not value:
|
||||
return None
|
||||
|
||||
if isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY:
|
||||
return value
|
||||
elif isinstance(value, File):
|
||||
return value.to_dict()
|
||||
|
||||
return None
|
||||
|
||||
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
|
||||
if self._workflow_run and self._workflow_run.id == workflow_run_id:
|
||||
cached_workflow_run = self._workflow_run
|
||||
cached_workflow_run = session.merge(cached_workflow_run)
|
||||
return cached_workflow_run
|
||||
stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
|
||||
workflow_run = session.scalar(stmt)
|
||||
if not workflow_run:
|
||||
raise WorkflowRunNotFoundError(workflow_run_id)
|
||||
self._workflow_run = workflow_run
|
||||
|
||||
return workflow_run
|
||||
|
||||
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
|
||||
if node_execution_id not in self._workflow_node_executions:
|
||||
raise ValueError(f"Workflow node execution not found: {node_execution_id}")
|
||||
cached_workflow_node_execution = self._workflow_node_executions[node_execution_id]
|
||||
return session.merge(cached_workflow_node_execution)
|
||||
|
||||
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
|
||||
"""
|
||||
Handle agent log
|
||||
:param task_id: task id
|
||||
:param event: agent log event
|
||||
:return:
|
||||
"""
|
||||
return AgentLogStreamResponse(
|
||||
task_id=task_id,
|
||||
data=AgentLogStreamResponse.Data(
|
||||
node_execution_id=event.node_execution_id,
|
||||
id=event.id,
|
||||
parent_id=event.parent_id,
|
||||
label=event.label,
|
||||
error=event.error,
|
||||
status=event.status,
|
||||
data=event.data,
|
||||
metadata=event.metadata,
|
||||
node_id=event.node_id,
|
||||
),
|
||||
)
|
||||
@ -0,0 +1,170 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional
|
||||
|
||||
import openai
|
||||
from httpx import Timeout
|
||||
from openai import OpenAI
|
||||
from openai.types import ModerationCreateResponse
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
|
||||
|
||||
|
||||
class OpenAIModerationModel(ModerationModel):
|
||||
"""
|
||||
Model class for OpenAI text moderation model.
|
||||
"""
|
||||
|
||||
def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:param text: text to moderate
|
||||
:param user: unique user id
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
|
||||
# init model client
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
# chars per chunk
|
||||
length = self._get_max_characters_per_chunk(model, credentials)
|
||||
text_chunks = [text[i : i + length] for i in range(0, len(text), length)]
|
||||
|
||||
max_text_chunks = self._get_max_chunks(model, credentials)
|
||||
chunks = [text_chunks[i : i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
|
||||
|
||||
for text_chunk in chunks:
|
||||
moderation_result = self._moderation_invoke(model=model, client=client, texts=text_chunk)
|
||||
|
||||
for result in moderation_result.results:
|
||||
if result.flagged is True:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def validate_credentials(self, model: str, credentials: dict) -> None:
|
||||
"""
|
||||
Validate model credentials
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return:
|
||||
"""
|
||||
try:
|
||||
# transform credentials to kwargs for model instance
|
||||
credentials_kwargs = self._to_credential_kwargs(credentials)
|
||||
client = OpenAI(**credentials_kwargs)
|
||||
|
||||
# call moderation model
|
||||
self._moderation_invoke(
|
||||
model=model,
|
||||
client=client,
|
||||
texts=["ping"],
|
||||
)
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
def _moderation_invoke(self, model: str, client: OpenAI, texts: list[str]) -> ModerationCreateResponse:
|
||||
"""
|
||||
Invoke moderation model
|
||||
|
||||
:param model: model name
|
||||
:param client: model client
|
||||
:param texts: texts to moderate
|
||||
:return: false if text is safe, true otherwise
|
||||
"""
|
||||
# call moderation model
|
||||
moderation_result = client.moderations.create(model=model, input=texts)
|
||||
|
||||
return moderation_result
|
||||
|
||||
def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get max characters per chunk
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: max characters per chunk
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
|
||||
max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
|
||||
return max_characters_per_chunk
|
||||
|
||||
return 2000
|
||||
|
||||
def _get_max_chunks(self, model: str, credentials: dict) -> int:
|
||||
"""
|
||||
Get max chunks for given embedding model
|
||||
|
||||
:param model: model name
|
||||
:param credentials: model credentials
|
||||
:return: max chunks
|
||||
"""
|
||||
model_schema = self.get_model_schema(model, credentials)
|
||||
|
||||
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
|
||||
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
return max_chunks
|
||||
|
||||
return 1
|
||||
|
||||
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
|
||||
"""
|
||||
Transform credentials to kwargs for model instance
|
||||
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
credentials_kwargs = {
|
||||
"api_key": credentials["openai_api_key"],
|
||||
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
|
||||
"max_retries": 1,
|
||||
}
|
||||
|
||||
if credentials.get("openai_api_base"):
|
||||
openai_api_base = credentials["openai_api_base"].rstrip("/")
|
||||
credentials_kwargs["base_url"] = openai_api_base + "/v1"
|
||||
|
||||
if "openai_organization" in credentials:
|
||||
credentials_kwargs["organization"] = credentials["openai_organization"]
|
||||
|
||||
return credentials_kwargs
|
||||
|
||||
@property
|
||||
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
|
||||
"""
|
||||
Map model invoke error to unified error
|
||||
The key is the error type thrown to the caller
|
||||
The value is the error type thrown by the model,
|
||||
which needs to be converted into a unified error type for the caller.
|
||||
|
||||
:return: Invoke error mapping
|
||||
"""
|
||||
return {
|
||||
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
|
||||
InvokeServerUnavailableError: [openai.InternalServerError],
|
||||
InvokeRateLimitError: [openai.RateLimitError],
|
||||
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
|
||||
InvokeBadRequestError: [
|
||||
openai.BadRequestError,
|
||||
openai.NotFoundError,
|
||||
openai.UnprocessableEntityError,
|
||||
openai.APIError,
|
||||
],
|
||||
}
|
||||
@ -0,0 +1,22 @@
|
||||
- claude-3-haiku@20240307
|
||||
- claude-3-opus@20240229
|
||||
- claude-3-sonnet@20240229
|
||||
- claude-3-5-sonnet-v2@20241022
|
||||
- claude-3-5-sonnet@20240620
|
||||
- gemini-1.0-pro-vision-001
|
||||
- gemini-1.0-pro-002
|
||||
- gemini-1.5-flash-001
|
||||
- gemini-1.5-flash-002
|
||||
- gemini-1.5-pro-001
|
||||
- gemini-1.5-pro-002
|
||||
- gemini-2.0-flash-001
|
||||
- gemini-2.0-flash-exp
|
||||
- gemini-2.0-flash-lite-preview-02-05
|
||||
- gemini-2.0-flash-thinking-exp-01-21
|
||||
- gemini-2.0-flash-thinking-exp-1219
|
||||
- gemini-2.0-pro-exp-02-05
|
||||
- gemini-exp-1114
|
||||
- gemini-exp-1121
|
||||
- gemini-exp-1206
|
||||
- gemini-flash-experimental
|
||||
- gemini-pro-experimental
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-2.0-flash-001
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash 001
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-2.0-flash-lite-preview-02-05
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Lite Preview 0205
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 1048576
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: gemini-2.0-flash-thinking-exp-01-21
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 0121
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,39 @@
|
||||
model: gemini-2.0-flash-thinking-exp-1219
|
||||
label:
|
||||
en_US: Gemini 2.0 Flash Thinking Exp 1219
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,37 @@
|
||||
model: gemini-2.0-pro-exp-02-05
|
||||
label:
|
||||
en_US: Gemini 2.0 Pro Exp 0205
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- document
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2000000
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: presence_penalty
|
||||
use_template: presence_penalty
|
||||
- name: frequency_penalty
|
||||
use_template: frequency_penalty
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
required: true
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-exp-1114
|
||||
label:
|
||||
en_US: Gemini exp 1114
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-exp-1121
|
||||
label:
|
||||
en_US: Gemini exp 1121
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 32767
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,41 @@
|
||||
model: gemini-exp-1206
|
||||
label:
|
||||
en_US: Gemini exp 1206
|
||||
model_type: llm
|
||||
features:
|
||||
- agent-thought
|
||||
- vision
|
||||
- tool-call
|
||||
- stream-tool-call
|
||||
- document
|
||||
- video
|
||||
- audio
|
||||
model_properties:
|
||||
mode: chat
|
||||
context_size: 2097152
|
||||
parameter_rules:
|
||||
- name: temperature
|
||||
use_template: temperature
|
||||
- name: top_p
|
||||
use_template: top_p
|
||||
- name: top_k
|
||||
label:
|
||||
zh_Hans: 取样数量
|
||||
en_US: Top k
|
||||
type: int
|
||||
help:
|
||||
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
|
||||
en_US: Only sample from the top K options for each subsequent token.
|
||||
required: false
|
||||
- name: max_output_tokens
|
||||
use_template: max_tokens
|
||||
default: 8192
|
||||
min: 1
|
||||
max: 8192
|
||||
- name: json_schema
|
||||
use_template: json_schema
|
||||
pricing:
|
||||
input: '0.00'
|
||||
output: '0.00'
|
||||
unit: '0.000001'
|
||||
currency: USD
|
||||
@ -0,0 +1,113 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.agent.entities import AgentInvokeMessage
|
||||
from core.plugin.entities.plugin import GenericProviderID
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginAgentProviderEntity,
|
||||
)
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginAgentManager(BasePluginManager):
|
||||
def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
|
||||
"""
|
||||
Fetch agent providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for strategy in declaration.get("strategies", []):
|
||||
strategy["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agent_strategies",
|
||||
list[PluginAgentProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for strategy in provider.declaration.strategies:
|
||||
strategy.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
agent_provider_id = GenericProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
# skip if error occurs
|
||||
if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None:
|
||||
return json_response
|
||||
|
||||
for strategy in json_response.get("data", {}).get("declaration", {}).get("strategies", []):
|
||||
strategy["identity"]["provider"] = agent_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/agent_strategy",
|
||||
PluginAgentProviderEntity,
|
||||
params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for strategy in response.declaration.strategies:
|
||||
strategy.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
agent_provider: str,
|
||||
agent_strategy: str,
|
||||
agent_params: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[AgentInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the agent with the given tenant, user, plugin, provider, name and parameters.
|
||||
"""
|
||||
|
||||
agent_provider_id = GenericProviderID(agent_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/agent_strategy/invoke",
|
||||
AgentInvokeMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"agent_strategy_provider": agent_provider_id.provider_name,
|
||||
"agent_strategy": agent_strategy,
|
||||
"agent_strategy_params": agent_params,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": agent_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
@ -0,0 +1,12 @@
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginAssetManager(BasePluginManager):
|
||||
def fetch_asset(self, tenant_id: str, id: str) -> bytes:
|
||||
"""
|
||||
Fetch an asset by id.
|
||||
"""
|
||||
response = self._request(method="GET", path=f"plugin/{tenant_id}/asset/{id}")
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"can not found asset {id}")
|
||||
return response.content
|
||||
@ -0,0 +1,237 @@
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TypeVar
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.model_runtime.errors.invoke import (
|
||||
InvokeAuthorizationError,
|
||||
InvokeBadRequestError,
|
||||
InvokeConnectionError,
|
||||
InvokeRateLimitError,
|
||||
InvokeServerUnavailableError,
|
||||
)
|
||||
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
|
||||
from core.plugin.manager.exc import (
|
||||
PluginDaemonBadRequestError,
|
||||
PluginDaemonInternalServerError,
|
||||
PluginDaemonNotFoundError,
|
||||
PluginDaemonUnauthorizedError,
|
||||
PluginInvokeError,
|
||||
PluginNotFoundError,
|
||||
PluginPermissionDeniedError,
|
||||
PluginUniqueIdentifierError,
|
||||
)
|
||||
|
||||
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
|
||||
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
|
||||
|
||||
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BasePluginManager:
|
||||
def _request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | str | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
stream: bool = False,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API.
|
||||
"""
|
||||
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
|
||||
headers = headers or {}
|
||||
headers["X-Api-Key"] = plugin_daemon_inner_api_key
|
||||
headers["Accept-Encoding"] = "gzip, deflate, br"
|
||||
|
||||
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
|
||||
data = json.dumps(data)
|
||||
|
||||
try:
|
||||
response = requests.request(
|
||||
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
|
||||
)
|
||||
except requests.exceptions.ConnectionError:
|
||||
logger.exception("Request to Plugin Daemon Service failed")
|
||||
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
|
||||
|
||||
return response
|
||||
|
||||
def _stream_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
params: dict | None = None,
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files, stream=True)
|
||||
for line in response.iter_lines():
|
||||
line = line.decode("utf-8").strip()
|
||||
if line.startswith("data:"):
|
||||
line = line[5:].strip()
|
||||
if line:
|
||||
yield line
|
||||
|
||||
def _stream_request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
yield type(**json.loads(line)) # type: ignore
|
||||
|
||||
def _request_with_model(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
return type(**response.json()) # type: ignore
|
||||
|
||||
def _request_with_plugin_daemon_response(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
transformer: Callable[[dict], dict] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Make a request to the plugin daemon inner API and return the response as a model.
|
||||
"""
|
||||
response = self._request(method, path, headers, data, params, files)
|
||||
json_response = response.json()
|
||||
if transformer:
|
||||
json_response = transformer(json_response)
|
||||
|
||||
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
|
||||
if rep.code != 0:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
except Exception:
|
||||
raise ValueError(f"{rep.message}, code: {rep.code}")
|
||||
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
if rep.data is None:
|
||||
frame = inspect.currentframe()
|
||||
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
|
||||
|
||||
return rep.data
|
||||
|
||||
def _request_with_plugin_daemon_response_stream(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
type: type[T],
|
||||
headers: dict | None = None,
|
||||
data: bytes | dict | None = None,
|
||||
params: dict | None = None,
|
||||
files: dict | None = None,
|
||||
) -> Generator[T, None, None]:
|
||||
"""
|
||||
Make a stream request to the plugin daemon inner API and yield the response as a model.
|
||||
"""
|
||||
for line in self._stream_request(method, path, params, headers, data, files):
|
||||
line_data = None
|
||||
try:
|
||||
line_data = json.loads(line)
|
||||
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
|
||||
except Exception:
|
||||
# TODO modify this when line_data has code and message
|
||||
if line_data and "error" in line_data:
|
||||
raise ValueError(line_data["error"])
|
||||
else:
|
||||
raise ValueError(line)
|
||||
|
||||
if rep.code != 0:
|
||||
if rep.code == -500:
|
||||
try:
|
||||
error = PluginDaemonError(**json.loads(rep.message))
|
||||
except Exception:
|
||||
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
|
||||
|
||||
self._handle_plugin_daemon_error(error.error_type, error.message)
|
||||
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
|
||||
if rep.data is None:
|
||||
frame = inspect.currentframe()
|
||||
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
|
||||
yield rep.data
|
||||
|
||||
def _handle_plugin_daemon_error(self, error_type: str, message: str):
|
||||
"""
|
||||
handle the error from plugin daemon
|
||||
"""
|
||||
match error_type:
|
||||
case PluginDaemonInnerError.__name__:
|
||||
raise PluginDaemonInnerError(code=-500, message=message)
|
||||
case PluginInvokeError.__name__:
|
||||
error_object = json.loads(message)
|
||||
invoke_error_type = error_object.get("error_type")
|
||||
args = error_object.get("args")
|
||||
match invoke_error_type:
|
||||
case InvokeRateLimitError.__name__:
|
||||
raise InvokeRateLimitError(description=args.get("description"))
|
||||
case InvokeAuthorizationError.__name__:
|
||||
raise InvokeAuthorizationError(description=args.get("description"))
|
||||
case InvokeBadRequestError.__name__:
|
||||
raise InvokeBadRequestError(description=args.get("description"))
|
||||
case InvokeConnectionError.__name__:
|
||||
raise InvokeConnectionError(description=args.get("description"))
|
||||
case InvokeServerUnavailableError.__name__:
|
||||
raise InvokeServerUnavailableError(description=args.get("description"))
|
||||
case CredentialsValidateFailedError.__name__:
|
||||
raise CredentialsValidateFailedError(error_object.get("message"))
|
||||
case _:
|
||||
raise PluginInvokeError(description=message)
|
||||
case PluginDaemonInternalServerError.__name__:
|
||||
raise PluginDaemonInternalServerError(description=message)
|
||||
case PluginDaemonBadRequestError.__name__:
|
||||
raise PluginDaemonBadRequestError(description=message)
|
||||
case PluginDaemonNotFoundError.__name__:
|
||||
raise PluginDaemonNotFoundError(description=message)
|
||||
case PluginUniqueIdentifierError.__name__:
|
||||
raise PluginUniqueIdentifierError(description=message)
|
||||
case PluginNotFoundError.__name__:
|
||||
raise PluginNotFoundError(description=message)
|
||||
case PluginDaemonUnauthorizedError.__name__:
|
||||
raise PluginDaemonUnauthorizedError(description=message)
|
||||
case PluginPermissionDeniedError.__name__:
|
||||
raise PluginPermissionDeniedError(description=message)
|
||||
case _:
|
||||
raise Exception(f"got unknown error from plugin daemon: {error_type}, message: {message}")
|
||||
@ -0,0 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginDebuggingManager(BasePluginManager):
|
||||
def get_debugging_key(self, tenant_id: str) -> str:
|
||||
"""
|
||||
Get the debugging key for the given tenant.
|
||||
"""
|
||||
|
||||
class Response(BaseModel):
|
||||
key: str
|
||||
|
||||
response = self._request_with_plugin_daemon_response("POST", f"plugin/{tenant_id}/debugging/key", Response)
|
||||
|
||||
return response.key
|
||||
@ -0,0 +1,116 @@
|
||||
from core.plugin.entities.endpoint import EndpointEntityWithInstance
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginEndpointManager(BasePluginManager):
|
||||
def create_endpoint(
|
||||
self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
|
||||
) -> bool:
|
||||
"""
|
||||
Create an endpoint for the given plugin.
|
||||
|
||||
Errors will be raised if any error occurs.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/setup",
|
||||
bool,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"plugin_unique_identifier": plugin_unique_identifier,
|
||||
"settings": settings,
|
||||
"name": name,
|
||||
},
|
||||
)
|
||||
|
||||
def list_endpoints(self, tenant_id: str, user_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant and user.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/endpoint/list",
|
||||
list[EndpointEntityWithInstance],
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def list_endpoints_for_single_plugin(self, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
|
||||
"""
|
||||
List all endpoints for the given tenant, user and plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/endpoint/list/plugin",
|
||||
list[EndpointEntityWithInstance],
|
||||
params={"plugin_id": plugin_id, "page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
|
||||
"""
|
||||
Update the settings of the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/update",
|
||||
bool,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"endpoint_id": endpoint_id,
|
||||
"name": name,
|
||||
"settings": settings,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Delete the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/remove",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Enable the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/enable",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
def disable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
|
||||
"""
|
||||
Disable the given endpoint.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/endpoint/disable",
|
||||
bool,
|
||||
data={
|
||||
"endpoint_id": endpoint_id,
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
@ -0,0 +1,49 @@
|
||||
class PluginDaemonError(Exception):
|
||||
"""Base class for all plugin daemon errors."""
|
||||
|
||||
def __init__(self, description: str) -> None:
|
||||
self.description = description
|
||||
|
||||
def __str__(self) -> str:
|
||||
# returns the class name and description
|
||||
return f"{self.__class__.__name__}: {self.description}"
|
||||
|
||||
|
||||
class PluginDaemonInternalError(PluginDaemonError):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDaemonClientSideError(PluginDaemonError):
|
||||
pass
|
||||
|
||||
|
||||
class PluginDaemonInternalServerError(PluginDaemonInternalError):
|
||||
description: str = "Internal Server Error"
|
||||
|
||||
|
||||
class PluginDaemonUnauthorizedError(PluginDaemonInternalError):
|
||||
description: str = "Unauthorized"
|
||||
|
||||
|
||||
class PluginDaemonNotFoundError(PluginDaemonInternalError):
|
||||
description: str = "Not Found"
|
||||
|
||||
|
||||
class PluginDaemonBadRequestError(PluginDaemonClientSideError):
|
||||
description: str = "Bad Request"
|
||||
|
||||
|
||||
class PluginInvokeError(PluginDaemonClientSideError):
|
||||
description: str = "Invoke Error"
|
||||
|
||||
|
||||
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
|
||||
description: str = "Unique Identifier Error"
|
||||
|
||||
|
||||
class PluginNotFoundError(PluginDaemonClientSideError):
|
||||
description: str = "Plugin Not Found"
|
||||
|
||||
|
||||
class PluginPermissionDeniedError(PluginDaemonClientSideError):
|
||||
description: str = "Permission Denied"
|
||||
@ -0,0 +1,531 @@
|
||||
import binascii
|
||||
from collections.abc import Generator, Sequence
|
||||
from typing import IO, Optional
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMResultChunk
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity
|
||||
from core.model_runtime.entities.rerank_entities import RerankResult
|
||||
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin_daemon import (
|
||||
PluginBasicBooleanResponse,
|
||||
PluginDaemonInnerError,
|
||||
PluginLLMNumTokensResponse,
|
||||
PluginModelProviderEntity,
|
||||
PluginModelSchemaEntity,
|
||||
PluginStringResultResponse,
|
||||
PluginTextEmbeddingNumTokensResponse,
|
||||
PluginVoicesResponse,
|
||||
)
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginModelManager(BasePluginManager):
|
||||
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
|
||||
"""
|
||||
Fetch model providers for the given tenant.
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/models",
|
||||
list[PluginModelProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
)
|
||||
return response
|
||||
|
||||
def get_model_schema(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> AIModelEntity | None:
|
||||
"""
|
||||
Get model schema
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/schema",
|
||||
PluginModelSchemaEntity,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.model_schema
|
||||
|
||||
return None
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
if resp.credentials and isinstance(resp.credentials, dict):
|
||||
credentials.update(resp.credentials)
|
||||
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def validate_model_credentials(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
if resp.credentials and isinstance(resp.credentials, dict):
|
||||
credentials.update(resp.credentials)
|
||||
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def invoke_llm(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: Optional[dict] = None,
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
stop: Optional[list[str]] = None,
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Invoke llm
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/invoke",
|
||||
type=LLMResultChunk,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "llm",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"model_parameters": model_parameters,
|
||||
"tools": tools,
|
||||
"stop": stop,
|
||||
"stream": stream,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
yield from response
|
||||
except PluginDaemonInnerError as e:
|
||||
raise ValueError(e.message + str(e.code))
|
||||
|
||||
def get_llm_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model_type: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
prompt_messages: list[PromptMessage],
|
||||
tools: Optional[list[PromptMessageTool]] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Get number of tokens for llm
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
|
||||
type=PluginLLMNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": model_type,
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"prompt_messages": prompt_messages,
|
||||
"tools": tools,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.num_tokens
|
||||
|
||||
return 0
|
||||
|
||||
def invoke_text_embedding(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
input_type: str,
|
||||
) -> TextEmbeddingResult:
|
||||
"""
|
||||
Invoke text embedding
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
|
||||
type=TextEmbeddingResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
"input_type": input_type,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("Failed to invoke text embedding")
|
||||
|
||||
def get_text_embedding_num_tokens(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
texts: list[str],
|
||||
) -> list[int]:
|
||||
"""
|
||||
Get number of tokens for text embedding
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
|
||||
type=PluginTextEmbeddingNumTokensResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "text-embedding",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"texts": texts,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.num_tokens
|
||||
|
||||
return []
|
||||
|
||||
def invoke_rerank(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
query: str,
|
||||
docs: list[str],
|
||||
score_threshold: Optional[float] = None,
|
||||
top_n: Optional[int] = None,
|
||||
) -> RerankResult:
|
||||
"""
|
||||
Invoke rerank
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
|
||||
type=RerankResult,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "rerank",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"query": query,
|
||||
"docs": docs,
|
||||
"score_threshold": score_threshold,
|
||||
"top_n": top_n,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp
|
||||
|
||||
raise ValueError("Failed to invoke rerank")
|
||||
|
||||
def invoke_tts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
content_text: str,
|
||||
voice: str,
|
||||
) -> Generator[bytes, None, None]:
|
||||
"""
|
||||
Invoke tts
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"tenant_id": tenant_id,
|
||||
"content_text": content_text,
|
||||
"voice": voice,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
for result in response:
|
||||
hex_str = result.result
|
||||
yield binascii.unhexlify(hex_str)
|
||||
except PluginDaemonInnerError as e:
|
||||
raise ValueError(e.message + str(e.code))
|
||||
|
||||
def get_tts_model_voices(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
language: Optional[str] = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Get tts model voices
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
|
||||
type=PluginVoicesResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "tts",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"language": language,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
voices = []
|
||||
for voice in resp.voices:
|
||||
voices.append({"name": voice.name, "value": voice.value})
|
||||
|
||||
return voices
|
||||
|
||||
return []
|
||||
|
||||
def invoke_speech_to_text(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
file: IO[bytes],
|
||||
) -> str:
|
||||
"""
|
||||
Invoke speech to text
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
|
||||
type=PluginStringResultResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "speech2text",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"file": binascii.hexlify(file.read()).decode(),
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("Failed to invoke speech to text")
|
||||
|
||||
def invoke_moderation(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
plugin_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
credentials: dict,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Invoke moderation
|
||||
"""
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
method="POST",
|
||||
path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
|
||||
type=PluginBasicBooleanResponse,
|
||||
data=jsonable_encoder(
|
||||
{
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": provider,
|
||||
"model_type": "moderation",
|
||||
"model": model,
|
||||
"credentials": credentials,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
),
|
||||
headers={
|
||||
"X-Plugin-ID": plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
raise ValueError("Failed to invoke moderation")
|
||||
@ -0,0 +1,249 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.plugin.entities.bundle import PluginBundleDependency
|
||||
from core.plugin.entities.plugin import (
|
||||
GenericProviderID,
|
||||
MissingPluginDependency,
|
||||
PluginDeclaration,
|
||||
PluginEntity,
|
||||
PluginInstallation,
|
||||
PluginInstallationSource,
|
||||
)
|
||||
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
|
||||
|
||||
class PluginInstallationManager(BasePluginManager):
|
||||
def fetch_plugin_by_identifier(
|
||||
self,
|
||||
tenant_id: str,
|
||||
identifier: str,
|
||||
) -> bool:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/identifier",
|
||||
bool,
|
||||
params={"plugin_unique_identifier": identifier},
|
||||
)
|
||||
|
||||
def list_plugins(self, tenant_id: str) -> list[PluginEntity]:
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/list",
|
||||
list[PluginEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
)
|
||||
|
||||
def upload_pkg(
|
||||
self,
|
||||
tenant_id: str,
|
||||
pkg: bytes,
|
||||
verify_signature: bool = False,
|
||||
) -> PluginUploadResponse:
|
||||
"""
|
||||
Upload a plugin package and return the plugin unique identifier.
|
||||
"""
|
||||
body = {
|
||||
"dify_pkg": ("dify_pkg", pkg, "application/octet-stream"),
|
||||
}
|
||||
|
||||
data = {
|
||||
"verify_signature": "true" if verify_signature else "false",
|
||||
}
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upload/package",
|
||||
PluginUploadResponse,
|
||||
files=body,
|
||||
data=data,
|
||||
)
|
||||
|
||||
def upload_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle: bytes,
|
||||
verify_signature: bool = False,
|
||||
) -> Sequence[PluginBundleDependency]:
|
||||
"""
|
||||
Upload a plugin bundle and return the dependencies.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upload/bundle",
|
||||
list[PluginBundleDependency],
|
||||
files={"dify_bundle": ("dify_bundle", bundle, "application/octet-stream")},
|
||||
data={"verify_signature": "true" if verify_signature else "false"},
|
||||
)
|
||||
|
||||
def install_from_identifiers(
|
||||
self,
|
||||
tenant_id: str,
|
||||
identifiers: Sequence[str],
|
||||
source: PluginInstallationSource,
|
||||
metas: list[dict],
|
||||
) -> PluginInstallTaskStartResponse:
|
||||
"""
|
||||
Install a plugin from an identifier.
|
||||
"""
|
||||
# exception will be raised if the request failed
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/identifiers",
|
||||
PluginInstallTaskStartResponse,
|
||||
data={
|
||||
"plugin_unique_identifiers": identifiers,
|
||||
"source": source,
|
||||
"metas": metas,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_tasks(self, tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
|
||||
"""
|
||||
Fetch plugin installation tasks.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/install/tasks",
|
||||
list[PluginInstallTask],
|
||||
params={"page": page, "page_size": page_size},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_task(self, tenant_id: str, task_id: str) -> PluginInstallTask:
|
||||
"""
|
||||
Fetch a plugin installation task.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}",
|
||||
PluginInstallTask,
|
||||
)
|
||||
|
||||
def delete_plugin_installation_task(self, tenant_id: str, task_id: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete",
|
||||
bool,
|
||||
)
|
||||
|
||||
def delete_all_plugin_installation_task_items(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
Delete all plugin installation task items.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/delete_all",
|
||||
bool,
|
||||
)
|
||||
|
||||
def delete_plugin_installation_task_item(self, tenant_id: str, task_id: str, identifier: str) -> bool:
|
||||
"""
|
||||
Delete a plugin installation task item.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete/{identifier}",
|
||||
bool,
|
||||
)
|
||||
|
||||
def fetch_plugin_manifest(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
|
||||
"""
|
||||
Fetch a plugin manifest.
|
||||
"""
|
||||
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/fetch/manifest",
|
||||
PluginDeclaration,
|
||||
params={"plugin_unique_identifier": plugin_unique_identifier},
|
||||
)
|
||||
|
||||
def fetch_plugin_installation_by_ids(
|
||||
self, tenant_id: str, plugin_ids: Sequence[str]
|
||||
) -> Sequence[PluginInstallation]:
|
||||
"""
|
||||
Fetch plugin installations by ids.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/installation/fetch/batch",
|
||||
list[PluginInstallation],
|
||||
data={"plugin_ids": plugin_ids},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def fetch_missing_dependencies(
|
||||
self, tenant_id: str, plugin_unique_identifiers: list[str]
|
||||
) -> list[MissingPluginDependency]:
|
||||
"""
|
||||
Fetch missing dependencies
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/installation/missing",
|
||||
list[MissingPluginDependency],
|
||||
data={"plugin_unique_identifiers": plugin_unique_identifiers},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def uninstall(self, tenant_id: str, plugin_installation_id: str) -> bool:
|
||||
"""
|
||||
Uninstall a plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/uninstall",
|
||||
bool,
|
||||
data={
|
||||
"plugin_installation_id": plugin_installation_id,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def upgrade_plugin(
|
||||
self,
|
||||
tenant_id: str,
|
||||
original_plugin_unique_identifier: str,
|
||||
new_plugin_unique_identifier: str,
|
||||
source: PluginInstallationSource,
|
||||
meta: dict,
|
||||
) -> PluginInstallTaskStartResponse:
|
||||
"""
|
||||
Upgrade a plugin.
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/install/upgrade",
|
||||
PluginInstallTaskStartResponse,
|
||||
data={
|
||||
"original_plugin_unique_identifier": original_plugin_unique_identifier,
|
||||
"new_plugin_unique_identifier": new_plugin_unique_identifier,
|
||||
"source": source,
|
||||
"meta": meta,
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
|
||||
"""
|
||||
Check if the tools exist
|
||||
"""
|
||||
return self._request_with_plugin_daemon_response(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/management/tools/check_existence",
|
||||
list[bool],
|
||||
data={
|
||||
"provider_ids": [
|
||||
{
|
||||
"plugin_id": provider_id.plugin_id,
|
||||
"provider_name": provider_id.provider_name,
|
||||
}
|
||||
for provider_id in provider_ids
|
||||
]
|
||||
},
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
@ -0,0 +1,188 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
|
||||
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
|
||||
from core.plugin.manager.base import BasePluginManager
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
|
||||
|
||||
|
||||
class PluginToolManager(BasePluginManager):
|
||||
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
|
||||
"""
|
||||
Fetch tool providers for the given tenant.
|
||||
"""
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
for provider in json_response.get("data", []):
|
||||
declaration = provider.get("declaration", {}) or {}
|
||||
provider_name = declaration.get("identity", {}).get("name")
|
||||
for tool in declaration.get("tools", []):
|
||||
tool["identity"]["provider"] = provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tools",
|
||||
list[PluginToolProviderEntity],
|
||||
params={"page": 1, "page_size": 256},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
for provider in response:
|
||||
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in provider.declaration.tools:
|
||||
tool.identity.provider = provider.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
|
||||
"""
|
||||
Fetch tool provider for the given tenant and plugin.
|
||||
"""
|
||||
tool_provider_id = ToolProviderID(provider)
|
||||
|
||||
def transformer(json_response: dict[str, Any]) -> dict:
|
||||
data = json_response.get("data")
|
||||
if data:
|
||||
for tool in data.get("declaration", {}).get("tools", []):
|
||||
tool["identity"]["provider"] = tool_provider_id.provider_name
|
||||
|
||||
return json_response
|
||||
|
||||
response = self._request_with_plugin_daemon_response(
|
||||
"GET",
|
||||
f"plugin/{tenant_id}/management/tool",
|
||||
PluginToolProviderEntity,
|
||||
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||
transformer=transformer,
|
||||
)
|
||||
|
||||
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||
|
||||
# override the provider name for each tool to plugin_id/provider_name
|
||||
for tool in response.declaration.tools:
|
||||
tool.identity.provider = response.declaration.identity.name
|
||||
|
||||
return response
|
||||
|
||||
def invoke(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
tool_provider: str,
|
||||
tool_name: str,
|
||||
credentials: dict[str, Any],
|
||||
tool_parameters: dict[str, Any],
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
"""
|
||||
Invoke the tool with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||
"""
|
||||
|
||||
tool_provider_id = GenericProviderID(tool_provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/invoke",
|
||||
ToolInvokeMessage,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"tool": tool_name,
|
||||
"credentials": credentials,
|
||||
"tool_parameters": tool_parameters,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
return response
|
||||
|
||||
def validate_provider_credentials(
|
||||
self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
validate the credentials of the provider
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
|
||||
PluginBasicBooleanResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.result
|
||||
|
||||
return False
|
||||
|
||||
def get_runtime_parameters(
|
||||
self,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
provider: str,
|
||||
credentials: dict[str, Any],
|
||||
tool: str,
|
||||
conversation_id: Optional[str] = None,
|
||||
app_id: Optional[str] = None,
|
||||
message_id: Optional[str] = None,
|
||||
) -> list[ToolParameter]:
|
||||
"""
|
||||
get the runtime parameters of the tool
|
||||
"""
|
||||
tool_provider_id = GenericProviderID(provider)
|
||||
|
||||
class RuntimeParametersResponse(BaseModel):
|
||||
parameters: list[ToolParameter]
|
||||
|
||||
response = self._request_with_plugin_daemon_response_stream(
|
||||
"POST",
|
||||
f"plugin/{tenant_id}/dispatch/tool/get_runtime_parameters",
|
||||
RuntimeParametersResponse,
|
||||
data={
|
||||
"user_id": user_id,
|
||||
"conversation_id": conversation_id,
|
||||
"app_id": app_id,
|
||||
"message_id": message_id,
|
||||
"data": {
|
||||
"provider": tool_provider_id.provider_name,
|
||||
"tool": tool,
|
||||
"credentials": credentials,
|
||||
},
|
||||
},
|
||||
headers={
|
||||
"X-Plugin-ID": tool_provider_id.plugin_id,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
for resp in response:
|
||||
return resp.parameters
|
||||
|
||||
return []
|
||||
@ -0,0 +1,258 @@
|
||||
"""Functionality for splitting text."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
|
||||
from core.rag.splitter.text_splitter import (
|
||||
TS,
|
||||
Collection,
|
||||
Literal,
|
||||
RecursiveCharacterTextSplitter,
|
||||
Set,
|
||||
TokenTextSplitter,
|
||||
Union,
|
||||
)
|
||||
|
||||
|
||||
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
|
||||
"""
|
||||
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
|
||||
""" # 文档字符串,说明该类的作用是实现基于 GPT-2 的编码器,避免使用 tiktoken。
|
||||
|
||||
@classmethod
|
||||
def from_encoder(
|
||||
cls: type[TS],
|
||||
allowed_special: Union[Literal["all"], Set[str]] = set(), # 允许的特殊字符集合,默认为空集。
|
||||
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # 禁止的特殊字符集合,默认为 "all"。
|
||||
**kwargs: Any, # 其他关键字参数。
|
||||
):
|
||||
def _token_encoder(texts: list[str]) -> list[int]: # 定义一个内部函数,用于计算文本的 token 数量。
|
||||
if not texts: # 如果输入的文本列表为空,则返回空列表。
|
||||
return []
|
||||
# 否则,使用默认的 GPT-2 tokenizer 计算 token 数量。
|
||||
return [GPT2Tokenizer.get_num_tokens(text) for text in texts]
|
||||
|
||||
if issubclass(cls, TokenTextSplitter): # 如果当前类是 TokenTextSplitter 的子类。
|
||||
extra_kwargs = { # 构造额外的关键字参数。
|
||||
"model_name": "gpt2", # 模型名称。
|
||||
"allowed_special": allowed_special, # 允许的特殊字符。
|
||||
"disallowed_special": disallowed_special, # 禁止的特殊字符。
|
||||
}
|
||||
kwargs = {**kwargs, **extra_kwargs} # 将额外参数合并到 kwargs 中。
|
||||
|
||||
return cls(length_function=_token_encoder, **kwargs) # 返回当前类的实例,并传入长度计算函数和其他参数。
|
||||
|
||||
|
||||
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
|
||||
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
|
||||
"""Create a new TextSplitter.""" # 文档字符串,说明构造函数的作用是创建一个新的文本分割器。
|
||||
super().__init__(**kwargs) # 调用父类的构造函数,初始化基类的属性。
|
||||
self._fixed_separator = fixed_separator # 固定分隔符,默认为 "\n\n"。
|
||||
self._separators = separators or ["\n\n", "\n", " ", ""] # 备用分隔符列表,默认为 ["\n\n", "\n", " ", ""]。
|
||||
|
||||
def split_text(self, text: str) -> list[str]: # 定义主方法,用于分割文本。
|
||||
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是分割输入文本并返回块。
|
||||
if self._fixed_separator: # 如果设置了固定分隔符。
|
||||
chunks = text.split(self._fixed_separator) # 使用固定分隔符将文本分割成初步的块。
|
||||
else: # 如果未设置固定分隔符。
|
||||
chunks = [text] # 将整个文本作为一个块。
|
||||
|
||||
final_chunks = [] # 初始化最终的块列表。
|
||||
chunks_lengths = self._length_function(chunks) # 计算每个块的长度。
|
||||
for chunk, chunk_length in zip(chunks, chunks_lengths): # 遍历每个块及其长度。
|
||||
if chunk_length > self._chunk_size: # 如果块的长度超过限制。
|
||||
if self._keep_separator :
|
||||
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。
|
||||
continue
|
||||
final_chunks.extend(self.recursive_split_text_(chunk)) # 调用递归分割方法进一步拆分。
|
||||
else: # 如果块的长度未超过限制。
|
||||
final_chunks.append(chunk) # 直接保留该块。
|
||||
|
||||
return final_chunks # 返回最终的块列表。
|
||||
|
||||
def recursive_split_text_(self, text: str) -> list[str]: # 定义递归分割方法。
|
||||
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
|
||||
|
||||
final_chunks = [] # 初始化最终的块列表。
|
||||
separator = self._separators[-1] # 默认使用备用分隔符列表中的最后一个分隔符。
|
||||
new_separators = [] # 初始化新的分隔符列表。
|
||||
|
||||
for i, _s in enumerate(self._separators): # 遍历备用分隔符列表。
|
||||
if _s == "": # 如果遇到空字符串分隔符。
|
||||
separator = _s # 设置分隔符为空字符串。
|
||||
break # 结束循环。
|
||||
if _s in text: # 如果当前分隔符存在于文本中。
|
||||
separator = _s # 设置分隔符为当前分隔符。
|
||||
new_separators = self._separators[i + 1 :] # 更新新的分隔符列表。
|
||||
break # 结束循环。
|
||||
|
||||
# Now that we have the separator, split the text # 已经确定了分隔符,开始分割文本。
|
||||
if separator: # 如果分隔符不为空。
|
||||
if separator == " ": # 如果分隔符是空格。
|
||||
splits = text.split() # 按空格分割文本。
|
||||
else: # 如果分隔符不是空格。
|
||||
splits = text.split(separator) # 按指定分隔符分割文本。
|
||||
else: # 如果分隔符为空字符串。
|
||||
splits = list(text) # 将文本按字符分割成列表。
|
||||
splits = [s for s in splits if (s not in {""})] # 过滤掉空字符串和换行符。
|
||||
|
||||
_good_splits = [] # 初始化符合长度要求的块列表。
|
||||
_good_splits_lengths = [] # 缓存这些块的长度。
|
||||
self._keep_separator = False
|
||||
_separator = "" if self._keep_separator else separator # 根据是否保留分隔符决定连接符。
|
||||
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
|
||||
if _separator != "": # 如果连接符不为空。
|
||||
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
|
||||
print("-----",s,s_len,self._chunk_size)
|
||||
if s_len < self._chunk_size: # 如果长度小于限制。
|
||||
_good_splits.append(s) # 将其加入符合要求的块列表。
|
||||
_good_splits_lengths.append(s_len) # 缓存其长度。
|
||||
else: # 如果长度超出限制。
|
||||
if _good_splits: # 如果有符合要求的块。
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) # 合并这些块。
|
||||
final_chunks.extend(merged_text) # 将合并后的块加入最终块列表。
|
||||
_good_splits = [] # 清空符合要求的块列表。
|
||||
_good_splits_lengths = [] # 清空长度缓存。
|
||||
if not new_separators: # 如果没有新的分隔符。
|
||||
final_chunks.append(s) # 直接保留当前部分。
|
||||
else: # 如果有新的分隔符。
|
||||
other_info = self._split_text(s, new_separators) # 递归调用分割方法。
|
||||
final_chunks.extend(other_info) # 将结果加入最终块列表。
|
||||
|
||||
if _good_splits: # 如果还有剩余的符合要求的块。
|
||||
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) # 合并这些块。
|
||||
final_chunks.extend(merged_text) # 将合并后的块加入最终块列表。
|
||||
else: # 如果连接符为空。
|
||||
current_part = "" # 初始化当前块。
|
||||
current_length = 0 # 初始化当前块的长度。
|
||||
overlap_part = "" # 初始化重叠部分。
|
||||
overlap_part_length = 0 # 初始化重叠部分的长度。
|
||||
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
|
||||
if current_length + s_len <= self._chunk_size - self._chunk_overlap: # 如果当前块可以容纳更多内容。
|
||||
current_part += s # 将当前部分加入当前块。
|
||||
current_length += s_len # 更新当前块的长度。
|
||||
elif current_length + s_len <= self._chunk_size: # 如果当前块接近长度限制。
|
||||
current_part += s # 将当前部分加入当前块。
|
||||
current_length += s_len # 更新当前块的长度。
|
||||
overlap_part += s # 将当前部分加入重叠部分。
|
||||
overlap_part_length += s_len # 更新重叠部分的长度。
|
||||
else: # 如果当前块已满。
|
||||
final_chunks.append(current_part) # 将当前块加入最终块列表。
|
||||
current_part = overlap_part + s # 构造新的当前块。
|
||||
current_length = s_len + overlap_part_length # 更新当前块的长度。
|
||||
overlap_part = "" # 清空重叠部分。
|
||||
overlap_part_length = 0 # 清空重叠部分的长度。
|
||||
if current_part: # 如果还有剩余的当前块。
|
||||
final_chunks.append(current_part) # 将其加入最终块列表。
|
||||
|
||||
return final_chunks # 返回最终的块列表。
|
||||
|
||||
def recursive_split_text_keep_separator_(self, text: str) -> list[str]: # 定义递归分割方法。
|
||||
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
|
||||
|
||||
final_chunks = [] # 初始化最终的块列表。
|
||||
current_part_list = []
|
||||
self.append_next_split_text(current_part_list=current_part_list,
|
||||
current_length_list=[],
|
||||
text=text,
|
||||
final_chunks = final_chunks,
|
||||
separators = self._separators)
|
||||
|
||||
if len(current_part_list): # 如果还有剩余的当前块。
|
||||
final_chunks.append("".join(current_part_list)) # 将其加入最终块列表。
|
||||
|
||||
return final_chunks # 返回最终的块列表。
|
||||
|
||||
@classmethod
|
||||
def get_splits_(self,text:str, separators:list[str]) -> (list[str],list[str]): # 定义递归分割方法。
|
||||
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
|
||||
if len(separators) > 0:
|
||||
separator = separators[-1] # 默认使用备用分隔符列表中的最后一个分隔符。
|
||||
new_separators = [] # 初始化新的分隔符列表。
|
||||
for i, _s in enumerate(separators): # 遍历备用分隔符列表。
|
||||
if _s in text: # 如果当前分隔符存在于文本中。
|
||||
separator = _s # 设置分隔符为当前分隔符。
|
||||
new_separators = separators[i + 1 :] # 更新新的分隔符列表。
|
||||
break # 结束循环。
|
||||
# Now that we have the separator, split the text # 已经确定了分隔符,开始分割文本。
|
||||
if separator: # 如果分隔符不为空。
|
||||
splits = text.split(separator) # 按指定分隔符分割文本。
|
||||
else: # 如果分隔符为空字符串。
|
||||
splits = list(text) # 将文本按字符分割成列表。
|
||||
# splits = [s for s in splits if (s not in {""})] # 过滤掉空字符串和换行符。
|
||||
return splits,new_separators
|
||||
else:
|
||||
return [text],[]
|
||||
|
||||
def append_next_split_text(self,
|
||||
current_part_list:list[str],
|
||||
current_length_list:list[int],
|
||||
text: str,
|
||||
final_chunks: list[str],
|
||||
separators : list[str]): # 定义递归分割方法。
|
||||
if text:
|
||||
# 需要判断是否可以再拼接
|
||||
splits, new_separators_ = self.get_splits_(text, separators)
|
||||
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
|
||||
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
|
||||
|
||||
current_length = sum(current_length_list)
|
||||
if "制定综合主进度" in s:
|
||||
print(s)
|
||||
|
||||
if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
|
||||
current_part_list.append(s) # 将当前部分加入当前块。
|
||||
current_length_list.append(s_len)
|
||||
else:
|
||||
if len(new_separators_) == 0:
|
||||
# 将片段加入到列表中
|
||||
final_chunks.append("".join(current_part_list))
|
||||
# 计算出重叠部分的内容
|
||||
overlap_part_length_,overlap_part_ = self.get_overlap_part(current_part_list,current_length_list)
|
||||
# 将重叠部分作为下一个片段的开头
|
||||
current_part_list.clear()
|
||||
current_part_list.append(overlap_part_)
|
||||
current_length_list.clear()
|
||||
current_length_list.append(overlap_part_length_)
|
||||
|
||||
if overlap_part_length_ + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
|
||||
current_part_list.append(s) # 将当前部分加入当前块。
|
||||
current_length_list.append(s_len)
|
||||
continue
|
||||
# 递归计算
|
||||
self.append_next_split_text(current_part_list=current_part_list,
|
||||
current_length_list=current_length_list,
|
||||
text=s,
|
||||
final_chunks=final_chunks,
|
||||
separators=new_separators_)
|
||||
|
||||
def get_overlap_part(self,
|
||||
current_part_list:list[str],
|
||||
current_length_list:list[int]) -> (int,str): # 定义递归分割方法。
|
||||
# 一下计算出
|
||||
overlap_part_length_ = 0
|
||||
overlap_part_list = []
|
||||
current_length_list_reversed = list(reversed(current_length_list))
|
||||
current_part_list_reversed = list(reversed(current_part_list))
|
||||
for index, s_len_ in enumerate(current_length_list_reversed):
|
||||
if overlap_part_length_ + s_len_ > self._chunk_overlap:
|
||||
if overlap_part_length_ < self._chunk_overlap:
|
||||
text = current_part_list_reversed[index]
|
||||
texts = list(text)
|
||||
text_lens = self._length_function(texts)
|
||||
texts_reversed = list(reversed(texts))
|
||||
text_lens_reversed = list(reversed(text_lens))
|
||||
for s_, len_ in zip(texts_reversed,text_lens_reversed):
|
||||
if overlap_part_length_ + len_ > self._chunk_overlap:
|
||||
break
|
||||
overlap_part_length_ += len_
|
||||
overlap_part_list[0:0] = s_
|
||||
# overlap_part_list.append(s_)
|
||||
break
|
||||
overlap_part_length_ += s_len_
|
||||
overlap_part_list[0:0] = current_part_list_reversed[index]
|
||||
# overlap_part_list.append(current_part_list_reversed[index])
|
||||
return overlap_part_length_, "".join(overlap_part_list)
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.workflow.graph_engine.entities.graph import GraphParallel
|
||||
|
||||
|
||||
class NextGraphNode(BaseModel):
|
||||
node_id: str
|
||||
"""next node id"""
|
||||
|
||||
parallel: Optional[GraphParallel] = None
|
||||
"""parallel"""
|
||||
@ -0,0 +1,3 @@
|
||||
from .vanna_node import VannaNode
|
||||
|
||||
__all__ = ["VannaNode"]
|
||||
@ -0,0 +1,26 @@
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.llm import ModelConfig, VisionConfig
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
|
||||
class VannaConfig(BaseModel):
|
||||
"""
|
||||
Vanna Config.
|
||||
"""
|
||||
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
|
||||
|
||||
class VannaNodeData(BaseNodeData):
|
||||
|
||||
model: ModelConfig
|
||||
query: list[str]
|
||||
instruction: Optional[str] = None
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
|
||||
@ -0,0 +1,50 @@
|
||||
class VannaNodeError(ValueError):
|
||||
"""Base error for VannaNode."""
|
||||
|
||||
|
||||
class InvalidModelTypeError(VannaNodeError):
|
||||
"""Raised when the model is not a Large Language Model."""
|
||||
|
||||
|
||||
class ModelSchemaNotFoundError(VannaNodeError):
|
||||
"""Raised when the model schema is not found."""
|
||||
|
||||
|
||||
class InvalidInvokeResultError(VannaNodeError):
|
||||
"""Raised when the invoke result is invalid."""
|
||||
|
||||
|
||||
class InvalidTextContentTypeError(VannaNodeError):
|
||||
"""Raised when the text content type is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberOfParametersError(VannaNodeError):
|
||||
"""Raised when the number of parameters is invalid."""
|
||||
|
||||
|
||||
class RequiredParameterMissingError(VannaNodeError):
|
||||
"""Raised when a required parameter is missing."""
|
||||
|
||||
|
||||
class InvalidSelectValueError(VannaNodeError):
|
||||
"""Raised when a select value is invalid."""
|
||||
|
||||
|
||||
class InvalidNumberValueError(VannaNodeError):
|
||||
"""Raised when a number value is invalid."""
|
||||
|
||||
|
||||
class InvalidBoolValueError(VannaNodeError):
|
||||
"""Raised when a bool value is invalid."""
|
||||
|
||||
|
||||
class InvalidStringValueError(VannaNodeError):
|
||||
"""Raised when a string value is invalid."""
|
||||
|
||||
|
||||
class InvalidArrayValueError(VannaNodeError):
|
||||
"""Raised when an array value is invalid."""
|
||||
|
||||
|
||||
class InvalidModelModeError(VannaNodeError):
|
||||
"""Raised when the model mode is invalid."""
|
||||
@ -0,0 +1,92 @@
|
||||
from typing import Any, Optional, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
from core.workflow.entities.node_entities import NodeRunResult
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from extensions.utils.vanna_text2sql import VannaServer
|
||||
from models.workflow import WorkflowNodeExecutionStatus
|
||||
|
||||
from .entities import VannaNodeData
|
||||
|
||||
|
||||
class Config:
|
||||
def __init__(self, supplier):
|
||||
self.embedding_supplier = "SiliconFlow"
|
||||
self.milvus_uri = dify_config.MILVUS_URI
|
||||
self.milvus_database = 'vanna_demo'
|
||||
self.supplier = supplier
|
||||
self.sql_type = 'postgres'
|
||||
self.sql_config = {
|
||||
"host": dify_config.DB_HOST,
|
||||
"dbname": 'vanna_demo',
|
||||
"user": dify_config.DB_USERNAME,
|
||||
"password": dify_config.DB_PASSWORD,
|
||||
"port": dify_config.DB_PORT
|
||||
}
|
||||
|
||||
vn_instances = {}
|
||||
|
||||
def get_vanna_server(key, combined_config):
|
||||
if key not in vn_instances:
|
||||
vn_instances[key] = VannaServer(combined_config)
|
||||
return vn_instances[key]
|
||||
class VannaNode(LLMNode):
|
||||
# FIXME: figure out why here is different from super class
|
||||
_node_data_cls = VannaNodeData # type: ignore
|
||||
_node_type = NodeType.VANNA
|
||||
|
||||
_model_instance: Optional[ModelInstance] = None
|
||||
_model_config: Optional[ModelConfigWithCredentialsEntity] = None
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Optional[dict] = None) -> dict:
|
||||
return {
|
||||
"model": {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
def _run(self):
|
||||
node_data = cast(VannaNodeData, self.node_data)
|
||||
variable = self.graph_runtime_state.variable_pool.get(node_data.query)
|
||||
query = variable.text if variable else ""
|
||||
|
||||
model_instance, model_config = self._fetch_model_config(self.node_data.model)
|
||||
# 'tongyi' 通义 'openai' openai 'ollama' ollama 'deepseek' deepseek
|
||||
llm_type = model_instance.provider.rsplit('/')[-1]
|
||||
api_key = ''
|
||||
base_url = ''
|
||||
if llm_type == 'tongyi':
|
||||
api_key = model_instance.credentials.get('dashscope_api_key')
|
||||
elif llm_type == 'deepseek':
|
||||
api_key = model_instance.credentials.get('api_key')
|
||||
elif llm_type == 'ollama':
|
||||
base_url = model_instance.credentials.get('base_url')
|
||||
|
||||
cache_kay = llm_type + api_key if api_key else base_url
|
||||
model = model_instance.model
|
||||
|
||||
vanna_config = {
|
||||
"llm_type": llm_type,
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"ollama_host": base_url
|
||||
}
|
||||
config = Config("")
|
||||
# 合并配置
|
||||
combined_config = {**config.__dict__, **config.sql_config, **vanna_config}
|
||||
|
||||
cache_data = get_vanna_server(cache_kay, combined_config)
|
||||
|
||||
# 提问获取sql和结果
|
||||
sql = cache_data.generate_sql(query)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={"output": sql}
|
||||
)
|
||||
|
||||
@ -0,0 +1,228 @@
|
||||
from vanna.ollama import Ollama
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
from functools import wraps
|
||||
from flask import Flask, jsonify, Response, request
|
||||
import flask
|
||||
from extensions.storage.cache import MemoryCache
|
||||
from dify_app import DifyApp
|
||||
from vanna.milvus import Milvus_VectorStore
|
||||
from pymilvus import MilvusClient
|
||||
from configs import dify_config
|
||||
# SETUP
|
||||
cache = MemoryCache()
|
||||
milvus_uri = dify_config.MILVUS_URI
|
||||
|
||||
milvus_client = MilvusClient(uri=milvus_uri)
|
||||
milvus_client.use_database("test")
|
||||
class MyVanna(Milvus_VectorStore, Ollama):
|
||||
def __init__(self, config=None):
|
||||
Milvus_VectorStore.__init__(self, config=config)
|
||||
Ollama.__init__(self, config=config)
|
||||
|
||||
# vn = MyVanna(config={
|
||||
# 'model': 'qwen2:7b', # 本地ollama大模型名称
|
||||
# 'ollama_host':'http://wsd.wisdomidata.com:19042', # 本地ollama大模型服务地址
|
||||
# 'milvus_client': milvus_client, # 本地milvus向量数据库服务地址
|
||||
# "n_results": 12,
|
||||
# })
|
||||
# vn.connect_to_postgres(
|
||||
# host=dify_config.DB_HOST,
|
||||
# dbname='vanna_demo',
|
||||
# user=dify_config.DB_USERNAME,
|
||||
# password=dify_config.DB_PASSWORD,
|
||||
# port=dify_config.DB_PORT
|
||||
# )
|
||||
# vn.connect_to_mysql(
|
||||
# host='122.51.104.137',
|
||||
# port=33306,
|
||||
# dbname='demo',
|
||||
# user='sws',
|
||||
# password='123456'
|
||||
# )
|
||||
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
|
||||
def requires_cache(fields):
|
||||
def decorator(f):
|
||||
@wraps(f)
|
||||
def decorated(*args, **kwargs):
|
||||
id = request.args.get('id')
|
||||
|
||||
if id is None:
|
||||
return jsonify({"type": "error", "error": "No id provided"})
|
||||
|
||||
for field in fields:
|
||||
if cache.get(id=id, field=field) is None:
|
||||
return jsonify({"type": "error", "error": f"No {field} found"})
|
||||
|
||||
field_values = {field: cache.get(id=id, field=field) for field in fields}
|
||||
|
||||
# Add the id to the field_values
|
||||
field_values['id'] = id
|
||||
|
||||
return f(*args, **field_values, **kwargs)
|
||||
|
||||
return decorated
|
||||
|
||||
return decorator
|
||||
|
||||
@app.route('/api/v0/generate_questions', methods=['GET'])
|
||||
def generate_questions():
|
||||
return jsonify({
|
||||
"type": "question_list",
|
||||
"questions": vn.generate_questions(),
|
||||
"header": "Here are some questions you can ask:"
|
||||
})
|
||||
|
||||
@app.route('/api/v0/generate_sql', methods=['GET'])
|
||||
def generate_sql():
|
||||
question = flask.request.args.get('question')
|
||||
|
||||
if question is None:
|
||||
return jsonify({"type": "error", "error": "No question provided"})
|
||||
|
||||
id = cache.generate_id(question=question)
|
||||
sql = vn.generate_sql(question=question)
|
||||
|
||||
cache.set(id=id, field='question', value=question)
|
||||
cache.set(id=id, field='sql', value=sql)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "sql",
|
||||
"id": id,
|
||||
"text": sql,
|
||||
})
|
||||
|
||||
@app.route('/api/v0/run_sql', methods=['GET'])
|
||||
@requires_cache(['sql'])
|
||||
def run_sql(id: str, sql: str):
|
||||
try:
|
||||
df = vn.run_sql(sql=sql)
|
||||
|
||||
cache.set(id=id, field='df', value=df)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "df",
|
||||
"id": id,
|
||||
"df": df.head(10).to_json(orient='records'),
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
@app.route('/api/v0/download_csv', methods=['GET'])
|
||||
@requires_cache(['df'])
|
||||
def download_csv(id: str, df):
|
||||
csv = df.to_csv()
|
||||
|
||||
return Response(
|
||||
csv,
|
||||
mimetype="text/csv",
|
||||
headers={"Content-disposition":
|
||||
f"attachment; filename={id}.csv"})
|
||||
|
||||
@app.route('/api/v0/generate_plotly_figure', methods=['GET'])
|
||||
@requires_cache(['df', 'question', 'sql'])
|
||||
def generate_plotly_figure(id: str, df, question, sql):
|
||||
try:
|
||||
code = vn.generate_plotly_code(question=question, sql=sql,
|
||||
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
|
||||
fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
|
||||
fig_json = fig.to_json()
|
||||
|
||||
cache.set(id=id, field='fig_json', value=fig_json)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "plotly_figure",
|
||||
"id": id,
|
||||
"fig": fig_json,
|
||||
})
|
||||
except Exception as e:
|
||||
# Print the stack trace
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
@app.route('/api/v0/get_training_data', methods=['GET'])
|
||||
def get_training_data():
|
||||
df = vn.get_training_data()
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "df",
|
||||
"id": "training_data",
|
||||
"df": df.head(25).to_json(orient='records'),
|
||||
})
|
||||
|
||||
@app.route('/api/v0/remove_training_data', methods=['POST'])
|
||||
def remove_training_data():
|
||||
# Get id from the JSON body
|
||||
id = flask.request.json.get('id')
|
||||
|
||||
if id is None:
|
||||
return jsonify({"type": "error", "error": "No id provided"})
|
||||
|
||||
if vn.remove_training_data(id=id):
|
||||
return jsonify({"success": True})
|
||||
else:
|
||||
return jsonify({"type": "error", "error": "Couldn't remove training data"})
|
||||
|
||||
@app.route('/api/v0/train', methods=['POST'])
|
||||
def add_training_data():
|
||||
question = flask.request.json.get('question')
|
||||
sql = flask.request.json.get('sql')
|
||||
ddl = flask.request.json.get('ddl')
|
||||
documentation = flask.request.json.get('documentation')
|
||||
|
||||
try:
|
||||
id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
|
||||
|
||||
return jsonify({"id": id})
|
||||
except Exception as e:
|
||||
print("TRAINING ERROR", e)
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
@app.route('/api/v0/generate_followup_questions', methods=['GET'])
|
||||
@requires_cache(['df', 'question', 'sql'])
|
||||
def generate_followup_questions(id: str, df, question, sql):
|
||||
followup_questions = vn.generate_followup_questions(question=question, sql=sql, df=df)
|
||||
|
||||
cache.set(id=id, field='followup_questions', value=followup_questions)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "question_list",
|
||||
"id": id,
|
||||
"questions": followup_questions,
|
||||
"header": "Here are some followup questions you can ask:"
|
||||
})
|
||||
|
||||
@app.route('/api/v0/load_question', methods=['GET'])
|
||||
@requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
|
||||
def load_question(id: str, question, sql, df, fig_json, followup_questions):
|
||||
try:
|
||||
return jsonify(
|
||||
{
|
||||
"type": "question_cache",
|
||||
"id": id,
|
||||
"question": question,
|
||||
"sql": sql,
|
||||
"df": df.head(10).to_json(orient='records'),
|
||||
"fig": fig_json,
|
||||
"followup_questions": followup_questions,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
@app.route('/api/v0/get_question_history', methods=['GET'])
|
||||
def get_question_history():
|
||||
return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question'])})
|
||||
@ -0,0 +1,238 @@
|
||||
import json
|
||||
import ast
|
||||
from configs import dify_config
|
||||
from extensions.utils.vanna_text2sql import VannaServer
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
from dify_app import DifyApp
|
||||
from flask import Flask, jsonify, Response, request
|
||||
import flask
|
||||
from werkzeug.exceptions import BadRequest
|
||||
import logging
|
||||
import plotly.io as pio
|
||||
from functools import lru_cache
|
||||
from datetime import datetime
|
||||
class Config:
|
||||
def __init__(self, supplier):
|
||||
self.embedding_supplier = "SiliconFlow"
|
||||
self.milvus_uri = dify_config.MILVUS_URI
|
||||
self.milvus_database = 'vanna_demo'
|
||||
self.supplier = supplier
|
||||
# self.llm_type = 'tongyi'
|
||||
# self.model = 'qwen-max'
|
||||
# self.api_key = 'sk-ba5d240e2dc0483e9e24404d957a15d5'
|
||||
# 本地模型
|
||||
# self.ollama_host = 'http://wsd.wisdomidata.com:19042'
|
||||
# self.model = 'qwen2:7b'
|
||||
self.llm_type = 'deepseek'
|
||||
self.model = 'deepseek-coder'
|
||||
self.api_key = 'sk-0382990b7a90496c889774b1d3843f90'
|
||||
self.sql_type = 'postgres'
|
||||
self.sql_config = {
|
||||
"host": dify_config.DB_HOST,
|
||||
"dbname": 'vanna_demo',
|
||||
"user": dify_config.DB_USERNAME,
|
||||
"password": dify_config.DB_PASSWORD,
|
||||
"port": dify_config.DB_PORT
|
||||
}
|
||||
|
||||
# 存储不同的 VannaServer 实例
|
||||
vn_instances = {}
|
||||
# 获取vanna实例
|
||||
def get_vn_instance(supplier=""):
|
||||
"""获取或创建VannaServer实例"""
|
||||
if supplier == "":
|
||||
supplier = "default"
|
||||
if supplier not in vn_instances:
|
||||
config = Config(supplier)
|
||||
# 合并配置
|
||||
combined_config = {**config.__dict__, **config.sql_config}
|
||||
vn_instances[supplier] = VannaServer(combined_config)
|
||||
return vn_instances[supplier]
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
@app.route('/api/ask', methods=['POST'])
|
||||
def ask_route():
|
||||
"""提问接口"""
|
||||
data = request.json
|
||||
question = data.get('question', '')
|
||||
visualize = data.get('visualize', True)
|
||||
auto_train = data.get('auto_train', False)
|
||||
supplier = data.get('supplier', "") # GITEE, ZHIPU, SiliconFlow
|
||||
|
||||
if not question:
|
||||
raise BadRequest("Question is required")
|
||||
|
||||
server = get_vn_instance(supplier)
|
||||
try:
|
||||
sql, df, fig = server.ask(question=question, visualize=visualize, auto_train=auto_train)
|
||||
|
||||
df_json = df.to_json(orient='records', force_ascii=False)
|
||||
|
||||
"""
|
||||
<img id="plotly-image" src="data:image/png;base64,{{ img_base64 }}" alt="Plotly Image">
|
||||
"""
|
||||
|
||||
# fig_js_path = '../output/html/vanna_fig.js'
|
||||
# fig_html_path = 'http://localhost:8000/html/vanna_fig.html'
|
||||
# figure_json = pio.to_json(fig)
|
||||
# with open(fig_js_path, 'w', encoding='utf-8') as f:
|
||||
# f.write(figure_json)
|
||||
"""
|
||||
<div id="plotly-div"></div>
|
||||
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
|
||||
<script>
|
||||
var fig_json = {{ fig_json }};
|
||||
Plotly.newPlot('plotly-div', fig_json.data, fig_json.layout);
|
||||
</script>
|
||||
"""
|
||||
|
||||
logging.info("Query processed successfully")
|
||||
return jsonify({
|
||||
'sql': sql,
|
||||
'data': df_json,
|
||||
# 'img_base64': img_base64,
|
||||
# 'plotly_figure': fig_html_path
|
||||
}), 200
|
||||
except Exception as e:
|
||||
logging.error(f"Error processing request: {e}")
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
@app.route('/api/vn_train', methods=['POST'])
|
||||
def vn_train_route():
|
||||
"""训练接口"""
|
||||
data = request.json
|
||||
# required_fields = ['question', 'sql']
|
||||
# validate_input(data, required_fields)
|
||||
|
||||
supplier = data.get('supplier', "")
|
||||
question = data.get('question', '')
|
||||
sql = data.get('sql', '')
|
||||
documentation = data.get('documentation', '')
|
||||
ddl = data.get('ddl', '')
|
||||
schema = data.get('schema', False)
|
||||
|
||||
# 验证至少有一个参数不为空
|
||||
if not any([question, sql, documentation, ddl, schema]):
|
||||
return jsonify(
|
||||
{
|
||||
'error': 'At least one of the parameters (question, sql, documentation, ddl, schema) must be provided'}), 400
|
||||
|
||||
server = get_vn_instance(supplier)
|
||||
server.vn_train(question=question, sql=sql, documentation=documentation, ddl=ddl)
|
||||
if schema:
|
||||
try:
|
||||
# server.schema_train()
|
||||
# 更新建表DDL语句
|
||||
server.refresh_create_table_ddl_train()
|
||||
server.refresh_schema_train()
|
||||
except Exception as e:
|
||||
logging.info(f"Error initializing vector store: {e}")
|
||||
|
||||
logging.info("Training completed successfully")
|
||||
return jsonify({'status': 'success'}), 200
|
||||
|
||||
|
||||
@app.route('/api/docs/update', methods=['POST'])
|
||||
def update_schema_train_list_route():
|
||||
"""训练接口"""
|
||||
data = request.json
|
||||
docs = data.get('docs', [])
|
||||
server = get_vn_instance("")
|
||||
server.update_schema_train_list(docs=docs)
|
||||
return jsonify({'status': 'success'}), 200
|
||||
|
||||
@app.route('/api/get_training_data', methods=['GET'])
|
||||
def get_training_data_route():
|
||||
"""获取训练数据接口"""
|
||||
supplier = request.args.get('supplier', "")
|
||||
server = get_vn_instance(supplier)
|
||||
|
||||
@lru_cache(maxsize=128) # 添加缓存机制
|
||||
def cached_get_training_data():
|
||||
return server.get_training_data()
|
||||
|
||||
training_data = cached_get_training_data()
|
||||
logging.info("Fetched training data successfully")
|
||||
|
||||
return jsonify({
|
||||
'data': json.loads(training_data.to_json(orient='records'))
|
||||
}), 200
|
||||
|
||||
@app.route('/api/generate_sql', methods=['GET'])
|
||||
def generate_sql():
|
||||
question = request.args.get('question')
|
||||
supplier = request.args.get('supplier','')
|
||||
|
||||
if question is None:
|
||||
return jsonify({"type": "error", "error": "No question provided"})
|
||||
server = get_vn_instance(supplier)
|
||||
sql = server.generate_sql(question=question)
|
||||
return jsonify(
|
||||
{
|
||||
"sql": sql
|
||||
}) , 200
|
||||
|
||||
@app.route('/api/run_sql', methods=['POST'])
|
||||
def run_sql():
|
||||
data = request.json
|
||||
supplier = data.get('supplier', "")
|
||||
sql = data.get('sql', '')
|
||||
try:
|
||||
server = get_vn_instance(supplier)
|
||||
df = server.run_sql(sql=sql)
|
||||
df_json = df.to_json(orient='records', force_ascii=False)
|
||||
return df_json, 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
@app.route('/api/training/data/export', methods=['GET'])
|
||||
def training_data_export():
|
||||
supplier = request.args.get('supplier', "")
|
||||
|
||||
server = get_vn_instance(supplier)
|
||||
|
||||
# @lru_cache(maxsize=128) # 添加缓存机制
|
||||
# def cached_get_training_data():
|
||||
# return server.training_data_export()
|
||||
# training_data = cached_get_training_data()
|
||||
|
||||
data = server.training_data_export()
|
||||
content = ",\n".join(str(line) for line in data)
|
||||
|
||||
file_name = datetime.now().strftime('%Y-%m-%d-%H-%M')
|
||||
|
||||
# 创建一个可下载的文本响应
|
||||
return Response(
|
||||
f"[\n{content}\n]",
|
||||
mimetype='text/plain',
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename={file_name}.txt"
|
||||
}
|
||||
)
|
||||
|
||||
@app.route('/api/training/data/import', methods=['POST'])
|
||||
def training_data_import():
|
||||
|
||||
if 'file' not in request.files:
|
||||
return jsonify({"type": "error", "error": "未上传文件"}), 400
|
||||
|
||||
file = request.files['file']
|
||||
if file.filename == '':
|
||||
return jsonify({"type": "error", "error": "文件名为空"}), 400
|
||||
|
||||
try:
|
||||
# 读取文件并解析每一行的 JSON 对象
|
||||
content = file.read().decode('utf-8').strip()
|
||||
data_list = ast.literal_eval(content)
|
||||
|
||||
server = get_vn_instance("")
|
||||
result = server.training_data_import(data_list)
|
||||
if result:
|
||||
return jsonify({"type": "error", "error": "存在数据集question 或 sql为空"})
|
||||
|
||||
return jsonify({'status': 'success'}), 200
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({"type": "error", "error": f"文件解析失败: {str(e)}"}), 500
|
||||
@ -0,0 +1,62 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import uuid
|
||||
|
||||
class Cache(ABC):
|
||||
@abstractmethod
|
||||
def generate_id(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self, id, field):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_all(self, field_list) -> list:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def set(self, id, field, value):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def delete(self, id):
|
||||
pass
|
||||
|
||||
|
||||
class MemoryCache(Cache):
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def generate_id(self, *args, **kwargs):
|
||||
return str(uuid.uuid4())
|
||||
|
||||
def set(self, id, field, value):
|
||||
if id not in self.cache:
|
||||
self.cache[id] = {}
|
||||
|
||||
self.cache[id][field] = value
|
||||
|
||||
def get(self, id, field):
|
||||
if id not in self.cache:
|
||||
return None
|
||||
|
||||
if field not in self.cache[id]:
|
||||
return None
|
||||
|
||||
return self.cache[id][field]
|
||||
|
||||
def get_all(self, field_list) -> list:
|
||||
return [
|
||||
{
|
||||
"id": id,
|
||||
**{
|
||||
field: self.get(id=id, field=field)
|
||||
for field in field_list
|
||||
}
|
||||
}
|
||||
for id in self.cache
|
||||
]
|
||||
|
||||
def delete(self, id):
|
||||
if id in self.cache:
|
||||
del self.cache[id]
|
||||
@ -0,0 +1,150 @@
|
||||
import traceback
|
||||
from typing import Union, Tuple
|
||||
import pandas as pd
|
||||
import plotly
|
||||
from PIL import Image as PILImage
|
||||
import io
|
||||
|
||||
|
||||
def ask(
|
||||
vanna_instance,
|
||||
question: Union[str, None] = None,
|
||||
print_results: bool = True,
|
||||
auto_train: bool = True,
|
||||
visualize: bool = True, # if False, will not generate plotly code
|
||||
allow_llm_to_see_data: bool = False,
|
||||
) -> Union[
|
||||
Tuple[
|
||||
Union[str, None],
|
||||
Union[pd.DataFrame, None],
|
||||
Union[plotly.graph_objs.Figure, None],
|
||||
],
|
||||
None,
|
||||
]:
|
||||
"""
|
||||
**Example:**
|
||||
python
|
||||
vn.ask("What are the top 10 customers by sales?")
|
||||
Ask Vanna.AI a question and get the SQL query that answers it.
|
||||
|
||||
Args:
|
||||
question (str): The question to ask.
|
||||
print_results (bool): Whether to print the results of the SQL query.
|
||||
auto_train (bool): Whether to automatically train Vanna.AI on the question and SQL query.
|
||||
visualize (bool): Whether to generate plotly code and display the plotly figure.
|
||||
|
||||
Returns:
|
||||
Tuple[str, pd.DataFrame, plotly.graph_objs.Figure]: The SQL query, the results of the SQL query, and the plotly figure.
|
||||
"""
|
||||
|
||||
if question is None:
|
||||
question = input("Enter a question: ")
|
||||
|
||||
try:
|
||||
sql = vanna_instance.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None, None, None
|
||||
|
||||
if print_results:
|
||||
try:
|
||||
Code = __import__("IPython.display", fromlist=["Code"]).Code
|
||||
display = __import__("IPython.display", fromlist=["display"]).display
|
||||
display(Code(sql))
|
||||
except Exception as e:
|
||||
print(sql)
|
||||
|
||||
if vanna_instance.run_sql_is_set is False:
|
||||
print(
|
||||
"If you want to run the SQL query, connect to a database first."
|
||||
)
|
||||
|
||||
if print_results:
|
||||
return None
|
||||
else:
|
||||
return sql, None, None
|
||||
|
||||
try:
|
||||
df = vanna_instance.run_sql(sql)
|
||||
|
||||
if print_results:
|
||||
try:
|
||||
display = __import__("IPython.display", fromlist=["display"]).display
|
||||
display(df)
|
||||
except Exception as e:
|
||||
print(df)
|
||||
|
||||
if len(df) > 0 and auto_train:
|
||||
vanna_instance.add_question_sql(question=question, sql=sql)
|
||||
|
||||
# Only generate plotly code if visualize is True
|
||||
if visualize:
|
||||
try:
|
||||
plotly_code = vanna_instance.generate_plotly_code(
|
||||
question=question,
|
||||
sql=sql,
|
||||
df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
|
||||
)
|
||||
fig = vanna_instance.get_plotly_figure(plotly_code=plotly_code, df=df)
|
||||
if print_results:
|
||||
try:
|
||||
display = __import__("IPython.display", fromlist=["display"]).display
|
||||
display(plotly_code)
|
||||
except Exception as e:
|
||||
print(plotly_code)
|
||||
|
||||
except Exception as e:
|
||||
# Print stack trace
|
||||
traceback.print_exc()
|
||||
print("Couldn't run plotly code: ", e)
|
||||
if print_results:
|
||||
return None
|
||||
else:
|
||||
return sql, df, None
|
||||
else:
|
||||
return sql, df, None
|
||||
|
||||
except Exception as e:
|
||||
print("Couldn't run sql: ", e)
|
||||
if print_results:
|
||||
return None
|
||||
else:
|
||||
return sql, None, None
|
||||
return sql, df, fig
|
||||
|
||||
|
||||
def display_image_in_pycharm(fig):
|
||||
"""Display image in PyCharm using matplotlib or PIL."""
|
||||
try:
|
||||
# Try to use IPython.display if available
|
||||
try:
|
||||
display = __import__("IPython.display", fromlist=["display"]).display
|
||||
Image = __import__("IPython.display", fromlist=["Image"]).Image
|
||||
img_bytes = fig.to_image(format="png", scale=2)
|
||||
display(Image(img_bytes))
|
||||
except AttributeError:
|
||||
print("fig does not have to_image method, using fig.savefig instead")
|
||||
fig.savefig("output.png")
|
||||
display(Image("output.png"))
|
||||
except ImportError:
|
||||
print("IPython.display not available, using matplotlib to show image")
|
||||
fig.show()
|
||||
except Exception as e:
|
||||
print(f"Failed to display image using IPython.display: {e}")
|
||||
traceback.print_exc()
|
||||
try:
|
||||
# Use matplotlib to show image
|
||||
fig.show()
|
||||
except Exception as e:
|
||||
print(f"Failed to display image using fig.show: {e}")
|
||||
traceback.print_exc()
|
||||
try:
|
||||
# Use PIL to show image
|
||||
img_bytes = io.BytesIO()
|
||||
fig.savefig(img_bytes, format='png')
|
||||
img_bytes.seek(0)
|
||||
pil_img = PILImage.open(img_bytes)
|
||||
pil_img.show()
|
||||
except Exception as e:
|
||||
print(f"Failed to display image using PIL: {e}")
|
||||
traceback.print_exc()
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
@ -0,0 +1,81 @@
|
||||
"""add target_tenant_id
|
||||
|
||||
Revision ID: 0c79d303c76d
|
||||
Revises: d20049ed0af6
|
||||
Create Date: 2025-05-28 10:19:55.445389
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '0c79d303c76d'
|
||||
down_revision = 'd20049ed0af6'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('accounts', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
|
||||
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
|
||||
|
||||
with op.batch_alter_table('tool_published_apps', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
|
||||
batch_op.add_column(sa.Column('conversation_id', models.types.StringUUID(), nullable=True))
|
||||
batch_op.add_column(sa.Column('file_key', sa.String(length=255), nullable=False))
|
||||
batch_op.add_column(sa.Column('mimetype', sa.String(length=255), nullable=False))
|
||||
batch_op.add_column(sa.Column('original_url', sa.String(length=2048), nullable=True))
|
||||
batch_op.add_column(sa.Column('name', sa.String(), nullable=False))
|
||||
batch_op.add_column(sa.Column('size', sa.Integer(), nullable=False))
|
||||
|
||||
with op.batch_alter_table('upload_files', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('file_id', sa.String(length=255), nullable=True))
|
||||
|
||||
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow_conversation_variables_app_id_idx')
|
||||
batch_op.drop_index('workflow_conversation_variables_created_at_idx')
|
||||
batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False)
|
||||
batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
|
||||
batch_op.drop_index('workflow__conversation_variables_created_at_idx')
|
||||
batch_op.drop_index('workflow__conversation_variables_app_id_idx')
|
||||
batch_op.create_index('workflow_conversation_variables_created_at_idx', ['created_at'], unique=False)
|
||||
batch_op.create_index('workflow_conversation_variables_app_id_idx', ['app_id'], unique=False)
|
||||
|
||||
with op.batch_alter_table('upload_files', schema=None) as batch_op:
|
||||
batch_op.drop_column('file_id')
|
||||
|
||||
with op.batch_alter_table('tool_published_apps', schema=None) as batch_op:
|
||||
batch_op.drop_column('size')
|
||||
batch_op.drop_column('name')
|
||||
batch_op.drop_column('original_url')
|
||||
batch_op.drop_column('mimetype')
|
||||
batch_op.drop_column('file_key')
|
||||
batch_op.drop_column('conversation_id')
|
||||
batch_op.drop_column('tenant_id')
|
||||
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.drop_column('target_tenant_id')
|
||||
|
||||
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||
batch_op.drop_column('target_tenant_id')
|
||||
|
||||
with op.batch_alter_table('accounts', schema=None) as batch_op:
|
||||
batch_op.drop_column('target_tenant_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -0,0 +1,33 @@
|
||||
"""providers table add private_key
|
||||
|
||||
Revision ID: 582c477e905b
|
||||
Revises: 8e00c75b3907
|
||||
Create Date: 2025-05-28 10:41:31.662951
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '582c477e905b'
|
||||
down_revision = '8e00c75b3907'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('encrypt_private_key', sa.Text(), nullable=True))
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('tenants', schema=None) as batch_op:
|
||||
batch_op.drop_column('encrypt_private_key')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,4 @@
|
||||
[virtualenvs]
|
||||
in-project = true
|
||||
create = true
|
||||
prefer-active-python = true
|
||||
@ -0,0 +1,49 @@
|
||||
from typing import Any
|
||||
|
||||
import toml # type: ignore
|
||||
|
||||
|
||||
def load_api_poetry_configs() -> dict[str, Any]:
|
||||
pyproject_toml = toml.load("api/pyproject.toml")
|
||||
return pyproject_toml["tool"]["poetry"]
|
||||
|
||||
|
||||
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
|
||||
configs = load_api_poetry_configs()
|
||||
configs_by_group = {"main": configs}
|
||||
for group_name in configs["group"]:
|
||||
configs_by_group[group_name] = configs["group"][group_name]
|
||||
dependencies_by_group = {group_name: base["dependencies"] for group_name, base in configs_by_group.items()}
|
||||
return dependencies_by_group
|
||||
|
||||
|
||||
def test_group_dependencies_sorted():
|
||||
for group_name, dependencies in load_all_dependency_groups().items():
|
||||
dependency_names = list(dependencies.keys())
|
||||
expected_dependency_names = sorted(set(dependency_names))
|
||||
section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies"
|
||||
assert expected_dependency_names == dependency_names, (
|
||||
f"Dependencies in group {group_name} are not sorted. "
|
||||
f"Check and fix [{section}] section in pyproject.toml file"
|
||||
)
|
||||
|
||||
|
||||
def test_group_dependencies_version_operator():
|
||||
for group_name, dependencies in load_all_dependency_groups().items():
|
||||
for dependency_name, specification in dependencies.items():
|
||||
version_spec = specification if isinstance(specification, str) else specification["version"]
|
||||
assert not version_spec.startswith("^"), (
|
||||
f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
|
||||
f"'^' operator is too wide and not allowed in the version specification."
|
||||
)
|
||||
|
||||
|
||||
def test_duplicated_dependency_crossing_groups() -> None:
|
||||
all_dependency_names: list[str] = []
|
||||
for dependencies in load_all_dependency_groups().values():
|
||||
dependency_names = list(dependencies.keys())
|
||||
all_dependency_names.extend(dependency_names)
|
||||
expected_all_dependency_names = set(all_dependency_names)
|
||||
assert sorted(expected_all_dependency_names) == sorted(all_dependency_names), (
|
||||
"Duplicated dependencies crossing groups are found"
|
||||
)
|
||||
@ -0,0 +1,7 @@
|
||||
from core.helper.marketplace import download_plugin_pkg
|
||||
|
||||
|
||||
def test_download_plugin_pkg():
|
||||
pkg = download_plugin_pkg("langgenius/bing:0.0.1@e58735424d2104f208c2bd683c5142e0332045b425927067acf432b26f3d970b")
|
||||
assert pkg is not None
|
||||
assert len(pkg) > 0
|
||||
@ -0,0 +1,18 @@
|
||||
#!/bin/bash
|
||||
|
||||
# rely on `poetry` in path
|
||||
if ! command -v poetry &> /dev/null; then
|
||||
echo "Installing Poetry ..."
|
||||
pip install poetry
|
||||
fi
|
||||
|
||||
# check poetry.lock in sync with pyproject.toml
|
||||
poetry check -C api --lock
|
||||
if [ $? -ne 0 ]; then
|
||||
# update poetry.lock
|
||||
# refreshing lockfile only without updating locked versions
|
||||
echo "poetry.lock is outdated, refreshing without updating locked versions ..."
|
||||
poetry lock -C api
|
||||
else
|
||||
echo "poetry.lock is ready."
|
||||
fi
|
||||
@ -0,0 +1,13 @@
|
||||
#!/bin/bash
|
||||
|
||||
# rely on `poetry` in path
|
||||
if ! command -v poetry &> /dev/null; then
|
||||
echo "Installing Poetry ..."
|
||||
pip install poetry
|
||||
fi
|
||||
|
||||
# refreshing lockfile, updating locked versions
|
||||
poetry update -C api
|
||||
|
||||
# check poetry.lock in sync with pyproject.toml
|
||||
poetry check -C api --lock
|
||||
@ -0,0 +1,4 @@
|
||||
|
||||
docker compose down
|
||||
|
||||
docker compose up -d
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.3 MiB |
Binary file not shown.
|
After Width: | Height: | Size: 790 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 257 KiB |
Loading…
Reference in New Issue