初始化合并后端代码

pull/22121/head
liuchangsheng@wisdomidata.com 1 year ago
parent 0720bc7408
commit ae25db7ad1

@ -162,7 +162,7 @@ COUCHBASE_BUCKET_NAME=Embeddings
COUCHBASE_SCOPE_NAME=_default COUCHBASE_SCOPE_NAME=_default
# Milvus configuration # Milvus configuration
MILVUS_URI=http://127.0.0.1:19530 MILVUS_URI=http://wsd.wisdomidata.com:19044
MILVUS_TOKEN= MILVUS_TOKEN=
MILVUS_USER=root MILVUS_USER=root
MILVUS_PASSWORD=Milvus MILVUS_PASSWORD=Milvus
@ -498,3 +498,21 @@ QUEUE_MONITOR_THRESHOLD=200
QUEUE_MONITOR_ALERT_EMAILS= QUEUE_MONITOR_ALERT_EMAILS=
# Monitor interval in minutes, default is 30 minutes # Monitor interval in minutes, default is 30 minutes
QUEUE_MONITOR_INTERVAL=30 QUEUE_MONITOR_INTERVAL=30
################################# 自定义配置 (custom configuration: plugin identifiers and initial model settings)
PLUGIN_UNIQUE_IDENTIFIERS=langgenius/ollama:0.0.3@9ded90ac00e8510119a24be7396ba77191c9610d5e1e29f59d68fa1229822fc7,langgenius/huggingface_tei:0.0.3@7ae4cd259ec7d6f95931de898b77f8b0d374cc95a72ecece82168f90107350bd
INIT_MODEL_LLM_BASE_URL=http://wsd.wisdomidata.com:19042
INIT_MODEL_LLM_NAME=qwq:32b
INIT_MODEL_LLM_PROVIDER=langgenius/ollama/ollama
INIT_MODEL_LLM_CONTEXT_SIZE=4096
INIT_MODEL_LLM_MAX_TOKENS=4096
INIT_MODEL_TEXT_EMBEDDING_BASE_URL=http://wsd.wisdomidata.com:19042
INIT_MODEL_TEXT_EMBEDDING_NAME=bge-m3:latest
INIT_MODEL_TEXT_EMBEDDING_PROVIDER=langgenius/ollama/ollama
INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE=4096
INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS=4096
INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL=http://wsd.wisdomidata.com:19086
INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME=bge-reranker-large
INIT_MODEL_TEXT_EMBEDDING_RERANK_PROVIDER=langgenius/huggingface_tei/huggingface_tei

@ -60,6 +60,7 @@ def initialize_extensions(app: DifyApp):
ext_storage, ext_storage,
ext_timezone, ext_timezone,
ext_warnings, ext_warnings,
ext_vanna_server,
) )
extensions = [ extensions = [
@ -85,6 +86,7 @@ def initialize_extensions(app: DifyApp):
ext_commands, ext_commands,
ext_otel, ext_otel,
ext_request_logging, ext_request_logging,
ext_vanna_server,
] ]
for ext in extensions: for ext in extensions:
short_name = ext.__name__.split(".")[-1] short_name = ext.__name__.split(".")[-1]

@ -0,0 +1,42 @@
#data_source:
# type: upload_file
# info_list:
# data_source_type: upload_file
# file_info_list:
# file_ids:
# - none
indexing_technique: high_quality
process_rule:
rules:
pre_processing_rules:
- id: remove_extra_spaces
enabled: true
- id: remove_urls_emails
enabled: true
segmentation:
separator: '&&&&&'
max_tokens: 500
chunk_overlap: 50
mode: custom
doc_form: text_model
doc_language: Chinese
retrieval_model:
search_method: hybrid_search
reranking_enable: true
reranking_mode: weighted_score
reranking_model:
reranking_provider_name: langgenius/huggingface_tei/huggingface_tei
reranking_model_name: bge-reranker-large
weights:
weight_type: customized
vector_setting:
vector_weight: 0.7
embedding_provider_name: ''
embedding_model_name: ''
keyword_setting:
keyword_weight: 0.3
top_k: 10
score_threshold_enabled: false
score_threshold: 0
embedding_model: 'bge-m3:latest'
embedding_model_provider: langgenius/ollama/ollama

@ -0,0 +1,11 @@
model: ${INIT_MODEL_TEXT_EMBEDDING_NAME}
model_type: text-embedding
credentials:
mode: chat
context_size: ${INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE}
max_tokens: ${INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS}
vision_support: false
function_call_support: false
base_url: ${INIT_MODEL_TEXT_EMBEDDING_BASE_URL}
load_balancing:
enabled: false

@ -0,0 +1,6 @@
model: ${INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME}
model_type: rerank
credentials:
server_url: ${INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL}
load_balancing:
enabled: false

@ -0,0 +1,11 @@
model: ${INIT_MODEL_LLM_NAME}
model_type: llm
credentials:
mode: chat
context_size: ${INIT_MODEL_LLM_CONTEXT_SIZE}
max_tokens: ${INIT_MODEL_LLM_MAX_TOKENS}
vision_support: false
function_call_support: false
base_url: ${INIT_MODEL_LLM_BASE_URL}
load_balancing:
enabled: false

@ -0,0 +1,38 @@
import yaml
from pathlib import Path
def get_init_knowledge_config(config: dict) -> dict:
    """Return the settings from ``dataset_config.yml`` combined with *config*.

    This is a thin wrapper that forwards to ``get_ext_config`` with the fixed
    file name ``dataset_config.yml`` and no placeholder parameters.

    :param config: caller-supplied settings passed through as ``config``
    :return: whatever ``get_ext_config`` returns for that file
    """
    knowledge_config = get_ext_config(file_name="dataset_config.yml", config=config)
    return knowledge_config
def get_ext_config(file_name: str, config: dict = None, params: dict = None) -> dict:
    """Load a YAML file from the ``ext`` directory next to this module.

    The parsed document is passed through ``replace_placeholders`` with
    *params* and then shallow-merged with *config*; top-level keys in
    *config* override keys of the same name from the file.

    :param file_name: name of the YAML file inside the ``ext`` directory
    :param config: optional overrides merged on top of the file content
    :param params: optional placeholder values handed to ``replace_placeholders``
    :return: the resulting configuration dict (empty dict for an empty file)
    :raises OSError: if the file cannot be opened
    :raises yaml.YAMLError: if the file is not valid YAML
    """
    # Resolve relative to this module so the lookup does not depend on the
    # process working directory.
    config_path = Path(__file__).resolve().parent / "ext" / file_name

    # Explicit encoding: the platform default is not guaranteed to be UTF-8.
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load only builds plain Python data, never arbitrary objects.
        config_data = yaml.safe_load(f)

    # An empty YAML document parses to None, which cannot be merged below.
    if config_data is None:
        config_data = {}

    config_data = replace_placeholders(data=config_data, params=params)
    if config is not None:
        config_data = {**config_data, **config}
    return config_data
# Substitute ``${NAME}`` placeholders inside parsed YAML data.
def replace_placeholders(data, params: dict = None) -> dict:
    """Recursively replace ``${KEY}`` placeholders in *data* with values from *params*.

    Dicts and lists are rebuilt with their values/items processed recursively
    (dict keys are left untouched). Strings have every ``${KEY}`` occurrence
    replaced for each key in *params*. Any other type is returned unchanged.

    :param data: parsed YAML data (dict, list, str or scalar)
    :param params: mapping of placeholder name to replacement value; when
        ``None`` the data is returned as-is. Values are converted with
        ``str()`` so numeric settings can be substituted too.
    :return: the data with placeholders substituted; the result has the same
        container type as *data*, despite the ``dict`` annotation
    """
    if params is None:
        return data

    if isinstance(data, dict):
        return {k: replace_placeholders(v, params) for k, v in data.items()}

    if isinstance(data, list):
        return [replace_placeholders(item, params) for item in data]

    if isinstance(data, str):
        for key, value in params.items():
            placeholder = f"${{{key}}}"  # e.g. ${DB_HOST}
            # str.replace only accepts strings; coerce so ints do not raise.
            data = data.replace(placeholder, str(value))
        return data

    # Numbers, booleans, None, ... carry no placeholders.
    return data

@ -868,6 +868,63 @@ class AccountConfig(BaseSettings):
default=False, default=False,
) )
class ExtConfig(BaseSettings):
    """Settings for the custom extension: plugin identifiers and initial models.

    All values are kept as strings, including the numeric sizes.
    """

    PLUGIN_UNIQUE_IDENTIFIERS: str = Field(
        description="Comma-separated plugin unique identifiers",
        default="langgenius/ollama:0.0.3@9ded90ac00e8510119a24be7396ba77191c9610d5e1e29f59d68fa1229822fc7",
    )

    INIT_MODEL_LLM_BASE_URL: str = Field(
        description="Base URL of the server hosting the initial LLM model",
        default="http://120.46.154.21:19042",
    )

    INIT_MODEL_LLM_NAME: str = Field(
        description="Name of the initial LLM model",
        default="qwq:32b",
    )

    INIT_MODEL_LLM_PROVIDER: str = Field(
        description="Provider identifier of the initial LLM model",
        default="langgenius/ollama/ollama",
    )

    INIT_MODEL_LLM_CONTEXT_SIZE: str = Field(
        description="Context size configured for the initial LLM model",
        default="4096",
    )

    INIT_MODEL_LLM_MAX_TOKENS: str = Field(
        description="Max tokens configured for the initial LLM model",
        default="4096",
    )

    INIT_MODEL_TEXT_EMBEDDING_BASE_URL: str = Field(
        description="Base URL of the server hosting the initial text-embedding model",
        default="http://120.46.154.21:19042",
    )

    INIT_MODEL_TEXT_EMBEDDING_NAME: str = Field(
        description="Name of the initial text-embedding model",
        default="bge-m3:latest",
    )

    INIT_MODEL_TEXT_EMBEDDING_PROVIDER: str = Field(
        description="Provider identifier of the initial text-embedding model",
        default="langgenius/ollama/ollama",
    )

    INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE: str = Field(
        description="Context size configured for the initial text-embedding model",
        default="4096",
    )

    INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS: str = Field(
        description="Max tokens configured for the initial text-embedding model",
        default="4096",
    )

    INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL: str = Field(
        description="Server URL of the initial rerank model",
        default="http://120.46.154.21:19086",
    )

    INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME: str = Field(
        description="Name of the initial rerank model",
        default="bge-reranker-large",
    )

    INIT_MODEL_TEXT_EMBEDDING_RERANK_PROVIDER: str = Field(
        description="Provider identifier of the initial rerank model",
        default="langgenius/huggingface_tei/huggingface_tei",
    )
class FeatureConfig( class FeatureConfig(
# place the configs in alphabet order # place the configs in alphabet order
@ -902,5 +959,6 @@ class FeatureConfig(
# hosted services config # hosted services config
HostedServiceConfig, HostedServiceConfig,
CeleryBeatConfig, CeleryBeatConfig,
ExtConfig
): ):
pass pass

@ -7,6 +7,7 @@ from .explore.audio import ChatAudioApi, ChatTextApi
from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from .explore.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi
from .explore.conversation import ( from .explore.conversation import (
ConversationApi, ConversationApi,
ConversationBatchApi,
ConversationListApi, ConversationListApi,
ConversationPinApi, ConversationPinApi,
ConversationRenameApi, ConversationRenameApi,
@ -129,6 +130,13 @@ api.add_resource(
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>", "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>",
endpoint="installed_app_conversation", endpoint="installed_app_conversation",
) )
api.add_resource(
ConversationBatchApi,
"/installed-apps/<uuid:installed_app_id>/conversations/batch/remove",
endpoint="installed_app_conversation_batch_remove",
)
api.add_resource( api.add_resource(
ConversationPinApi, ConversationPinApi,
"/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin", "/installed-apps/<uuid:installed_app_id>/conversations/<uuid:c_id>/pin",
@ -170,6 +178,7 @@ from .tag import tags
# Import workspace controllers # Import workspace controllers
from .workspace import ( from .workspace import (
account, account,
account_ext,
agent_providers, agent_providers,
endpoint, endpoint,
load_balancing_config, load_balancing_config,

@ -244,7 +244,31 @@ class RefreshTokenApi(Resource):
return {"result": "fail", "data": str(e)}, 401 return {"result": "fail", "data": str(e)}, 401
class SingleSignApi(Resource):
    """Resource that signs an account in from an email and a tenant id."""

    @setup_required
    def post(self):
        """Authenticate by ``email`` + ``tenantId`` and return a token pair.

        Both fields are required JSON body members. On success the response is
        ``{"result": "success", "data": <token pair>}``.

        :raises AccountBannedError: when authentication reports ``AccountLoginError``
        :raises AccountNotFound: when the account is unknown and registration
            is not allowed
        """
        # NOTE(review): no password or token is read from the request here;
        # confirm that this endpoint is only reachable by a trusted caller.
        arg_parser = reqparse.RequestParser()
        for field_name in ("email", "tenantId"):
            arg_parser.add_argument(field_name, type=str, required=True, location="json")
        payload = arg_parser.parse_args()
        email = payload["email"]

        try:
            account = AccountService.authenticate_email(email, payload["tenantId"])
        except services.errors.account.AccountLoginError:
            raise AccountBannedError()
        except services.errors.account.AccountNotFoundError:
            if not FeatureService.get_system_features().is_allow_register:
                raise AccountNotFound()
            return {"result": "fail", "code": "account_not_found"}

        token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
        AccountService.reset_login_error_rate_limit(email)
        return {"result": "success", "data": token_pair.model_dump()}
api.add_resource(LoginApi, "/login") api.add_resource(LoginApi, "/login")
api.add_resource(SingleSignApi, "/single/login")
api.add_resource(LogoutApi, "/logout") api.add_resource(LogoutApi, "/logout")
api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login") api.add_resource(EmailCodeLoginSendEmailApi, "/email-code-login")
api.add_resource(EmailCodeLoginApi, "/email-code-login/validity") api.add_resource(EmailCodeLoginApi, "/email-code-login/validity")

@ -15,6 +15,7 @@ from services.conversation_service import ConversationService
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.web_conversation_service import WebConversationService from services.web_conversation_service import WebConversationService
import ast
class ConversationListApi(InstalledAppResource): class ConversationListApi(InstalledAppResource):
@marshal_with(conversation_infinite_scroll_pagination_fields) @marshal_with(conversation_infinite_scroll_pagination_fields)
@ -65,6 +66,27 @@ class ConversationApi(InstalledAppResource):
return {"result": "success"}, 204 return {"result": "success"}, 204
class ConversationBatchApi(InstalledAppResource):
    """Resource that deletes several conversations of an installed app at once."""

    @staticmethod
    def _parse_conv_ids(raw):
        """Convert the ``conv_ids`` request value into a list.

        Used as the reqparse ``type`` callable, so a ``ValueError`` raised here
        is reported to the client as a validation error instead of a server
        error.

        :param raw: the JSON value of ``conv_ids``; either a JSON array or, for
            backward compatibility, a string holding a list literal
        :return: the ids as a list
        :raises ValueError: if the value is not (or does not parse to) a list
        """
        if isinstance(raw, str):
            try:
                # literal_eval only evaluates literals, never arbitrary code.
                raw = ast.literal_eval(raw)
            except (ValueError, SyntaxError) as exc:
                raise ValueError("conv_ids must be a list of conversation ids") from exc
        if not isinstance(raw, (list, tuple)):
            raise ValueError("conv_ids must be a list of conversation ids")
        return list(raw)

    def post(self, installed_app):
        """Delete and unpin the conversations listed in ``conv_ids``.

        :param installed_app: the installed app whose conversations are removed
        :raises NotChatAppError: if the app is not a chat-type app
        :raises NotFound: if the service reports a missing conversation
        """
        app_model = installed_app.app
        app_mode = AppMode.value_of(app_model.mode)

        parser = reqparse.RequestParser()
        # required=True: a missing field is rejected during parsing instead of
        # failing later with an unhandled exception.
        parser.add_argument("conv_ids", type=self._parse_conv_ids, required=True, location="json")
        args = parser.parse_args()

        if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
            raise NotChatAppError()

        conv_ids = args["conv_ids"]
        try:
            ConversationService.batch_delete(app_model, conv_ids, current_user)
        except ConversationNotExistsError:
            raise NotFound("Conversation Not Exists.")
        WebConversationService.batch_unpin(app_model, conv_ids, current_user)

        return {"result": "success"}, 204
class ConversationRenameApi(InstalledAppResource): class ConversationRenameApi(InstalledAppResource):
@marshal_with(simple_conversation_fields) @marshal_with(simple_conversation_fields)

@ -0,0 +1,87 @@
from flask_restful import Resource, reqparse # type: ignore
import flask_login
from unstructured.utils import first
from controllers.console import api
from controllers.console.wraps import setup_required
from services.ext.account_ext_service import AccountExtService, TenantExtService
from models.account import (
Account,
Tenant,
)
from extensions.ext_database import db
class AccountsApi(Resource):
    """Resource that forwards a list of accounts to ``AccountExtService``."""

    # NOTE(review): only ``setup_required`` is applied; no login decorator is
    # visible on this endpoint. Confirm access control is enforced elsewhere.
    @setup_required
    def post(self):
        """Update the given accounts for ``target_tenant_id``.

        Expects the JSON fields ``accounts`` and ``target_tenant_id`` (both
        required) and always answers with an empty JSON object.
        """
        request_parser = reqparse.RequestParser()
        request_parser.add_argument(
            "accounts",
            # A value that is not a list is replaced by an empty list rather
            # than rejected.
            type=lambda value: value if isinstance(value, list) else [],
            required=True,
            location="json",
        )
        request_parser.add_argument("target_tenant_id", type=str, required=True, location="json")
        payload = request_parser.parse_args()

        AccountExtService.update_account_list(
            accounts=payload["accounts"],
            target_tenant_id=payload["target_tenant_id"],
        )
        return {}
class LoginAccountInfo:
    """Simple value holder for an account id, its name and a tenant id."""

    def __init__(self, id, name, tenant_id):
        """Store the three values unchanged on the instance."""
        self.id, self.name, self.tenant_id = id, name, tenant_id

    def to_dict(self):
        """Return the stored values as a dict with keys ``id``, ``name``, ``tenant_id``."""
        return dict(id=self.id, name=self.name, tenant_id=self.tenant_id)
class LoginAccountsApi(Resource):
    """Resource that reports the id, name and tenant id of the current user."""

    # NOTE(review): no login decorator is visible here; the handler relies on
    # ``flask_login.current_user`` being an authenticated account. Confirm.
    @setup_required
    def get(self):
        """Return ``{"id", "name", "tenant_id"}`` for ``flask_login.current_user``."""
        user = flask_login.current_user
        info = LoginAccountInfo(
            id=user.id,
            name=user.name,
            tenant_id=user.current_tenant.id,
        )
        return info.to_dict()
class TenantEnableApi(Resource):
    """Resource that enables a tenant through ``TenantExtService.enable_tenant``."""

    # NOTE(review): only ``setup_required`` is applied; no login decorator is
    # visible on this endpoint. Confirm access control is enforced elsewhere.
    @setup_required
    def post(self):
        """Enable the tenant named in the JSON body.

        Requires the string fields ``target_tenant_id`` and
        ``target_tenant_name``; responds with the service result's
        ``to_dict()`` and status 200.
        """
        request_parser = reqparse.RequestParser()
        for field_name in ("target_tenant_id", "target_tenant_name"):
            request_parser.add_argument(field_name, type=str, required=True, location="json")
        payload = request_parser.parse_args()

        tenant_account_info = TenantExtService.enable_tenant(
            target_tenant_id=payload["target_tenant_id"],
            target_tenant_name=payload["target_tenant_name"],
        )
        return tenant_account_info.to_dict(), 200
class TenantInitApi(Resource):
    """Resource that initialises a tenant through ``TenantExtService.init_tenant``."""

    # NOTE(review): only ``setup_required`` is applied; no login decorator is
    # visible on this endpoint. Confirm access control is enforced elsewhere.
    @setup_required
    def post(self):
        """Initialise the tenant named in the JSON body.

        Requires the string fields ``target_tenant_id`` and
        ``target_tenant_name``; responds with the service result's
        ``to_dict()`` and status 200.
        """
        request_parser = reqparse.RequestParser()
        for field_name in ("target_tenant_id", "target_tenant_name"):
            request_parser.add_argument(field_name, type=str, required=True, location="json")
        payload = request_parser.parse_args()

        tenant_data = TenantExtService.init_tenant(
            target_tenant_id=payload["target_tenant_id"],
            target_tenant_name=payload["target_tenant_name"],
        )
        return tenant_data.to_dict(), 200
# Register the resources defined in this module on the console API.
api.add_resource(AccountsApi, "/accounts/update")
api.add_resource(TenantEnableApi, "/tenant/enable")
api.add_resource(TenantInitApi, "/tenant/init")
api.add_resource(LoginAccountsApi, "/login/account/info")

@ -28,7 +28,7 @@ from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
from configs.ext_config import get_init_knowledge_config
class DocumentAddByTextApi(DatasetApiResource): class DocumentAddByTextApi(DatasetApiResource):
"""Resource for documents.""" """Resource for documents."""
@ -161,12 +161,15 @@ class DocumentAddByFileApi(DatasetApiResource):
def post(self, tenant_id, dataset_id): def post(self, tenant_id, dataset_id):
"""Create document by upload file.""" """Create document by upload file."""
args = {} args = {}
file_id = None
if "data" in request.form: if "data" in request.form:
args = json.loads(request.form["data"]) args = json.loads(request.form["data"])
if "doc_form" not in args: if "doc_form" not in args:
args["doc_form"] = "text_model" args["doc_form"] = "text_model"
if "doc_language" not in args: if "doc_language" not in args:
args["doc_language"] = "English" args["doc_language"] = "English"
if "file_id" in request.form:
file_id = int(request.form["file_id"])
# get dataset info # get dataset info
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
@ -199,12 +202,16 @@ class DocumentAddByFileApi(DatasetApiResource):
mimetype=file.mimetype, mimetype=file.mimetype,
user=current_user, user=current_user,
source="datasets", source="datasets",
file_id=file_id
) )
data_source = { data_source = {
"type": "upload_file", "type": "upload_file",
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
} }
args["data_source"] = data_source args["data_source"] = data_source
# 取默认的值
args = get_init_knowledge_config(args)
# validate args # validate args
knowledge_config = KnowledgeConfig(**args) knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config) DocumentService.document_create_args_validate(knowledge_config)

@ -157,6 +157,9 @@ SupportedComparisonOperator = Literal[
# for time # for time
"before", "before",
"after", "after",
# 扩展
"in",
"not in",
] ]

@ -0,0 +1,156 @@
import base64
import concurrent.futures
import logging
import queue
import re
import threading
from collections.abc import Iterable
from typing import Optional
from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentMessageEvent,
QueueLLMChunkEvent,
QueueNodeSucceededEvent,
QueueTextChunkEvent,
WorkflowQueueMessage,
)
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.model_runtime.entities.model_entities import ModelType
class AudioTrunk:
    """One audio event: a status string together with its audio payload."""

    def __init__(self, status: str, audio):
        """Keep *status* and *audio* on the instance without modification."""
        self.status, self.audio = status, audio
def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str):
    """Invoke text-to-speech for *text_content* on *model_instance*.

    :param text_content: text to synthesise; surrounding whitespace is stripped
    :param model_instance: object providing ``invoke_tts``
    :param tenant_id: tenant id forwarded to ``invoke_tts``
    :param voice: voice forwarded to ``invoke_tts``
    :return: the result of ``invoke_tts``, or ``None`` when the text is empty
        or consists only of whitespace
    """
    has_text = bool(text_content) and not text_content.isspace()
    if not has_text:
        return None

    tts_kwargs = {
        "content_text": text_content.strip(),
        "user": "responding_tts",
        "tenant_id": tenant_id,
        "voice": voice,
    }
    return model_instance.invoke_tts(**tts_kwargs)
def _process_future(
    future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None],
    audio_queue: queue.Queue[AudioTrunk],
):
    """Drain *future_queue* and publish the resulting audio on *audio_queue*.

    Futures are taken one by one until a ``None`` sentinel arrives. Every chunk
    of a non-empty result is base64-encoded and queued as a ``"responding"``
    ``AudioTrunk``. Any exception is logged as a warning and stops the
    draining. In every case a final ``"finish"`` ``AudioTrunk`` is queued.
    """
    try:
        # iter(callable, sentinel) keeps calling get() until it returns None.
        for pending in iter(future_queue.get, None):
            chunks = pending.result()
            if not chunks:
                continue
            for chunk in chunks:
                encoded = base64.b64encode(bytes(chunk))
                audio_queue.put(AudioTrunk("responding", audio=encoded))
    except Exception as e:
        logging.getLogger(__name__).warning(e)

    audio_queue.put(AudioTrunk("finish", b""))
class AppGeneratorTTSPublisher:
    """Turn published queue messages into audio events.

    Text carried by published messages is accumulated, cut into sentences at
    the punctuation matched by ``self.match`` and submitted to the tenant's
    default TTS model through a thread pool. The resulting audio is exposed
    through ``check_and_get_audio``.
    """

    def __init__(self, tenant_id: str, voice: str, language: Optional[str] = None):
        """Resolve the TTS model and voice, then start the worker thread.

        :param tenant_id: tenant whose default TTS model instance is used
        :param voice: requested voice; replaced by the first available voice
            when empty or not offered by the model
        :param language: optional language passed to ``get_tts_voices``
        """
        self.logger = logging.getLogger(__name__)
        self.tenant_id = tenant_id
        self.msg_text = ""
        self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue()
        self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue()
        self.match = re.compile(r"[。.!?]")
        self.model_manager = ModelManager()
        self.model_instance = self.model_manager.get_default_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.TTS
        )
        self.voices = self.model_instance.get_tts_voices(language=language)
        values = [item.get("value") for item in self.voices]
        self.voice = voice
        if not voice or voice not in values:
            self.voice = self.voices[0].get("value")
        self.MAX_SENTENCE = 2
        self._last_audio_event: Optional[AudioTrunk] = None
        self.last_message = None
        # The executor must exist before the worker thread starts, because
        # _runtime submits work to it as soon as a message arrives.
        self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=3)
        # FIXME better way to handle this threading.start
        threading.Thread(target=self._runtime).start()

    def publish(self, message: WorkflowQueueMessage | MessageQueueMessage | None, /):
        """Queue *message* for the worker; ``None`` marks the end of the stream."""
        self._msg_queue.put(message)

    def _runtime(self):
        """Worker loop: collect text from messages and submit TTS jobs.

        Runs until a ``None`` message is received or an exception occurs, then
        puts a ``None`` sentinel on the future queue so the consumer thread
        can finish.
        """
        future_queue: queue.Queue[concurrent.futures.Future[Iterable[bytes] | None] | None] = queue.Queue()
        threading.Thread(target=_process_future, args=(future_queue, self._audio_queue)).start()
        while True:
            try:
                message = self._msg_queue.get()
                if message is None:
                    # End of stream: flush whatever text is still buffered.
                    if self.msg_text and len(self.msg_text.strip()) > 0:
                        futures_result = self.executor.submit(
                            _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice
                        )
                        future_queue.put(futures_result)
                    break
                elif isinstance(message.event, QueueAgentMessageEvent | QueueLLMChunkEvent):
                    message_content = message.event.chunk.delta.message.content
                    if not message_content:
                        continue
                    if isinstance(message_content, str):
                        self.msg_text += message_content
                    elif isinstance(message_content, list):
                        for content in message_content:
                            if not isinstance(content, TextPromptMessageContent):
                                continue
                            self.msg_text += content.data
                elif isinstance(message.event, QueueTextChunkEvent):
                    self.msg_text += message.event.text
                elif isinstance(message.event, QueueNodeSucceededEvent):
                    if message.event.outputs is None:
                        continue
                    self.msg_text += message.event.outputs.get("output", "")
                self.last_message = message
                sentence_arr, text_tmp = self._extract_sentence(self.msg_text)
                # The required number of complete sentences grows by one after
                # every submission and is capped at 7.
                if len(sentence_arr) >= min(self.MAX_SENTENCE, 7):
                    self.MAX_SENTENCE += 1
                    text_content = "".join(sentence_arr)
                    futures_result = self.executor.submit(
                        _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice
                    )
                    future_queue.put(futures_result)
                    # Keep only the unfinished tail for the next round.
                    self.msg_text = text_tmp
            except Exception as e:
                self.logger.warning(e)
                break
        future_queue.put(None)

    def check_and_get_audio(self):
        """Return the next audio event without blocking.

        :return: the next ``AudioTrunk``, the remembered ``"finish"`` event
            once the stream has ended, or ``None`` when nothing is queued
        """
        try:
            if self._last_audio_event and self._last_audio_event.status == "finish":
                if self.executor:
                    self.executor.shutdown(wait=False)
                return self._last_audio_event
            audio = self._audio_queue.get_nowait()
            if audio and audio.status == "finish":
                self.executor.shutdown(wait=False)
            if audio:
                self._last_audio_event = audio
            return audio
        except queue.Empty:
            return None

    def _extract_sentence(self, org_text):
        """Split *org_text* after every character matched by ``self.match``.

        :return: ``(sentences, remainder)`` where each sentence ends with its
            terminator and *remainder* is the text after the last terminator
        """
        start = 0
        result = []
        for found in self.match.finditer(org_text):
            end = found.end()
            result.append(org_text[start:end])
            start = end
        return result, org_text[start:]

@ -0,0 +1,191 @@
import logging
from threading import Thread
from typing import Optional, Union
from flask import Flask, current_app
from configs import dify_config
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
)
from core.app.entities.queue_entities import (
QueueAnnotationReplyEvent,
QueueMessageFileEvent,
QueueRetrieverResourcesEvent,
)
from core.app.entities.task_entities import (
EasyUITaskState,
MessageFileStreamResponse,
MessageReplaceStreamResponse,
MessageStreamResponse,
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
from extensions.ext_database import db
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
class MessageCycleManage:
    """Helpers used while producing a message: naming, annotations, files, responses."""

    def __init__(
        self,
        *,
        application_generate_entity: Union[
            ChatAppGenerateEntity,
            CompletionAppGenerateEntity,
            AgentChatAppGenerateEntity,
            AdvancedChatAppGenerateEntity,
        ],
        task_state: Union[EasyUITaskState, WorkflowTaskState],
    ) -> None:
        """Remember the generate entity and the task state used by the helpers."""
        self._application_generate_entity = application_generate_entity
        self._task_state = task_state

    def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
        """
        Generate conversation name in a background thread.
        :param conversation_id: id of the conversation to name
        :param query: query
        :return: the started thread, or None when no name is generated
            (completion apps, follow-up messages, or the feature is disabled)
        """
        if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
            return None

        is_first_message = self._application_generate_entity.conversation_id is None
        extras = self._application_generate_entity.extras
        auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)

        if auto_generate_conversation_name and is_first_message:
            # start generate thread
            thread = Thread(
                target=self._generate_conversation_name_worker,
                kwargs={
                    "flask_app": current_app._get_current_object(),  # type: ignore
                    "conversation_id": conversation_id,
                    "query": query,
                },
            )

            thread.start()

            return thread

        return None

    def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
        """
        Thread target: generate and persist the conversation name.
        :param flask_app: application whose context is entered for DB access
        :param conversation_id: id of the conversation to name
        :param query: text handed to the name generator
        :return:
        """
        with flask_app.app_context():
            try:
                # get conversation and message
                conversation = db.session.query(Conversation).filter(Conversation.id == conversation_id).first()

                if not conversation:
                    return

                if conversation.mode != AppMode.COMPLETION.value:
                    app_model = conversation.app
                    if not app_model:
                        return

                    # generate conversation name
                    try:
                        name = LLMGenerator.generate_conversation_name(app_model.tenant_id, query)
                        conversation.name = name
                    except Exception:
                        # Best effort: a failed name generation must not
                        # break the worker; it is only logged in debug mode.
                        if dify_config.DEBUG:
                            logging.exception(f"generate conversation name failed, conversation_id: {conversation_id}")

                    db.session.merge(conversation)
                    db.session.commit()
            finally:
                # Always release the session, including on early return or
                # when merge/commit raises.
                db.session.close()

    def _handle_annotation_reply(self, event: QueueAnnotationReplyEvent) -> Optional[MessageAnnotation]:
        """
        Handle annotation reply.
        :param event: event carrying the message annotation id
        :return: the annotation, or None when it cannot be found
        """
        annotation = AppAnnotationService.get_annotation_by_id(event.message_annotation_id)
        if annotation:
            account = annotation.account
            self._task_state.metadata["annotation_reply"] = {
                "id": annotation.id,
                "account": {"id": annotation.account_id, "name": account.name if account else "Dify user"},
            }

            return annotation

        return None

    def _handle_retriever_resources(self, event: QueueRetrieverResourcesEvent) -> None:
        """
        Handle retriever resources.
        Stores the event's resources in the task state metadata, but only when
        the app config enables ``show_retrieve_source``.
        :param event: event
        :return:
        """
        if self._application_generate_entity.app_config.additional_features.show_retrieve_source:
            self._task_state.metadata["retriever_resources"] = event.retriever_resources

    def _message_file_to_stream_response(self, event: QueueMessageFileEvent) -> Optional[MessageFileStreamResponse]:
        """
        Message file to stream response.
        :param event: event carrying the message file id
        :return: the stream response, or None when the file is missing or has no url
        """
        message_file = db.session.query(MessageFile).filter(MessageFile.id == event.message_file_id).first()

        if message_file and message_file.url is not None:
            # get tool file id
            tool_file_id = message_file.url.split("/")[-1]
            # trim extension
            tool_file_id = tool_file_id.split(".")[0]

            # get extension
            if "." in message_file.url:
                extension = f".{message_file.url.split('.')[-1]}"
                # overly long "extensions" are replaced by .bin
                if len(extension) > 10:
                    extension = ".bin"
            else:
                extension = ".bin"
            # add sign url to local file
            if message_file.url.startswith("http"):
                url = message_file.url
            else:
                url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)

            return MessageFileStreamResponse(
                task_id=self._application_generate_entity.task_id,
                id=message_file.id,
                type=message_file.type,
                belongs_to=message_file.belongs_to or "user",
                url=url,
            )

        return None

    def _message_to_stream_response(
        self, answer: str, message_id: str, from_variable_selector: Optional[list[str]] = None
    ) -> MessageStreamResponse:
        """
        Message to stream response.
        :param answer: answer
        :param message_id: message id
        :param from_variable_selector: optional selector passed through unchanged
        :return:
        """
        return MessageStreamResponse(
            task_id=self._application_generate_entity.task_id,
            id=message_id,
            answer=answer,
            from_variable_selector=from_variable_selector,
        )

    def _message_replace_to_stream_response(self, answer: str) -> MessageReplaceStreamResponse:
        """
        Message replace to stream response.
        :param answer: answer
        :return:
        """
        return MessageReplaceStreamResponse(task_id=self._application_generate_entity.task_id, answer=answer)

@ -0,0 +1,964 @@
import json
import time
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
QueueLoopCompletedEvent,
QueueLoopNextEvent,
QueueLoopStartEvent,
QueueNodeExceptionEvent,
QueueNodeFailedEvent,
QueueNodeInIterationFailedEvent,
QueueNodeInLoopFailedEvent,
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueParallelBranchRunFailedEvent,
QueueParallelBranchRunStartedEvent,
QueueParallelBranchRunSucceededEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
LoopNodeCompletedStreamResponse,
LoopNodeNextStreamResponse,
LoopNodeStartStreamResponse,
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
ParallelBranchFinishedStreamResponse,
ParallelBranchStartStreamResponse,
WorkflowFinishStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.tools.tool_manager import ToolManager
from core.workflow.entities.node_entities import NodeRunMetadataKey
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry
from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser
from models.workflow import (
Workflow,
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
WorkflowNodeExecutionTriggeredFrom,
WorkflowRun,
WorkflowRunStatus,
)
from .exc import WorkflowRunNotFoundError
class WorkflowCycleManage:
def __init__(
    self,
    *,
    application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity],
    workflow_system_variables: dict[SystemVariableKey, Any],
) -> None:
    """Store the generate entity and system variables; start with empty run state."""
    # Values supplied by the caller.
    self._application_generate_entity = application_generate_entity
    self._workflow_system_variables = workflow_system_variables
    # Per-run state, initially empty.
    self._workflow_run: WorkflowRun | None = None
    self._workflow_node_executions: dict[str, WorkflowNodeExecution] = {}
def _handle_workflow_run_start(
    self,
    *,
    session: Session,
    workflow_id: str,
    user_id: str,
    created_by_role: CreatedByRole,
) -> WorkflowRun:
    """
    Create a RUNNING workflow run for the given workflow and add it to the session.
    :param session: session used for the lookups and for adding the new run
    :param workflow_id: id of the workflow to run
    :param user_id: stored as the run's created_by
    :param created_by_role: stored as the run's created_by_role
    :return: the new workflow run (added to the session)
    :raises ValueError: if no workflow with workflow_id exists
    """
    workflow = session.scalar(select(Workflow).where(Workflow.id == workflow_id))
    if not workflow:
        raise ValueError(f"Workflow not found: {workflow_id}")

    # Next sequence number within the same tenant and app.
    current_max = session.scalar(
        select(func.max(WorkflowRun.sequence_number)).where(
            WorkflowRun.tenant_id == workflow.tenant_id,
            WorkflowRun.app_id == workflow.app_id,
        )
    )
    new_sequence_number = (current_max or 0) + 1

    # User inputs plus the system variables as "sys.<name>", except "conversation".
    system_inputs = {
        f"sys.{key.value}": value
        for key, value in (self._workflow_system_variables or {}).items()
        if key.value != "conversation"
    }
    inputs = {**self._application_generate_entity.inputs, **system_inputs}

    if self._application_generate_entity.invoke_from == InvokeFrom.DEBUGGER:
        triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
    else:
        triggered_from = WorkflowRunTriggeredFrom.APP_RUN

    # handle special values
    inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})

    # init workflow run
    # TODO: This workflow_run_id should always not be None, maybe we can use a more elegant way to handle this
    workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID) or uuid4())

    workflow_run = WorkflowRun()
    # Identity and values copied from the workflow.
    workflow_run.id = workflow_run_id
    workflow_run.tenant_id = workflow.tenant_id
    workflow_run.app_id = workflow.app_id
    workflow_run.sequence_number = new_sequence_number
    workflow_run.workflow_id = workflow.id
    workflow_run.type = workflow.type
    workflow_run.triggered_from = triggered_from.value
    workflow_run.version = workflow.version
    workflow_run.graph = workflow.graph
    # Run-specific values.
    workflow_run.inputs = json.dumps(inputs)
    workflow_run.status = WorkflowRunStatus.RUNNING
    workflow_run.created_by_role = created_by_role
    workflow_run.created_by = user_id
    workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)

    session.add(workflow_run)

    return workflow_run
def _handle_workflow_run_success(
    self,
    *,
    session: Session,
    workflow_run_id: str,
    start_at: float,
    total_tokens: int,
    total_steps: int,
    outputs: Mapping[str, Any] | None = None,
    conversation_id: Optional[str] = None,
    trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
    """
    Mark a workflow run as succeeded and record its final figures.

    :param session: session the workflow run is loaded/merged into
    :param workflow_run_id: id of the workflow run to update
    :param start_at: ``time.perf_counter()`` reading taken when the run started
    :param total_tokens: total tokens
    :param total_steps: total steps
    :param outputs: run outputs; stored as JSON (``{}`` when empty)
    :param conversation_id: conversation id forwarded to the trace task
    :param trace_manager: when given, a workflow trace task is queued on it
    :return: the updated workflow run
    """
    run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
    normalized_outputs = WorkflowEntry.handle_special_values(outputs)

    run.status = WorkflowRunStatus.SUCCEEDED.value
    run.outputs = json.dumps(normalized_outputs or {})
    run.elapsed_time = time.perf_counter() - start_at
    run.total_tokens = total_tokens
    run.total_steps = total_steps
    run.finished_at = datetime.now(UTC).replace(tzinfo=None)

    if not trace_manager:
        return run

    trace_task = TraceTask(
        TraceTaskName.WORKFLOW_TRACE,
        workflow_run=run,
        conversation_id=conversation_id,
        user_id=trace_manager.user_id,
    )
    trace_manager.add_trace_task(trace_task)
    return run
def _handle_workflow_run_partial_success(
    self,
    *,
    session: Session,
    workflow_run_id: str,
    start_at: float,
    total_tokens: int,
    total_steps: int,
    outputs: Mapping[str, Any] | None = None,
    exceptions_count: int = 0,
    conversation_id: Optional[str] = None,
    trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun:
    """
    Mark a workflow run as partially succeeded and record its final figures.

    :param session: session the workflow run is loaded/merged into
    :param workflow_run_id: id of the workflow run to update
    :param start_at: ``time.perf_counter()`` reading taken when the run started
    :param total_tokens: total tokens
    :param total_steps: total steps
    :param outputs: run outputs; stored as JSON (``{}`` when empty)
    :param exceptions_count: value stored on the run's ``exceptions_count``
    :param conversation_id: conversation id forwarded to the trace task
    :param trace_manager: when given, a workflow trace task is queued on it
    :return: the updated workflow run
    """
    run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)

    # Hand the helper a plain dict copy, or None when there are no outputs.
    raw_outputs = dict(outputs) if outputs else None
    normalized_outputs = WorkflowEntry.handle_special_values(raw_outputs)

    run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
    run.outputs = json.dumps(normalized_outputs or {})
    run.elapsed_time = time.perf_counter() - start_at
    run.total_tokens = total_tokens
    run.total_steps = total_steps
    run.finished_at = datetime.now(UTC).replace(tzinfo=None)
    run.exceptions_count = exceptions_count

    if not trace_manager:
        return run

    trace_task = TraceTask(
        TraceTaskName.WORKFLOW_TRACE,
        workflow_run=run,
        conversation_id=conversation_id,
        user_id=trace_manager.user_id,
    )
    trace_manager.add_trace_task(trace_task)
    return run
def _handle_workflow_run_failed(
    self,
    *,
    session: Session,
    workflow_run_id: str,
    start_at: float,
    total_tokens: int,
    total_steps: int,
    status: WorkflowRunStatus,
    error: str,
    conversation_id: Optional[str] = None,
    trace_manager: Optional[TraceQueueManager] = None,
    exceptions_count: int = 0,
) -> WorkflowRun:
    """
    Record a failed workflow run and fail its still-running node executions.

    Every node execution of this run that is still in RUNNING status is set
    to FAILED with the same error message.

    :param session: session the records are loaded/merged into
    :param workflow_run_id: id of the workflow run to update
    :param start_at: ``time.perf_counter()`` reading taken when the run started
    :param total_tokens: total tokens
    :param total_steps: total steps
    :param status: final status to store on the run
    :param error: error message
    :param conversation_id: conversation id forwarded to the trace task
    :param trace_manager: when given, a workflow trace task is queued on it
    :param exceptions_count: value stored on the run's ``exceptions_count``
    :return: the updated workflow run
    """
    run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)

    run.status = status.value
    run.error = error
    run.elapsed_time = time.perf_counter() - start_at
    run.total_tokens = total_tokens
    run.total_steps = total_steps
    run.finished_at = datetime.now(UTC).replace(tzinfo=None)
    run.exceptions_count = exceptions_count

    running_ids_stmt = select(WorkflowNodeExecution.node_execution_id).where(
        WorkflowNodeExecution.tenant_id == run.tenant_id,
        WorkflowNodeExecution.app_id == run.app_id,
        WorkflowNodeExecution.workflow_id == run.workflow_id,
        WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
        WorkflowNodeExecution.workflow_run_id == run.id,
        WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
    )
    running_ids = session.scalars(running_ids_stmt).all()

    # Use self._get_workflow_node_execution here to make sure the cache is updated.
    # All executions are resolved first, before any of them is modified.
    still_running = []
    for execution_id in running_ids:
        if execution_id:
            still_running.append(
                self._get_workflow_node_execution(session=session, node_execution_id=execution_id)
            )

    for node_execution in still_running:
        failed_at = datetime.now(UTC).replace(tzinfo=None)
        node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
        node_execution.error = error
        node_execution.finished_at = failed_at
        node_execution.elapsed_time = (failed_at - node_execution.created_at).total_seconds()

    if not trace_manager:
        return run

    trace_task = TraceTask(
        TraceTaskName.WORKFLOW_TRACE,
        workflow_run=run,
        conversation_id=conversation_id,
        user_id=trace_manager.user_id,
    )
    trace_manager.add_trace_task(trace_task)
    return run
def _handle_node_execution_start(
    self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution:
    """
    Create a RUNNING node execution record for a node-started event.

    The record is added to the session and cached under the event's
    ``node_execution_id``.

    :param session: session the new record is added to
    :param workflow_run: workflow run the node belongs to
    :param event: node started event
    :return: the new workflow node execution
    """
    initial_metadata = {
        NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
        NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
        NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
    }

    execution = WorkflowNodeExecution()
    execution.id = str(uuid4())

    # Ownership and creator are copied from the workflow run.
    execution.tenant_id = workflow_run.tenant_id
    execution.app_id = workflow_run.app_id
    execution.workflow_id = workflow_run.workflow_id
    execution.workflow_run_id = workflow_run.id
    execution.created_by_role = workflow_run.created_by_role
    execution.created_by = workflow_run.created_by
    execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value

    # Node identity and position are copied from the event.
    execution.predecessor_node_id = event.predecessor_node_id
    execution.index = event.node_run_index
    execution.node_execution_id = event.node_execution_id
    execution.node_id = event.node_id
    execution.node_type = event.node_type.value
    execution.title = event.node_data.title

    execution.status = WorkflowNodeExecutionStatus.RUNNING.value
    execution.execution_metadata = json.dumps(initial_metadata)
    execution.created_at = datetime.now(UTC).replace(tzinfo=None)

    session.add(execution)
    self._workflow_node_executions[event.node_execution_id] = execution
    return execution
def _handle_workflow_node_execution_success(
    self, *, session: Session, event: QueueNodeSucceededEvent
) -> WorkflowNodeExecution:
    """
    Record a successful node execution.

    :param session: session the cached node execution is merged into
    :param event: node succeeded event
    :return: the updated (merged) workflow node execution
    :raises ValueError: if the node execution is not in the local cache
        (raised by ``_get_workflow_node_execution``)
    """
    workflow_node_execution = self._get_workflow_node_execution(
        session=session, node_execution_id=event.node_execution_id
    )

    # Each payload is normalised exactly once.
    inputs = WorkflowEntry.handle_special_values(event.inputs)
    process_data = WorkflowEntry.handle_special_values(event.process_data)
    outputs = WorkflowEntry.handle_special_values(event.outputs)

    execution_metadata_dict = dict(event.execution_metadata or {})
    execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None

    finished_at = datetime.now(UTC).replace(tzinfo=None)
    elapsed_time = (finished_at - event.start_at).total_seconds()

    workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
    # Empty payloads are stored as NULL rather than as an empty JSON document.
    workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
    workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
    workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
    workflow_node_execution.execution_metadata = execution_metadata
    workflow_node_execution.finished_at = finished_at
    workflow_node_execution.elapsed_time = elapsed_time

    workflow_node_execution = session.merge(workflow_node_execution)
    return workflow_node_execution
def _handle_workflow_node_execution_failed(
    self,
    *,
    session: Session,
    event: QueueNodeFailedEvent
    | QueueNodeInIterationFailedEvent
    | QueueNodeInLoopFailedEvent
    | QueueNodeExceptionEvent,
) -> WorkflowNodeExecution:
    """
    Record a failed node execution.

    The status is EXCEPTION when the event is a ``QueueNodeExceptionEvent``
    and FAILED for every other accepted event type.

    :param session: session the cached node execution is merged into
    :param event: queue node failed event
    :return: the updated (merged) workflow node execution
    :raises ValueError: if the node execution is not in the local cache
        (raised by ``_get_workflow_node_execution``)
    """
    workflow_node_execution = self._get_workflow_node_execution(
        session=session, node_execution_id=event.node_execution_id
    )

    # Each payload is normalised exactly once.
    inputs = WorkflowEntry.handle_special_values(event.inputs)
    process_data = WorkflowEntry.handle_special_values(event.process_data)
    outputs = WorkflowEntry.handle_special_values(event.outputs)

    finished_at = datetime.now(UTC).replace(tzinfo=None)
    elapsed_time = (finished_at - event.start_at).total_seconds()
    execution_metadata = (
        json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
    )

    workflow_node_execution.status = (
        WorkflowNodeExecutionStatus.FAILED.value
        if not isinstance(event, QueueNodeExceptionEvent)
        else WorkflowNodeExecutionStatus.EXCEPTION.value
    )
    workflow_node_execution.error = event.error
    # Empty payloads are stored as NULL rather than as an empty JSON document.
    workflow_node_execution.inputs = json.dumps(inputs) if inputs else None
    workflow_node_execution.process_data = json.dumps(process_data) if process_data else None
    workflow_node_execution.outputs = json.dumps(outputs) if outputs else None
    workflow_node_execution.finished_at = finished_at
    workflow_node_execution.elapsed_time = elapsed_time
    workflow_node_execution.execution_metadata = execution_metadata

    workflow_node_execution = session.merge(workflow_node_execution)
    return workflow_node_execution
def _handle_workflow_node_execution_retried(
    self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution:
    """
    Record one retry attempt of a node as its own node execution row.

    A new record with RETRY status is created, added to the session and
    stored in the cache under the event's ``node_execution_id``.

    :param session: session the new record is added to
    :param workflow_run: workflow run the node belongs to
    :param event: queue node retry event
    :return: the new workflow node execution
    """
    started_at = event.start_at
    ended_at = datetime.now(UTC).replace(tzinfo=None)
    duration = (ended_at - started_at).total_seconds()

    retry_inputs = WorkflowEntry.handle_special_values(event.inputs)
    retry_outputs = WorkflowEntry.handle_special_values(event.outputs)

    # Location metadata always wins over same-named keys from the event metadata.
    location_metadata = {
        NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
        NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
        NodeRunMetadataKey.LOOP_ID: event.in_loop_id,
    }
    if event.execution_metadata is not None:
        combined_metadata = {**jsonable_encoder(event.execution_metadata), **location_metadata}
    else:
        combined_metadata = location_metadata
    serialized_metadata = json.dumps(combined_metadata)

    execution = WorkflowNodeExecution()
    execution.id = str(uuid4())

    # Ownership and creator are copied from the workflow run.
    execution.tenant_id = workflow_run.tenant_id
    execution.app_id = workflow_run.app_id
    execution.workflow_id = workflow_run.workflow_id
    execution.workflow_run_id = workflow_run.id
    execution.created_by_role = workflow_run.created_by_role
    execution.created_by = workflow_run.created_by
    execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value

    # Node identity and position are copied from the event.
    execution.predecessor_node_id = event.predecessor_node_id
    execution.node_execution_id = event.node_execution_id
    execution.node_id = event.node_id
    execution.node_type = event.node_type.value
    execution.title = event.node_data.title
    execution.index = event.node_run_index

    execution.status = WorkflowNodeExecutionStatus.RETRY.value
    execution.created_at = started_at
    execution.finished_at = ended_at
    execution.elapsed_time = duration
    execution.error = event.error
    execution.inputs = json.dumps(retry_inputs) if retry_inputs else None
    execution.outputs = json.dumps(retry_outputs) if retry_outputs else None
    execution.execution_metadata = serialized_metadata

    session.add(execution)
    self._workflow_node_executions[event.node_execution_id] = execution
    return execution
#################################################
# to stream responses #
#################################################
def _workflow_start_to_stream_response(
    self,
    *,
    session: Session,
    task_id: str,
    workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse:
    """
    Build the "workflow started" stream response for a workflow run.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run that has just started
    :return: workflow start stream response
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = WorkflowStartStreamResponse.Data(
        id=workflow_run.id,
        workflow_id=workflow_run.workflow_id,
        sequence_number=workflow_run.sequence_number,
        inputs=dict(workflow_run.inputs_dict or {}),
        created_at=int(workflow_run.created_at.timestamp()),
    )
    return WorkflowStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_finish_to_stream_response(
    self,
    *,
    session: Session,
    task_id: str,
    workflow_run: WorkflowRun,
) -> WorkflowFinishStreamResponse:
    """
    Build the "workflow finished" stream response for a workflow run.

    The creator is looked up (Account or EndUser, depending on the run's
    ``created_by_role``) and included as a small dict; it stays ``None``
    when the record cannot be found.

    :param session: session used to look up the creator
    :param task_id: task id
    :param workflow_run: finished workflow run
    :return: workflow finish stream response
    :raises NotImplementedError: if ``created_by_role`` is neither ACCOUNT nor END_USER
    """
    created_by = None
    if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
        stmt = select(Account).where(Account.id == workflow_run.created_by)
        account = session.scalar(stmt)
        if account:
            created_by = {
                "id": account.id,
                "name": account.name,
                "email": account.email,
            }
    elif workflow_run.created_by_role == CreatedByRole.END_USER:
        stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
        end_user = session.scalar(stmt)
        if end_user:
            created_by = {
                "id": end_user.id,
                "user": end_user.session_id,
            }
    else:
        raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")

    # Read the outputs once and guard both uses against an empty/None value,
    # so file extraction cannot fail on dict(None).
    outputs_dict = workflow_run.outputs_dict

    return WorkflowFinishStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=WorkflowFinishStreamResponse.Data(
            id=workflow_run.id,
            workflow_id=workflow_run.workflow_id,
            sequence_number=workflow_run.sequence_number,
            status=workflow_run.status,
            outputs=dict(outputs_dict) if outputs_dict else None,
            error=workflow_run.error,
            elapsed_time=workflow_run.elapsed_time,
            total_tokens=workflow_run.total_tokens,
            total_steps=workflow_run.total_steps,
            created_by=created_by,
            created_at=int(workflow_run.created_at.timestamp()),
            finished_at=int(workflow_run.finished_at.timestamp()),
            files=self._fetch_files_from_node_outputs(dict(outputs_dict or {})),
            exceptions_count=workflow_run.exceptions_count,
        ),
    )
def _workflow_node_start_to_stream_response(
    self,
    *,
    session: Session,
    event: QueueNodeStartedEvent,
    task_id: str,
    workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]:
    """
    Build the "node started" stream response for a node execution.

    For tool nodes the tool icon is added to the response extras.

    :param session: unused; accepted so the caller keeps a session open
    :param event: node started event
    :param task_id: task id
    :param workflow_node_execution: node execution created for the event
    :return: node start stream response, or None for iteration/loop nodes
        and for executions without a workflow run id
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    execution = workflow_node_execution
    is_container_node = execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}
    if is_container_node or not execution.workflow_run_id:
        return None

    payload = NodeStartStreamResponse.Data(
        id=execution.id,
        node_id=execution.node_id,
        node_type=execution.node_type,
        title=execution.title,
        index=execution.index,
        predecessor_node_id=execution.predecessor_node_id,
        inputs=execution.inputs_dict,
        created_at=int(execution.created_at.timestamp()),
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parent_parallel_id=event.parent_parallel_id,
        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
        parallel_run_id=event.parallel_mode_run_id,
        agent_strategy=event.agent_strategy,
    )
    response = NodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=execution.workflow_run_id,
        data=payload,
    )

    # extras logic
    if event.node_type == NodeType.TOOL:
        tool_node_data = cast(ToolNodeData, event.node_data)
        response.data.extras["icon"] = ToolManager.get_tool_icon(
            tenant_id=self._application_generate_entity.app_config.tenant_id,
            provider_type=tool_node_data.provider_type,
            provider_id=tool_node_data.provider_id,
        )

    return response
def _workflow_node_finish_to_stream_response(
    self,
    *,
    session: Session,
    event: QueueNodeSucceededEvent
    | QueueNodeFailedEvent
    | QueueNodeInIterationFailedEvent
    | QueueNodeInLoopFailedEvent
    | QueueNodeExceptionEvent,
    task_id: str,
    workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]:
    """
    Build the "node finished" stream response for a node execution.

    :param session: unused; accepted so the caller keeps a session open
    :param event: the event that finished the node
    :param task_id: task id
    :param workflow_node_execution: finished node execution
    :return: node finish stream response, or None for iteration/loop nodes
        and for executions lacking a workflow run id or a finish time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    execution = workflow_node_execution
    is_container_node = execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}
    if is_container_node or not execution.workflow_run_id or not execution.finished_at:
        return None

    payload = NodeFinishStreamResponse.Data(
        id=execution.id,
        node_id=execution.node_id,
        node_type=execution.node_type,
        index=execution.index,
        title=execution.title,
        predecessor_node_id=execution.predecessor_node_id,
        inputs=execution.inputs_dict,
        process_data=execution.process_data_dict,
        outputs=execution.outputs_dict,
        status=execution.status,
        error=execution.error,
        elapsed_time=execution.elapsed_time,
        execution_metadata=execution.execution_metadata_dict,
        created_at=int(execution.created_at.timestamp()),
        finished_at=int(execution.finished_at.timestamp()),
        files=self._fetch_files_from_node_outputs(execution.outputs_dict or {}),
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parent_parallel_id=event.parent_parallel_id,
        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
    )
    return NodeFinishStreamResponse(
        task_id=task_id,
        workflow_run_id=execution.workflow_run_id,
        data=payload,
    )
def _workflow_node_retry_to_stream_response(
    self,
    *,
    session: Session,
    event: QueueNodeRetryEvent,
    task_id: str,
    workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
    """
    Build the "node retry" stream response for a retried node execution.

    :param session: unused; accepted so the caller keeps a session open
    :param event: node retry event
    :param task_id: task id
    :param workflow_node_execution: node execution recorded for the retry
    :return: node retry stream response, or None for iteration/loop nodes
        and for executions lacking a workflow run id or a finish time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    execution = workflow_node_execution
    is_container_node = execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}
    if is_container_node or not execution.workflow_run_id or not execution.finished_at:
        return None

    payload = NodeRetryStreamResponse.Data(
        id=execution.id,
        node_id=execution.node_id,
        node_type=execution.node_type,
        index=execution.index,
        title=execution.title,
        predecessor_node_id=execution.predecessor_node_id,
        inputs=execution.inputs_dict,
        process_data=execution.process_data_dict,
        outputs=execution.outputs_dict,
        status=execution.status,
        error=execution.error,
        elapsed_time=execution.elapsed_time,
        execution_metadata=execution.execution_metadata_dict,
        created_at=int(execution.created_at.timestamp()),
        finished_at=int(execution.finished_at.timestamp()),
        files=self._fetch_files_from_node_outputs(execution.outputs_dict or {}),
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parent_parallel_id=event.parent_parallel_id,
        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
        retry_index=event.retry_index,
    )
    return NodeRetryStreamResponse(
        task_id=task_id,
        workflow_run_id=execution.workflow_run_id,
        data=payload,
    )
def _workflow_parallel_branch_start_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse:
    """
    Build the "parallel branch started" stream response.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the branch belongs to
    :param event: parallel branch run started event
    :return: parallel branch start stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = ParallelBranchStartStreamResponse.Data(
        parallel_id=event.parallel_id,
        # the branch is identified by the node the parallel run starts from
        parallel_branch_id=event.parallel_start_node_id,
        parent_parallel_id=event.parent_parallel_id,
        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
        created_at=int(time.time()),
    )
    return ParallelBranchStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_parallel_branch_finished_to_stream_response(
    self,
    *,
    session: Session,
    task_id: str,
    workflow_run: WorkflowRun,
    event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse:
    """
    Build the "parallel branch finished" stream response.

    The status is ``"succeeded"`` for a succeeded event and ``"failed"``
    otherwise; the error is only taken from a failed event.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the branch belongs to
    :param event: parallel branch run succeeded or failed event
    :return: parallel branch finished stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    if isinstance(event, QueueParallelBranchRunSucceededEvent):
        branch_status = "succeeded"
    else:
        branch_status = "failed"

    branch_error = None
    if isinstance(event, QueueParallelBranchRunFailedEvent):
        branch_error = event.error

    payload = ParallelBranchFinishedStreamResponse.Data(
        parallel_id=event.parallel_id,
        parallel_branch_id=event.parallel_start_node_id,
        parent_parallel_id=event.parent_parallel_id,
        parent_parallel_start_node_id=event.parent_parallel_start_node_id,
        iteration_id=event.in_iteration_id,
        loop_id=event.in_loop_id,
        status=branch_status,
        error=branch_error,
        created_at=int(time.time()),
    )
    return ParallelBranchFinishedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_iteration_start_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse:
    """
    Build the "iteration started" stream response for an iteration node.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the iteration belongs to
    :param event: iteration start event
    :return: iteration start stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = IterationNodeStartStreamResponse.Data(
        # the node id doubles as the response id
        id=event.node_id,
        node_id=event.node_id,
        node_type=event.node_type.value,
        title=event.node_data.title,
        created_at=int(time.time()),
        extras={},
        inputs=event.inputs or {},
        metadata=event.metadata or {},
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
    )
    return IterationNodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_iteration_next_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse:
    """
    Build the "iteration next" stream response for an iteration node.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the iteration belongs to
    :param event: iteration next event
    :return: iteration next stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = IterationNodeNextStreamResponse.Data(
        # the node id doubles as the response id
        id=event.node_id,
        node_id=event.node_id,
        node_type=event.node_type.value,
        title=event.node_data.title,
        index=event.index,
        pre_iteration_output=event.output,
        created_at=int(time.time()),
        extras={},
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parallel_mode_run_id=event.parallel_mode_run_id,
        duration=event.duration,
    )
    return IterationNodeNextStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_iteration_completed_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse:
    """
    Build the "iteration completed" stream response for an iteration node.

    The status is SUCCEEDED when the event carries no error and FAILED
    otherwise; the event's error message is passed through so a FAILED
    response is never sent without its reason.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the iteration belongs to
    :param event: iteration completed event
    :return: iteration completed stream response
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session
    return IterationNodeCompletedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=IterationNodeCompletedStreamResponse.Data(
            id=event.node_id,
            node_id=event.node_id,
            node_type=event.node_type.value,
            title=event.node_data.title,
            outputs=event.outputs,
            created_at=int(time.time()),
            extras={},
            inputs=event.inputs or {},
            status=WorkflowNodeExecutionStatus.SUCCEEDED
            if event.error is None
            else WorkflowNodeExecutionStatus.FAILED,
            # Forward the error that drives the status above (None on success).
            error=event.error,
            # event.start_at is compared against a naive UTC "now".
            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
            total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
            execution_metadata=event.metadata,
            finished_at=int(time.time()),
            steps=event.steps,
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
        ),
    )
def _workflow_loop_start_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopStartEvent
) -> LoopNodeStartStreamResponse:
    """
    Build the "loop started" stream response for a loop node.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the loop belongs to
    :param event: loop start event
    :return: loop start stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = LoopNodeStartStreamResponse.Data(
        # the node id doubles as the response id
        id=event.node_id,
        node_id=event.node_id,
        node_type=event.node_type.value,
        title=event.node_data.title,
        created_at=int(time.time()),
        extras={},
        inputs=event.inputs or {},
        metadata=event.metadata or {},
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
    )
    return LoopNodeStartStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_loop_next_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopNextEvent
) -> LoopNodeNextStreamResponse:
    """
    Build the "loop next" stream response for a loop node.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the loop belongs to
    :param event: loop next event
    :return: loop next stream response, timestamped with the current time
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session

    payload = LoopNodeNextStreamResponse.Data(
        # the node id doubles as the response id
        id=event.node_id,
        node_id=event.node_id,
        node_type=event.node_type.value,
        title=event.node_data.title,
        index=event.index,
        pre_loop_output=event.output,
        created_at=int(time.time()),
        extras={},
        parallel_id=event.parallel_id,
        parallel_start_node_id=event.parallel_start_node_id,
        parallel_mode_run_id=event.parallel_mode_run_id,
        duration=event.duration,
    )
    return LoopNodeNextStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=payload,
    )
def _workflow_loop_completed_to_stream_response(
    self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueLoopCompletedEvent
) -> LoopNodeCompletedStreamResponse:
    """
    Build the "loop completed" stream response for a loop node.

    The status is SUCCEEDED when the event carries no error and FAILED
    otherwise; the event's error message is passed through so a FAILED
    response is never sent without its reason.

    :param session: unused; accepted so the caller keeps a session open
    :param task_id: task id
    :param workflow_run: workflow run the loop belongs to
    :param event: loop completed event
    :return: loop completed stream response
    """
    # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
    _ = session
    return LoopNodeCompletedStreamResponse(
        task_id=task_id,
        workflow_run_id=workflow_run.id,
        data=LoopNodeCompletedStreamResponse.Data(
            id=event.node_id,
            node_id=event.node_id,
            node_type=event.node_type.value,
            title=event.node_data.title,
            outputs=event.outputs,
            created_at=int(time.time()),
            extras={},
            inputs=event.inputs or {},
            status=WorkflowNodeExecutionStatus.SUCCEEDED
            if event.error is None
            else WorkflowNodeExecutionStatus.FAILED,
            # Forward the error that drives the status above (None on success).
            error=event.error,
            # event.start_at is compared against a naive UTC "now".
            elapsed_time=(datetime.now(UTC).replace(tzinfo=None) - event.start_at).total_seconds(),
            total_tokens=event.metadata.get("total_tokens", 0) if event.metadata else 0,
            execution_metadata=event.metadata,
            finished_at=int(time.time()),
            steps=event.steps,
            parallel_id=event.parallel_id,
            parallel_start_node_id=event.parallel_start_node_id,
        ),
    )
def _fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
:return:
"""
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list
# Flatten the list of sequences into a single list of mappings
flattened_files = [file for sublist in files if sublist for file in sublist]
# Convert to tuple to match Sequence type
return tuple(flattened_files)
def _fetch_files_from_variable_value(self, value: Union[dict, list]) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from variable value
:param value: variable value
:return:
"""
if not value:
return []
files = []
if isinstance(value, list):
for item in value:
file = self._get_file_var_from_value(item)
if file:
files.append(file)
elif isinstance(value, dict):
file = self._get_file_var_from_value(value)
if file:
files.append(file)
return files
def _get_file_var_from_value(self, value: Union[dict, list]) -> Mapping[str, Any] | None:
    """
    Return the file mapping represented by a value, if it represents one.

    :param value: variable value
    :return: the dict itself when it is tagged with the file model identity,
        ``value.to_dict()`` when it is a ``File``, otherwise None
    """
    if not value:
        return None

    is_file_dict = isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY
    if is_file_dict:
        return value

    if isinstance(value, File):
        return value.to_dict()

    return None
def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
    """
    Return the workflow run with the given id, preferring the cached one.

    A cached run with a matching id is merged into the session and returned;
    otherwise the run is queried and then cached.

    :param session: session used for the merge or the query
    :param workflow_run_id: id of the workflow run
    :return: the workflow run
    :raises WorkflowRunNotFoundError: if no run with that id exists
    """
    cached_run = self._workflow_run
    if cached_run and cached_run.id == workflow_run_id:
        return session.merge(cached_run)

    loaded_run = session.scalar(select(WorkflowRun).where(WorkflowRun.id == workflow_run_id))
    if not loaded_run:
        raise WorkflowRunNotFoundError(workflow_run_id)

    self._workflow_run = loaded_run
    return loaded_run
def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
    """
    Return the cached node execution for an id, merged into the session.

    :param session: session the cached record is merged into
    :param node_execution_id: node execution id used as the cache key
    :return: the merged workflow node execution
    :raises ValueError: if the id is not in the cache
    """
    cache = self._workflow_node_executions
    if node_execution_id in cache:
        return session.merge(cache[node_execution_id])

    raise ValueError(f"Workflow node execution not found: {node_execution_id}")
def _handle_agent_log(self, task_id: str, event: QueueAgentLogEvent) -> AgentLogStreamResponse:
    """
    Convert an agent log event into its stream response.

    :param task_id: task id
    :param event: agent log event
    :return: agent log stream response carrying the event's fields
    """
    payload = AgentLogStreamResponse.Data(
        node_execution_id=event.node_execution_id,
        id=event.id,
        parent_id=event.parent_id,
        label=event.label,
        error=event.error,
        status=event.status,
        data=event.data,
        metadata=event.metadata,
        node_id=event.node_id,
    )
    return AgentLogStreamResponse(task_id=task_id, data=payload)

@ -445,7 +445,7 @@ class IndexingRunner:
chunk_size=max_tokens, chunk_size=max_tokens,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""], separators=["\n\n", "\n","", ". ", " ", ""],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
else: else:
@ -454,7 +454,7 @@ class IndexingRunner:
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=automatic_rules["max_tokens"], chunk_size=automatic_rules["max_tokens"],
chunk_overlap=automatic_rules["chunk_overlap"], chunk_overlap=automatic_rules["chunk_overlap"],
separators=["\n\n", "", ". ", " ", ""], separators=["\n\n", "\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )

@ -0,0 +1,170 @@
from collections.abc import Mapping
from typing import Optional
import openai
from httpx import Timeout
from openai import OpenAI
from openai.types import ModerationCreateResponse
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.moderation_model import ModerationModel
class OpenAIModerationModel(ModerationModel):
"""
Model class for OpenAI text moderation model.
"""
def _invoke(self, model: str, credentials: dict, text: str, user: Optional[str] = None) -> bool:
    """
    Run the moderation model over a text and report whether it is flagged.

    The text is cut into fixed-size character chunks, the chunks are sent
    in batches, and the first flagged result ends the scan.

    :param model: model name
    :param credentials: model credentials
    :param text: text to moderate
    :param user: unique user id
    :return: false if text is safe, true otherwise
    """
    # init model client from the credentials
    client = OpenAI(**self._to_credential_kwargs(credentials))

    # chars per chunk
    chunk_size = self._get_max_characters_per_chunk(model, credentials)
    pieces = [text[start : start + chunk_size] for start in range(0, len(text), chunk_size)]

    # chunks per request
    batch_size = self._get_max_chunks(model, credentials)
    batches = [pieces[start : start + batch_size] for start in range(0, len(pieces), batch_size)]

    for batch in batches:
        response = self._moderation_invoke(model=model, client=client, texts=batch)
        if any(result.flagged is True for result in response.results):
            return True

    return False
def validate_credentials(self, model: str, credentials: dict) -> None:
    """
    Validate model credentials by sending a one-word moderation request.

    :param model: model name
    :param credentials: model credentials
    :return:
    :raises CredentialsValidateFailedError: if building the client or the
        test request fails for any reason; the original exception is
        chained as the cause
    """
    try:
        # transform credentials to kwargs for model instance
        credentials_kwargs = self._to_credential_kwargs(credentials)
        client = OpenAI(**credentials_kwargs)

        # call moderation model
        self._moderation_invoke(
            model=model,
            client=client,
            texts=["ping"],
        )
    except Exception as ex:
        # Any failure here means the credentials are unusable; keep the
        # underlying error attached for debugging.
        raise CredentialsValidateFailedError(str(ex)) from ex
def _moderation_invoke(self, model: str, client: OpenAI, texts: list[str]) -> ModerationCreateResponse:
"""
Invoke moderation model
:param model: model name
:param client: model client
:param texts: texts to moderate
:return: false if text is safe, true otherwise
"""
# call moderation model
moderation_result = client.moderations.create(model=model, input=texts)
return moderation_result
def _get_max_characters_per_chunk(self, model: str, credentials: dict) -> int:
"""
Get max characters per chunk
:param model: model name
:param credentials: model credentials
:return: max characters per chunk
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK in model_schema.model_properties:
max_characters_per_chunk: int = model_schema.model_properties[ModelPropertyKey.MAX_CHARACTERS_PER_CHUNK]
return max_characters_per_chunk
return 2000
def _get_max_chunks(self, model: str, credentials: dict) -> int:
"""
Get max chunks for given embedding model
:param model: model name
:param credentials: model credentials
:return: max chunks
"""
model_schema = self.get_model_schema(model, credentials)
if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties:
max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
return max_chunks
return 1
def _to_credential_kwargs(self, credentials: Mapping) -> dict:
"""
Transform credentials to kwargs for model instance
:param credentials:
:return:
"""
credentials_kwargs = {
"api_key": credentials["openai_api_key"],
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"max_retries": 1,
}
if credentials.get("openai_api_base"):
openai_api_base = credentials["openai_api_base"].rstrip("/")
credentials_kwargs["base_url"] = openai_api_base + "/v1"
if "openai_organization" in credentials:
credentials_kwargs["organization"] = credentials["openai_organization"]
return credentials_kwargs
@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [openai.APIConnectionError, openai.APITimeoutError],
InvokeServerUnavailableError: [openai.InternalServerError],
InvokeRateLimitError: [openai.RateLimitError],
InvokeAuthorizationError: [openai.AuthenticationError, openai.PermissionDeniedError],
InvokeBadRequestError: [
openai.BadRequestError,
openai.NotFoundError,
openai.UnprocessableEntityError,
openai.APIError,
],
}

@ -0,0 +1,22 @@
- claude-3-haiku@20240307
- claude-3-opus@20240229
- claude-3-sonnet@20240229
- claude-3-5-sonnet-v2@20241022
- claude-3-5-sonnet@20240620
- gemini-1.0-pro-vision-001
- gemini-1.0-pro-002
- gemini-1.5-flash-001
- gemini-1.5-flash-002
- gemini-1.5-pro-001
- gemini-1.5-pro-002
- gemini-2.0-flash-001
- gemini-2.0-flash-exp
- gemini-2.0-flash-lite-preview-02-05
- gemini-2.0-flash-thinking-exp-01-21
- gemini-2.0-flash-thinking-exp-1219
- gemini-2.0-pro-exp-02-05
- gemini-exp-1114
- gemini-exp-1121
- gemini-exp-1206
- gemini-flash-experimental
- gemini-pro-experimental

@ -0,0 +1,41 @@
model: gemini-2.0-flash-001
label:
en_US: Gemini 2.0 Flash 001
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,41 @@
model: gemini-2.0-flash-lite-preview-02-05
label:
en_US: Gemini 2.0 Flash Lite Preview 0205
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-01-21
label:
en_US: Gemini 2.0 Flash Thinking Exp 0121
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,39 @@
model: gemini-2.0-flash-thinking-exp-1219
label:
en_US: Gemini 2.0 Flash Thinking Exp 1219
model_type: llm
features:
- agent-thought
- vision
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,37 @@
model: gemini-2.0-pro-exp-02-05
label:
en_US: Gemini 2.0 Pro Exp 0205
model_type: llm
features:
- agent-thought
- document
model_properties:
mode: chat
context_size: 2000000
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
en_US: Top k
type: int
help:
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: presence_penalty
use_template: presence_penalty
- name: frequency_penalty
use_template: frequency_penalty
- name: max_output_tokens
use_template: max_tokens
required: true
default: 8192
min: 1
max: 8192
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,41 @@
model: gemini-exp-1114
label:
en_US: Gemini exp 1114
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,41 @@
model: gemini-exp-1121
label:
en_US: Gemini exp 1121
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,41 @@
model: gemini-exp-1206
label:
en_US: Gemini exp 1206
model_type: llm
features:
- agent-thought
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
parameter_rules:
- name: temperature
use_template: temperature
- name: top_p
use_template: top_p
- name: top_k
label:
zh_Hans: 取样数量
en_US: Top k
type: int
help:
zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
en_US: Only sample from the top K options for each subsequent token.
required: false
- name: max_output_tokens
use_template: max_tokens
default: 8192
min: 1
max: 8192
- name: json_schema
use_template: json_schema
pricing:
input: '0.00'
output: '0.00'
unit: '0.000001'
currency: USD

@ -0,0 +1,66 @@
model: glm-4-air-0111
label:
en_US: glm-4-air-0111
model_type: llm
features:
- multi-tool-call
- agent-thought
- stream-tool-call
model_properties:
mode: chat
context_size: 131072
parameter_rules:
- name: temperature
use_template: temperature
default: 0.95
min: 0.0
max: 1.0
help:
zh_Hans: 采样温度,控制输出的随机性,必须为正数取值范围是:(0.0,1.0],不能等于 0,默认值为 0.95 值越大,会使输出更随机,更具创造性;值越小,输出会更加稳定或确定建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: Sampling temperature, which controls the randomness of the output. It must be a positive number in the range (0.0, 1.0] and cannot be equal to 0; the default value is 0.95. The larger the value, the more random and creative the output will be; the smaller the value, the more stable or deterministic the output will be. It is recommended that you adjust either top_p or temperature according to the application scenario, but not both at the same time.
- name: top_p
use_template: top_p
default: 0.7
help:
zh_Hans: 用温度取样的另一种方法,称为核取样取值范围是:(0.0, 1.0) 开区间,不能等于 0 或 1默认值为 0.7 模型考虑具有 top_p 概率质量tokens的结果例如0.1 意味着模型解码器只考虑从前 10% 的概率的候选集中取 tokens 建议您根据应用场景调整 top_p 或 temperature 参数,但不要同时调整两个参数。
en_US: An alternative to temperature sampling, called nucleus sampling. The value range is the open interval (0.0, 1.0), so it cannot be equal to 0 or 1; the default value is 0.7. The model considers only the tokens that make up the top_p probability mass. For example, 0.1 means the decoder only considers tokens from the candidate set within the top 10% probability mass. It is recommended that you adjust either top_p or temperature according to the application scenario, but not both at the same time.
- name: do_sample
label:
zh_Hans: 采样策略
en_US: Sampling strategy
type: boolean
help:
zh_Hans: do_sample 为 true 时启用采样策略do_sample 为 false 时采样策略 temperature、top_p 将不生效。默认值为 true。
en_US: When `do_sample` is set to true, the sampling strategy is enabled. When `do_sample` is set to false, the sampling strategies such as `temperature` and `top_p` will not take effect. The default value is true.
default: true
- name: max_tokens
use_template: max_tokens
default: 1024
min: 1
max: 4095
- name: web_search
type: boolean
label:
zh_Hans: 联网搜索
en_US: Web Search
default: false
help:
zh_Hans: 模型内置了互联网搜索服务,该参数控制模型在生成文本时是否参考使用互联网搜索结果。启用互联网搜索,模型会将搜索结果作为文本生成过程中的参考信息,但模型会基于其内部逻辑“自行判断”是否使用互联网搜索结果。
en_US: The model has a built-in Internet search service. This parameter controls whether the model refers to Internet search results when generating text. When Internet search is enabled, the model will use the search results as reference information in the text generation process, but the model will "judge" whether to use Internet search results based on its internal logic.
- name: response_format
label:
zh_Hans: 回复格式
en_US: Response Format
type: string
help:
zh_Hans: 指定模型必须输出的格式
en_US: specifying the format that the model must output
required: false
options:
- text
- json_object
pricing:
input: '0.0005'
output: '0.0005'
unit: '0.001'
currency: RMB

@ -0,0 +1,113 @@
from collections.abc import Generator
from typing import Any, Optional
from core.agent.entities import AgentInvokeMessage
from core.plugin.entities.plugin import GenericProviderID
from core.plugin.entities.plugin_daemon import (
PluginAgentProviderEntity,
)
from core.plugin.manager.base import BasePluginManager
class PluginAgentManager(BasePluginManager):
    """Plugin daemon client for listing agent strategy providers and invoking agent strategies."""

    def fetch_agent_strategy_providers(self, tenant_id: str) -> list[PluginAgentProviderEntity]:
        """
        Fetch agent strategy providers for the given tenant.

        Only the first page (up to 256 providers) is requested. In the returned entities the
        provider name, and the provider recorded on each strategy, are rewritten to
        "<plugin_id>/<provider_name>".
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            # "data" can be present but null (the single-provider transformer below treats
            # that as the error case), so never iterate over None here; this lets the
            # daemon error be reported by the caller instead of a TypeError.
            for provider in json_response.get("data") or []:
                declaration = provider.get("declaration", {}) or {}
                provider_name = (declaration.get("identity") or {}).get("name")
                for strategy in declaration.get("strategies") or []:
                    strategy["identity"]["provider"] = provider_name

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/agent_strategies",
            list[PluginAgentProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )

        for provider in response:
            provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"

            # override the provider name for each strategy to plugin_id/provider_name
            for strategy in provider.declaration.strategies:
                strategy.identity.provider = provider.declaration.identity.name

        return response

    def fetch_agent_strategy_provider(self, tenant_id: str, provider: str) -> PluginAgentProviderEntity:
        """
        Fetch a single agent strategy provider for the given tenant.

        `provider` is parsed with GenericProviderID to obtain the plugin id and provider
        name sent to the daemon. The returned entity has its name, and the provider of each
        strategy, rewritten to "<plugin_id>/<provider_name>".
        """
        agent_provider_id = GenericProviderID(provider)

        def transformer(json_response: dict[str, Any]) -> dict:
            # skip if error occurs
            if json_response.get("data") is None or json_response.get("data", {}).get("declaration") is None:
                return json_response

            for strategy in json_response.get("data", {}).get("declaration", {}).get("strategies", []):
                strategy["identity"]["provider"] = agent_provider_id.provider_name

            return json_response

        response = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/agent_strategy",
            PluginAgentProviderEntity,
            params={"provider": agent_provider_id.provider_name, "plugin_id": agent_provider_id.plugin_id},
            transformer=transformer,
        )

        response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"

        # override the provider name for each strategy to plugin_id/provider_name
        for strategy in response.declaration.strategies:
            strategy.identity.provider = response.declaration.identity.name

        return response

    def invoke(
        self,
        tenant_id: str,
        user_id: str,
        agent_provider: str,
        agent_strategy: str,
        agent_params: dict[str, Any],
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[AgentInvokeMessage, None, None]:
        """
        Invoke the agent with the given tenant, user, plugin, provider, name and parameters.

        Returns the stream of AgentInvokeMessage items produced by the daemon; the request
        is only sent once the returned generator is iterated.
        """
        agent_provider_id = GenericProviderID(agent_provider)

        response = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/agent_strategy/invoke",
            AgentInvokeMessage,
            data={
                "user_id": user_id,
                "conversation_id": conversation_id,
                "app_id": app_id,
                "message_id": message_id,
                "data": {
                    "agent_strategy_provider": agent_provider_id.provider_name,
                    "agent_strategy": agent_strategy,
                    "agent_strategy_params": agent_params,
                },
            },
            headers={
                "X-Plugin-ID": agent_provider_id.plugin_id,
                "Content-Type": "application/json",
            },
        )
        return response

@ -0,0 +1,12 @@
from core.plugin.manager.base import BasePluginManager
class PluginAssetManager(BasePluginManager):
    """Plugin daemon client for downloading plugin assets."""

    def fetch_asset(self, tenant_id: str, id: str) -> bytes:
        """
        Download the raw bytes of the asset identified by `id` for a tenant.

        Raises ValueError when the daemon answers with any status other than 200.
        """
        asset_response = self._request(method="GET", path=f"plugin/{tenant_id}/asset/{id}")
        if asset_response.status_code == 200:
            return asset_response.content
        raise ValueError(f"can not found asset {id}")

@ -0,0 +1,237 @@
import inspect
import json
import logging
from collections.abc import Callable, Generator
from typing import TypeVar
import requests
from pydantic import BaseModel
from yarl import URL
from configs import dify_config
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError
from core.plugin.manager.exc import (
PluginDaemonBadRequestError,
PluginDaemonInternalServerError,
PluginDaemonNotFoundError,
PluginDaemonUnauthorizedError,
PluginInvokeError,
PluginNotFoundError,
PluginPermissionDeniedError,
PluginUniqueIdentifierError,
)
# Base URL and API key of the plugin daemon inner API, read from the Dify config when this module is imported.
plugin_daemon_inner_api_baseurl = dify_config.PLUGIN_DAEMON_URL
plugin_daemon_inner_api_key = dify_config.PLUGIN_DAEMON_KEY
# Type variable for the payload type requested from the typed request helpers.
T = TypeVar("T", bound=(BaseModel | dict | list | bool | str))
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class BasePluginManager:
def _request(
self,
method: str,
path: str,
headers: dict | None = None,
data: bytes | dict | str | None = None,
params: dict | None = None,
files: dict | None = None,
stream: bool = False,
) -> requests.Response:
"""
Make a request to the plugin daemon inner API.
"""
url = URL(str(plugin_daemon_inner_api_baseurl)) / path
headers = headers or {}
headers["X-Api-Key"] = plugin_daemon_inner_api_key
headers["Accept-Encoding"] = "gzip, deflate, br"
if headers.get("Content-Type") == "application/json" and isinstance(data, dict):
data = json.dumps(data)
try:
response = requests.request(
method=method, url=str(url), headers=headers, data=data, params=params, stream=stream, files=files
)
except requests.exceptions.ConnectionError:
logger.exception("Request to Plugin Daemon Service failed")
raise PluginDaemonInnerError(code=-500, message="Request to Plugin Daemon Service failed")
return response
def _stream_request(
self,
method: str,
path: str,
params: dict | None = None,
headers: dict | None = None,
data: bytes | dict | None = None,
files: dict | None = None,
) -> Generator[bytes, None, None]:
"""
Make a stream request to the plugin daemon inner API
"""
response = self._request(method, path, headers, data, params, files, stream=True)
for line in response.iter_lines():
line = line.decode("utf-8").strip()
if line.startswith("data:"):
line = line[5:].strip()
if line:
yield line
def _stream_request_with_model(
self,
method: str,
path: str,
type: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
) -> Generator[T, None, None]:
"""
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
yield type(**json.loads(line)) # type: ignore
def _request_with_model(
self,
method: str,
path: str,
type: type[T],
headers: dict | None = None,
data: bytes | None = None,
params: dict | None = None,
files: dict | None = None,
) -> T:
"""
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
return type(**response.json()) # type: ignore
def _request_with_plugin_daemon_response(
self,
method: str,
path: str,
type: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
transformer: Callable[[dict], dict] | None = None,
) -> T:
"""
Make a request to the plugin daemon inner API and return the response as a model.
"""
response = self._request(method, path, headers, data, params, files)
json_response = response.json()
if transformer:
json_response = transformer(json_response)
rep = PluginDaemonBasicResponse[type](**json_response) # type: ignore
if rep.code != 0:
try:
error = PluginDaemonError(**json.loads(rep.message))
except Exception:
raise ValueError(f"{rep.message}, code: {rep.code}")
self._handle_plugin_daemon_error(error.error_type, error.message)
if rep.data is None:
frame = inspect.currentframe()
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
return rep.data
def _request_with_plugin_daemon_response_stream(
self,
method: str,
path: str,
type: type[T],
headers: dict | None = None,
data: bytes | dict | None = None,
params: dict | None = None,
files: dict | None = None,
) -> Generator[T, None, None]:
"""
Make a stream request to the plugin daemon inner API and yield the response as a model.
"""
for line in self._stream_request(method, path, params, headers, data, files):
line_data = None
try:
line_data = json.loads(line)
rep = PluginDaemonBasicResponse[type](**line_data) # type: ignore
except Exception:
# TODO modify this when line_data has code and message
if line_data and "error" in line_data:
raise ValueError(line_data["error"])
else:
raise ValueError(line)
if rep.code != 0:
if rep.code == -500:
try:
error = PluginDaemonError(**json.loads(rep.message))
except Exception:
raise PluginDaemonInnerError(code=rep.code, message=rep.message)
self._handle_plugin_daemon_error(error.error_type, error.message)
raise ValueError(f"plugin daemon: {rep.message}, code: {rep.code}")
if rep.data is None:
frame = inspect.currentframe()
raise ValueError(f"got empty data from plugin daemon: {frame.f_lineno if frame else 'unknown'}")
yield rep.data
def _handle_plugin_daemon_error(self, error_type: str, message: str):
"""
handle the error from plugin daemon
"""
match error_type:
case PluginDaemonInnerError.__name__:
raise PluginDaemonInnerError(code=-500, message=message)
case PluginInvokeError.__name__:
error_object = json.loads(message)
invoke_error_type = error_object.get("error_type")
args = error_object.get("args")
match invoke_error_type:
case InvokeRateLimitError.__name__:
raise InvokeRateLimitError(description=args.get("description"))
case InvokeAuthorizationError.__name__:
raise InvokeAuthorizationError(description=args.get("description"))
case InvokeBadRequestError.__name__:
raise InvokeBadRequestError(description=args.get("description"))
case InvokeConnectionError.__name__:
raise InvokeConnectionError(description=args.get("description"))
case InvokeServerUnavailableError.__name__:
raise InvokeServerUnavailableError(description=args.get("description"))
case CredentialsValidateFailedError.__name__:
raise CredentialsValidateFailedError(error_object.get("message"))
case _:
raise PluginInvokeError(description=message)
case PluginDaemonInternalServerError.__name__:
raise PluginDaemonInternalServerError(description=message)
case PluginDaemonBadRequestError.__name__:
raise PluginDaemonBadRequestError(description=message)
case PluginDaemonNotFoundError.__name__:
raise PluginDaemonNotFoundError(description=message)
case PluginUniqueIdentifierError.__name__:
raise PluginUniqueIdentifierError(description=message)
case PluginNotFoundError.__name__:
raise PluginNotFoundError(description=message)
case PluginDaemonUnauthorizedError.__name__:
raise PluginDaemonUnauthorizedError(description=message)
case PluginPermissionDeniedError.__name__:
raise PluginPermissionDeniedError(description=message)
case _:
raise Exception(f"got unknown error from plugin daemon: {error_type}, message: {message}")

@ -0,0 +1,17 @@
from pydantic import BaseModel
from core.plugin.manager.base import BasePluginManager
class PluginDebuggingManager(BasePluginManager):
    """Plugin daemon client for plugin debugging helpers."""

    def get_debugging_key(self, tenant_id: str) -> str:
        """
        Request the debugging key of a tenant from the plugin daemon and return it.
        """

        # shape of the `data` payload returned by the daemon for this endpoint
        class Response(BaseModel):
            key: str

        key_path = f"plugin/{tenant_id}/debugging/key"
        return self._request_with_plugin_daemon_response(method="POST", path=key_path, type=Response).key

@ -0,0 +1,116 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.manager.base import BasePluginManager
class PluginEndpointManager(BasePluginManager):
    """Plugin daemon client for endpoint lifecycle operations (create, list, update, delete, enable, disable)."""

    def create_endpoint(
        self, tenant_id: str, user_id: str, plugin_unique_identifier: str, name: str, settings: dict
    ) -> bool:
        """
        Set up a new endpoint for the given plugin.
        Any failure is reported by raising an error.
        """
        payload = {
            "user_id": user_id,
            "plugin_unique_identifier": plugin_unique_identifier,
            "settings": settings,
            "name": name,
        }
        json_headers = {"Content-Type": "application/json"}
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/endpoint/setup",
            type=bool,
            headers=json_headers,
            data=payload,
        )

    def list_endpoints(self, tenant_id: str, user_id: str, page: int, page_size: int):
        """
        Return one page of the tenant's endpoints.
        `user_id` is accepted but not sent to the daemon.
        """
        paging = {"page": page, "page_size": page_size}
        return self._request_with_plugin_daemon_response(
            method="GET",
            path=f"plugin/{tenant_id}/endpoint/list",
            type=list[EndpointEntityWithInstance],
            params=paging,
        )

    def list_endpoints_for_single_plugin(self, tenant_id: str, user_id: str, plugin_id: str, page: int, page_size: int):
        """
        Return one page of the tenant's endpoints that belong to a single plugin.
        `user_id` is accepted but not sent to the daemon.
        """
        query = {"plugin_id": plugin_id, "page": page, "page_size": page_size}
        return self._request_with_plugin_daemon_response(
            method="GET",
            path=f"plugin/{tenant_id}/endpoint/list/plugin",
            type=list[EndpointEntityWithInstance],
            params=query,
        )

    def update_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str, name: str, settings: dict):
        """
        Change the name and settings of an existing endpoint.
        """
        payload = {
            "user_id": user_id,
            "endpoint_id": endpoint_id,
            "name": name,
            "settings": settings,
        }
        json_headers = {"Content-Type": "application/json"}
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/endpoint/update",
            type=bool,
            data=payload,
            headers=json_headers,
        )

    def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
        """
        Remove an endpoint.
        `user_id` is accepted but not sent to the daemon.
        """
        payload = {"endpoint_id": endpoint_id}
        json_headers = {"Content-Type": "application/json"}
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/endpoint/remove",
            type=bool,
            data=payload,
            headers=json_headers,
        )

    def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
        """
        Switch an endpoint on.
        `user_id` is accepted but not sent to the daemon.
        """
        payload = {"endpoint_id": endpoint_id}
        json_headers = {"Content-Type": "application/json"}
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/endpoint/enable",
            type=bool,
            data=payload,
            headers=json_headers,
        )

    def disable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
        """
        Switch an endpoint off.
        `user_id` is accepted but not sent to the daemon.
        """
        payload = {"endpoint_id": endpoint_id}
        json_headers = {"Content-Type": "application/json"}
        return self._request_with_plugin_daemon_response(
            method="POST",
            path=f"plugin/{tenant_id}/endpoint/disable",
            type=bool,
            data=payload,
            headers=json_headers,
        )

@ -0,0 +1,49 @@
class PluginDaemonError(Exception):
    """Base class for all plugin daemon errors."""

    def __init__(self, description: str) -> None:
        """
        :param description: human-readable explanation of the error
        """
        # Forward the description to Exception so that `args` is populated; without it,
        # copying or pickling the exception re-creates it with no arguments and fails.
        super().__init__(description)
        self.description = description

    def __str__(self) -> str:
        # returns the class name and description
        return f"{self.__class__.__name__}: {self.description}"
class PluginDaemonInternalError(PluginDaemonError):
    """Intermediate base class grouping the plugin daemon errors that derive from it."""

    pass
class PluginDaemonClientSideError(PluginDaemonError):
    """Intermediate base class grouping the client-side plugin daemon errors that derive from it."""

    pass
class PluginDaemonInternalServerError(PluginDaemonInternalError):
    """Plugin daemon error of type "internal server error"."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Internal Server Error"
class PluginDaemonUnauthorizedError(PluginDaemonInternalError):
    """Plugin daemon error of type "unauthorized"."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Unauthorized"
class PluginDaemonNotFoundError(PluginDaemonInternalError):
    """Plugin daemon error of type "not found"."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Not Found"
class PluginDaemonBadRequestError(PluginDaemonClientSideError):
    """Plugin daemon error of type "bad request"."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Bad Request"
class PluginInvokeError(PluginDaemonClientSideError):
    """Client-side error for a failed plugin invocation."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Invoke Error"
class PluginUniqueIdentifierError(PluginDaemonClientSideError):
    """Client-side error concerning a plugin unique identifier."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Unique Identifier Error"
class PluginNotFoundError(PluginDaemonClientSideError):
    """Client-side error for a plugin that could not be found."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Plugin Not Found"
class PluginPermissionDeniedError(PluginDaemonClientSideError):
    """Client-side error for a denied permission."""

    # Class-level value; the instance attribute set in PluginDaemonError.__init__ shadows it.
    description: str = "Permission Denied"

@ -0,0 +1,531 @@
import binascii
from collections.abc import Generator, Sequence
from typing import IO, Optional
from core.model_runtime.entities.llm_entities import LLMResultChunk
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import (
PluginBasicBooleanResponse,
PluginDaemonInnerError,
PluginLLMNumTokensResponse,
PluginModelProviderEntity,
PluginModelSchemaEntity,
PluginStringResultResponse,
PluginTextEmbeddingNumTokensResponse,
PluginVoicesResponse,
)
from core.plugin.manager.base import BasePluginManager
class PluginModelManager(BasePluginManager):
def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]:
    """
    List the model providers that the plugin daemon reports for a tenant.

    Only the first page (up to 256 providers) is requested.
    """
    first_page = {"page": 1, "page_size": 256}
    return self._request_with_plugin_daemon_response(
        method="GET",
        path=f"plugin/{tenant_id}/management/models",
        type=list[PluginModelProviderEntity],
        params=first_page,
    )
def get_model_schema(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model_type: str,
    model: str,
    credentials: dict,
) -> AIModelEntity | None:
    """
    Ask the plugin for the schema of a model.

    Returns the `model_schema` of the first streamed item, or None when the stream is empty.
    """
    payload = {
        "user_id": user_id,
        "data": {
            "provider": provider,
            "model_type": model_type,
            "model": model,
            "credentials": credentials,
        },
    }
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    schema_stream = self._request_with_plugin_daemon_response_stream(
        "POST",
        f"plugin/{tenant_id}/dispatch/model/schema",
        PluginModelSchemaEntity,
        data=payload,
        headers=request_headers,
    )

    # only the first item of the stream is consulted
    for first_item in schema_stream:
        return first_item.model_schema

    return None
def validate_provider_credentials(
    self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict
) -> bool:
    """
    Validate provider-level credentials through the plugin.

    When the plugin sends back a credentials dict, it is merged into `credentials`
    in place. Returns the plugin's verdict, or False when the stream is empty.
    """
    payload = {
        "user_id": user_id,
        "data": {
            "provider": provider,
            "credentials": credentials,
        },
    }
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    verdict_stream = self._request_with_plugin_daemon_response_stream(
        "POST",
        f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials",
        PluginBasicBooleanResponse,
        data=payload,
        headers=request_headers,
    )

    # only the first item of the stream is consulted
    for verdict in verdict_stream:
        returned_credentials = verdict.credentials
        if returned_credentials and isinstance(returned_credentials, dict):
            credentials.update(returned_credentials)
        return verdict.result

    return False
def validate_model_credentials(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model_type: str,
    model: str,
    credentials: dict,
) -> bool:
    """
    Validate the credentials of a single model through the plugin.

    When the plugin sends back a credentials dict, it is merged into `credentials`
    in place. Returns the plugin's verdict, or False when the stream is empty.
    """
    payload = {
        "user_id": user_id,
        "data": {
            "provider": provider,
            "model_type": model_type,
            "model": model,
            "credentials": credentials,
        },
    }
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    verdict_stream = self._request_with_plugin_daemon_response_stream(
        "POST",
        f"plugin/{tenant_id}/dispatch/model/validate_model_credentials",
        PluginBasicBooleanResponse,
        data=payload,
        headers=request_headers,
    )

    # only the first item of the stream is consulted
    for verdict in verdict_stream:
        returned_credentials = verdict.credentials
        if returned_credentials and isinstance(returned_credentials, dict):
            credentials.update(returned_credentials)
        return verdict.result

    return False
def invoke_llm(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    model_parameters: Optional[dict] = None,
    tools: Optional[list[PromptMessageTool]] = None,
    stop: Optional[list[str]] = None,
    stream: bool = True,
) -> Generator[LLMResultChunk, None, None]:
    """
    Run an LLM invocation through the plugin and yield the result chunks.

    A PluginDaemonInnerError raised while streaming is re-raised as a ValueError
    built from its message and code.
    """
    invocation = {
        "provider": provider,
        "model_type": "llm",
        "model": model,
        "credentials": credentials,
        "prompt_messages": prompt_messages,
        "model_parameters": model_parameters,
        "tools": tools,
        "stop": stop,
        "stream": stream,
    }
    request_body = jsonable_encoder({"user_id": user_id, "data": invocation})
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    chunk_stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/llm/invoke",
        type=LLMResultChunk,
        data=request_body,
        headers=request_headers,
    )

    try:
        yield from chunk_stream
    except PluginDaemonInnerError as inner_error:
        raise ValueError(inner_error.message + str(inner_error.code))
def get_llm_num_tokens(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model_type: str,
    model: str,
    credentials: dict,
    prompt_messages: list[PromptMessage],
    tools: Optional[list[PromptMessageTool]] = None,
) -> int:
    """
    Ask the plugin daemon how many tokens the given prompt messages use.

    Returns ``num_tokens`` of the first streamed response, or ``0`` when the
    stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": model_type,
                "model": model,
                "credentials": credentials,
                "prompt_messages": prompt_messages,
                "tools": tools,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/llm/num_tokens",
        type=PluginLLMNumTokensResponse,
        data=payload,
        headers=request_headers,
    )
    first = next(iter(stream), None)
    return 0 if first is None else first.num_tokens
def invoke_text_embedding(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    texts: list[str],
    input_type: str,
) -> TextEmbeddingResult:
    """
    Embed ``texts`` through the plugin daemon.

    Returns the first streamed ``TextEmbeddingResult``.

    :raises ValueError: when the stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "text-embedding",
                "model": model,
                "credentials": credentials,
                "texts": texts,
                "input_type": input_type,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke",
        type=TextEmbeddingResult,
        data=payload,
        headers=request_headers,
    )
    for result in stream:
        # Only the first response is used.
        return result
    raise ValueError("Failed to invoke text embedding")
def get_text_embedding_num_tokens(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    texts: list[str],
) -> list[int]:
    """
    Ask the plugin daemon for the token counts of ``texts``.

    Returns ``num_tokens`` of the first streamed response, or an empty list
    when the stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "text-embedding",
                "model": model,
                "credentials": credentials,
                "texts": texts,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens",
        type=PluginTextEmbeddingNumTokensResponse,
        data=payload,
        headers=request_headers,
    )
    first = next(iter(stream), None)
    return [] if first is None else first.num_tokens
def invoke_rerank(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    query: str,
    docs: list[str],
    score_threshold: Optional[float] = None,
    top_n: Optional[int] = None,
) -> RerankResult:
    """
    Rerank ``docs`` against ``query`` through the plugin daemon.

    Returns the first streamed ``RerankResult``.

    :raises ValueError: when the stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "rerank",
                "model": model,
                "credentials": credentials,
                "query": query,
                "docs": docs,
                "score_threshold": score_threshold,
                "top_n": top_n,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/rerank/invoke",
        type=RerankResult,
        data=payload,
        headers=request_headers,
    )
    for result in stream:
        # Only the first response is used.
        return result
    raise ValueError("Failed to invoke rerank")
def invoke_tts(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    content_text: str,
    voice: str,
) -> Generator[bytes, None, None]:
    """
    Synthesize speech for ``content_text`` through the plugin daemon.

    Each streamed response carries a hex string in ``result``; it is decoded
    with ``binascii.unhexlify`` and yielded as bytes.

    :raises ValueError: when iterating the daemon stream raises
        ``PluginDaemonInnerError``; the text is the daemon message followed by
        its code, and the daemon error is attached as the cause.
    """
    response = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/tts/invoke",
        type=PluginStringResultResponse,
        data=jsonable_encoder(
            {
                "user_id": user_id,
                "data": {
                    "provider": provider,
                    "model_type": "tts",
                    "model": model,
                    "credentials": credentials,
                    "tenant_id": tenant_id,
                    "content_text": content_text,
                    "voice": voice,
                },
            }
        ),
        headers={
            "X-Plugin-ID": plugin_id,
            "Content-Type": "application/json",
        },
    )

    try:
        for result in response:
            hex_str = result.result
            yield binascii.unhexlify(hex_str)
    except PluginDaemonInnerError as e:
        # Chain explicitly so the daemon error is reported as the direct cause.
        raise ValueError(e.message + str(e.code)) from e
def get_tts_model_voices(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    language: Optional[str] = None,
) -> list[dict]:
    """
    List the voices of a TTS model as reported by the plugin daemon.

    Returns one ``{"name": ..., "value": ...}`` dict per voice of the first
    streamed response, or an empty list when the stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "tts",
                "model": model,
                "credentials": credentials,
                "language": language,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/tts/model/voices",
        type=PluginVoicesResponse,
        data=payload,
        headers=request_headers,
    )
    first = next(iter(stream), None)
    if first is None:
        return []
    return [{"name": item.name, "value": item.value} for item in first.voices]
def invoke_speech_to_text(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    file: IO[bytes],
) -> str:
    """
    Transcribe an audio file through the plugin daemon.

    The whole file is read and sent hex-encoded. Returns ``result`` of the
    first streamed response.

    :raises ValueError: when the stream yields nothing.
    """
    hex_audio = binascii.hexlify(file.read()).decode()
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "speech2text",
                "model": model,
                "credentials": credentials,
                "file": hex_audio,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/speech2text/invoke",
        type=PluginStringResultResponse,
        data=payload,
        headers=request_headers,
    )
    for item in stream:
        # Only the first response is used.
        return item.result
    raise ValueError("Failed to invoke speech to text")
def invoke_moderation(
    self,
    tenant_id: str,
    user_id: str,
    plugin_id: str,
    provider: str,
    model: str,
    credentials: dict,
    text: str,
) -> bool:
    """
    Run ``text`` through a moderation model via the plugin daemon.

    Returns ``result`` of the first streamed response.

    :raises ValueError: when the stream yields nothing.
    """
    payload = jsonable_encoder(
        {
            "user_id": user_id,
            "data": {
                "provider": provider,
                "model_type": "moderation",
                "model": model,
                "credentials": credentials,
                "text": text,
            },
        }
    )
    request_headers = {
        "X-Plugin-ID": plugin_id,
        "Content-Type": "application/json",
    }
    stream = self._request_with_plugin_daemon_response_stream(
        method="POST",
        path=f"plugin/{tenant_id}/dispatch/moderation/invoke",
        type=PluginBasicBooleanResponse,
        data=payload,
        headers=request_headers,
    )
    for item in stream:
        # Only the first response is used.
        return item.result
    raise ValueError("Failed to invoke moderation")

@ -0,0 +1,249 @@
from collections.abc import Sequence
from core.plugin.entities.bundle import PluginBundleDependency
from core.plugin.entities.plugin import (
GenericProviderID,
MissingPluginDependency,
PluginDeclaration,
PluginEntity,
PluginInstallation,
PluginInstallationSource,
)
from core.plugin.entities.plugin_daemon import PluginInstallTask, PluginInstallTaskStartResponse, PluginUploadResponse
from core.plugin.manager.base import BasePluginManager
class PluginInstallationManager(BasePluginManager):
    """Client for the plugin daemon's per-tenant plugin management endpoints."""

    def fetch_plugin_by_identifier(
        self,
        tenant_id: str,
        identifier: str,
    ) -> bool:
        """Query the fetch-by-identifier endpoint and return the daemon's boolean reply."""
        query = {"plugin_unique_identifier": identifier}
        return self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/fetch/identifier",
            bool,
            params=query,
        )

    def list_plugins(self, tenant_id: str) -> list[PluginEntity]:
        """List the tenant's plugins (first page, page size 256)."""
        paging = {"page": 1, "page_size": 256}
        return self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/list",
            list[PluginEntity],
            params=paging,
        )

    def upload_pkg(
        self,
        tenant_id: str,
        pkg: bytes,
        verify_signature: bool = False,
    ) -> PluginUploadResponse:
        """
        Upload a plugin package and return the plugin unique identifier.
        """
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/install/upload/package",
            PluginUploadResponse,
            files={"dify_pkg": ("dify_pkg", pkg, "application/octet-stream")},
            # The flag is sent as the literal form value "true" / "false".
            data={"verify_signature": "true" if verify_signature else "false"},
        )

    def upload_bundle(
        self,
        tenant_id: str,
        bundle: bytes,
        verify_signature: bool = False,
    ) -> Sequence[PluginBundleDependency]:
        """
        Upload a plugin bundle and return the dependencies.
        """
        attachment = {"dify_bundle": ("dify_bundle", bundle, "application/octet-stream")}
        form = {"verify_signature": "true" if verify_signature else "false"}
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/install/upload/bundle",
            list[PluginBundleDependency],
            files=attachment,
            data=form,
        )

    def install_from_identifiers(
        self,
        tenant_id: str,
        identifiers: Sequence[str],
        source: PluginInstallationSource,
        metas: list[dict],
    ) -> PluginInstallTaskStartResponse:
        """
        Start installing the plugins named by ``identifiers``.
        """
        payload = {
            "plugin_unique_identifiers": identifiers,
            "source": source,
            "metas": metas,
        }
        # exception will be raised if the request failed
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/install/identifiers",
            PluginInstallTaskStartResponse,
            data=payload,
            headers={"Content-Type": "application/json"},
        )

    def fetch_plugin_installation_tasks(self, tenant_id: str, page: int, page_size: int) -> Sequence[PluginInstallTask]:
        """
        Fetch one page of plugin installation tasks.
        """
        paging = {"page": page, "page_size": page_size}
        return self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/install/tasks",
            list[PluginInstallTask],
            params=paging,
        )

    def fetch_plugin_installation_task(self, tenant_id: str, task_id: str) -> PluginInstallTask:
        """
        Fetch a single plugin installation task by id.
        """
        endpoint = f"plugin/{tenant_id}/management/install/tasks/{task_id}"
        return self._request_with_plugin_daemon_response("GET", endpoint, PluginInstallTask)

    def delete_plugin_installation_task(self, tenant_id: str, task_id: str) -> bool:
        """
        Delete a plugin installation task.
        """
        endpoint = f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete"
        return self._request_with_plugin_daemon_response("POST", endpoint, bool)

    def delete_all_plugin_installation_task_items(self, tenant_id: str) -> bool:
        """
        Delete all plugin installation task items of the tenant.
        """
        endpoint = f"plugin/{tenant_id}/management/install/tasks/delete_all"
        return self._request_with_plugin_daemon_response("POST", endpoint, bool)

    def delete_plugin_installation_task_item(self, tenant_id: str, task_id: str, identifier: str) -> bool:
        """
        Delete one item of a plugin installation task.
        """
        endpoint = f"plugin/{tenant_id}/management/install/tasks/{task_id}/delete/{identifier}"
        return self._request_with_plugin_daemon_response("POST", endpoint, bool)

    def fetch_plugin_manifest(self, tenant_id: str, plugin_unique_identifier: str) -> PluginDeclaration:
        """
        Fetch the manifest (declaration) of a plugin.
        """
        query = {"plugin_unique_identifier": plugin_unique_identifier}
        return self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/fetch/manifest",
            PluginDeclaration,
            params=query,
        )

    def fetch_plugin_installation_by_ids(
        self, tenant_id: str, plugin_ids: Sequence[str]
    ) -> Sequence[PluginInstallation]:
        """
        Fetch the installations of several plugins in one request.
        """
        payload = {"plugin_ids": plugin_ids}
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/installation/fetch/batch",
            list[PluginInstallation],
            data=payload,
            headers={"Content-Type": "application/json"},
        )

    def fetch_missing_dependencies(
        self, tenant_id: str, plugin_unique_identifiers: list[str]
    ) -> list[MissingPluginDependency]:
        """
        Fetch the missing dependencies among the given plugin identifiers.
        """
        payload = {"plugin_unique_identifiers": plugin_unique_identifiers}
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/installation/missing",
            list[MissingPluginDependency],
            data=payload,
            headers={"Content-Type": "application/json"},
        )

    def uninstall(self, tenant_id: str, plugin_installation_id: str) -> bool:
        """
        Uninstall a plugin installation.
        """
        payload = {"plugin_installation_id": plugin_installation_id}
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/uninstall",
            bool,
            data=payload,
            headers={"Content-Type": "application/json"},
        )

    def upgrade_plugin(
        self,
        tenant_id: str,
        original_plugin_unique_identifier: str,
        new_plugin_unique_identifier: str,
        source: PluginInstallationSource,
        meta: dict,
    ) -> PluginInstallTaskStartResponse:
        """
        Start upgrading a plugin from one unique identifier to another.
        """
        payload = {
            "original_plugin_unique_identifier": original_plugin_unique_identifier,
            "new_plugin_unique_identifier": new_plugin_unique_identifier,
            "source": source,
            "meta": meta,
        }
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/install/upgrade",
            PluginInstallTaskStartResponse,
            data=payload,
            headers={"Content-Type": "application/json"},
        )

    def check_tools_existence(self, tenant_id: str, provider_ids: Sequence[GenericProviderID]) -> Sequence[bool]:
        """
        Check whether the tools of the given providers exist.
        """
        entries = []
        for provider_id in provider_ids:
            entries.append(
                {
                    "plugin_id": provider_id.plugin_id,
                    "provider_name": provider_id.provider_name,
                }
            )
        return self._request_with_plugin_daemon_response(
            "POST",
            f"plugin/{tenant_id}/management/tools/check_existence",
            list[bool],
            data={"provider_ids": entries},
            headers={"Content-Type": "application/json"},
        )

@ -0,0 +1,188 @@
from collections.abc import Generator
from typing import Any, Optional
from pydantic import BaseModel
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.manager.base import BasePluginManager
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
class PluginToolManager(BasePluginManager):
    """Client for the plugin daemon's tool management and tool dispatch endpoints."""

    def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
        """
        List the tenant's tool providers (first page, page size 256).

        Each provider name, and the ``provider`` field of each of its tools, is
        rewritten to ``<plugin_id>/<provider name>`` before returning.
        """

        def transformer(json_response: dict[str, Any]) -> dict:
            # Fill every tool's identity.provider in the raw JSON with its provider's declared name.
            for entry in json_response.get("data", []):
                declaration = entry.get("declaration", {}) or {}
                declared_name = declaration.get("identity", {}).get("name")
                for tool in declaration.get("tools", []):
                    tool["identity"]["provider"] = declared_name
            return json_response

        providers = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/tools",
            list[PluginToolProviderEntity],
            params={"page": 1, "page_size": 256},
            transformer=transformer,
        )

        for item in providers:
            declaration = item.declaration
            declaration.identity.name = f"{item.plugin_id}/{declaration.identity.name}"
            # override the provider name for each tool to plugin_id/provider_name
            for tool in declaration.tools:
                tool.identity.provider = declaration.identity.name

        return providers

    def fetch_tool_provider(self, tenant_id: str, provider: str) -> PluginToolProviderEntity:
        """
        Fetch one tool provider of the tenant.

        ``provider`` is parsed with ``ToolProviderID``; the returned entity has its
        name, and each tool's ``provider`` field, rewritten to
        ``<plugin_id>/<provider name>``.
        """
        tool_provider_id = ToolProviderID(provider)

        def transformer(json_response: dict[str, Any]) -> dict:
            data = json_response.get("data")
            if not data:
                return json_response
            for tool in data.get("declaration", {}).get("tools", []):
                tool["identity"]["provider"] = tool_provider_id.provider_name
            return json_response

        query = {"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id}
        entity = self._request_with_plugin_daemon_response(
            "GET",
            f"plugin/{tenant_id}/management/tool",
            PluginToolProviderEntity,
            params=query,
            transformer=transformer,
        )

        declaration = entity.declaration
        declaration.identity.name = f"{entity.plugin_id}/{declaration.identity.name}"
        # override the provider name for each tool to plugin_id/provider_name
        for tool in declaration.tools:
            tool.identity.provider = declaration.identity.name

        return entity

    def invoke(
        self,
        tenant_id: str,
        user_id: str,
        tool_provider: str,
        tool_name: str,
        credentials: dict[str, Any],
        tool_parameters: dict[str, Any],
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> Generator[ToolInvokeMessage, None, None]:
        """
        Invoke a tool through the plugin daemon.

        ``tool_provider`` is parsed with ``GenericProviderID`` to obtain the plugin
        id (sent as the ``X-Plugin-ID`` header) and the provider name. The
        daemon's message stream is returned as is.
        """
        tool_provider_id = GenericProviderID(tool_provider)
        payload = {
            "user_id": user_id,
            "conversation_id": conversation_id,
            "app_id": app_id,
            "message_id": message_id,
            "data": {
                "provider": tool_provider_id.provider_name,
                "tool": tool_name,
                "credentials": credentials,
                "tool_parameters": tool_parameters,
            },
        }
        request_headers = {
            "X-Plugin-ID": tool_provider_id.plugin_id,
            "Content-Type": "application/json",
        }
        return self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/tool/invoke",
            ToolInvokeMessage,
            data=payload,
            headers=request_headers,
        )

    def validate_provider_credentials(
        self, tenant_id: str, user_id: str, provider: str, credentials: dict[str, Any]
    ) -> bool:
        """
        Validate the credentials of a tool provider.

        Returns ``result`` of the first streamed response, or ``False`` when the
        stream yields nothing.
        """
        tool_provider_id = GenericProviderID(provider)
        payload = {
            "user_id": user_id,
            "data": {
                "provider": tool_provider_id.provider_name,
                "credentials": credentials,
            },
        }
        request_headers = {
            "X-Plugin-ID": tool_provider_id.plugin_id,
            "Content-Type": "application/json",
        }
        stream = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/tool/validate_credentials",
            PluginBasicBooleanResponse,
            data=payload,
            headers=request_headers,
        )
        first = next(iter(stream), None)
        return False if first is None else first.result

    def get_runtime_parameters(
        self,
        tenant_id: str,
        user_id: str,
        provider: str,
        credentials: dict[str, Any],
        tool: str,
        conversation_id: Optional[str] = None,
        app_id: Optional[str] = None,
        message_id: Optional[str] = None,
    ) -> list[ToolParameter]:
        """
        Get the runtime parameters of a tool.

        Returns ``parameters`` of the first streamed response, or an empty list
        when the stream yields nothing.
        """
        tool_provider_id = GenericProviderID(provider)

        class RuntimeParametersResponse(BaseModel):
            parameters: list[ToolParameter]

        payload = {
            "user_id": user_id,
            "conversation_id": conversation_id,
            "app_id": app_id,
            "message_id": message_id,
            "data": {
                "provider": tool_provider_id.provider_name,
                "tool": tool,
                "credentials": credentials,
            },
        }
        request_headers = {
            "X-Plugin-ID": tool_provider_id.plugin_id,
            "Content-Type": "application/json",
        }
        stream = self._request_with_plugin_daemon_response_stream(
            "POST",
            f"plugin/{tenant_id}/dispatch/tool/get_runtime_parameters",
            RuntimeParametersResponse,
            data=payload,
            headers=request_headers,
        )
        first = next(iter(stream), None)
        return [] if first is None else first.parameters

@ -422,6 +422,7 @@ class RetrievalService:
if score_value is not None and isinstance(score_value, int | float | str) if score_value is not None and isinstance(score_value, int | float | str)
else None else None
) )
cls.append_next_segments(records=records,dataset_documents=dataset_documents)
# Create RetrievalSegments object # Create RetrievalSegments object
retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score) retrieval_segment = RetrievalSegments(segment=segment, child_chunks=child_chunks, score=score)
@ -431,3 +432,93 @@ class RetrievalService:
except Exception as e: except Exception as e:
db.session.rollback() db.session.rollback()
raise e raise e
@classmethod
def append_next_segments(cls, records: list[dict], dataset_documents: dict):
    """
    Pass on to ``set_next_segments`` only the records whose dataset document
    is present in ``dataset_documents`` and is not parent-child indexed.
    """
    eligible = []
    for record in records:
        document_id = record["segment"].document_id
        if document_id not in dataset_documents:
            continue
        dataset_document = dataset_documents[document_id]
        if dataset_document and dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
            eligible.append(record)
    cls.set_next_segments(records=eligible)
# For the documents of the given records: load their segments and merge in the following segment.
@classmethod
def set_next_segments(cls, records: list[dict]):
    """
    Load every segment of the documents referenced by ``records`` and hand
    them, grouped by document id, to ``merged_next_segment_content``.

    Nothing is queried or merged when ``records`` is empty.
    """
    document_ids = [record["segment"].document_id for record in records]
    doc_segment_ids = [record["segment"].id for record in records]
    if not document_ids:
        return

    # Fetch all segments of the involved documents in a single query.
    document_segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
    document_segment_data = {}
    for document_segment in document_segments:
        document_segment_data.setdefault(document_segment.document_id, []).append(document_segment)

    cls.merged_next_segment_content(
        records=records,
        document_segment_data=document_segment_data,
        doc_segment_ids=doc_segment_ids,
    )
@classmethod
def merged_next_segment_content(cls,records: list[dict],document_segment_data: dict,doc_segment_ids: list) :
    """
    Extend the content of the best-scoring records with their following segment.

    Records are visited in descending ``score`` order. For each one, the segment
    that follows it in the same document is looked up; when it exists and its id
    is not already in ``doc_segment_ids``, its content is merged in via
    ``merged_text`` and its id is appended to ``doc_segment_ids``. Processing
    stops once the budget of three is used up.

    :param records: dicts with a ``"segment"`` object and a ``"score"``
    :param document_segment_data: segments grouped by document id
    :param doc_segment_ids: ids of segments already present; extended in place
    """
    # Sort by score in descending order
    sorted_records = sorted(records, key=lambda r: r["score"], reverse=True)
    # Only the three best-scoring records are handled; when the next segment already
    # exists, move on to the following record until three have been handled
    index = 3
    for record in sorted_records:
        if index == 0:
            break
        document_id = record["segment"].document_id
        doc_segment_id = record["segment"].id
        content = record["segment"].content
        document_segments = document_segment_data[document_id]
        # Look up the segment that follows this one
        next_segment = cls.get_next_segment(doc_segment_id=doc_segment_id,document_segments=document_segments)
        if next_segment and next_segment.id not in doc_segment_ids:
            merged_string, merged = cls.merged_text(content, next_segment.content)
            doc_segment_ids.append(next_segment.id)
            if merged:
                # NOTE(review): if the segment is a session-attached ORM object this
                # assignment may get flushed to the database - confirm
                record["segment"].content = merged_string
            # NOTE(review): nesting of this decrement was reconstructed from a source
            # without indentation - confirm it belongs here and not under `if merged`
            index -= 1
@classmethod
def merged_text(cls, text, target_text) -> tuple[str, bool]:
    """
    Append ``target_text`` to ``text`` when they overlap by more than 10 characters.

    The overlap is the longest suffix of ``text`` that is also a prefix of
    ``target_text``. When it is longer than 10 characters, the two strings
    are joined with the duplicated part kept only once.

    :return: ``(merged_string, merged)``; ``merged_string`` is ``text``
        unchanged and ``merged`` is ``False`` when no such overlap exists
    """
    longest_possible = min(len(text), len(target_text))
    # Overlaps of 10 characters or fewer never trigger a merge, so only longer
    # ones are tried, longest first; the first hit is the maximum overlap.
    for overlap_length in range(longest_possible, 10, -1):
        if text.endswith(target_text[:overlap_length]):
            return text + target_text[overlap_length:], True
    return text, False
@classmethod
def get_next_segment(cls, doc_segment_id, document_segments: "list[DocumentSegment]") -> "DocumentSegment | None":
    """
    Return the segment whose position directly follows that of ``doc_segment_id``.

    :param doc_segment_id: id of the segment whose successor is wanted
    :param document_segments: candidate segments, each exposing ``id`` and ``position``
    :return: the segment at ``position + 1``, or ``None`` when the list is
        empty, the id is not in the list, or no segment has that position
    """
    if not document_segments:
        return None

    current = None
    for document_segment in document_segments:
        if document_segment.id == doc_segment_id:
            current = document_segment
    if current is None:
        # Unknown id: there is no position to continue from, so there is no "next".
        return None

    wanted_position = current.position + 1
    for document_segment in document_segments:
        if document_segment.position == wanted_position:
            return document_segment
    return None

@ -23,6 +23,10 @@ SupportedComparisonOperator = Literal[
# for time # for time
"before", "before",
"after", "after",
# Extended operators
"in",
"not in",
] ]

@ -7,7 +7,7 @@ from core.rag.extractor.blob.blob import Blob
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_storage import storage from extensions.ext_storage import storage
from services.ext.read_file_service import ReadPdfService
class PdfExtractor(BaseExtractor): class PdfExtractor(BaseExtractor):
"""Load pdf files. """Load pdf files.
@ -48,7 +48,8 @@ class PdfExtractor(BaseExtractor):
) -> Iterator[Document]: ) -> Iterator[Document]:
"""Lazy load given path as pages.""" """Lazy load given path as pages."""
blob = Blob.from_path(self._file_path) blob = Blob.from_path(self._file_path)
yield from self.parse(blob) yield from self.parse_ext(blob)
# yield from self.parse(blob)
def parse(self, blob: Blob) -> Iterator[Document]: def parse(self, blob: Blob) -> Iterator[Document]:
"""Lazily parse the blob.""" """Lazily parse the blob."""
@ -57,12 +58,23 @@ class PdfExtractor(BaseExtractor):
with blob.as_bytes_io() as file_path: with blob.as_bytes_io() as file_path:
pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True) pdf_reader = pypdfium2.PdfDocument(file_path, autoclose=True)
try: try:
content_arr = []
for page_number, page in enumerate(pdf_reader): for page_number, page in enumerate(pdf_reader):
text_page = page.get_textpage() text_page = page.get_textpage()
content = text_page.get_text_range() content = text_page.get_text_range()
# print(content)
content_arr.append(content)
text_page.close() text_page.close()
page.close() page.close()
metadata = {"source": blob.source, "page": page_number}
yield Document(page_content=content, metadata=metadata) contents = " ".join([p for p in content_arr])
metadata = {"source": blob.source, "page": 1}
yield Document(page_content=contents, metadata=metadata)
finally: finally:
pdf_reader.close() pdf_reader.close()
def parse_ext(self, blob: Blob) -> Iterator[Document]:
    """
    Yield the whole file as a single document.

    The text is obtained from ``ReadPdfService.load_content`` for
    ``self._file_path``; the metadata carries the blob's source and a fixed
    page number of 1.
    """
    text = ReadPdfService().load_content(self._file_path)
    yield Document(page_content=text, metadata={"source": blob.source, "page": 1})

@ -13,7 +13,15 @@ from core.rag.splitter.fixed_text_splitter import (
) )
from core.rag.splitter.text_splitter import TextSplitter from core.rag.splitter.text_splitter import TextSplitter
from models.dataset import Dataset, DatasetProcessRule from models.dataset import Dataset, DatasetProcessRule
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
TokenTextSplitter,
Union,
)
class BaseIndexProcessor(ABC): class BaseIndexProcessor(ABC):
"""Interface for extract files.""" """Interface for extract files."""
@ -69,7 +77,8 @@ class BaseIndexProcessor(ABC):
chunk_size=max_tokens, chunk_size=max_tokens,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""], # separators=["\n\n","\n", "。", ". ", " ", "#"],
separators=["\n\n", "\n", "", "", "", "", "", ". ", "?", "", "!", ")", ":", ",", "#", "", " "],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
else: else:
@ -77,7 +86,8 @@ class BaseIndexProcessor(ABC):
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder( character_splitter = EnhanceRecursiveCharacterTextSplitter.from_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"], chunk_size=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["max_tokens"],
chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"], chunk_overlap=DatasetProcessRule.AUTOMATIC_RULES["segmentation"]["chunk_overlap"],
separators=["\n\n", "", ". ", " ", ""], # separators=["\n\n", "\n", "。", ". ", " ", ""],
separators=["\n\n", "\n", "", "", "", "", "", ". ", "?", "", "!", ")", ":", ",", "#", "", " "],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )

@ -74,6 +74,9 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
chunks_lengths = self._length_function(chunks) chunks_lengths = self._length_function(chunks)
for chunk, chunk_length in zip(chunks, chunks_lengths): for chunk, chunk_length in zip(chunks, chunks_lengths):
if chunk_length > self._chunk_size: if chunk_length > self._chunk_size:
if self._keep_separator :
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。
continue
final_chunks.extend(self.recursive_split_text(chunk)) final_chunks.extend(self.recursive_split_text(chunk))
else: else:
final_chunks.append(chunk) final_chunks.append(chunk)
@ -153,3 +156,112 @@ class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter)
final_chunks.append(current_part) final_chunks.append(current_part)
return final_chunks return final_chunks
def recursive_split_text_keep_separator_(self, text: str) -> list[str]:
    """
    Split ``text`` into chunks with ``append_next_split_text``.

    The walk starts from ``self._separators``; whatever is still pending once
    it finishes becomes the last chunk.
    """
    chunks: list[str] = []
    pending_parts: list[str] = []
    self.append_next_split_text(
        current_part_list=pending_parts,
        current_length_list=[],
        text=text,
        final_chunks=chunks,
        separators=self._separators,
    )
    if pending_parts:
        # Flush the parts that were collected but never closed into a chunk.
        chunks.append("".join(pending_parts))
    return chunks
@classmethod
def get_splits_(cls, text: str, separators: list[str]) -> tuple[list[str], list[str]]:
    """
    Split ``text`` on the first separator of ``separators`` that occurs in it.

    The separator stays attached to the end of the piece it terminated, so
    joining the returned pieces with ``""`` reproduces ``text``. When no
    separator occurs in ``text`` the last one is used, which leaves the text
    in one piece. An empty-string separator splits into single characters.

    :param text: text to split
    :param separators: candidate separators in priority order
    :return: ``(pieces, remaining_separators)`` where ``remaining_separators``
        are the candidates after the one that was used (empty when none
        matched or ``separators`` is empty)
    """
    if not separators:
        return [text], []

    separator = separators[-1]
    remaining: list[str] = []
    for position, candidate in enumerate(separators):
        if candidate in text:
            separator = candidate
            remaining = separators[position + 1 :]
            break

    if not separator:
        # An empty separator means character-level splitting.
        return list(text), remaining

    pieces = text.split(separator)
    # str.split drops the separator; put it back on every piece it terminated.
    splits = [piece + separator for piece in pieces[:-1]]
    if pieces[-1]:
        splits.append(pieces[-1])
    return splits, remaining
def append_next_split_text(
    self,
    current_part_list: list[str],
    current_length_list: list[int],
    text: str,
    final_chunks: list[str],
    separators: list[str],
):
    """
    Distribute ``text`` over chunks of at most ``self._chunk_size``.

    ``text`` is split with ``get_splits_``. Pieces are appended to the pending
    chunk while they fit. A piece that does not fit is split again with the
    remaining, finer separators; when none remain, the pending chunk is closed
    and a new one is started from its overlap (see ``get_overlap_part``).

    All list arguments are updated in place.

    :param current_part_list: parts of the chunk currently being built
    :param current_length_list: lengths of those parts, as measured by
        ``self._length_function``
    :param text: text still to be distributed; nothing happens when it is empty
    :param final_chunks: completed chunks
    :param separators: separators that may still be used, in priority order
    """
    if not text:
        return

    splits, finer_separators = self.get_splits_(text, separators)
    split_lengths = self._length_function(splits)
    for piece, piece_length in zip(splits, split_lengths):
        if sum(current_length_list) + piece_length <= self._chunk_size:
            current_part_list.append(piece)
            current_length_list.append(piece_length)
            continue

        if finer_separators:
            # The piece is too big for the pending chunk: break it down further.
            self.append_next_split_text(
                current_part_list=current_part_list,
                current_length_list=current_length_list,
                text=piece,
                final_chunks=final_chunks,
                separators=finer_separators,
            )
            continue

        # No finer separator is left: close the pending chunk (never an empty one).
        if current_part_list:
            final_chunks.append("".join(current_part_list))
        overlap_length, overlap_text = self.get_overlap_part(current_part_list, current_length_list)
        current_part_list.clear()
        current_length_list.clear()
        # Seed the next chunk with the overlap only when the piece still fits
        # behind it; otherwise the piece starts the chunk alone so that no text
        # is lost, even if the piece itself exceeds the chunk size.
        if overlap_text and overlap_length + piece_length <= self._chunk_size:
            current_part_list.append(overlap_text)
            current_length_list.append(overlap_length)
        current_part_list.append(piece)
        current_length_list.append(piece_length)
def get_overlap_part(self, current_part_list: list[str], current_length_list: list[int]) -> tuple[int, str]:
    """
    Take the tail of a chunk that is carried over as overlap.

    Parts are consumed from the end while their accumulated length stays
    within ``self._chunk_overlap``. The first part that would exceed it
    contributes only as many trailing characters as still fit, each character
    measured with ``self._length_function``.

    :param current_part_list: parts of the chunk, in order
    :param current_length_list: length of each part, in the same order
    :return: ``(overlap_length, overlap_text)``
    """
    budget = self._chunk_overlap
    overlap_length = 0
    tail: list[str] = []  # collected back to front, reversed once at the end

    for part, part_length in zip(reversed(current_part_list), reversed(current_length_list)):
        if overlap_length + part_length <= budget:
            overlap_length += part_length
            tail.append(part)
            continue

        if overlap_length < budget:
            # Fill what is left of the budget with trailing characters of this part.
            characters = list(part)
            character_lengths = self._length_function(characters)
            for character, character_length in zip(reversed(characters), reversed(character_lengths)):
                if overlap_length + character_length > budget:
                    break
                overlap_length += character_length
                tail.append(character)
        break

    tail.reverse()
    return overlap_length, "".join(tail)

@ -0,0 +1,258 @@
"""Functionality for splitting text."""
from __future__ import annotations
from typing import Any, Optional
from core.model_manager import ModelInstance
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.rag.splitter.text_splitter import (
TS,
Collection,
Literal,
RecursiveCharacterTextSplitter,
Set,
TokenTextSplitter,
Union,
)
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
    """
    Recursive character splitter whose ``from_encoder`` measures length with
    the GPT-2 tokenizer, to prevent using of tiktoken.
    """

    @classmethod
    def from_encoder(
        cls: type[TS],
        allowed_special: Union[Literal["all"], Set[str]] = set(),
        disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        **kwargs: Any,
    ):
        """
        Build a splitter whose length function counts GPT-2 tokens.

        For ``TokenTextSplitter`` subclasses the model name ``"gpt2"`` and the
        two special-token settings are forced into the constructor arguments;
        all other keyword arguments are passed through.
        """

        def _token_encoder(texts: list[str]) -> list[int]:
            # One token count per input text; nothing to count for an empty input.
            return [GPT2Tokenizer.get_num_tokens(text) for text in texts] if texts else []

        if issubclass(cls, TokenTextSplitter):
            kwargs.update(
                model_name="gpt2",
                allowed_special=allowed_special,
                disallowed_special=disallowed_special,
            )

        return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[list[str]] = None, **kwargs: Any):
    """
    Create a new TextSplitter.

    :param fixed_separator: separator used for the first, coarse split
    :param separators: fallback separators for oversized chunks; a default
        list is used when none (or an empty list) is given
    :param kwargs: forwarded to the parent constructor
    """
    super().__init__(**kwargs)
    self._fixed_separator = fixed_separator
    if separators:
        self._separators = separators
    else:
        self._separators = ["\n\n", "\n", " ", ""]
def split_text(self, text: str) -> list[str]: # 定义主方法,用于分割文本。
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是分割输入文本并返回块。
if self._fixed_separator: # 如果设置了固定分隔符。
chunks = text.split(self._fixed_separator) # 使用固定分隔符将文本分割成初步的块。
else: # 如果未设置固定分隔符。
chunks = [text] # 将整个文本作为一个块。
final_chunks = [] # 初始化最终的块列表。
chunks_lengths = self._length_function(chunks) # 计算每个块的长度。
for chunk, chunk_length in zip(chunks, chunks_lengths): # 遍历每个块及其长度。
if chunk_length > self._chunk_size: # 如果块的长度超过限制。
if self._keep_separator :
final_chunks.extend(self.recursive_split_text_keep_separator_(chunk)) # 调用递归分割方法进一步拆分。
continue
final_chunks.extend(self.recursive_split_text_(chunk)) # 调用递归分割方法进一步拆分。
else: # 如果块的长度未超过限制。
final_chunks.append(chunk) # 直接保留该块。
return final_chunks # 返回最终的块列表。
def recursive_split_text_(self, text: str) -> list[str]: # 定义递归分割方法。
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
final_chunks = [] # 初始化最终的块列表。
separator = self._separators[-1] # 默认使用备用分隔符列表中的最后一个分隔符。
new_separators = [] # 初始化新的分隔符列表。
for i, _s in enumerate(self._separators): # 遍历备用分隔符列表。
if _s == "": # 如果遇到空字符串分隔符。
separator = _s # 设置分隔符为空字符串。
break # 结束循环。
if _s in text: # 如果当前分隔符存在于文本中。
separator = _s # 设置分隔符为当前分隔符。
new_separators = self._separators[i + 1 :] # 更新新的分隔符列表。
break # 结束循环。
# Now that we have the separator, split the text # 已经确定了分隔符,开始分割文本。
if separator: # 如果分隔符不为空。
if separator == " ": # 如果分隔符是空格。
splits = text.split() # 按空格分割文本。
else: # 如果分隔符不是空格。
splits = text.split(separator) # 按指定分隔符分割文本。
else: # 如果分隔符为空字符串。
splits = list(text) # 将文本按字符分割成列表。
splits = [s for s in splits if (s not in {""})] # 过滤掉空字符串和换行符。
_good_splits = [] # 初始化符合长度要求的块列表。
_good_splits_lengths = [] # 缓存这些块的长度。
self._keep_separator = False
_separator = "" if self._keep_separator else separator # 根据是否保留分隔符决定连接符。
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
if _separator != "": # 如果连接符不为空。
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
print("-----",s,s_len,self._chunk_size)
if s_len < self._chunk_size: # 如果长度小于限制。
_good_splits.append(s) # 将其加入符合要求的块列表。
_good_splits_lengths.append(s_len) # 缓存其长度。
else: # 如果长度超出限制。
if _good_splits: # 如果有符合要求的块。
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) # 合并这些块。
final_chunks.extend(merged_text) # 将合并后的块加入最终块列表。
_good_splits = [] # 清空符合要求的块列表。
_good_splits_lengths = [] # 清空长度缓存。
if not new_separators: # 如果没有新的分隔符。
final_chunks.append(s) # 直接保留当前部分。
else: # 如果有新的分隔符。
other_info = self._split_text(s, new_separators) # 递归调用分割方法。
final_chunks.extend(other_info) # 将结果加入最终块列表。
if _good_splits: # 如果还有剩余的符合要求的块。
merged_text = self._merge_splits(_good_splits, _separator, _good_splits_lengths) # 合并这些块。
final_chunks.extend(merged_text) # 将合并后的块加入最终块列表。
else: # 如果连接符为空。
current_part = "" # 初始化当前块。
current_length = 0 # 初始化当前块的长度。
overlap_part = "" # 初始化重叠部分。
overlap_part_length = 0 # 初始化重叠部分的长度。
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
if current_length + s_len <= self._chunk_size - self._chunk_overlap: # 如果当前块可以容纳更多内容。
current_part += s # 将当前部分加入当前块。
current_length += s_len # 更新当前块的长度。
elif current_length + s_len <= self._chunk_size: # 如果当前块接近长度限制。
current_part += s # 将当前部分加入当前块。
current_length += s_len # 更新当前块的长度。
overlap_part += s # 将当前部分加入重叠部分。
overlap_part_length += s_len # 更新重叠部分的长度。
else: # 如果当前块已满。
final_chunks.append(current_part) # 将当前块加入最终块列表。
current_part = overlap_part + s # 构造新的当前块。
current_length = s_len + overlap_part_length # 更新当前块的长度。
overlap_part = "" # 清空重叠部分。
overlap_part_length = 0 # 清空重叠部分的长度。
if current_part: # 如果还有剩余的当前块。
final_chunks.append(current_part) # 将其加入最终块列表。
return final_chunks # 返回最终的块列表。
def recursive_split_text_keep_separator_(self, text: str) -> list[str]: # 定义递归分割方法。
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
final_chunks = [] # 初始化最终的块列表。
current_part_list = []
self.append_next_split_text(current_part_list=current_part_list,
current_length_list=[],
text=text,
final_chunks = final_chunks,
separators = self._separators)
if len(current_part_list): # 如果还有剩余的当前块。
final_chunks.append("".join(current_part_list)) # 将其加入最终块列表。
return final_chunks # 返回最终的块列表。
@classmethod
def get_splits_(self,text:str, separators:list[str]) -> (list[str],list[str]): # 定义递归分割方法。
"""Split incoming text and return chunks.""" # 文档字符串,说明该方法的作用是递归地分割文本并返回块。
if len(separators) > 0:
separator = separators[-1] # 默认使用备用分隔符列表中的最后一个分隔符。
new_separators = [] # 初始化新的分隔符列表。
for i, _s in enumerate(separators): # 遍历备用分隔符列表。
if _s in text: # 如果当前分隔符存在于文本中。
separator = _s # 设置分隔符为当前分隔符。
new_separators = separators[i + 1 :] # 更新新的分隔符列表。
break # 结束循环。
# Now that we have the separator, split the text # 已经确定了分隔符,开始分割文本。
if separator: # 如果分隔符不为空。
splits = text.split(separator) # 按指定分隔符分割文本。
else: # 如果分隔符为空字符串。
splits = list(text) # 将文本按字符分割成列表。
# splits = [s for s in splits if (s not in {""})] # 过滤掉空字符串和换行符。
return splits,new_separators
else:
return [text],[]
def append_next_split_text(self,
current_part_list:list[str],
current_length_list:list[int],
text: str,
final_chunks: list[str],
separators : list[str]): # 定义递归分割方法。
if text:
# 需要判断是否可以再拼接
splits, new_separators_ = self.get_splits_(text, separators)
s_lens = self._length_function(splits) # 计算每个分割部分的长度。
for s, s_len in zip(splits, s_lens): # 遍历每个分割部分及其长度。
current_length = sum(current_length_list)
if "制定综合主进度" in s:
print(s)
if current_length + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
current_part_list.append(s) # 将当前部分加入当前块。
current_length_list.append(s_len)
else:
if len(new_separators_) == 0:
# 将片段加入到列表中
final_chunks.append("".join(current_part_list))
# 计算出重叠部分的内容
overlap_part_length_,overlap_part_ = self.get_overlap_part(current_part_list,current_length_list)
# 将重叠部分作为下一个片段的开头
current_part_list.clear()
current_part_list.append(overlap_part_)
current_length_list.clear()
current_length_list.append(overlap_part_length_)
if overlap_part_length_ + s_len <= self._chunk_size: # 如果当前块可以容纳更多内容。
current_part_list.append(s) # 将当前部分加入当前块。
current_length_list.append(s_len)
continue
# 递归计算
self.append_next_split_text(current_part_list=current_part_list,
current_length_list=current_length_list,
text=s,
final_chunks=final_chunks,
separators=new_separators_)
def get_overlap_part(self,
current_part_list:list[str],
current_length_list:list[int]) -> (int,str): # 定义递归分割方法。
# 一下计算出
overlap_part_length_ = 0
overlap_part_list = []
current_length_list_reversed = list(reversed(current_length_list))
current_part_list_reversed = list(reversed(current_part_list))
for index, s_len_ in enumerate(current_length_list_reversed):
if overlap_part_length_ + s_len_ > self._chunk_overlap:
if overlap_part_length_ < self._chunk_overlap:
text = current_part_list_reversed[index]
texts = list(text)
text_lens = self._length_function(texts)
texts_reversed = list(reversed(texts))
text_lens_reversed = list(reversed(text_lens))
for s_, len_ in zip(texts_reversed,text_lens_reversed):
if overlap_part_length_ + len_ > self._chunk_overlap:
break
overlap_part_length_ += len_
overlap_part_list[0:0] = s_
# overlap_part_list.append(s_)
break
overlap_part_length_ += s_len_
overlap_part_list[0:0] = current_part_list_reversed[index]
# overlap_part_list.append(current_part_list_reversed[index])
return overlap_part_length_, "".join(overlap_part_list)

@ -0,0 +1,307 @@
"""Abstract interface for document loader implementations."""
from core.rag.splitter.fixed_text_splitter_ext import (
EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter,
)
from services.ext.read_file_service import ReadPdfService
class BaseIndexExtProcessor:
    """Manual harness that splits the text of a PDF file and prints the chunks."""

    def _get_splitter(
        self,
        pdf_file_path: str = r"D:\a.pdf",
    ):
        """Split the text of a PDF into chunks and print every chunk.

        Builds a ``FixedRecursiveCharacterTextSplitter`` via ``from_encoder``
        (chunk size 500, overlap 50, fixed separator ``"@@@@@"``), loads the PDF
        text with ``ReadPdfService().load_content`` and prints each chunk after a
        ``---------`` marker line. Nothing is returned.

        Args:
            pdf_file_path: Path of the PDF to load. Defaults to the path that was
                previously hard-coded here.
        """
        character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
            chunk_size=500,
            chunk_overlap=50,
            fixed_separator="@@@@@",
            # NOTE(review): several entries below are empty strings. An empty
            # separator means "split into single characters", so the entries after
            # the first "" only ever see one-character inputs. Presumably these
            # were meant to be other punctuation marks - confirm against the
            # original file.
            separators=["\n\n", "\n", "", "", "", "", "", ". ", "?", "", "!", ")", ":", ",", "#", "", " "],
        )
        pdf_text = ReadPdfService().load_content(pdf_file_path=pdf_file_path)
        document_nodes = character_splitter.split_text(pdf_text)
        for document in document_nodes:
            print("---------")
            print(document)
# Manual entry point: run the splitter harness when this module is executed directly.
if __name__ == "__main__":
    indexExtProcessor = BaseIndexExtProcessor()
    indexExtProcessor._get_splitter()

@ -159,6 +159,50 @@ class TextSplitter(BaseDocumentTransformer, ABC):
) )
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs) return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
@classmethod
def from_tiktoken_encoder(
    cls: type[TS],
    encoding_name: str = "gpt2",
    model_name: Optional[str] = None,
    allowed_special: Union[Literal["all"], Set[str]] = set(),
    disallowed_special: Union[Literal["all"], Collection[str]] = "all",
    **kwargs: Any,
) -> TS:
    """Text splitter that uses tiktoken encoder to count length.

    Args:
        encoding_name: tiktoken encoding to load when ``model_name`` is None.
        model_name: When given, the encoding is looked up for this model instead.
        allowed_special: Passed to every ``encode`` call.
        disallowed_special: Passed to every ``encode`` call.
        **kwargs: Passed through to the ``cls`` constructor.

    Returns:
        A ``cls`` instance whose ``length_function`` maps a list of strings to
        their tiktoken token counts.

    Raises:
        ImportError: If the ``tiktoken`` package cannot be imported.
    """
    try:
        import tiktoken
    except ImportError:
        raise ImportError(
            "Could not import tiktoken python package. "
            "This is needed in order to calculate max_tokens_for_prompt. "
            "Please install it with `pip install tiktoken`."
        )

    # A model name, when supplied, takes precedence over the encoding name.
    encoding = (
        tiktoken.get_encoding(encoding_name) if model_name is None else tiktoken.encoding_for_model(model_name)
    )

    def _count_tokens(piece: str) -> int:
        token_ids = encoding.encode(
            piece,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )
        return len(token_ids)

    def _batch_lengths(pieces):
        return [_count_tokens(piece) for piece in pieces]

    init_kwargs = dict(kwargs)
    if issubclass(cls, TokenTextSplitter):
        # Token splitters also need the tokenizer settings themselves; these
        # override same-named entries supplied by the caller.
        init_kwargs.update(
            encoding_name=encoding_name,
            model_name=model_name,
            allowed_special=allowed_special,
            disallowed_special=disallowed_special,
        )
    return cls(length_function=_batch_lengths, **init_kwargs)
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]: def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them.""" """Transform sequence of documents by splitting them."""
return self.split_documents(list(documents)) return self.split_documents(list(documents))

@ -0,0 +1,13 @@
from typing import Optional
from pydantic import BaseModel
from core.workflow.graph_engine.entities.graph import GraphParallel
class NextGraphNode(BaseModel):
    """Pydantic model pairing a next node id with an optional ``GraphParallel``."""

    # Id of the node to visit next.
    node_id: str
    """next node id"""

    # Optional GraphParallel attached to that node; None when absent.
    parallel: Optional[GraphParallel] = None
    """parallel"""

@ -25,6 +25,7 @@ class NodeType(StrEnum):
DOCUMENT_EXTRACTOR = "document-extractor" DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator" LIST_OPERATOR = "list-operator"
AGENT = "agent" AGENT = "agent"
VANNA = "vanna"
class ErrorStrategy(StrEnum): class ErrorStrategy(StrEnum):

@ -95,6 +95,9 @@ SupportedComparisonOperator = Literal[
# for time # for time
"before", "before",
"after", "after",
# 扩展
"in",
"not in"
] ]

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast from typing import Any, Optional, cast
from sqlalchemy import Float, and_, func, or_, text from sqlalchemy import Float, Integer, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast from sqlalchemy import cast as sqlalchemy_cast
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -370,7 +370,8 @@ class KnowledgeRetrievalNode(LLMNode):
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name metadata_name = condition.name
expected_value = condition.value expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): # if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty", "in"):
if isinstance(expected_value, str): if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value expected_value
@ -531,6 +532,14 @@ class KnowledgeRetrievalNode(LLMNode):
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value) filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
case "" | ">=": case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value) filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
case "in":
if value is None or value == "":
filters.append(1 == 2)
else:
values = value.split(',')
filters.append(
(text("documents.doc_metadata ->> :key in :value")).params(key=metadata_name, value=tuple(values))
)
case _: case _:
pass pass
return filters return filters

@ -1,5 +1,6 @@
from collections.abc import Mapping from collections.abc import Mapping
from core.workflow.nodes.vanna.vanna_node import VannaNode
from core.workflow.nodes.agent.agent_node import AgentNode from core.workflow.nodes.agent.agent_node import AgentNode
from core.workflow.nodes.answer import AnswerNode from core.workflow.nodes.answer import AnswerNode
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
@ -119,4 +120,8 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[BaseNode]]] = {
LATEST_VERSION: AgentNode, LATEST_VERSION: AgentNode,
"1": AgentNode, "1": AgentNode,
}, },
NodeType.VANNA: {
LATEST_VERSION: VannaNode,
"1": VannaNode,
},
} }

@ -0,0 +1,3 @@
from .vanna_node import VannaNode
__all__ = ["VannaNode"]

@ -0,0 +1,26 @@
from typing import Any, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.llm import ModelConfig, VisionConfig
from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
class VannaConfig(BaseModel):
    """
    Vanna Config.

    Holds a provider name, a model name and the LLM mode.
    """

    provider: str
    name: str
    mode: LLMMode
class VannaNodeData(BaseNodeData):
    """Node data for the Vanna workflow node."""

    # Model configuration; VannaNode._run passes it to _fetch_model_config.
    model: ModelConfig
    # Handed to variable_pool.get() in VannaNode._run to look up the question text.
    query: list[str]
    # Optional instruction text; not read by the visible VannaNode._run.
    instruction: Optional[str] = None
    # Vision settings; default to a fresh VisionConfig.
    vision: VisionConfig = Field(default_factory=VisionConfig)

@ -0,0 +1,50 @@
class VannaNodeError(ValueError):
    """Base error for VannaNode.

    Subclasses ``ValueError``, so every error below can be caught as either type.
    """


class InvalidModelTypeError(VannaNodeError):
    """Raised when the model is not a Large Language Model."""


class ModelSchemaNotFoundError(VannaNodeError):
    """Raised when the model schema is not found."""


class InvalidInvokeResultError(VannaNodeError):
    """Raised when the invoke result is invalid."""


class InvalidTextContentTypeError(VannaNodeError):
    """Raised when the text content type is invalid."""


class InvalidNumberOfParametersError(VannaNodeError):
    """Raised when the number of parameters is invalid."""


class RequiredParameterMissingError(VannaNodeError):
    """Raised when a required parameter is missing."""


class InvalidSelectValueError(VannaNodeError):
    """Raised when a select value is invalid."""


class InvalidNumberValueError(VannaNodeError):
    """Raised when a number value is invalid."""


class InvalidBoolValueError(VannaNodeError):
    """Raised when a bool value is invalid."""


class InvalidStringValueError(VannaNodeError):
    """Raised when a string value is invalid."""


class InvalidArrayValueError(VannaNodeError):
    """Raised when an array value is invalid."""


class InvalidModelModeError(VannaNodeError):
    """Raised when the model mode is invalid."""

@ -0,0 +1,92 @@
from typing import Any, Optional, cast
from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.model_manager import ModelInstance
from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.enums import NodeType
from core.workflow.nodes.llm import LLMNode
from extensions.utils.vanna_text2sql import VannaServer
from models.workflow import WorkflowNodeExecutionStatus
from .entities import VannaNodeData
class Config:
    """Static settings that VannaNode._run merges into the VannaServer config.

    VannaNode._run spreads ``__dict__`` of an instance into the final config
    dict, so every attribute assigned here becomes a top-level config key
    (including ``sql_config`` itself, whose entries are spread a second time).
    """

    def __init__(self, supplier):
        """Populate the settings; ``supplier`` is stored unchanged."""
        self.embedding_supplier = "SiliconFlow"
        self.milvus_uri = dify_config.MILVUS_URI
        self.milvus_database = 'vanna_demo'
        self.supplier = supplier
        self.sql_type = 'postgres'
        # Connection settings: host, user, password and port come from dify_config,
        # the database name is fixed.
        self.sql_config = {
            "host": dify_config.DB_HOST,
            "dbname": 'vanna_demo',
            "user": dify_config.DB_USERNAME,
            "password": dify_config.DB_PASSWORD,
            "port": dify_config.DB_PORT
        }
# Module-level cache of VannaServer instances, indexed by the caller-supplied key.
vn_instances = {}


def get_vanna_server(key, combined_config):
    """Return the cached server for ``key``, creating it on first use.

    ``combined_config`` is only consulted when no instance exists for ``key``
    yet; it is then passed to ``VannaServer`` and the result is cached.
    """
    if key in vn_instances:
        return vn_instances[key]
    server = VannaServer(combined_config)
    vn_instances[key] = server
    return server
class VannaNode(LLMNode):
    """Workflow node that turns a question into SQL through a cached VannaServer."""

    # FIXME: figure out why here is different from super class
    _node_data_cls = VannaNodeData  # type: ignore
    _node_type = NodeType.VANNA

    _model_instance: Optional[ModelInstance] = None
    _model_config: Optional[ModelConfigWithCredentialsEntity] = None

    @classmethod
    def get_default_config(cls, filters: Optional[dict] = None) -> dict:
        """Return the default node config: an empty ``model`` section."""
        return {
            "model": {
            }
        }

    def _run(self):
        """Generate SQL for the configured query variable.

        Reads the question from the variable pool (an empty string when the
        variable is missing), derives the LLM settings from the node's model
        configuration, obtains a ``VannaServer`` for those settings and returns
        the generated SQL under the ``output`` key.
        """
        node_data = cast(VannaNodeData, self.node_data)
        variable = self.graph_runtime_state.variable_pool.get(node_data.query)
        query = variable.text if variable else ""

        model_instance, _ = self._fetch_model_config(self.node_data.model)
        # The last path segment of the provider id selects the backend:
        # 'tongyi', 'openai', 'ollama' or 'deepseek'.
        llm_type = model_instance.provider.rsplit('/')[-1]
        api_key = ''
        base_url = ''
        if llm_type == 'tongyi':
            api_key = model_instance.credentials.get('dashscope_api_key')
        elif llm_type == 'deepseek':
            api_key = model_instance.credentials.get('api_key')
        elif llm_type == 'ollama':
            base_url = model_instance.credentials.get('base_url')

        model = model_instance.model
        # The key covers every value that goes into vanna_config. Previously the
        # model name was left out, and providers with neither an api key nor a
        # base url all shared the key '', so a server built for a different model
        # or provider could be reused.
        cache_key = (llm_type, model, api_key or '', base_url or '')

        vanna_config = {
            "llm_type": llm_type,
            "model": model,
            "api_key": api_key,
            "ollama_host": base_url
        }
        config = Config("")
        # Merge the static settings, the flattened SQL settings and the LLM settings.
        combined_config = {**config.__dict__, **config.sql_config, **vanna_config}
        server = get_vanna_server(cache_key, combined_config)

        # Ask the question and return the generated SQL.
        sql = server.generate_sql(query)
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"output": sql}
        )

@ -0,0 +1,228 @@
from vanna.ollama import Ollama
from dotenv import load_dotenv
load_dotenv()
from functools import wraps
from flask import Flask, jsonify, Response, request
import flask
from extensions.storage.cache import MemoryCache
from dify_app import DifyApp
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient
from configs import dify_config
# SETUP
# Cache that the route handlers registered in init_app read from and write to.
cache = MemoryCache()
# NOTE(review): the Milvus client is created and switched to database "test"
# as a side effect of importing this module. In the visible code the client is
# only referenced from the commented-out MyVanna construction below.
milvus_uri = dify_config.MILVUS_URI
milvus_client = MilvusClient(uri=milvus_uri)
milvus_client.use_database("test")
class MyVanna(Milvus_VectorStore, Ollama):
    """Vanna class combining the Milvus vector store with the Ollama LLM backend."""

    def __init__(self, config=None):
        """Initialise both bases explicitly with the same ``config``."""
        Milvus_VectorStore.__init__(self, config=config)
        Ollama.__init__(self, config=config)
# vn = MyVanna(config={
# 'model': 'qwen2:7b', # 本地ollama大模型名称
# 'ollama_host':'http://wsd.wisdomidata.com:19042', # 本地ollama大模型服务地址
# 'milvus_client': milvus_client, # 本地milvus向量数据库服务地址
# "n_results": 12,
# })
# vn.connect_to_postgres(
# host=dify_config.DB_HOST,
# dbname='vanna_demo',
# user=dify_config.DB_USERNAME,
# password=dify_config.DB_PASSWORD,
# port=dify_config.DB_PORT
# )
# vn.connect_to_mysql(
# host='122.51.104.137',
# port=33306,
# dbname='demo',
# user='sws',
# password='123456'
# )
def init_app(app: DifyApp):
    """Register the ``/api/v0/*`` Vanna routes on ``app``.

    The handlers share the module-level ``cache`` to pass intermediate results
    (question, sql, dataframe, figure, follow-up questions) between requests,
    addressed by the ``id`` returned from ``/api/v0/generate_sql``.

    NOTE(review): the handlers call a module-level ``vn`` object, but in the
    visible code its construction is commented out. Unless ``vn`` is defined
    elsewhere, every handler that uses it raises ``NameError`` at request time.
    """

    def requires_cache(fields):
        """Decorator factory: inject cached ``fields`` for the request's ``id``.

        The wrapped view is called with ``id`` plus one keyword argument per
        requested field. If ``id`` is missing from the query string, or any
        field has no cached value, a JSON error payload is returned instead.
        """
        def decorator(f):
            @wraps(f)
            def decorated(*args, **kwargs):
                id = request.args.get('id')
                if id is None:
                    return jsonify({"type": "error", "error": "No id provided"})
                for field in fields:
                    if cache.get(id=id, field=field) is None:
                        return jsonify({"type": "error", "error": f"No {field} found"})
                field_values = {field: cache.get(id=id, field=field) for field in fields}
                # Add the id to the field_values
                field_values['id'] = id
                return f(*args, **field_values, **kwargs)
            return decorated
        return decorator

    @app.route('/api/v0/generate_questions', methods=['GET'])
    def generate_questions():
        """Return a list of suggested questions."""
        return jsonify({
            "type": "question_list",
            "questions": vn.generate_questions(),
            "header": "Here are some questions you can ask:"
        })

    @app.route('/api/v0/generate_sql', methods=['GET'])
    def generate_sql():
        """Generate SQL for the ``question`` query parameter and cache both."""
        question = flask.request.args.get('question')
        if question is None:
            return jsonify({"type": "error", "error": "No question provided"})
        id = cache.generate_id(question=question)
        sql = vn.generate_sql(question=question)
        cache.set(id=id, field='question', value=question)
        cache.set(id=id, field='sql', value=sql)
        return jsonify(
            {
                "type": "sql",
                "id": id,
                "text": sql,
            })

    @app.route('/api/v0/run_sql', methods=['GET'])
    @requires_cache(['sql'])
    def run_sql(id: str, sql: str):
        """Run the cached SQL, cache the dataframe and return its first 10 rows."""
        try:
            df = vn.run_sql(sql=sql)
            cache.set(id=id, field='df', value=df)
            return jsonify(
                {
                    "type": "df",
                    "id": id,
                    "df": df.head(10).to_json(orient='records'),
                })
        except Exception as e:
            return jsonify({"type": "error", "error": str(e)})

    @app.route('/api/v0/download_csv', methods=['GET'])
    @requires_cache(['df'])
    def download_csv(id: str, df):
        """Return the cached dataframe as a CSV attachment named ``<id>.csv``."""
        csv = df.to_csv()
        return Response(
            csv,
            mimetype="text/csv",
            headers={"Content-disposition":
                     f"attachment; filename={id}.csv"})

    @app.route('/api/v0/generate_plotly_figure', methods=['GET'])
    @requires_cache(['df', 'question', 'sql'])
    def generate_plotly_figure(id: str, df, question, sql):
        """Generate Plotly code for the cached result, render it and cache the figure JSON."""
        try:
            code = vn.generate_plotly_code(question=question, sql=sql,
                                           df_metadata=f"Running df.dtypes gives:\n {df.dtypes}")
            fig = vn.get_plotly_figure(plotly_code=code, df=df, dark_mode=False)
            fig_json = fig.to_json()
            cache.set(id=id, field='fig_json', value=fig_json)
            return jsonify(
                {
                    "type": "plotly_figure",
                    "id": id,
                    "fig": fig_json,
                })
        except Exception as e:
            # Print the stack trace
            import traceback
            traceback.print_exc()
            return jsonify({"type": "error", "error": str(e)})

    @app.route('/api/v0/get_training_data', methods=['GET'])
    def get_training_data():
        """Return the first 25 rows of the training data."""
        df = vn.get_training_data()
        return jsonify(
            {
                "type": "df",
                "id": "training_data",
                "df": df.head(25).to_json(orient='records'),
            })

    @app.route('/api/v0/remove_training_data', methods=['POST'])
    def remove_training_data():
        """Remove the training entry whose ``id`` is given in the JSON body."""
        # Get id from the JSON body
        id = flask.request.json.get('id')
        if id is None:
            return jsonify({"type": "error", "error": "No id provided"})
        if vn.remove_training_data(id=id):
            return jsonify({"success": True})
        else:
            return jsonify({"type": "error", "error": "Couldn't remove training data"})

    @app.route('/api/v0/train', methods=['POST'])
    def add_training_data():
        """Train with the question/sql/ddl/documentation values from the JSON body."""
        question = flask.request.json.get('question')
        sql = flask.request.json.get('sql')
        ddl = flask.request.json.get('ddl')
        documentation = flask.request.json.get('documentation')
        try:
            id = vn.train(question=question, sql=sql, ddl=ddl, documentation=documentation)
            return jsonify({"id": id})
        except Exception as e:
            print("TRAINING ERROR", e)
            return jsonify({"type": "error", "error": str(e)})

    @app.route('/api/v0/generate_followup_questions', methods=['GET'])
    @requires_cache(['df', 'question', 'sql'])
    def generate_followup_questions(id: str, df, question, sql):
        """Generate follow-up questions for the cached result and cache them."""
        followup_questions = vn.generate_followup_questions(question=question, sql=sql, df=df)
        cache.set(id=id, field='followup_questions', value=followup_questions)
        return jsonify(
            {
                "type": "question_list",
                "id": id,
                "questions": followup_questions,
                "header": "Here are some followup questions you can ask:"
            })

    @app.route('/api/v0/load_question', methods=['GET'])
    @requires_cache(['question', 'sql', 'df', 'fig_json', 'followup_questions'])
    def load_question(id: str, question, sql, df, fig_json, followup_questions):
        """Return everything cached for ``id``; all five fields must be present."""
        try:
            return jsonify(
                {
                    "type": "question_cache",
                    "id": id,
                    "question": question,
                    "sql": sql,
                    "df": df.head(10).to_json(orient='records'),
                    "fig": fig_json,
                    "followup_questions": followup_questions,
                })
        except Exception as e:
            return jsonify({"type": "error", "error": str(e)})

    @app.route('/api/v0/get_question_history', methods=['GET'])
    def get_question_history():
        """Return the cached questions."""
        return jsonify({"type": "question_history", "questions": cache.get_all(field_list=['question'])})

@ -0,0 +1,238 @@
import json
import ast
from configs import dify_config
from extensions.utils.vanna_text2sql import VannaServer
from dotenv import load_dotenv
load_dotenv()
from dify_app import DifyApp
from flask import Flask, jsonify, Response, request
import flask
from werkzeug.exceptions import BadRequest
import logging
import plotly.io as pio
from functools import lru_cache
from datetime import datetime
class Config:
    """Settings used to build a ``VannaServer`` for one supplier.

    ``get_vn_instance`` spreads ``__dict__`` of an instance into the server
    config, so every attribute assigned here becomes a top-level config key.
    """

    def __init__(self, supplier):
        """Populate the settings; ``supplier`` is stored unchanged.

        The LLM API key is read from the ``VANNA_LLM_API_KEY`` environment
        variable when it is set (this module calls ``load_dotenv()`` on import,
        so a ``.env`` entry works as well).
        """
        import os

        self.embedding_supplier = "SiliconFlow"
        self.milvus_uri = dify_config.MILVUS_URI
        self.milvus_database = 'vanna_demo'
        self.supplier = supplier

        # Alternatives used here before: llm_type 'tongyi' with model 'qwen-max',
        # or a local Ollama host with model 'qwen2:7b'.
        self.llm_type = 'deepseek'
        self.model = 'deepseek-coder'
        # SECURITY: the fallback literal is a credential committed to source
        # control. It is kept only so existing deployments keep working; rotate
        # it and supply the real key through VANNA_LLM_API_KEY instead.
        self.api_key = os.getenv('VANNA_LLM_API_KEY', 'sk-0382990b7a90496c889774b1d3843f90')

        self.sql_type = 'postgres'
        # Connection settings: host, user, password and port come from dify_config,
        # the database name is fixed.
        self.sql_config = {
            "host": dify_config.DB_HOST,
            "dbname": 'vanna_demo',
            "user": dify_config.DB_USERNAME,
            "password": dify_config.DB_PASSWORD,
            "port": dify_config.DB_PORT
        }
# Cache of VannaServer instances, one per supplier name.
vn_instances = {}


# Look up (or lazily build) the VannaServer for a supplier.
def get_vn_instance(supplier=""):
    """Return the VannaServer for ``supplier``, creating it on first use.

    An empty supplier name is mapped to ``"default"``.
    """
    key = "default" if supplier == "" else supplier
    if key in vn_instances:
        return vn_instances[key]

    config = Config(key)
    # Merge the instance attributes with the flattened SQL settings.
    merged = {**config.__dict__, **config.sql_config}
    server = VannaServer(merged)
    vn_instances[key] = server
    return server
def init_app(app: DifyApp):
@app.route('/api/ask', methods=['POST'])
def ask_route():
"""提问接口"""
data = request.json
question = data.get('question', '')
visualize = data.get('visualize', True)
auto_train = data.get('auto_train', False)
supplier = data.get('supplier', "") # GITEE, ZHIPU, SiliconFlow
if not question:
raise BadRequest("Question is required")
server = get_vn_instance(supplier)
try:
sql, df, fig = server.ask(question=question, visualize=visualize, auto_train=auto_train)
df_json = df.to_json(orient='records', force_ascii=False)
"""
<img id="plotly-image" src="data:image/png;base64,{{ img_base64 }}" alt="Plotly Image">
"""
# fig_js_path = '../output/html/vanna_fig.js'
# fig_html_path = 'http://localhost:8000/html/vanna_fig.html'
# figure_json = pio.to_json(fig)
# with open(fig_js_path, 'w', encoding='utf-8') as f:
# f.write(figure_json)
"""
<div id="plotly-div"></div>
<script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
<script>
var fig_json = {{ fig_json }};
Plotly.newPlot('plotly-div', fig_json.data, fig_json.layout);
</script>
"""
logging.info("Query processed successfully")
return jsonify({
'sql': sql,
'data': df_json,
# 'img_base64': img_base64,
# 'plotly_figure': fig_html_path
}), 200
except Exception as e:
logging.error(f"Error processing request: {e}")
return jsonify({'error': str(e)}), 500
@app.route('/api/vn_train', methods=['POST'])
def vn_train_route():
    """Add training data (question/SQL pair, documentation, DDL) to an instance.

    When ``schema`` is truthy the stored table DDL and column documentation
    are refreshed as well. That refresh is best-effort: a failure is logged
    and the request still reports success.
    """
    payload = request.get_json(silent=True)
    # A missing or non-object JSON body is treated as an empty request.
    data = payload if isinstance(payload, dict) else {}
    supplier = data.get('supplier', "")
    question = data.get('question', '')
    sql = data.get('sql', '')
    documentation = data.get('documentation', '')
    ddl = data.get('ddl', '')
    schema = data.get('schema', False)
    # At least one parameter must be provided.
    if not any([question, sql, documentation, ddl, schema]):
        return jsonify(
            {
                'error': 'At least one of the parameters (question, sql, documentation, ddl, schema) must be provided'}), 400
    server = get_vn_instance(supplier)
    server.vn_train(question=question, sql=sql, documentation=documentation, ddl=ddl)
    if schema:
        try:
            # Rebuild the CREATE TABLE DDL entries, then the column docs.
            server.refresh_create_table_ddl_train()
            server.refresh_schema_train()
        except Exception:
            # Keep the best-effort behaviour but record the real failure
            # together with its traceback.
            logging.exception("Error refreshing schema training data")
    logging.info("Training completed successfully")
    return jsonify({'status': 'success'}), 200
@app.route('/api/docs/update', methods=['POST'])
def update_schema_train_list_route():
    """Forward the posted ``docs`` list to the default instance's update_schema_train_list."""
    posted_docs = request.json.get('docs', [])
    get_vn_instance("").update_schema_train_list(docs=posted_docs)
    body = {'status': 'success'}
    return jsonify(body), 200
@app.route('/api/get_training_data', methods=['GET'])
def get_training_data_route():
    """Return the training data of the instance selected by ``supplier``."""
    supplier = request.args.get('supplier', "")
    server = get_vn_instance(supplier)
    # No lru_cache wrapper here: a cache created inside the handler is rebuilt
    # on every request and therefore can never produce a hit.
    training_data = server.get_training_data()
    logging.info("Fetched training data successfully")
    return jsonify({
        'data': json.loads(training_data.to_json(orient='records'))
    }), 200
@app.route('/api/generate_sql', methods=['GET'])
def generate_sql():
    """Generate SQL for the ``question`` query parameter without running it."""
    question = request.args.get('question')
    supplier = request.args.get('supplier', '')
    if not question:
        # Same error body as before, now with a client-error status like the
        # other routes use for validation failures.
        return jsonify({"type": "error", "error": "No question provided"}), 400
    server = get_vn_instance(supplier)
    sql = server.generate_sql(question=question)
    return jsonify({"sql": sql}), 200
@app.route('/api/run_sql', methods=['POST'])
def run_sql():
    """Run the posted ``sql`` on the selected instance and return the rows as JSON.

    NOTE(review): the SQL text comes straight from the request body and is
    executed as-is; confirm this endpoint is only reachable by trusted callers.
    """
    payload = request.get_json(silent=True)
    # A missing or non-object JSON body is treated as an empty request.
    data = payload if isinstance(payload, dict) else {}
    supplier = data.get('supplier', "")
    sql = data.get('sql', '')
    if not sql:
        return jsonify({"type": "error", "error": "No sql provided"}), 400
    try:
        server = get_vn_instance(supplier)
        df = server.run_sql(sql=sql)
        df_json = df.to_json(orient='records', force_ascii=False)
        # df_json is already serialized; label it as JSON rather than letting
        # the framework pick the default mimetype for a plain string.
        return Response(df_json, status=200, mimetype='application/json')
    except Exception as e:
        logging.exception("Error running SQL")
        return jsonify({"type": "error", "error": str(e)}), 500
@app.route('/api/training/data/export', methods=['GET'])
def training_data_export():
    """Send the question/SQL training pairs as a downloadable text file.

    The body is a bracketed list with one ``str()``-rendered record per line,
    separated by commas.
    """
    server = get_vn_instance(request.args.get('supplier', ""))
    records = server.training_data_export()
    body = "[\n" + ",\n".join(map(str, records)) + "\n]"
    # The download is named after the current minute.
    stamp = datetime.now().strftime('%Y-%m-%d-%H-%M')
    disposition = {"Content-Disposition": f"attachment; filename={stamp}.txt"}
    return Response(body, mimetype='text/plain', headers=disposition)
@app.route('/api/training/data/import', methods=['POST'])
def training_data_import():
    """Import question/SQL training pairs from an uploaded text file.

    The file must contain a Python-literal list of dicts, which is the format
    the export route writes. It is parsed with ``ast.literal_eval`` and never
    evaluated as code.
    """
    if 'file' not in request.files:
        return jsonify({"type": "error", "error": "未上传文件"}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({"type": "error", "error": "文件名为空"}), 400
    try:
        content = file.read().decode('utf-8').strip()
        data_list = ast.literal_eval(content)
        if not isinstance(data_list, list):
            # Any other literal (dict, string, number) cannot be imported.
            return jsonify({"type": "error", "error": "文件解析失败: 内容不是列表"}), 400
        server = get_vn_instance("")
        has_empty_fields = server.training_data_import(data_list)
        if has_empty_fields:
            # Validation failure: report it as a client error, not as 200.
            return jsonify({"type": "error", "error": "存在数据集question 或 sql为空"}), 400
        return jsonify({'status': 'success'}), 200
    except Exception as e:
        return jsonify({"type": "error", "error": f"文件解析失败: {str(e)}"}), 500

@ -0,0 +1,62 @@
from abc import ABC, abstractmethod
import uuid
class Cache(ABC):
    """Abstract interface for a store of per-id field/value records."""

    @abstractmethod
    def generate_id(self, *args, **kwargs):
        """Produce an identifier for a new record."""

    @abstractmethod
    def get(self, id, field):
        """Look up ``field`` of the record ``id``."""

    @abstractmethod
    def get_all(self, field_list) -> list:
        """Collect the fields in ``field_list`` for every record."""

    @abstractmethod
    def set(self, id, field, value):
        """Store ``value`` under ``field`` of the record ``id``."""

    @abstractmethod
    def delete(self, id):
        """Remove the record ``id``."""
class MemoryCache(Cache):
    """Cache implementation backed by a dict of ``id -> {field: value}``."""

    def __init__(self):
        self.cache = {}

    def generate_id(self, *args, **kwargs):
        """Return a fresh random UUID4 string; the arguments are ignored."""
        return str(uuid.uuid4())

    def set(self, id, field, value):
        """Store ``value`` under ``field``, creating the record when needed."""
        record = self.cache.setdefault(id, {})
        record[field] = value

    def get(self, id, field):
        """Return the stored value, or ``None`` for an unknown id or field."""
        try:
            return self.cache[id][field]
        except KeyError:
            return None

    def get_all(self, field_list) -> list:
        """Return one dict per record holding its ``id`` and the requested fields."""
        rows = []
        for record_id in self.cache:
            row = {"id": record_id}
            for field in field_list:
                row[field] = self.get(id=record_id, field=field)
            rows.append(row)
        return rows

    def delete(self, id):
        """Drop the record ``id`` if it exists."""
        self.cache.pop(id, None)

@ -0,0 +1,150 @@
import traceback
from typing import Union, Tuple
import pandas as pd
import plotly
from PIL import Image as PILImage
import io
def _display_or_print(value, as_code: bool = False) -> None:
    """Show ``value`` through IPython when it is importable, otherwise print it."""
    try:
        ipython_display = __import__("IPython.display", fromlist=["display"])
        rendered = ipython_display.Code(value) if as_code else value
        ipython_display.display(rendered)
    except Exception:
        print(value)


def ask(
    vanna_instance,
    question: Union[str, None] = None,
    print_results: bool = True,
    auto_train: bool = True,
    visualize: bool = True,  # if False, will not generate plotly code
    allow_llm_to_see_data: bool = False,
) -> Tuple[
    Union[str, None],
    Union[pd.DataFrame, None],
    Union["plotly.graph_objs.Figure", None],
]:
    """Generate SQL for a question, run it and optionally build a plotly figure.

    Every exit path returns a 3-tuple so callers can always unpack the result.
    Earlier some failure paths returned a bare ``None`` when ``print_results``
    was true, which made ``sql, df, fig = ask(...)`` raise ``TypeError``.

    Args:
        vanna_instance: Object providing ``generate_sql``, ``run_sql``,
            ``run_sql_is_set``, ``add_question_sql``, ``generate_plotly_code``
            and ``get_plotly_figure``.
        question: The question to ask; read from stdin when ``None``.
        print_results: Whether to display/print the SQL, rows and plotly code.
        auto_train: Whether to store the question/SQL pair when rows came back.
        visualize: Whether to generate plotly code and a figure.
        allow_llm_to_see_data: Forwarded to ``generate_sql``.

    Returns:
        ``(sql, df, fig)``. Elements that could not be produced are ``None``:
        all three when SQL generation fails, ``df`` and ``fig`` when the query
        cannot be run, ``fig`` when plotting fails or ``visualize`` is false.
    """
    if question is None:
        question = input("Enter a question: ")

    try:
        sql = vanna_instance.generate_sql(question=question, allow_llm_to_see_data=allow_llm_to_see_data)
    except Exception as e:
        print(e)
        return None, None, None

    if print_results:
        _display_or_print(sql, as_code=True)

    if vanna_instance.run_sql_is_set is False:
        print(
            "If you want to run the SQL query, connect to a database first."
        )
        return sql, None, None

    try:
        df = vanna_instance.run_sql(sql)
        if print_results:
            _display_or_print(df)
        if len(df) > 0 and auto_train:
            vanna_instance.add_question_sql(question=question, sql=sql)
    except Exception as e:
        print("Couldn't run sql: ", e)
        return sql, None, None

    # Only generate plotly code if visualize is True
    if not visualize:
        return sql, df, None

    try:
        plotly_code = vanna_instance.generate_plotly_code(
            question=question,
            sql=sql,
            df_metadata=f"Running df.dtypes gives:\n {df.dtypes}",
        )
        fig = vanna_instance.get_plotly_figure(plotly_code=plotly_code, df=df)
        if print_results:
            _display_or_print(plotly_code)
    except Exception as e:
        # Print stack trace
        traceback.print_exc()
        print("Couldn't run plotly code: ", e)
        return sql, df, None

    return sql, df, fig
def display_image_in_pycharm(fig):
    """Display image in PyCharm using matplotlib or PIL.

    Tries a chain of fallbacks and prints (rather than raises) each failure:
    IPython display of ``fig.to_image``, then ``fig.savefig`` when ``fig`` has
    no ``to_image``, then ``fig.show()``, and finally rendering via
    ``fig.savefig`` into memory and opening the result with PIL.

    NOTE(review): the nesting of the fallback handlers is reconstructed from a
    flattened listing - confirm against the original file.
    """
    try:
        # Try to use IPython.display if available
        try:
            display = __import__("IPython.display", fromlist=["display"]).display
            Image = __import__("IPython.display", fromlist=["Image"]).Image
            img_bytes = fig.to_image(format="png", scale=2)
            display(Image(img_bytes))
        except AttributeError:
            print("fig does not have to_image method, using fig.savefig instead")
            # Writes output.png into the current working directory.
            fig.savefig("output.png")
            display(Image("output.png"))
        except ImportError:
            print("IPython.display not available, using matplotlib to show image")
            fig.show()
    except Exception as e:
        print(f"Failed to display image using IPython.display: {e}")
        traceback.print_exc()
        try:
            # Use matplotlib to show image
            fig.show()
        except Exception as e:
            print(f"Failed to display image using fig.show: {e}")
            traceback.print_exc()
            try:
                # Use PIL to show image
                img_bytes = io.BytesIO()
                fig.savefig(img_bytes, format='png')
                # Rewind so PIL reads the buffer from the beginning.
                img_bytes.seek(0)
                pil_img = PILImage.open(img_bytes)
                pil_img.show()
            except Exception as e:
                print(f"Failed to display image using PIL: {e}")
                traceback.print_exc()

@ -0,0 +1,425 @@
import os
import json
from vanna.ollama import Ollama
from vanna.qianwen import QianWenAI_Chat
from vanna.deepseek import DeepSeekChat
from extensions.utils.rewrite_ask import ask
from dotenv import load_dotenv
import plotly.io as pio
from vanna.milvus import Milvus_VectorStore
from pymilvus import MilvusClient,model
from collections import defaultdict
load_dotenv()
# 设置显示后端为浏览器
pio.renderers.default = 'browser'
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
from typing import List
import ollama
import numpy as np
from pymilvus.model.base import BaseEmbeddingFunction
# 自定义嵌入式模型适配milvus向量数据库
class CustomEmbeddingFunction(BaseEmbeddingFunction):
    """Embedding function that serves Milvus from an Ollama embedding model.

    Recognised ``config`` keys: ``host`` (Ollama server URL), ``embed_model``
    (model name), ``keep_alive`` and ``options`` (both forwarded to the
    Ollama embeddings call).
    """

    def __init__(self, config=None):
        # Allow construction without a config; the defaults below then apply.
        config = config or {}
        model_host = config.get('host', 'http://wsd.wisdomidata.com:19042')
        self.embed_model = config.get('embed_model', 'bge-m3')
        self.embedding_model = ollama.Client(model_host)
        self.keep_alive = config.get('keep_alive', None)
        self.ollama_options = config.get('options', {})
        self.num_ctx = self.ollama_options.get('num_ctx', 2048)

    def __call__(self, texts: List[str]) -> List[np.ndarray]:
        """Embed ``texts`` and return one vector per text.

        The embeddings used to be computed and then dropped, so calling the
        instance always produced ``None``.
        """
        return self.encode_documents(texts)

    def _encode(self, texts: list[str]) -> list[list[float]]:
        """Request one embedding per text from the Ollama client."""
        return [
            self.embedding_model.embeddings(
                model=self.embed_model,
                prompt=text,
                options=self.ollama_options,
                keep_alive=self.keep_alive,
            )["embedding"]
            for text in texts
        ]

    def encode_documents(self, documents: List[str]) -> List[np.ndarray]:
        """Embed documents, converting each embedding to an ``np.ndarray``."""
        return [np.array(embedding) for embedding in self._encode(documents)]

    def encode_queries(self, queries: List[str]) -> List[np.ndarray]:
        """Embed queries, converting each embedding to an ``np.ndarray``."""
        return [np.array(embedding) for embedding in self._encode(queries)]
class VannaServer:
def __init__(self, config):
    """Store ``config`` and build the Vanna instance from it."""
    self.config = config
    # _initialize_vn reads self.config, so it must run after the assignment.
    self.vn = self._initialize_vn()
def _initialize_vn(self):
    """Create the Vanna object described by ``self.config`` and connect it to the database.

    ``supplier``, ``llm_type``, ``model``, ``api_key``, ``milvus_uri`` and
    ``sql_type`` are read by subscript, so a config lacking any of them raises
    ``KeyError``. The remaining settings fall back to environment variables
    or literal defaults.

    Returns:
        The Vanna instance, connected to Postgres or MySQL when ``sql_type``
        is ``"postgres"`` or ``"mysql"``; otherwise left unconnected.
    """
    config = self.config
    # Read for its required-key check only; not used again in this method.
    supplier = config["supplier"]
    llm_type = config["llm_type"]
    model_ = config["model"]
    api_key = config["api_key"]
    ollama_host = config["ollama_host"] if "ollama_host" in config else None
    milvus_uri = config["milvus_uri"]
    sql_type = config["sql_type"]
    host = config["host"] if "host" in config else os.getenv("DB_HOST", "localhost")
    dbname = config["dbname"] if "dbname" in config else os.getenv("DB_NAME", "dify_data")
    user = config["user"] if "user" in config else os.getenv("DB_USER", "root")
    password = config["password"] if "password" in config else os.getenv("DB_PASSWORD", "mysql")
    port = config["port"] if "port" in config else int(os.getenv("DB_PORT", 3306))
    milvus_database = config["milvus_database"] if "milvus_database" in config else "test"
    milvus_client = MilvusClient(uri=milvus_uri,db_name=milvus_database)
    embedding_host = config["embedding_host"] if "embedding_host" in config else 'http://wsd.wisdomidata.com:19042'
    embedding_model = config["embedding_model"] if "embedding_model" in config else "bge-m3" # BAAI/bge-m3
    embedding_function = CustomEmbeddingFunction({
        "host": embedding_host,
        "embed_model": embedding_model
    })
    # Ollama is the chat backend unless llm_type selects another one below.
    chat_llm = Ollama
    if llm_type == "ollama":
        config = {
            'model': model_, # name of the Ollama model
            'ollama_host': ollama_host, # address of the Ollama service
            'milvus_client': milvus_client, # client for the Milvus vector database
            "n_results": 12,
            "embedding_function": embedding_function,
        }
    else:
        config = {
            'model': model_, # name of the hosted model
            'api_key': api_key, # API key for the hosted model
            'milvus_client': milvus_client, # client for the Milvus vector database
            "n_results": 12,
            "embedding_function": embedding_function,
        }
        if llm_type == "tongyi":
            chat_llm = QianWenAI_Chat
        elif llm_type == "deepseek":
            chat_llm = DeepSeekChat
    MyVanna = make_vanna_class(ChatClass=chat_llm)
    vn = MyVanna(config)
    if sql_type == "postgres":
        vn.connect_to_postgres(host=host, dbname=dbname, user=user, password=password, port=port)
    elif sql_type == "mysql":
        vn.connect_to_mysql(host=host, dbname=dbname, user=user, password=password, port=port)
    return vn
def schema_train(self):
# The information schema query may need some tweaking depending on your database. This is a good starting point.
df_information_schema = self.vn.run_sql("SELECT * FROM INFORMATION_SCHEMA.COLUMNS where table_schema = 'public'")
# This will break up the information schema into bite-sized chunks that can be referenced by the LLM
plan = self.vn.get_training_plan_generic(df_information_schema)
# print(plan)
# If you like the plan, then uncomment this and run it to train
self.vn.train(plan=plan)
# 更新建表DDL语句
def refresh_create_table_ddl_train(self):
    """Rebuild the ``CREATE TABLE`` training entries in the ``vannaddl`` collection.

    Runs a catalog query that assembles, per table of the ``public`` schema,
    one text made of a CREATE TABLE statement, column comments, foreign-key
    statements and the table comment. Stored entries whose ``ddl`` starts
    with ``"CREATE TABLE "`` are then deleted and every generated text is
    trained again.

    NOTE(review): the old entries are deleted before the new ones are trained,
    so a failure during training presumably leaves the collection without
    table DDL - confirm this is acceptable.
    """
    sql = """
SELECT
'CREATE TABLE '
|| C.TABLE_NAME
|| ' ('
|| C.COLUMN_NAMES
|| ');'
|| C.COMMENT_COLUMNS
|| CASE WHEN FK.FOREIGN_KEY_COLUMNS IS NOT NULL THEN FK.FOREIGN_KEY_COLUMNS ELSE '' END
|| CASE WHEN FK.FOREIGN_KEY_DESC IS NOT NULL THEN FK.FOREIGN_KEY_DESC ELSE '' END
|| 'COMMENT ON TABLE '
|| C.TABLE_NAME
|| ' IS '''
|| G.DESCRIPTION
|| ''';'
AS DDL,
C.TABLE_NAME
FROM (
SELECT
COL.TABLE_NAME,
COL.TABLE_SCHEMA,
STRING_AGG(
COL.COLUMN_NAME
|| ' '
|| COL.DATA_TYPE
|| COALESCE('(' || COL.CHARACTER_MAXIMUM_LENGTH || ')', '')
|| COALESCE(' DEFAULT ' || COL.COLUMN_DEFAULT, '')
|| CASE
WHEN COL.IS_NULLABLE = 'NO' THEN ' NOT NULL'
ELSE ''
END,
','
) AS COLUMN_NAMES,
STRING_AGG(
'COMMENT ON COLUMN '
|| COL.TABLE_NAME
|| '.'
|| COL.COLUMN_NAME
|| ' IS '''
|| PGD.DESCRIPTION
|| ''';',
''
) AS COMMENT_COLUMNS
FROM
PG_CATALOG.PG_STATIO_ALL_TABLES AS ST
INNER JOIN
PG_CATALOG.PG_DESCRIPTION AS PGD
ON PGD.OBJOID = ST.RELID
INNER JOIN
INFORMATION_SCHEMA.COLUMNS AS COL
ON (
COL.TABLE_SCHEMA = ST.SCHEMANAME
AND COL.TABLE_NAME = ST.RELNAME
AND COL.ORDINAL_POSITION = PGD.OBJSUBID
)
WHERE
COL.TABLE_SCHEMA = 'public'
GROUP BY
COL.TABLE_SCHEMA,
COL.TABLE_NAME
) C
LEFT JOIN (
SELECT
N.NSPNAME AS SCHEMA_NAME,
C.RELNAME AS TABLE_NAME,
D.DESCRIPTION
FROM
PG_CATALOG.PG_DESCRIPTION D
JOIN
PG_CATALOG.PG_CLASS C
ON C.OID = D.OBJOID
JOIN
PG_CATALOG.PG_NAMESPACE N
ON N.OID = C.RELNAMESPACE
WHERE
C.RELKIND = 'r'
AND D.OBJSUBID = 0
) G
ON G.SCHEMA_NAME = C.TABLE_SCHEMA
AND G.TABLE_NAME = C.TABLE_NAME
LEFT JOIN (
SELECT rel_src.relname AS source_table,
STRING_AGG(
'ALTER TABLE '
|| rel_src.relname
|| ' ADD CONSTRAINT '
|| con.conname
|| ' FOREIGN KEY ('
|| att_src.attname
|| ') REFERENCES '
|| rel_tgt.relname
|| '('
|| att_tgt.attname
|| ');'
,
''
) AS FOREIGN_KEY_COLUMNS,
STRING_AGG(
'COMMENT ON CONSTRAINT '
|| con.conname
|| ' ON '
|| rel_src.relname
|| ' IS '''
|| d.description
|| ''';',
''
) AS FOREIGN_KEY_DESC
FROM
pg_constraint con
JOIN pg_class rel_src ON rel_src.oid = con.conrelid
JOIN pg_class rel_tgt ON rel_tgt.oid = con.confrelid
JOIN pg_attribute att_src ON att_src.attrelid = rel_src.oid AND att_src.attnum = ANY(con.conkey)
JOIN pg_attribute att_tgt ON att_tgt.attrelid = rel_tgt.oid AND att_tgt.attnum = ANY(con.confkey)
LEFT JOIN pg_description d ON d.objoid = con.oid
WHERE
con.contype = 'f'
GROUP BY
rel_src.relname
) FK ON FK.source_table = C.TABLE_NAME
WHERE C.TABLE_NAME NOT IN ('flyway_table_dict','flyway_schema_history')
"""
    # The information schema query may need some tweaking depending on your database. This is a good starting point.
    c_table_ddl_list = self.vn.run_sql(sql)
    # Convert the DataFrame into a list of dicts (one per table).
    c_table_ddl_records = c_table_ddl_list.to_dict(orient='records')
    exist_ddl_data = self.vn.milvus_client.query(
        collection_name="vannaddl",
        output_fields=["*"],
        limit=10000,
    )
    # Only the generated CREATE TABLE entries are replaced; other stored DDL stays.
    exists_list = filter(lambda m: m["ddl"].startswith("CREATE TABLE "), exist_ddl_data)
    remove_ids = [exist["id"] for exist in exists_list]
    if len(remove_ids) > 0:
        self.vn.milvus_client.delete(collection_name="vannaddl", ids=remove_ids)
    for table_ddl in c_table_ddl_records:
        # presumably the result column is named "ddl" because the database
        # lower-cases the unquoted DDL alias - TODO confirm
        self.vn.train(ddl=table_ddl["ddl"])
    self.vn.milvus_client.refresh_load(collection_name="vannaddl")
def refresh_schema_train(self):
exist_doc_data = self.vn.milvus_client.query(
collection_name="vannadoc",
output_fields=["*"],
limit=10000,
)
exists_list = filter(lambda m: m["doc"].startswith("The following columns are in the "), exist_doc_data)
remove_ids = [exist["id"] for exist in exists_list]
if len(remove_ids) > 0:
self.vn.milvus_client.delete(collection_name="vannadoc", ids=remove_ids)
self.schema_train()
self.vn.milvus_client.refresh_load(collection_name="vannadoc")
def update_schema_train_list(self,docs : list[str]):
exist_doc_data = self.vn.milvus_client.query(
collection_name="vannadoc",
output_fields=["*"],
limit=10000,
)
exists_list = filter(lambda m: not m["doc"].startswith("The following columns are in the "), exist_doc_data)
remove_ids = [exist["id"] for exist in exists_list]
if len(remove_ids) > 0:
self.vn.milvus_client.delete(collection_name="vannadoc", ids=remove_ids)
dict_docs = self.get_dict_docs()
docs.extend(dict_docs)
for doc in docs:
self.vn.train(documentation=doc)
# self.schema_train()
self.vn.milvus_client.refresh_load(collection_name="vannadoc")
def get_dict_docs(self) -> list[str]:
dict_docs = []
sql = "select id,table_name,column_name,column_remark,table_remark,dict_values from flyway_table_dict"
c_table_dict_list = self.vn.run_sql(sql)
# 将 DataFrame 转换为字典列表
c_table_dict_records = c_table_dict_list.to_dict(orient='records')
table_names = list(set(item['table_name'] for item in c_table_dict_records))
grouped = defaultdict(list)
for table_dict in c_table_dict_records:
table_name = table_dict['table_name'] # 分组依据字段
grouped[table_name].append(table_dict)
grouped_dict = dict(grouped)
for table_name in table_names:
columns_list = grouped_dict[table_name]
dict_values = ';'.join(f"字段:{item['column_remark']}({item['column_name']})的值:{item["dict_values"]}" for item in columns_list)
column = columns_list[0]
doc = f"{column["table_remark"]}表:{column["table_name"]},{dict_values}"
dict_docs.append(doc)
return dict_docs
def vn_train(self, question="", sql="", documentation="", ddl=""):
if question and sql:
# 训练问答对
self.vn.train(
question=question,
sql=sql
)
elif sql:
# You can also add SQL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL.
self.vn.train(sql=sql)
if documentation:
# Sometimes you may want to add documentation about your business terminology or definitions.
self.vn.train(documentation=documentation)
if ddl:
# You can also add DDL queries to your training data. This is useful if you have some queries already laying around. You can just copy and paste those from your editor to begin generating new SQL.
self.vn.train(ddl=ddl)
def get_training_data(self):
training_data = self.vn.get_training_data()
# print(training_data)
return training_data
def ask(self, question, visualize=True, auto_train=True, *args, **kwargs):
    """Answer ``question`` through the module-level ``ask`` helper.

    Extra positional and keyword arguments are forwarded unchanged.

    Returns:
        ``(sql, df, fig)``. When the helper yields ``None`` instead of a
        tuple, ``(None, None, None)`` is returned so callers can always
        unpack the result instead of hitting a ``TypeError``.
    """
    result = ask(self.vn, question, *args, visualize=visualize, auto_train=auto_train, **kwargs)
    if result is None:
        return None, None, None
    sql, df, fig = result
    return sql, df, fig
def generate_sql(self, question):
return self.vn.generate_sql(question=question)
def run_sql(self, sql):
return self.vn.run_sql(sql=sql)
def training_data_export(self):
training_data = self.vn.milvus_client.query(
collection_name="vannasql",
output_fields=["*"],
limit=10000,
)
result = []
if training_data is not None:
result = [{"question":t['text'], "sql": t['sql']} for t in training_data]
return result
def training_data_import(self, data_list):
empty_items = list(filter(
lambda item: item['question'] is None or item['question'] == "" or item['sql'] is None or item['sql'] == "",
data_list
))
if bool(empty_items):
return True
exist_doc_data = self.vn.milvus_client.query(
collection_name="vannasql",
output_fields=["*"],
limit=10000,
)
data_texts = {t["question"]: t for t in data_list}
if bool(exist_doc_data):
remove_ids = [item["id"] for item in exist_doc_data if item['text'] in data_texts ]
if bool(remove_ids):
self.vn.milvus_client.delete(collection_name="vannasql", ids=remove_ids)
for item in data_list:
self.vn.train(
question=item["question"],
sql=item["sql"],
)
self.vn.milvus_client.refresh_load(collection_name="vannasql")
return False
def make_vanna_class(ChatClass=Ollama):
    """Build a Vanna class that combines ``Milvus_VectorStore`` with ``ChatClass``."""

    class MyVanna(Milvus_VectorStore, ChatClass):
        def __init__(self, config=None):
            # Initialise both bases explicitly with the same config.
            for base in (Milvus_VectorStore, ChatClass):
                base.__init__(self, config=config)

        def is_sql_valid(self, sql: str) -> bool:
            """Always report the SQL as not valid (placeholder implementation)."""
            return False

        def generate_query_explanation(self, sql: str):
            """Ask the chat model to explain ``sql`` and return its reply."""
            system_part = self.system_message("You are a helpful assistant that will explain a SQL query")
            user_part = self.user_message("Explain this SQL query: " + sql)
            return self.submit_prompt(prompt=[system_part, user_part])

    return MyVanna
# Usage example
if __name__ == '__main__':
    # NOTE(review): _initialize_vn reads llm_type, model, api_key, milvus_uri
    # and sql_type by subscript, so this config presumably raises KeyError
    # as written - confirm and extend the example config.
    config = {"supplier": "GITEE"}
    server = VannaServer(config)
    # server.schema_train()
    server.ask("汇总每个类别的销售量和销售额, 并按照销售量进行降序排列")

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

@ -7,7 +7,7 @@ from Crypto.Random import get_random_bytes
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs import gmpy2_pkcs10aep_cipher from libs import gmpy2_pkcs10aep_cipher
from models import Tenant
def generate_key_pair(tenant_id): def generate_key_pair(tenant_id):
private_key = RSA.generate(2048) private_key = RSA.generate(2048)
@ -16,11 +16,11 @@ def generate_key_pair(tenant_id):
pem_private = private_key.export_key() pem_private = private_key.export_key()
pem_public = public_key.export_key() pem_public = public_key.export_key()
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem" # filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
#
storage.save(filepath, pem_private) # storage.save(filepath, pem_private)
return pem_public.decode() return pem_public.decode(), pem_private.decode()
prefix_hybrid = b"HYBRID:" prefix_hybrid = b"HYBRID:"
@ -46,16 +46,21 @@ def encrypt(text, public_key):
def get_decrypt_decoding(tenant_id): def get_decrypt_decoding(tenant_id):
filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest()) from extensions.ext_database import db
# filepath = "privkeys/{tenant_id}".format(tenant_id=tenant_id) + "/private.pem"
# cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(filepath.encode()).hexdigest())
cache_key = "tenant_privkey:{hash}".format(hash=hashlib.sha3_256(tenant_id.encode('utf-8')).hexdigest())
private_key = redis_client.get(cache_key) private_key = redis_client.get(cache_key)
if not private_key: if not private_key:
try: # try:
private_key = storage.load(filepath) # private_key = storage.load(filepath)
except FileNotFoundError: # except FileNotFoundError:
raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id)) # raise PrivkeyNotFoundError("Private key not found, tenant_id: {tenant_id}".format(tenant_id=tenant_id))
tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).one_or_none()
private_key = tenant.encrypt_private_key
redis_client.setex(cache_key, 120, private_key) redis_client.setex(cache_key, 120, private_key)
rsa_key = RSA.import_key(private_key) rsa_key = RSA.import_key(private_key)

@ -0,0 +1,81 @@
"""add target_tenant_id
Revision ID: 0c79d303c76d
Revises: d20049ed0af6
Create Date: 2025-05-28 10:19:55.445389
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0c79d303c76d'
down_revision = 'd20049ed0af6'
branch_labels = None
depends_on = None
def upgrade():
    """Apply the revision.

    Adds ``target_tenant_id`` to ``accounts``, ``datasets`` and ``tenants``,
    adds seven columns to ``tool_published_apps``, adds ``file_id`` to
    ``upload_files`` and renames the two ``workflow_conversation_variables``
    indexes by dropping and recreating them.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('accounts', schema=None) as batch_op:
        batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.add_column(sa.Column('target_tenant_id', sa.String(length=255), nullable=True))
    # NOTE(review): the nullable=False columns below have no server_default;
    # adding them presumably fails when the table already holds rows - confirm.
    with op.batch_alter_table('tool_published_apps', schema=None) as batch_op:
        batch_op.add_column(sa.Column('tenant_id', models.types.StringUUID(), nullable=False))
        batch_op.add_column(sa.Column('conversation_id', models.types.StringUUID(), nullable=True))
        batch_op.add_column(sa.Column('file_key', sa.String(length=255), nullable=False))
        batch_op.add_column(sa.Column('mimetype', sa.String(length=255), nullable=False))
        batch_op.add_column(sa.Column('original_url', sa.String(length=2048), nullable=True))
        batch_op.add_column(sa.Column('name', sa.String(), nullable=False))
        batch_op.add_column(sa.Column('size', sa.Integer(), nullable=False))
    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.add_column(sa.Column('file_id', sa.String(length=255), nullable=True))
    # The recreated index names differ from the dropped ones only by a double underscore.
    with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
        batch_op.drop_index('workflow_conversation_variables_app_id_idx')
        batch_op.drop_index('workflow_conversation_variables_created_at_idx')
        batch_op.create_index('workflow__conversation_variables_app_id_idx', ['app_id'], unique=False)
        batch_op.create_index('workflow__conversation_variables_created_at_idx', ['created_at'], unique=False)
    # ### end Alembic commands ###
def downgrade():
    """Revert the revision by undoing the upgrade steps in reverse order.

    Restores the original index names on ``workflow_conversation_variables``
    and drops every column the upgrade added.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
        batch_op.drop_index('workflow__conversation_variables_created_at_idx')
        batch_op.drop_index('workflow__conversation_variables_app_id_idx')
        batch_op.create_index('workflow_conversation_variables_created_at_idx', ['created_at'], unique=False)
        batch_op.create_index('workflow_conversation_variables_app_id_idx', ['app_id'], unique=False)
    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.drop_column('file_id')
    with op.batch_alter_table('tool_published_apps', schema=None) as batch_op:
        batch_op.drop_column('size')
        batch_op.drop_column('name')
        batch_op.drop_column('original_url')
        batch_op.drop_column('mimetype')
        batch_op.drop_column('file_key')
        batch_op.drop_column('conversation_id')
        batch_op.drop_column('tenant_id')
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.drop_column('target_tenant_id')
    with op.batch_alter_table('datasets', schema=None) as batch_op:
        batch_op.drop_column('target_tenant_id')
    with op.batch_alter_table('accounts', schema=None) as batch_op:
        batch_op.drop_column('target_tenant_id')
    # ### end Alembic commands ###

@ -0,0 +1,33 @@
"""tenants table add encrypt_private_key
Revision ID: 582c477e905b
Revises: 8e00c75b3907
Create Date: 2025-05-28 10:41:31.662951
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '582c477e905b'
down_revision = '8e00c75b3907'
branch_labels = None
depends_on = None
def upgrade():
    """Add the nullable text column ``encrypt_private_key`` to ``tenants``."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.add_column(sa.Column('encrypt_private_key', sa.Text(), nullable=True))
    # ### end Alembic commands ###
def downgrade():
    """Drop the ``encrypt_private_key`` column from ``tenants``."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('tenants', schema=None) as batch_op:
        batch_op.drop_column('encrypt_private_key')
    # ### end Alembic commands ###

@ -100,11 +100,13 @@ class Account(UserMixin, Base):
initialized_at = db.Column(db.DateTime) initialized_at = db.Column(db.DateTime)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
target_tenant_id = db.Column(db.String(255), nullable=True)
@reconstructor @reconstructor
def init_on_load(self): def init_on_load(self):
self.role: Optional[TenantAccountRole] = None self.role: Optional[TenantAccountRole] = None
self._current_tenant: Optional[Tenant] = None self._current_tenant: Optional[Tenant] = None
target_tenant_id = db.Column(db.String(255), nullable=True)
@property @property
def is_password_set(self): def is_password_set(self):
@ -199,11 +201,13 @@ class Tenant(Base):
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
encrypt_public_key = db.Column(db.Text) encrypt_public_key = db.Column(db.Text)
encrypt_private_key = db.Column(db.Text)
plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying")) plan = db.Column(db.String(255), nullable=False, server_default=db.text("'basic'::character varying"))
status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying")) status = db.Column(db.String(255), nullable=False, server_default=db.text("'normal'::character varying"))
custom_config = db.Column(db.Text) custom_config = db.Column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
target_tenant_id = db.Column(db.String(255), nullable=True)
def get_accounts(self) -> list[Account]: def get_accounts(self) -> list[Account]:
return ( return (

@ -63,6 +63,7 @@ class Dataset(Base):
collection_binding_id = db.Column(StringUUID, nullable=True) collection_binding_id = db.Column(StringUUID, nullable=True)
retrieval_model = db.Column(JSONB, nullable=True) retrieval_model = db.Column(JSONB, nullable=True)
built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) built_in_field_enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
target_tenant_id = db.Column(db.String(255), nullable=True)
@property @property
def dataset_keyword_table(self): def dataset_keyword_table(self):
@ -418,6 +419,7 @@ class Document(Base):
"mime_type": file_detail.mime_type, "mime_type": file_detail.mime_type,
"created_by": file_detail.created_by, "created_by": file_detail.created_by,
"created_at": file_detail.created_at.timestamp(), "created_at": file_detail.created_at.timestamp(),
"file_id": file_detail.file_id,
} }
} }
elif self.data_source_type in {"notion_import", "website_crawl"}: elif self.data_source_type in {"notion_import", "website_crawl"}:

@ -1388,6 +1388,26 @@ class AppAnnotationSetting(Base):
updated_user_id = db.Column(StringUUID, nullable=False) updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property @property
def collection_binding_detail(self): def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding from .dataset import DatasetCollectionBinding
@ -1545,6 +1565,7 @@ class UploadFile(Base):
used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True) used_at: Mapped[datetime | None] = db.Column(db.DateTime, nullable=True)
hash: Mapped[str | None] = db.Column(db.String(255), nullable=True) hash: Mapped[str | None] = db.Column(db.String(255), nullable=True)
source_url: Mapped[str] = mapped_column(sa.TEXT, default="") source_url: Mapped[str] = mapped_column(sa.TEXT, default="")
file_id = db.Column(db.String(255), nullable=True, default=0)
def __init__( def __init__(
self, self,
@ -1564,6 +1585,7 @@ class UploadFile(Base):
used_at: datetime | None = None, used_at: datetime | None = None,
hash: str | None = None, hash: str | None = None,
source_url: str = "", source_url: str = "",
file_id:str| None = None
): ):
self.tenant_id = tenant_id self.tenant_id = tenant_id
self.storage_type = storage_type self.storage_type = storage_type
@ -1580,6 +1602,7 @@ class UploadFile(Base):
self.used_at = used_at self.used_at = used_at
self.hash = hash self.hash = hash
self.source_url = source_url self.source_url = source_url
self.file_id = file_id
class ApiRequest(Base): class ApiRequest(Base):

10375
api/poetry.lock generated

File diff suppressed because it is too large Load Diff

@ -0,0 +1,4 @@
[virtualenvs]
in-project = true
create = true
prefer-active-python = true

@ -1,6 +1,7 @@
import base64 import base64
import json import json
import logging import logging
import random
import secrets import secrets
import uuid import uuid
from datetime import UTC, datetime, timedelta from datetime import UTC, datetime, timedelta
@ -175,6 +176,7 @@ class AccountService:
return cast(Account, account) return cast(Account, account)
@staticmethod @staticmethod
def update_account_password(account, password, new_password): def update_account_password(account, password, new_password):
"""update account password""" """update account password"""
@ -587,6 +589,10 @@ class AccountService:
return False return False
def _get_login_cache_key(*, account_id: str, token: str) -> str:
    """Build the cache key ``account_login:<account_id>:<token>`` for a login entry."""
    return f"account_login:{account_id}:{token}"
class TenantService: class TenantService:
@staticmethod @staticmethod
def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant: def create_tenant(name: str, is_setup: Optional[bool] = False, is_from_dashboard: Optional[bool] = False) -> Tenant:
@ -604,7 +610,10 @@ class TenantService:
db.session.add(tenant) db.session.add(tenant)
db.session.commit() db.session.commit()
tenant.encrypt_public_key = generate_key_pair(tenant.id) # tenant.encrypt_public_key = generate_key_pair(tenant.id)
pem_public, pem_private = generate_key_pair(tenant.id)
tenant.encrypt_public_key = pem_public
tenant.encrypt_private_key = pem_private
db.session.commit() db.session.commit()
return tenant return tenant

@ -0,0 +1,348 @@
import logging
import yaml
from typing import Optional
import flask_login
from pathlib import Path
from constants.languages import languages
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.plugin.manager.exc import PluginDaemonClientSideError
from extensions.ext_database import db
from models.account import (
Account,
Tenant,
)
from services.account_service import AccountService, TenantService
from services.dataset_service import DatasetService
from services.errors.account import (
AccountRegisterError,
TenantNotFoundError,
)
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.ext.dataset_ext_service import DatasetExtService
from services.model_load_balancing_service import ModelLoadBalancingService
from services.model_provider_service import ModelProviderService
from services.plugin.plugin_service import PluginService
from configs import dify_config
from configs.ext_config import get_ext_config
import os
class AccountInfo:
    """Plain holder for one externally supplied account record."""

    def __init__(self, email, name, user_id, tenant_id):
        self.email, self.name = email, name
        self.user_id, self.tenant_id = user_id, tenant_id

    def to_dict(self):
        """Serialise the record to a dict.

        NOTE(review): the keys look copied from a tenant payload
        ("tenant_id" carries the email, "api_key" carries the tenant id and
        user_id is omitted). Kept as-is because consumers may depend on the
        current shape -- confirm and correct with the callers.
        """
        payload = {}
        payload["tenant_id"] = self.email
        payload["tenant_name"] = self.name
        payload["api_key"] = self.tenant_id
        return payload
class TenantAccountInfo:
    """Result of enabling a tenant: the tenant identity plus its admin credentials."""

    # Attribute names double as the keys emitted by to_dict(), in this order.
    _FIELDS = ("tenant_id", "tenant_name", "admin_account", "admin_account_password")

    def __init__(
        self,
        tenant_id: str,
        tenant_name: str,
        admin_account: str,
        admin_account_password: str,
    ):
        self.tenant_id = tenant_id
        self.tenant_name = tenant_name
        self.admin_account = admin_account
        self.admin_account_password = admin_account_password

    def to_dict(self):
        """Return the four fields as a dict keyed by attribute name."""
        return {field: getattr(self, field) for field in self._FIELDS}
class TenantData:
    """Result of initialising a tenant: its dataset API key and dataset ids."""

    def __init__(self, api_key: str, dataset_ids: list[str]):
        self.api_key, self.dataset_ids = api_key, dataset_ids

    def to_dict(self):
        """Return ``api_key`` and ``dataset_ids`` as a dict."""
        return dict(api_key=self.api_key, dataset_ids=self.dataset_ids)
class AccountExtService:
    """Account helpers for tenants that are mirrored from an external system."""

    # NOTE(review): hard-coded initial password for synced accounts. It is a
    # shared, guessable credential -- move it to configuration or force a
    # reset on first login.
    DEFAULT_PASSWORD = "wisdom@123"

    @staticmethod
    def _field(entry, key: str):
        """Read ``key`` from an account entry given as a mapping or an AccountInfo."""
        if isinstance(entry, AccountInfo):
            return getattr(entry, key)
        return entry[key]

    @staticmethod
    def create_account_and_tenant(
        email: str,
        name: str,
        tenant_name: str,
        target_tenant_id: str,
        interface_language: Optional[str] = None,
        password: Optional[str] = None,
    ) -> Account:
        """Create an account, make it owner of a tenant, and tag both.

        Both the account and its current tenant get ``target_tenant_id``
        before the session is committed.

        :return: the newly created account
        """
        account = AccountService.create_account(
            email=email, name=name, interface_language=interface_language, password=password, is_setup=True
        )
        account.target_tenant_id = target_tenant_id
        TenantService.create_owner_tenant_if_not_exist(account=account, name=tenant_name, is_setup=True)
        account.current_tenant.target_tenant_id = target_tenant_id
        db.session.commit()
        return account

    @staticmethod
    def get_admin_account() -> Account:
        """Return the first account whose ``target_tenant_id`` is "100", or None."""
        return db.session.query(Account).filter(Account.target_tenant_id == "100").first()

    @staticmethod
    def update_account_list(
        accounts: list[AccountInfo],
        target_tenant_id: str,
        interface_language: Optional[str] = None,
    ):
        """Synchronise external accounts into the tenant tagged ``target_tenant_id``.

        Entries whose email already exists for that target tenant get their
        name and user_id refreshed. Unknown emails are created with the
        default password and added as members of the tenant.

        :param accounts: entries exposing ``email``, ``name`` and ``user_id``,
            either as mappings (the shape callers pass today) or as AccountInfo
        :raises AccountRegisterError: on any failure, including a missing tenant
        """
        db.session.begin_nested()
        try:
            # Resolve the tenant mirrored from the external system.
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id=target_tenant_id)
            if tenant is None:
                raise TenantNotFoundError("企业未初始,请联系管理员!")

            # Index the accounts that already belong to this target tenant by email.
            existing = db.session.query(Account).filter(Account.target_tenant_id == target_tenant_id).all()
            existing_by_email = {item.email: item for item in existing}

            read = AccountExtService._field
            for entry in accounts:
                email = read(entry, "email")
                known = existing_by_email.get(email)
                if known is not None:
                    known.name = read(entry, "name")
                    known.user_id = read(entry, "user_id")
                    continue

                created = AccountService.create_account(
                    email=email,
                    name=read(entry, "name"),
                    interface_language=interface_language or languages[0],
                    password=AccountExtService.DEFAULT_PASSWORD,
                    is_setup=True,
                )
                created.user_id = read(entry, "user_id")
                created.target_tenant_id = target_tenant_id
                # Attach the new account to the tenant.
                TenantService.create_tenant_member(tenant, created)
            db.session.commit()
        except WorkSpaceNotAllowedCreateError:
            # Still best-effort (no re-raise), but no longer silent.
            db.session.rollback()
            logging.warning("Account sync rolled back: workspace creation is not allowed")
        except AccountRegisterError:
            db.session.rollback()
            logging.exception("Register failed")
            raise
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e
class TenantExtService:
    """Tenant bootstrap helpers: lookup, plugin/model provisioning, initialisation."""

    @staticmethod
    def get_tenant() -> Tenant:
        """Return the first tenant row, which is treated as the default tenant."""
        return db.session.query(Tenant).first()

    @staticmethod
    def get_tenant_by_target_tenant_id(target_tenant_id: str) -> Tenant:
        """Return the tenant tagged with ``target_tenant_id``, or None."""
        return db.session.query(Tenant).filter(Tenant.target_tenant_id == target_tenant_id).first()

    @staticmethod
    def setModeConfig(tenant_id: str, args: dict[str, object], provider: str) -> None:
        """Apply one model configuration (load balancing and credentials) to a tenant.

        :param args: reads ``model``, ``model_type``, ``credentials`` and the
            optional ``load_balancing`` and ``config_from`` entries
        :raises ValueError: when load balancing is enabled without ``configs``,
            or when credential validation fails
        """
        model_load_balancing_service = ModelLoadBalancingService()
        model = args["model"]
        model_type = args["model_type"]

        load_balancing = args["load_balancing"] if "load_balancing" in args else None
        if load_balancing and "enabled" in load_balancing and load_balancing["enabled"]:
            if "configs" not in load_balancing:
                raise ValueError("invalid load balancing configs")

            # Save the load balancing configs, then switch the feature on.
            model_load_balancing_service.update_load_balancing_configs(
                tenant_id=tenant_id,
                provider=provider,
                model=model,
                model_type=model_type,
                configs=load_balancing["configs"],
            )
            model_load_balancing_service.enable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=model, model_type=model_type
            )
        else:
            model_load_balancing_service.disable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=model, model_type=model_type
            )

        # Predefined models carry no per-model credentials to save.
        if args.get("config_from", "") != "predefined-model":
            model_provider_service = ModelProviderService()
            try:
                model_provider_service.save_model_credentials(
                    tenant_id=tenant_id,
                    provider=provider,
                    model=model,
                    model_type=model_type,
                    credentials=args["credentials"],
                )
            except CredentialsValidateFailedError as ex:
                logging.exception(
                    f"Failed to save model credentials, tenant_id: {tenant_id},"
                    f" model: {args.get('model')}, model_type: {args.get('model_type')}"
                )
                raise ValueError(str(ex)) from ex

    @staticmethod
    def install_plugin(tenant_id: str):
        """Install the bundled plugins, then configure the initial models."""
        TenantExtService.install_langgenius(tenant_id=tenant_id)
        TenantExtService.install_model(tenant_id=tenant_id)

    @staticmethod
    def install_model(tenant_id: str):
        """Configure the initial LLM, text-embedding and rerank models from dify_config."""
        params = {
            "INIT_MODEL_LLM_NAME": dify_config.INIT_MODEL_LLM_NAME,
            "INIT_MODEL_LLM_CONTEXT_SIZE": dify_config.INIT_MODEL_LLM_CONTEXT_SIZE,
            "INIT_MODEL_LLM_MAX_TOKENS": dify_config.INIT_MODEL_LLM_MAX_TOKENS,
            "INIT_MODEL_LLM_BASE_URL": dify_config.INIT_MODEL_LLM_BASE_URL,
        }
        llm_config = get_ext_config(file_name="plugin_llm_config.yml", params=params)
        TenantExtService.setModeConfig(
            tenant_id=tenant_id, args=llm_config, provider=dify_config.INIT_MODEL_LLM_PROVIDER
        )

        params = {
            "INIT_MODEL_TEXT_EMBEDDING_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_NAME,
            "INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE": dify_config.INIT_MODEL_TEXT_EMBEDDING_CONTEXT_SIZE,
            "INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS": dify_config.INIT_MODEL_TEXT_EMBEDDING_MAX_TOKENS,
            "INIT_MODEL_TEXT_EMBEDDING_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_BASE_URL,
        }
        text_embedding_config = get_ext_config(file_name="plugin_embedding_config.yml", params=params)
        TenantExtService.setModeConfig(
            tenant_id=tenant_id,
            args=text_embedding_config,
            provider=dify_config.INIT_MODEL_TEXT_EMBEDDING_PROVIDER,
        )

        params = {
            "INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_NAME,
            "INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL": dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_BASE_URL,
        }
        text_embedding_rerank_config = get_ext_config(
            file_name="plugin_embedding_rerank_config.yml", params=params
        )
        TenantExtService.setModeConfig(
            tenant_id=tenant_id,
            args=text_embedding_rerank_config,
            provider=dify_config.INIT_MODEL_TEXT_EMBEDDING_RERANK_PROVIDER,
        )

    @staticmethod
    def install_langgenius(tenant_id: str):
        """Upload the bundled plugin packages and install those not yet installed."""
        uploaded_identifiers = TenantExtService.upload_langgenius(tenant_id=tenant_id)
        # Identifiers of the plugins the tenant already has.
        installed = {item.plugin_unique_identifier for item in PluginService.list(tenant_id)}
        # Keep only the identifiers that still need installing.
        pending = [identifier for identifier in uploaded_identifiers if identifier not in installed]
        if not pending:
            # Nothing new: do not issue an empty install request.
            return
        PluginService.install_from_marketplace_pkg(tenant_id, pending)

    @staticmethod
    def upload_langgenius(tenant_id: str) -> list[str]:
        """Upload every file under ``plugins/langgenius`` and return the identifiers.

        The directory is resolved three levels above this module. Files are
        uploaded in sorted name order so the result is deterministic.

        :raises ValueError: when the plugin daemon rejects a package
        """
        directory = Path(__file__).parent.parent.parent / "plugins" / "langgenius"
        unique_identifiers = []
        for file_path in sorted(directory.iterdir()):
            if not file_path.is_file():
                continue
            logging.info("读取文件:%s", file_path)
            content = file_path.read_bytes()
            try:
                response = PluginService.upload_pkg(tenant_id=tenant_id, pkg=content)
            except PluginDaemonClientSideError as e:
                raise ValueError(e) from e
            unique_identifiers.append(response.unique_identifier)
        return unique_identifiers

    @staticmethod
    def enable_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantAccountInfo:
        """Ensure a tenant and its admin account exist for an external tenant.

        If the tenant already exists, only the admin account is created when
        missing. Otherwise the account and tenant are created together.

        :return: tenant id/name plus the admin name and initial password
        :raises AccountRegisterError: on any failure
        """
        db.session.begin_nested()
        # NOTE(review): hard-coded initial admin password -- move to
        # configuration or force a reset on first login.
        password = "wisdom@123"
        try:
            email = f"admin@{target_tenant_id}.com"
            admin_name = f"{target_tenant_name}-管理员"
            # Check whether the tenant has already been created.
            tenant = TenantExtService.get_tenant_by_target_tenant_id(target_tenant_id)
            if tenant is not None:
                account = AccountService.get_user_through_email(email)
                if account is None:
                    account = AccountService.create_account(
                        email=email,
                        name=admin_name,
                        password=password,
                        is_setup=True,
                        interface_language="zh-Hans",
                    )
                    TenantService.create_tenant_member(tenant, account, role="owner")
                    account.target_tenant_id = target_tenant_id
                    # Persist the tag; this path previously never committed it.
                    db.session.commit()
                # Keep the tenant resolved above rather than replacing it with
                # account.current_tenant, which this path never sets.
            else:
                account = AccountExtService.create_account_and_tenant(
                    email=email,
                    name=admin_name,
                    tenant_name=target_tenant_name,
                    target_tenant_id=target_tenant_id,
                    interface_language="zh-Hans",
                    password=password,
                )
                tenant = account.current_tenant

            # NOTE(review): admin_account carries the display name, not the
            # login email -- confirm which one the caller expects.
            return TenantAccountInfo(
                tenant_name=tenant.name,
                tenant_id=tenant.id,
                admin_account=admin_name,
                admin_account_password=password,
            )
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e

    @staticmethod
    def init_tenant(
        target_tenant_id: str,
        target_tenant_name: str,
    ) -> TenantData:
        """Provision the logged-in user's current tenant.

        Installs plugins and models, creates the default datasets and
        fetches (or creates) the dataset API token.

        :return: the dataset API key and the ids of the tenant's datasets
        :raises AccountRegisterError: on any failure
        """
        db.session.begin_nested()
        try:
            account = flask_login.current_user
            tenant = account.current_tenant
            # Install plugins and initial model configuration.
            TenantExtService.install_plugin(tenant_id=tenant.id)
            # Create the default knowledge bases when missing.
            datasets = DatasetExtService.init_dataset(
                tenant=tenant,
                target_tenant_id=target_tenant_id,
                target_tenant_name=target_tenant_name,
                account=account,
            )
            # Fetch or create the dataset API token.
            api_token = DatasetExtService().get_or_add_datasets_api_token(tenant_id=tenant.id)
            db.session.commit()
            return TenantData(
                api_key=api_token.token,
                dataset_ids=[dataset.id for dataset in datasets],
            )
        except Exception as e:
            db.session.rollback()
            logging.exception("Register failed")
            raise AccountRegisterError(f"Registration failed: {e}") from e

@ -0,0 +1,177 @@
from models import ApiToken, Account, Tenant
from models.dataset import (
Dataset,DocumentSegment
)
from core.rag.models.document import Document as DocumentModel
from core.errors.error import (
LLMBadRequestError,
ModelCurrentlyNotSupportError,
ProviderTokenNotInitError,
QuotaExceededError,
)
from controllers.console.app.error import (
ProviderModelCurrentlyNotSupportError,
ProviderNotInitializeError,
ProviderQuotaExceededError,
)
from extensions.ext_database import db
from services.dataset_service import DatasetService, DocumentService
from configs.ext_config import get_init_knowledge_config
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
class DatasetExtService:
    """Dataset helpers for tenants that are mirrored from an external system."""

    resource_type = "dataset"
    token_prefix = "dataset-"
    max_keys = 10

    @staticmethod
    def get_datasets(tenant_id=None, target_tenant_id=None) -> list[Dataset]:
        """Return the datasets matching both ``tenant_id`` and ``target_tenant_id``."""
        return Dataset.query.filter(
            Dataset.tenant_id == tenant_id, Dataset.target_tenant_id == target_tenant_id
        ).all()

    @staticmethod
    def init_dataset(
        tenant: Tenant = None,
        target_tenant_id: str = None,
        target_tenant_name: str = None,
        account: Account = None,
    ) -> list[Dataset]:
        """Return the tenant's datasets, creating the two default ones when none exist.

        The defaults are PUBLIC_KNOWLEDGE and COMPANY_KNOWLEDGE, created in
        that order, tagged with ``target_tenant_id`` and committed together.
        """
        datasets = DatasetExtService.get_datasets(tenant_id=tenant.id, target_tenant_id=target_tenant_id)
        if datasets:
            return datasets

        defaults = (
            ("PUBLIC_KNOWLEDGE", f"{target_tenant_name}的公共知识库"),
            ("COMPANY_KNOWLEDGE", f"{target_tenant_name}的企业知识库"),
        )
        datasets = []
        for name, description in defaults:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=tenant.id,
                name=name,
                description=description,
                indexing_technique="",
                account=account,
            )
            dataset.target_tenant_id = target_tenant_id
            datasets.append(dataset)
        db.session.commit()
        return datasets

    @staticmethod
    def set_dataset_config(dataset=None, current_user=None):
        """Apply the default knowledge configuration to ``dataset``.

        :raises ProviderNotInitializeError: provider token is not initialised
        :raises ProviderQuotaExceededError: provider quota is exhausted
        :raises ProviderModelCurrentlyNotSupportError: model is not supported
        """
        # Start from the default configuration values and validate them.
        args = get_init_knowledge_config({})
        knowledge_config = KnowledgeConfig(**args)
        try:
            DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as ex:
            raise ProviderQuotaExceededError() from ex
        except ModelCurrentlyNotSupportError as ex:
            raise ProviderModelCurrentlyNotSupportError() from ex

    def get_or_add_datasets_api_token(self, tenant_id: str):
        """Return a dataset API token for the tenant, creating one if none exists.

        When tokens already exist, the last row of the unordered query result
        is returned.
        """
        api_tokens = (
            db.session.query(ApiToken)
            .filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == tenant_id)
            .all()
        )
        if api_tokens:
            return api_tokens[-1]

        api_token = ApiToken()
        api_token.tenant_id = tenant_id
        api_token.token = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token
class DocumentExtService:
    """Extend retrieved segments with the segment that follows them."""

    @staticmethod
    def set_next_segments(all_documents: list["DocumentModel"]):
        """Append the overlapping following segment to each childless document.

        Loads all segments of the referenced source documents in one query,
        groups them by document id, then delegates to
        :meth:`merged_next_segment_content`. Documents are modified in place.
        """
        if not all_documents:
            return

        document_ids = []
        doc_segment_ids = []
        for document in all_documents:
            if document.children is None:
                doc_segment_ids.append(document.metadata["doc_id"])
                document_ids.append(document.metadata["document_id"])

        if not document_ids:
            return

        document_segments = (
            db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
        )
        document_segment_data = {}
        for document_segment in document_segments:
            document_segment_data.setdefault(document_segment.document_id, []).append(document_segment)

        DocumentExtService.merged_next_segment_content(
            all_documents=all_documents,
            document_segment_data=document_segment_data,
            doc_segment_ids=doc_segment_ids,
        )

    @staticmethod
    def merged_next_segment_content(
        all_documents: list["DocumentModel"], document_segment_data: dict, doc_segment_ids: list
    ):
        """Merge each childless document with its next segment when they overlap.

        A next segment is used only if its ``index_node_id`` is not already in
        ``doc_segment_ids``. Every segment used is appended to that list, so
        the caller's list is mutated and no segment is merged twice.
        """
        if not all_documents:
            return

        for document in all_documents:
            if document.children is not None:
                continue
            doc_segment_id = document.metadata["doc_id"]
            document_id = document.metadata["document_id"]
            # A document may have no stored segments; treat that as "no next
            # segment" instead of raising KeyError.
            document_segments = document_segment_data.get(document_id)
            next_segment = DocumentExtService.get_next_segment(
                doc_segment_id=doc_segment_id, document_segments=document_segments
            )
            if next_segment and next_segment.index_node_id not in doc_segment_ids:
                doc_segment_ids.append(next_segment.index_node_id)
                document.page_content = DocumentExtService.merged_text(
                    document.page_content, next_segment.content
                )

    @staticmethod
    def merged_text(text, target_text) -> str:
        """Join ``target_text`` onto ``text`` when they overlap by more than 10 chars.

        The overlap is the longest suffix of ``text`` that equals a prefix of
        ``target_text``. With an overlap of 10 characters or fewer, ``text``
        is returned unchanged.
        """
        longest = min(len(text), len(target_text))
        # Only overlaps longer than 10 characters matter, so scan from the
        # longest candidate down and stop at the first match.
        for overlap_length in range(longest, 10, -1):
            if text[-overlap_length:] == target_text[:overlap_length]:
                return text + target_text[overlap_length:]
        return text

    @staticmethod
    def get_next_segment(doc_segment_id, document_segments: list["DocumentSegment"]) -> "DocumentSegment | None":
        """Return the segment positioned right after ``doc_segment_id``, or None.

        None is also returned when ``doc_segment_id`` is not among
        ``document_segments``. Previously the search fell back to position 0
        in that case.
        """
        if not document_segments:
            return None

        current_position = None
        for document_segment in document_segments:
            if document_segment.index_node_id == doc_segment_id:
                current_position = document_segment.position
        if current_position is None:
            return None

        for document_segment in document_segments:
            if document_segment.position == current_position + 1:
                return document_segment
        return None

@ -0,0 +1,299 @@
class ReadPdfService:
@classmethod
def load_content(cls, pdf_file_path: str) -> str | None:
doc = None
try:
# PDF标题提取需要pymupdf库
import fitz
doc = fitz.open(pdf_file_path)
contents = []
for page in doc:
page_height = page.rect.height
page_width = page.rect.width
# 旧版本中 Page 对象可能没有 get_text 方法,使用 getText 方法替代
# 从 v1.21.0 版本开始fitz.Page 类的 getText 方法已被弃用,应使用 get_text 方法
blocks = page.get_text("dict")["blocks"] # type: ignore
page_number = page.number
# if page_number == 39:
# print("-----------------------------")
blocks = cls.handle_blocks(blocks)
for block in blocks:
# if page_number == 39:
# print("-----------------------------")
# print(block)
type = block["type"]
if type == 1:
continue
content = cls.get_block_content(block,page_width,page_height)
if content is not None:
# if "图5-5" in content:
# # print(content)
# print("aaaaaaaaa")
contents.append(content)
return "\n".join(contents)
finally:
if doc is not None:
doc.close()
@classmethod
def handle_blocks(cls, blocks: list[dict]) -> list[dict] | None:
if blocks is not None:
handle_block_list = cls.sort_blocks(blocks)
handle_block_list = cls.filter_inner_img_block(handle_block_list)
return handle_block_list
return blocks
@classmethod
def sort_blocks(cls, blocks: list[dict]) -> list[dict] | None:
if blocks is not None:
def custom_sorted(block:dict) -> float:
bbox = block["bbox"]
return bbox[1]
sorted_data_asc = sorted(blocks, key=custom_sorted)
return sorted_data_asc
return blocks
@classmethod
def handle_lines(cls, blocks: list[dict]) -> list[dict] | None:
if blocks is not None:
def top_sorted(block:dict) -> float:
bbox = block["bbox"]
return bbox[1]
def left_sorted(block:dict) -> float:
bbox = block["bbox"]
return bbox[0]
sorted_data_asc = sorted(blocks, key=top_sorted)
sorted_data_asc = sorted(sorted_data_asc, key=left_sorted)
return sorted_data_asc
return blocks
@classmethod
def is_inner_img_block(cls, block: dict, img_bboxs: list[dict]) -> bool:
is_inner_img = False
type = block["type"]
if type != 1:
for img_bbox in img_bboxs:
if not is_inner_img:
bbox = block["bbox"]
is_inner_ = (bbox[0] >= img_bbox[0]
and bbox[1] >= img_bbox[1]
and bbox[2] <= img_bbox[2]
and bbox[3] <= img_bbox[3])
if is_inner_:
is_inner_img = True
return is_inner_img
@classmethod
def get_only_row_img_bboxs(cls, blocks: list[dict],img_bboxs: list[dict]) -> list[dict]:
# 判断图片是否是单独一行,单独一行的图片,内部的文字正常处理,反之,不处理(如果两个图片并列的话,内部的文字也是正常处理)
only_row_img_bboxs = []
for img_bbox in img_bboxs:
# 同层级是否只有图片
is_only_img = True
for block in blocks:
type = block["type"]
if type != 1:
bbox = block["bbox"]
# 判断当前是否是不在图片内的文本
is_inner_img = cls.is_inner_img_block(block, img_bboxs)
# 判断当前是否在同层级
is_save_level = (
(
bbox[0] > img_bbox[2]
or bbox[2] < img_bbox[0]
)
and bbox[1] >= img_bbox[1]
and bbox[3] <= img_bbox[3]
)
if not is_inner_img and is_save_level:
is_only_img = False
if is_only_img:
only_row_img_bboxs.append(img_bbox)
return only_row_img_bboxs
@classmethod
def filter_inner_img_block(cls, blocks: list[dict]) -> list[dict] | None:
if blocks is not None:
# 判断是否有图片
img_bboxs:list[dict] = []
for block in blocks:
bbox = block["bbox"]
type = block["type"]
if type == 1:
img_bboxs.append(bbox)
if len(img_bboxs) > 0:
# 获取所有单独在一行的图片区域集合
only_row_img_bboxs = cls.get_only_row_img_bboxs(blocks, img_bboxs)
filter_blocks: list[dict] = []
for block in blocks:
type = block["type"]
if type != 1:
# 判断是否在单行都是图片的区域
is_only_row_img = cls.is_inner_img_block(block, only_row_img_bboxs)
# 判断是否在图片区域内
is_inner_img = cls.is_inner_img_block(block, img_bboxs)
# 如果在单行都是图片的区域内,返回值。或者不在图片区域内,返回值。
if is_only_row_img or not is_inner_img:
filter_blocks.append(block)
# else:
# content_ = cls.load_block_content(block)
# print(content_)
return filter_blocks
return blocks
@classmethod
def get_block_content(cls, block: dict, page_width: float, page_height : float) -> str | None:
header = cls.is_header_fitz(block=block, page_width=page_width, page_height=page_height)
footer = cls.is_footer_fitz(block=block, page_width=page_width, page_height=page_height)
if not header and not footer:
return cls.load_block_content(block=block)
return None
@classmethod
def load_block_content(cls, block: dict) -> str | None:
if "lines" in block:
line_texts = []
lines = cls.handle_lines(block["lines"])
for line in lines:
texts = []
for span in line["spans"]:
text = span["text"].strip()
if text:
texts.append(text)
if len(texts) > 0:
line_text = "".join(texts)
line_texts.append(line_text)
if len(line_texts) > 0:
# print("************************************",len(line_texts))
# print("\n".join(line_texts))
return "\n".join(line_texts)
return None
@classmethod
def is_heading_fitz(cls, span: dict, page_width: float) -> bool:
"""
判断一个文本片段是否为标题
:param span: PyMuPDF 返回的文本片段信息
:param page_width: 页面宽度用于判断位置
:return: 是否为标题
"""
# 特征 1字体加粗
is_bold = "bold" in span["font"].lower()
# 特征 2字体大小相对较大
is_large_font = span["size"] > 14 # 适当降低阈值
# 特征 3文本位置靠近页面顶部或居中
is_top = span["origin"][1] < 100 # 距离页面顶部小于 100 像素
is_centered = abs(span["origin"][0] - page_width / 2) < 50 # 水平居中
# 特征 4文本格式包含大写字母或编号
text = span["text"].strip()
is_uppercase = text.isupper()
is_numbered = any(text.startswith(f"{i}.") for i in range(1, 10)) # 如 "1.", "2."
# 综合判断
# return is_bold or is_large_font
# flg = (is_bold or is_large_font) and (is_top or is_centered) and (is_uppercase or is_numbered)
flg = is_numbered
if flg :
print("标题:",text)
return flg
@classmethod
def is_top_fitz(cls, bbox_top: float) -> bool:
# print("----------------",bbox_top)
if bbox_top < 58:
return True
return False
@classmethod
def is_bottom_fitz(cls, bbox_bottom: float, page_height) -> bool:
# print("----------------",bbox_top)
if bbox_bottom > page_height - 60:
return True
return False
@classmethod
def is_centered_fitz(cls, origin_x: float, page_width: float) -> bool:
is_centered = abs(origin_x - page_width / 2) < 20 # 水平居中
return is_centered
@classmethod
def get_origin_bybbox(cls,bbox:list[float]) -> list[float]:
x = (bbox[0] + bbox[2]) / 2
y = (bbox[1] + bbox[3]) / 2
origin = list([x,y])
return origin
@classmethod
def get_first_span(cls,block:dict) -> dict | None:
if "lines" in block and len(block["lines"]) > 0:
lines = block["lines"]
line = lines[0]
return line["spans"][0]
return None
@classmethod
def is_bold_byspan(cls,span:dict) -> bool:
is_bold = "bold" in span["font"].lower()
return is_bold
@classmethod
def is_header_fitz(cls,block: dict,page_width: float, page_height: float) -> bool:
# 判断block
if block:
bbox = block["bbox"]
# number = block["number"]
origin = cls.get_origin_bybbox(bbox)
if "lines" in block and len(block["lines"]) == 1:
span = cls.get_first_span(block)
# 是否加粗
is_bold = cls.is_bold_byspan(span)
# 判断字体是否比较小
is_font_size = span["size"] < 10
# 不满行
is_not_full_line = bbox[0] > 120 or (bbox[3] < (page_width -120))
# 靠近页面顶部
is_top = cls.is_top_fitz(bbox_top = bbox[1]) # 距离页面顶部小于 100 像素
# 水平居中
is_centered = cls.is_centered_fitz(origin_x = origin[0], page_width=page_width)
# 在判断字体是否加粗,或者字体大小,一般页眉页脚的字体比较小
return is_top or ( bbox[1] < 70 and not is_bold and is_font_size and is_centered and is_not_full_line )
return False
@classmethod
def is_footer_fitz(cls,block: dict,page_width: float, page_height: float) -> bool:
# 判断block
if block:
bbox = block["bbox"]
# number = block["number"]
origin = cls.get_origin_bybbox(bbox)
if "lines" in block :
span = cls.get_first_span(block)
# 是否加粗
is_bold = cls.is_bold_byspan(span)
# 判断字体是否比较小
is_font_size = span["size"] < 10
# 不满行
is_not_full_line = bbox[0] > 120 or (bbox[3] < (page_width -120))
# 靠近页面顶部
is_bottom = cls.is_bottom_fitz(bbox_bottom = bbox[3],page_height=page_height) # 距离页面顶部小于 100 像素
# 水平居中
is_centered = cls.is_centered_fitz(origin_x = origin[0], page_width=page_width)
# 在判断字体是否加粗,或者字体大小,一般页眉页脚的字体比较小
return is_bottom or ( bbox[3] > page_height - 70 and not is_bold and is_font_size and is_centered and is_not_full_line )
return False
if __name__ == "__main__":
    # Manual smoke run: dump the extracted text of a local sample PDF.
    sample_pdf = r"D:\a.pdf"
    print(ReadPdfService().load_content(pdf_file_path=sample_pdf))

@ -37,6 +37,7 @@ class FileService:
user: Union[Account, EndUser, Any], user: Union[Account, EndUser, Any],
source: Literal["datasets"] | None = None, source: Literal["datasets"] | None = None,
source_url: str = "", source_url: str = "",
file_id: int | None = None,
) -> UploadFile: ) -> UploadFile:
# get file extension # get file extension
extension = os.path.splitext(filename)[1].lstrip(".").lower() extension = os.path.splitext(filename)[1].lstrip(".").lower()
@ -87,6 +88,7 @@ class FileService:
used=False, used=False,
hash=hashlib.sha3_256(content).hexdigest(), hash=hashlib.sha3_256(content).hexdigest(),
source_url=source_url, source_url=source_url,
file_id=file_id
) )
db.session.add(upload_file) db.session.add(upload_file)

@ -111,3 +111,24 @@ class WebConversationService:
db.session.delete(pinned_conversation) db.session.delete(pinned_conversation)
db.session.commit() db.session.commit()
@classmethod
def batch_unpin(cls, app_model: App, conversation_ids: list[str], user: Optional[Union[Account, EndUser]]):
    """Unpin every listed conversation that ``user`` has pinned in ``app_model``.

    Does nothing without a user, with no ids, or when nothing matches.
    Previously only the first matching pinned row was deleted.
    """
    if not user or not conversation_ids:
        return

    created_by_role = "account" if isinstance(user, Account) else "end_user"
    pinned_conversations = (
        db.session.query(PinnedConversation)
        .filter(
            PinnedConversation.app_id == app_model.id,
            PinnedConversation.conversation_id.in_(conversation_ids),
            PinnedConversation.created_by_role == created_by_role,
            PinnedConversation.created_by == user.id,
        )
        .all()
    )
    if not pinned_conversations:
        return

    for pinned_conversation in pinned_conversations:
        db.session.delete(pinned_conversation)
    db.session.commit()

@ -0,0 +1,49 @@
from typing import Any
import toml # type: ignore
def load_api_poetry_configs() -> dict[str, Any]:
    """Parse ``api/pyproject.toml`` and return its ``[tool.poetry]`` table."""
    parsed = toml.load("api/pyproject.toml")
    poetry_table = parsed["tool"]["poetry"]
    return poetry_table
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
    """Map each group name to its dependency table.

    The top-level ``[tool.poetry]`` dependencies are listed under "main".
    """
    poetry_table = load_api_poetry_configs()
    # "main" comes first; explicitly declared groups follow in file order.
    tables_by_group = {"main": poetry_table, **poetry_table["group"]}
    return {name: table["dependencies"] for name, table in tables_by_group.items()}
def test_group_dependencies_sorted():
    """Every dependency group must list its packages in sorted order."""
    for group_name, dependencies in load_all_dependency_groups().items():
        dependency_names = list(dependencies.keys())
        expected_dependency_names = sorted(set(dependency_names))
        # The top-level dependencies are keyed "main", so a plain truthiness
        # check would always report the group-style section name.
        if group_name == "main":
            section = "tool.poetry.dependencies"
        else:
            section = f"tool.poetry.group.{group_name}.dependencies"
        assert expected_dependency_names == dependency_names, (
            f"Dependencies in group {group_name} are not sorted. "
            f"Check and fix [{section}] section in pyproject.toml file"
        )
def test_group_dependencies_version_operator():
    """No dependency may use the caret (``^``) version operator.

    A specification may be a plain version string, a table, or a list of
    tables. Tables without a ``version`` key (for example git or path
    dependencies) have nothing to check and are skipped instead of raising.
    """
    for group_name, dependencies in load_all_dependency_groups().items():
        for dependency_name, specification in dependencies.items():
            candidates = specification if isinstance(specification, list) else [specification]
            for candidate in candidates:
                version_spec = candidate if isinstance(candidate, str) else candidate.get("version")
                if version_spec is None:
                    continue
                assert not version_spec.startswith("^"), (
                    f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
                    f"'^' operator is too wide and not allowed in the version specification."
                )
def test_duplicated_dependency_crossing_groups() -> None:
    """A package must not be declared in more than one dependency group."""
    all_dependency_names: list[str] = [
        name for dependencies in load_all_dependency_groups().values() for name in dependencies
    ]
    unique_sorted = sorted(set(all_dependency_names))
    assert unique_sorted == sorted(all_dependency_names), (
        "Duplicated dependencies crossing groups are found"
    )

@ -0,0 +1,7 @@
from core.helper.marketplace import download_plugin_pkg
def test_download_plugin_pkg():
    """Downloading a known marketplace plugin returns a non-empty package."""
    plugin_identifier = "langgenius/bing:0.0.1@e58735424d2104f208c2bd683c5142e0332045b425927067acf432b26f3d970b"
    pkg = download_plugin_pkg(plugin_identifier)
    assert pkg is not None
    assert len(pkg) > 0

@ -0,0 +1,18 @@
#!/bin/bash

# Requires `poetry` on PATH; install it with pip when it is missing.
if ! command -v poetry > /dev/null 2>&1; then
    echo "Installing Poetry ..."
    pip install poetry
fi

# Check that api/poetry.lock is in sync with api/pyproject.toml.
if poetry check -C api --lock; then
    echo "poetry.lock is ready."
else
    # Refresh the lockfile only, without updating the locked versions.
    echo "poetry.lock is outdated, refreshing without updating locked versions ..."
    poetry lock -C api
fi

@ -0,0 +1,13 @@
#!/bin/bash

# Requires `poetry` on PATH; install it with pip when it is missing.
command -v poetry > /dev/null 2>&1 || {
    echo "Installing Poetry ..."
    pip install poetry
}

# Refresh the lockfile and update the locked versions.
poetry update -C api

# Check that api/poetry.lock is in sync with api/pyproject.toml.
poetry check -C api --lock

@ -0,0 +1,4 @@
docker compose down
docker compose up -d

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 790 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 257 KiB

Loading…
Cancel
Save