Merge branch 'feat/r2' into deploy/rag-dev
# Conflicts: # web/app/components/workflow-app/components/workflow-main.tsx # web/app/components/workflow/constants.ts # web/app/components/workflow/header/run-and-history.tsx # web/app/components/workflow/hooks-store/store.ts # web/app/components/workflow/hooks/use-nodes-interactions.ts # web/app/components/workflow/hooks/use-workflow-interactions.ts # web/app/components/workflow/hooks/use-workflow.ts # web/app/components/workflow/nodes/_base/components/panel-operator/panel-operator-popup.tsx # web/app/components/workflow/nodes/_base/components/variable/var-reference-picker.tsx # web/app/components/workflow/nodes/code/use-config.ts # web/app/components/workflow/nodes/llm/default.ts # web/app/components/workflow/panel/index.tsx # web/app/components/workflow/panel/version-history-panel/index.tsx # web/app/components/workflow/store/workflow/index.ts # web/app/components/workflow/types.ts # web/config/index.ts # web/types/workflow.tsfeat/rag-2
commit
7c5893db91
@ -0,0 +1,197 @@
|
|||||||
|
from flask import redirect, request
|
||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from flask_restful import ( # type: ignore
|
||||||
|
Resource, # type: ignore
|
||||||
|
reqparse,
|
||||||
|
)
|
||||||
|
from werkzeug.exceptions import Forbidden, NotFound
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from core.model_runtime.errors.validate import CredentialsValidateFailedError
|
||||||
|
from core.plugin.impl.oauth import OAuthHandler
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.login import login_required
|
||||||
|
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
|
||||||
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourcePluginOauthApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
provider = args["provider"]
|
||||||
|
plugin_id = args["plugin_id"]
|
||||||
|
# Check user role first
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
# get all plugin oauth configs
|
||||||
|
plugin_oauth_config = (
|
||||||
|
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||||
|
)
|
||||||
|
if not plugin_oauth_config:
|
||||||
|
raise NotFound()
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
redirect_url = (
|
||||||
|
f"{dify_config.CONSOLE_WEB_URL}/oauth/datasource/callback?provider={provider}&plugin_id={plugin_id}"
|
||||||
|
)
|
||||||
|
system_credentials = plugin_oauth_config.system_credentials
|
||||||
|
if system_credentials:
|
||||||
|
system_credentials["redirect_url"] = redirect_url
|
||||||
|
response = oauth_handler.get_authorization_url(
|
||||||
|
current_user.current_tenant.id, current_user.id, plugin_id, provider, system_credentials=system_credentials
|
||||||
|
)
|
||||||
|
return response.model_dump()
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceOauthCallback(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
provider = args["provider"]
|
||||||
|
plugin_id = args["plugin_id"]
|
||||||
|
oauth_handler = OAuthHandler()
|
||||||
|
plugin_oauth_config = (
|
||||||
|
db.session.query(DatasourceOauthParamConfig).filter_by(provider=provider, plugin_id=plugin_id).first()
|
||||||
|
)
|
||||||
|
if not plugin_oauth_config:
|
||||||
|
raise NotFound()
|
||||||
|
credentials = oauth_handler.get_credentials(
|
||||||
|
current_user.current_tenant.id,
|
||||||
|
current_user.id,
|
||||||
|
plugin_id,
|
||||||
|
provider,
|
||||||
|
system_credentials=plugin_oauth_config.system_credentials,
|
||||||
|
request=request,
|
||||||
|
)
|
||||||
|
datasource_provider = DatasourceProvider(
|
||||||
|
plugin_id=plugin_id, provider=provider, auth_type="oauth", encrypted_credentials=credentials
|
||||||
|
)
|
||||||
|
db.session.add(datasource_provider)
|
||||||
|
db.session.commit()
|
||||||
|
return redirect(f"{dify_config.CONSOLE_WEB_URL}")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceAuth(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def post(self):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("name", type=str, required=False, nullable=False, location="json", default="test")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
|
||||||
|
try:
|
||||||
|
datasource_provider_service.datasource_provider_credentials_validate(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
provider=args["provider"],
|
||||||
|
plugin_id=args["plugin_id"],
|
||||||
|
credentials=args["credentials"],
|
||||||
|
name=args["name"],
|
||||||
|
)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ValueError(str(ex))
|
||||||
|
|
||||||
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
datasources = datasource_provider_service.get_datasource_credentials(
|
||||||
|
tenant_id=current_user.current_tenant_id, provider=args["provider"], plugin_id=args["plugin_id"]
|
||||||
|
)
|
||||||
|
return {"result": datasources}, 200
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceAuthUpdateDeleteApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def delete(self, auth_id: str):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
datasource_provider_service.remove_datasource_credentials(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
auth_id=auth_id,
|
||||||
|
provider=args["provider"],
|
||||||
|
plugin_id=args["plugin_id"],
|
||||||
|
)
|
||||||
|
return {"result": "success"}, 200
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
def patch(self, auth_id: str):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("provider", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("plugin_id", type=str, required=True, nullable=False, location="args")
|
||||||
|
parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
try:
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
datasource_provider_service.update_datasource_credentials(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
auth_id=auth_id,
|
||||||
|
provider=args["provider"],
|
||||||
|
plugin_id=args["plugin_id"],
|
||||||
|
credentials=args["credentials"],
|
||||||
|
)
|
||||||
|
except CredentialsValidateFailedError as ex:
|
||||||
|
raise ValueError(str(ex))
|
||||||
|
|
||||||
|
return {"result": "success"}, 201
|
||||||
|
|
||||||
|
|
||||||
|
# Import Rag Pipeline
|
||||||
|
api.add_resource(
|
||||||
|
DatasourcePluginOauthApi,
|
||||||
|
"/oauth/plugin/datasource",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
DatasourceOauthCallback,
|
||||||
|
"/oauth/plugin/datasource/callback",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
DatasourceAuth,
|
||||||
|
"/auth/plugin/datasource",
|
||||||
|
)
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
DatasourceAuthUpdateDeleteApi,
|
||||||
|
"/auth/plugin/datasource/<string:auth_id>",
|
||||||
|
)
|
||||||
@ -0,0 +1,55 @@
|
|||||||
|
from flask_restful import ( # type: ignore
|
||||||
|
Resource, # type: ignore
|
||||||
|
reqparse,
|
||||||
|
)
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models import Account
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
|
|
||||||
|
class DataSourceContentPreviewApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_rag_pipeline
|
||||||
|
def post(self, pipeline: Pipeline, node_id: str):
|
||||||
|
"""
|
||||||
|
Run datasource content preview
|
||||||
|
"""
|
||||||
|
if not isinstance(current_user, Account):
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||||
|
parser.add_argument("datasource_type", type=str, required=True, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
inputs = args.get("inputs")
|
||||||
|
if inputs is None:
|
||||||
|
raise ValueError("missing inputs")
|
||||||
|
datasource_type = args.get("datasource_type")
|
||||||
|
if datasource_type is None:
|
||||||
|
raise ValueError("missing datasource_type")
|
||||||
|
|
||||||
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
preview_content = rag_pipeline_service.run_datasource_node_preview(
|
||||||
|
pipeline=pipeline,
|
||||||
|
node_id=node_id,
|
||||||
|
user_inputs=inputs,
|
||||||
|
account=current_user,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
is_published=True,
|
||||||
|
)
|
||||||
|
return preview_content, 200
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
DataSourceContentPreviewApi,
|
||||||
|
"/rag/pipelines/<uuid:pipeline_id>/workflows/published/datasource/nodes/<string:node_id>/preview",
|
||||||
|
)
|
||||||
@ -0,0 +1,162 @@
|
|||||||
|
import logging
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from flask_restful import Resource, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
enterprise_license_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.login import login_required
|
||||||
|
from models.dataset import PipelineCustomizedTemplate
|
||||||
|
from services.entities.knowledge_entities.rag_pipeline_entities import PipelineTemplateInfoEntity
|
||||||
|
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_name(name):
|
||||||
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_description_length(description):
|
||||||
|
if len(description) > 400:
|
||||||
|
raise ValueError("Description cannot exceed 400 characters.")
|
||||||
|
return description
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineTemplateListApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def get(self):
|
||||||
|
type = request.args.get("type", default="built-in", type=str)
|
||||||
|
language = request.args.get("language", default="en-US", type=str)
|
||||||
|
# get pipeline templates
|
||||||
|
pipeline_templates = RagPipelineService.get_pipeline_templates(type, language)
|
||||||
|
return pipeline_templates, 200
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineTemplateDetailApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def get(self, template_id: str):
|
||||||
|
type = request.args.get("type", default="built-in", type=str)
|
||||||
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
pipeline_template = rag_pipeline_service.get_pipeline_template_detail(template_id, type)
|
||||||
|
return pipeline_template, 200
|
||||||
|
|
||||||
|
|
||||||
|
class CustomizedPipelineTemplateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def patch(self, template_id: str):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"name",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
help="Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"description",
|
||||||
|
type=str,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"icon_info",
|
||||||
|
type=dict,
|
||||||
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
pipeline_template_info = PipelineTemplateInfoEntity(**args)
|
||||||
|
RagPipelineService.update_customized_pipeline_template(template_id, pipeline_template_info)
|
||||||
|
return 200
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def delete(self, template_id: str):
|
||||||
|
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||||
|
return 200
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def post(self, template_id: str):
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
template = (
|
||||||
|
session.query(PipelineCustomizedTemplate).filter(PipelineCustomizedTemplate.id == template_id).first()
|
||||||
|
)
|
||||||
|
if not template:
|
||||||
|
raise ValueError("Customized pipeline template not found.")
|
||||||
|
|
||||||
|
return {"data": template.yaml_content}, 200
|
||||||
|
|
||||||
|
|
||||||
|
class PublishCustomizedPipelineTemplateApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@enterprise_license_required
|
||||||
|
def post(self, pipeline_id: str):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"name",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
help="Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"description",
|
||||||
|
type=str,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"icon_info",
|
||||||
|
type=dict,
|
||||||
|
location="json",
|
||||||
|
nullable=True,
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
rag_pipeline_service = RagPipelineService()
|
||||||
|
rag_pipeline_service.publish_customized_pipeline_template(pipeline_id, args)
|
||||||
|
return {"result": "success"}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
PipelineTemplateListApi,
|
||||||
|
"/rag/pipeline/templates",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
PipelineTemplateDetailApi,
|
||||||
|
"/rag/pipeline/templates/<string:template_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
CustomizedPipelineTemplateApi,
|
||||||
|
"/rag/pipeline/customized/templates/<string:template_id>",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
PublishCustomizedPipelineTemplateApi,
|
||||||
|
"/rag/pipelines/<string:pipeline_id>/customized/publish",
|
||||||
|
)
|
||||||
@ -0,0 +1,171 @@
|
|||||||
|
from flask_login import current_user # type: ignore # type: ignore
|
||||||
|
from flask_restful import Resource, marshal, reqparse # type: ignore
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
import services
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
cloud_edition_billing_rate_limit_check,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from fields.dataset_fields import dataset_detail_fields
|
||||||
|
from libs.login import login_required
|
||||||
|
from models.dataset import DatasetPermissionEnum
|
||||||
|
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||||
|
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||||
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_name(name):
|
||||||
|
if not name or len(name) < 1 or len(name) > 40:
|
||||||
|
raise ValueError("Name must be between 1 to 40 characters.")
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_description_length(description):
|
||||||
|
if len(description) > 400:
|
||||||
|
raise ValueError("Description cannot exceed 400 characters.")
|
||||||
|
return description
|
||||||
|
|
||||||
|
|
||||||
|
class CreateRagPipelineDatasetApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
def post(self):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"name",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"description",
|
||||||
|
type=str,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"icon_info",
|
||||||
|
type=dict,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"permission",
|
||||||
|
type=str,
|
||||||
|
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default=DatasetPermissionEnum.ONLY_ME,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"partial_member_list",
|
||||||
|
type=list,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"yaml_content",
|
||||||
|
type=str,
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
help="yaml_content is required.",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
if not current_user.is_dataset_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
||||||
|
try:
|
||||||
|
import_info = RagPipelineDslService.create_rag_pipeline_dataset(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||||
|
)
|
||||||
|
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||||
|
DatasetPermissionService.update_partial_member_list(
|
||||||
|
current_user.current_tenant_id,
|
||||||
|
import_info["dataset_id"],
|
||||||
|
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||||
|
)
|
||||||
|
except services.errors.dataset.DatasetNameDuplicateError:
|
||||||
|
raise DatasetNameDuplicateError()
|
||||||
|
|
||||||
|
return import_info, 201
|
||||||
|
|
||||||
|
|
||||||
|
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||||
|
def post(self):
|
||||||
|
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||||
|
if not current_user.is_dataset_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"name",
|
||||||
|
nullable=False,
|
||||||
|
required=True,
|
||||||
|
help="type is required. Name must be between 1 to 40 characters.",
|
||||||
|
type=_validate_name,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"description",
|
||||||
|
type=str,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default="",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"icon_info",
|
||||||
|
type=dict,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default={},
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"permission",
|
||||||
|
type=str,
|
||||||
|
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default=DatasetPermissionEnum.ONLY_ME,
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"partial_member_list",
|
||||||
|
type=list,
|
||||||
|
nullable=True,
|
||||||
|
required=False,
|
||||||
|
default=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||||
|
tenant_id=current_user.current_tenant_id,
|
||||||
|
rag_pipeline_dataset_create_entity=RagPipelineDatasetCreateEntity(**args),
|
||||||
|
)
|
||||||
|
return marshal(dataset, dataset_detail_fields), 201
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||||
|
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||||
@ -0,0 +1,146 @@
|
|||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from flask_login import current_user # type: ignore
|
||||||
|
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||||
|
from controllers.console.wraps import (
|
||||||
|
account_initialization_required,
|
||||||
|
setup_required,
|
||||||
|
)
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||||
|
from libs.login import login_required
|
||||||
|
from models import Account
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from services.app_dsl_service import ImportStatus
|
||||||
|
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineImportApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@marshal_with(pipeline_import_fields)
|
||||||
|
def post(self):
|
||||||
|
# Check user role first
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("mode", type=str, required=True, location="json")
|
||||||
|
parser.add_argument("yaml_content", type=str, location="json")
|
||||||
|
parser.add_argument("yaml_url", type=str, location="json")
|
||||||
|
parser.add_argument("name", type=str, location="json")
|
||||||
|
parser.add_argument("description", type=str, location="json")
|
||||||
|
parser.add_argument("icon_type", type=str, location="json")
|
||||||
|
parser.add_argument("icon", type=str, location="json")
|
||||||
|
parser.add_argument("icon_background", type=str, location="json")
|
||||||
|
parser.add_argument("pipeline_id", type=str, location="json")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# Create service with session
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
import_service = RagPipelineDslService(session)
|
||||||
|
# Import app
|
||||||
|
account = cast(Account, current_user)
|
||||||
|
result = import_service.import_rag_pipeline(
|
||||||
|
account=account,
|
||||||
|
import_mode=args["mode"],
|
||||||
|
yaml_content=args.get("yaml_content"),
|
||||||
|
yaml_url=args.get("yaml_url"),
|
||||||
|
pipeline_id=args.get("pipeline_id"),
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Return appropriate status code based on result
|
||||||
|
status = result.status
|
||||||
|
if status == ImportStatus.FAILED.value:
|
||||||
|
return result.model_dump(mode="json"), 400
|
||||||
|
elif status == ImportStatus.PENDING.value:
|
||||||
|
return result.model_dump(mode="json"), 202
|
||||||
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineImportConfirmApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@marshal_with(pipeline_import_fields)
|
||||||
|
def post(self, import_id):
|
||||||
|
# Check user role first
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
# Create service with session
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
import_service = RagPipelineDslService(session)
|
||||||
|
# Confirm import
|
||||||
|
account = cast(Account, current_user)
|
||||||
|
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Return appropriate status code based on result
|
||||||
|
if result.status == ImportStatus.FAILED.value:
|
||||||
|
return result.model_dump(mode="json"), 400
|
||||||
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@get_rag_pipeline
|
||||||
|
@account_initialization_required
|
||||||
|
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||||
|
def get(self, pipeline: Pipeline):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
import_service = RagPipelineDslService(session)
|
||||||
|
result = import_service.check_dependencies(pipeline=pipeline)
|
||||||
|
|
||||||
|
return result.model_dump(mode="json"), 200
|
||||||
|
|
||||||
|
|
||||||
|
class RagPipelineExportApi(Resource):
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@get_rag_pipeline
|
||||||
|
@account_initialization_required
|
||||||
|
def get(self, pipeline: Pipeline):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
|
||||||
|
# Add include_secret params
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
export_service = RagPipelineDslService(session)
|
||||||
|
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||||
|
|
||||||
|
return {"data": result}, 200
|
||||||
|
|
||||||
|
|
||||||
|
# Import Rag Pipeline
|
||||||
|
api.add_resource(
|
||||||
|
RagPipelineImportApi,
|
||||||
|
"/rag/pipelines/imports",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
RagPipelineImportConfirmApi,
|
||||||
|
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
RagPipelineImportCheckDependenciesApi,
|
||||||
|
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||||
|
)
|
||||||
|
api.add_resource(
|
||||||
|
RagPipelineExportApi,
|
||||||
|
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||||
|
)
|
||||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,43 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from controllers.console.datasets.error import PipelineNotFoundError
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.login import current_user
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def get_rag_pipeline(
|
||||||
|
view: Optional[Callable] = None,
|
||||||
|
):
|
||||||
|
def decorator(view_func):
|
||||||
|
@wraps(view_func)
|
||||||
|
def decorated_view(*args, **kwargs):
|
||||||
|
if not kwargs.get("pipeline_id"):
|
||||||
|
raise ValueError("missing pipeline_id in path parameters")
|
||||||
|
|
||||||
|
pipeline_id = kwargs.get("pipeline_id")
|
||||||
|
pipeline_id = str(pipeline_id)
|
||||||
|
|
||||||
|
del kwargs["pipeline_id"]
|
||||||
|
|
||||||
|
pipeline = (
|
||||||
|
db.session.query(Pipeline)
|
||||||
|
.filter(Pipeline.id == pipeline_id, Pipeline.tenant_id == current_user.current_tenant_id)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not pipeline:
|
||||||
|
raise PipelineNotFoundError()
|
||||||
|
|
||||||
|
kwargs["pipeline"] = pipeline
|
||||||
|
|
||||||
|
return view_func(*args, **kwargs)
|
||||||
|
|
||||||
|
return decorated_view
|
||||||
|
|
||||||
|
if view is None:
|
||||||
|
return decorator
|
||||||
|
else:
|
||||||
|
return decorator(view)
|
||||||
@ -0,0 +1,95 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||||
|
from core.app.entities.task_entities import (
|
||||||
|
AppStreamResponse,
|
||||||
|
ErrorStreamResponse,
|
||||||
|
NodeFinishStreamResponse,
|
||||||
|
NodeStartStreamResponse,
|
||||||
|
PingStreamResponse,
|
||||||
|
WorkflowAppBlockingResponse,
|
||||||
|
WorkflowAppStreamResponse,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||||
|
_blocking_response_type = WorkflowAppBlockingResponse
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_full_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
Convert blocking full response.
|
||||||
|
:param blocking_response: blocking response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return dict(blocking_response.to_dict())
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_blocking_simple_response(cls, blocking_response: WorkflowAppBlockingResponse) -> dict: # type: ignore[override]
|
||||||
|
"""
|
||||||
|
Convert blocking simple response.
|
||||||
|
:param blocking_response: blocking response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
return cls.convert_blocking_full_response(blocking_response)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_full_response(
|
||||||
|
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||||
|
) -> Generator[dict | str, None, None]:
|
||||||
|
"""
|
||||||
|
Convert stream full response.
|
||||||
|
:param stream_response: stream response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for chunk in stream_response:
|
||||||
|
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||||
|
sub_stream_response = chunk.stream_response
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, PingStreamResponse):
|
||||||
|
yield "ping"
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_chunk = {
|
||||||
|
"event": sub_stream_response.event.value,
|
||||||
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||||
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
|
response_chunk.update(data)
|
||||||
|
else:
|
||||||
|
response_chunk.update(sub_stream_response.to_dict())
|
||||||
|
yield response_chunk
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def convert_stream_simple_response(
|
||||||
|
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||||
|
) -> Generator[dict | str, None, None]:
|
||||||
|
"""
|
||||||
|
Convert stream simple response.
|
||||||
|
:param stream_response: stream response
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
for chunk in stream_response:
|
||||||
|
chunk = cast(WorkflowAppStreamResponse, chunk)
|
||||||
|
sub_stream_response = chunk.stream_response
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, PingStreamResponse):
|
||||||
|
yield "ping"
|
||||||
|
continue
|
||||||
|
|
||||||
|
response_chunk = {
|
||||||
|
"event": sub_stream_response.event.value,
|
||||||
|
"workflow_run_id": chunk.workflow_run_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if isinstance(sub_stream_response, ErrorStreamResponse):
|
||||||
|
data = cls._error_to_stream_response(sub_stream_response.err)
|
||||||
|
response_chunk.update(data)
|
||||||
|
elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
|
||||||
|
response_chunk.update(sub_stream_response.to_ignore_detail_dict())
|
||||||
|
else:
|
||||||
|
response_chunk.update(sub_stream_response.to_dict())
|
||||||
|
yield response_chunk
|
||||||
@ -0,0 +1,64 @@
|
|||||||
|
from core.app.app_config.base_app_config_manager import BaseAppConfigManager
|
||||||
|
from core.app.app_config.common.sensitive_word_avoidance.manager import SensitiveWordAvoidanceConfigManager
|
||||||
|
from core.app.app_config.entities import RagPipelineVariableEntity, WorkflowUIBasedAppConfig
|
||||||
|
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||||
|
from core.app.app_config.features.text_to_speech.manager import TextToSpeechConfigManager
|
||||||
|
from core.app.app_config.workflow_ui_based_app.variables.manager import WorkflowVariablesConfigManager
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from models.model import AppMode
|
||||||
|
from models.workflow import Workflow
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineConfig(WorkflowUIBasedAppConfig):
|
||||||
|
"""
|
||||||
|
Pipeline Config Entity.
|
||||||
|
"""
|
||||||
|
|
||||||
|
rag_pipeline_variables: list[RagPipelineVariableEntity] = []
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineConfigManager(BaseAppConfigManager):
|
||||||
|
@classmethod
|
||||||
|
def get_pipeline_config(cls, pipeline: Pipeline, workflow: Workflow) -> PipelineConfig:
|
||||||
|
pipeline_config = PipelineConfig(
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
app_id=pipeline.id,
|
||||||
|
app_mode=AppMode.RAG_PIPELINE,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
rag_pipeline_variables=WorkflowVariablesConfigManager.convert_rag_pipeline_variable(workflow=workflow),
|
||||||
|
)
|
||||||
|
|
||||||
|
return pipeline_config
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def config_validate(cls, tenant_id: str, config: dict, only_structure_validate: bool = False) -> dict:
|
||||||
|
"""
|
||||||
|
Validate for pipeline config
|
||||||
|
|
||||||
|
:param tenant_id: tenant id
|
||||||
|
:param config: app model config args
|
||||||
|
:param only_structure_validate: only validate the structure of the config
|
||||||
|
"""
|
||||||
|
related_config_keys = []
|
||||||
|
|
||||||
|
# file upload validation
|
||||||
|
config, current_related_config_keys = FileUploadConfigManager.validate_and_set_defaults(config=config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# text_to_speech
|
||||||
|
config, current_related_config_keys = TextToSpeechConfigManager.validate_and_set_defaults(config)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
# moderation validation
|
||||||
|
config, current_related_config_keys = SensitiveWordAvoidanceConfigManager.validate_and_set_defaults(
|
||||||
|
tenant_id=tenant_id, config=config, only_structure_validate=only_structure_validate
|
||||||
|
)
|
||||||
|
related_config_keys.extend(current_related_config_keys)
|
||||||
|
|
||||||
|
related_config_keys = list(set(related_config_keys))
|
||||||
|
|
||||||
|
# Filter out extra parameters
|
||||||
|
filtered_config = {key: config.get(key) for key in related_config_keys}
|
||||||
|
|
||||||
|
return filtered_config
|
||||||
@ -0,0 +1,621 @@
|
|||||||
|
import contextvars
|
||||||
|
import datetime
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any, Literal, Optional, Union, overload
|
||||||
|
|
||||||
|
from flask import Flask, current_app
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
|
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager
|
||||||
|
from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager
|
||||||
|
from core.app.apps.pipeline.pipeline_runner import PipelineRunner
|
||||||
|
from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter
|
||||||
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||||
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
|
from core.entities.knowledge_entities import PipelineDataset, PipelineDocument
|
||||||
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
|
from core.rag.index_processor.constant.built_in_field import BuiltInField
|
||||||
|
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||||
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
|
from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
from models.dataset import Document, DocumentPipelineExecutionLog, Pipeline
|
||||||
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
|
from models.model import AppMode
|
||||||
|
from services.dataset_service import DocumentService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineGenerator(BaseAppGenerator):
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: Literal[True],
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Mapping[str, Any] | Generator[Mapping | str, None, None] | None: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: Literal[False],
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Mapping[str, Any]: ...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: bool,
|
||||||
|
call_depth: int,
|
||||||
|
workflow_thread_pool_id: Optional[str],
|
||||||
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None]]: ...
|
||||||
|
|
||||||
|
def generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
streaming: bool = True,
|
||||||
|
call_depth: int = 0,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> Union[Mapping[str, Any], Generator[Mapping | str, None, None], None]:
|
||||||
|
# convert to app config
|
||||||
|
pipeline_config = PipelineConfigManager.get_pipeline_config(
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow=workflow,
|
||||||
|
)
|
||||||
|
# Add null check for dataset
|
||||||
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Pipeline dataset is required")
|
||||||
|
inputs: Mapping[str, Any] = args["inputs"]
|
||||||
|
start_node_id: str = args["start_node_id"]
|
||||||
|
datasource_type: str = args["datasource_type"]
|
||||||
|
datasource_info_list: list[Mapping[str, Any]] = args["datasource_info_list"]
|
||||||
|
batch = time.strftime("%Y%m%d%H%M%S") + str(secrets.randbelow(900000) + 100000)
|
||||||
|
documents = []
|
||||||
|
if invoke_from == InvokeFrom.PUBLISHED:
|
||||||
|
for datasource_info in datasource_info_list:
|
||||||
|
position = DocumentService.get_documents_position(dataset.id)
|
||||||
|
document = self._build_document(
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
built_in_field_enabled=dataset.built_in_field_enabled,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
created_from="rag-pipeline",
|
||||||
|
position=position,
|
||||||
|
account=user,
|
||||||
|
batch=batch,
|
||||||
|
document_form=dataset.chunk_structure,
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
documents.append(document)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
# run in child thread
|
||||||
|
for i, datasource_info in enumerate(datasource_info_list):
|
||||||
|
workflow_run_id = str(uuid.uuid4())
|
||||||
|
document_id = None
|
||||||
|
if invoke_from == InvokeFrom.PUBLISHED:
|
||||||
|
document_id = documents[i].id
|
||||||
|
document_pipeline_execution_log = DocumentPipelineExecutionLog(
|
||||||
|
document_id=document_id,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
datasource_info=json.dumps(datasource_info),
|
||||||
|
datasource_node_id=start_node_id,
|
||||||
|
input_data=inputs,
|
||||||
|
pipeline_id=pipeline.id,
|
||||||
|
created_by=user.id,
|
||||||
|
)
|
||||||
|
db.session.add(document_pipeline_execution_log)
|
||||||
|
db.session.commit()
|
||||||
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=pipeline_config,
|
||||||
|
pipeline_config=pipeline_config,
|
||||||
|
datasource_type=datasource_type,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
start_node_id=start_node_id,
|
||||||
|
batch=batch,
|
||||||
|
document_id=document_id,
|
||||||
|
inputs=self._prepare_user_inputs(
|
||||||
|
user_inputs=inputs,
|
||||||
|
variables=pipeline_config.rag_pipeline_variables,
|
||||||
|
tenant_id=pipeline.tenant_id,
|
||||||
|
strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False,
|
||||||
|
),
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
call_depth=call_depth,
|
||||||
|
workflow_execution_id=workflow_run_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
|
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING
|
||||||
|
else:
|
||||||
|
workflow_triggered_from = WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=workflow_triggered_from,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
|
||||||
|
)
|
||||||
|
if invoke_from == InvokeFrom.DEBUGGER:
|
||||||
|
return self._generate(
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
context=contextvars.copy_context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
user=user,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
invoke_from=invoke_from,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# run in child thread
|
||||||
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
|
worker_thread = threading.Thread(
|
||||||
|
target=self._generate,
|
||||||
|
kwargs={
|
||||||
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
|
"context": context,
|
||||||
|
"pipeline": pipeline,
|
||||||
|
"workflow_id": workflow.id,
|
||||||
|
"user": user,
|
||||||
|
"application_generate_entity": application_generate_entity,
|
||||||
|
"invoke_from": invoke_from,
|
||||||
|
"workflow_execution_repository": workflow_execution_repository,
|
||||||
|
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||||
|
"streaming": streaming,
|
||||||
|
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_thread.start()
|
||||||
|
# return batch, dataset, documents
|
||||||
|
return {
|
||||||
|
"batch": batch,
|
||||||
|
"dataset": PipelineDataset(
|
||||||
|
id=dataset.id,
|
||||||
|
name=dataset.name,
|
||||||
|
description=dataset.description,
|
||||||
|
chunk_structure=dataset.chunk_structure,
|
||||||
|
).model_dump(),
|
||||||
|
"documents": [
|
||||||
|
PipelineDocument(
|
||||||
|
id=document.id,
|
||||||
|
position=document.position,
|
||||||
|
data_source_type=document.data_source_type,
|
||||||
|
data_source_info=json.loads(document.data_source_info) if document.data_source_info else None,
|
||||||
|
name=document.name,
|
||||||
|
indexing_status=document.indexing_status,
|
||||||
|
error=document.error,
|
||||||
|
enabled=document.enabled,
|
||||||
|
).model_dump()
|
||||||
|
for document in documents
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def _generate(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
flask_app: Flask,
|
||||||
|
context: contextvars.Context,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow_id: str,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
invoke_from: InvokeFrom,
|
||||||
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
|
streaming: bool = True,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param pipeline: Pipeline
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param user: account or end user
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param invoke_from: invoke from source
|
||||||
|
:param workflow_execution_repository: repository for workflow execution
|
||||||
|
:param workflow_node_execution_repository: repository for workflow node execution
|
||||||
|
:param streaming: is stream
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
"""
|
||||||
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
|
# init queue manager
|
||||||
|
workflow = db.session.query(Workflow).filter(Workflow.id == workflow_id).first()
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError(f"Workflow not found: {workflow_id}")
|
||||||
|
queue_manager = PipelineQueueManager(
|
||||||
|
task_id=application_generate_entity.task_id,
|
||||||
|
user_id=application_generate_entity.user_id,
|
||||||
|
invoke_from=application_generate_entity.invoke_from,
|
||||||
|
app_mode=AppMode.RAG_PIPELINE,
|
||||||
|
)
|
||||||
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
|
# new thread
|
||||||
|
worker_thread = threading.Thread(
|
||||||
|
target=self._generate_worker,
|
||||||
|
kwargs={
|
||||||
|
"flask_app": current_app._get_current_object(), # type: ignore
|
||||||
|
"context": context,
|
||||||
|
"queue_manager": queue_manager,
|
||||||
|
"application_generate_entity": application_generate_entity,
|
||||||
|
"workflow_thread_pool_id": workflow_thread_pool_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_thread.start()
|
||||||
|
|
||||||
|
# return response or stream generator
|
||||||
|
response = self._handle_response(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow=workflow,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
user=user,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
stream=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from)
|
||||||
|
|
||||||
|
def single_iteration_generate(
|
||||||
|
self,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
node_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
streaming: bool = True,
|
||||||
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param node_id: the node id
|
||||||
|
:param user: account or end user
|
||||||
|
:param args: request args
|
||||||
|
:param streaming: is streamed
|
||||||
|
"""
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("node_id is required")
|
||||||
|
|
||||||
|
if args.get("inputs") is None:
|
||||||
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
|
# convert to app config
|
||||||
|
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||||
|
|
||||||
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Pipeline dataset is required")
|
||||||
|
|
||||||
|
# init application generate entity - use RagPipelineGenerateEntity instead
|
||||||
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=pipeline_config,
|
||||||
|
pipeline_config=pipeline_config,
|
||||||
|
datasource_type=args.get("datasource_type", ""),
|
||||||
|
datasource_info=args.get("datasource_info", {}),
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
batch=args.get("batch", ""),
|
||||||
|
document_id=args.get("document_id"),
|
||||||
|
inputs={},
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
call_depth=0,
|
||||||
|
workflow_execution_id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
user=user,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
def single_loop_generate(
|
||||||
|
self,
|
||||||
|
pipeline: Pipeline,
|
||||||
|
workflow: Workflow,
|
||||||
|
node_id: str,
|
||||||
|
user: Account | EndUser,
|
||||||
|
args: Mapping[str, Any],
|
||||||
|
streaming: bool = True,
|
||||||
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
Generate App response.
|
||||||
|
|
||||||
|
:param app_model: App
|
||||||
|
:param workflow: Workflow
|
||||||
|
:param node_id: the node id
|
||||||
|
:param user: account or end user
|
||||||
|
:param args: request args
|
||||||
|
:param streaming: is streamed
|
||||||
|
"""
|
||||||
|
if not node_id:
|
||||||
|
raise ValueError("node_id is required")
|
||||||
|
|
||||||
|
if args.get("inputs") is None:
|
||||||
|
raise ValueError("inputs is required")
|
||||||
|
|
||||||
|
dataset = pipeline.dataset
|
||||||
|
if not dataset:
|
||||||
|
raise ValueError("Pipeline dataset is required")
|
||||||
|
|
||||||
|
# convert to app config
|
||||||
|
pipeline_config = PipelineConfigManager.get_pipeline_config(pipeline=pipeline, workflow=workflow)
|
||||||
|
|
||||||
|
# init application generate entity
|
||||||
|
application_generate_entity = RagPipelineGenerateEntity(
|
||||||
|
task_id=str(uuid.uuid4()),
|
||||||
|
app_config=pipeline_config,
|
||||||
|
pipeline_config=pipeline_config,
|
||||||
|
datasource_type=args.get("datasource_type", ""),
|
||||||
|
datasource_info=args.get("datasource_info", {}),
|
||||||
|
batch=args.get("batch", ""),
|
||||||
|
document_id=args.get("document_id"),
|
||||||
|
dataset_id=dataset.id,
|
||||||
|
inputs={},
|
||||||
|
files=[],
|
||||||
|
user_id=user.id,
|
||||||
|
stream=streaming,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
extras={"auto_generate_conversation_name": False},
|
||||||
|
single_loop_run=RagPipelineGenerateEntity.SingleLoopRunEntity(node_id=node_id, inputs=args["inputs"]),
|
||||||
|
workflow_execution_id=str(uuid.uuid4()),
|
||||||
|
)
|
||||||
|
contexts.plugin_tool_providers.set({})
|
||||||
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
|
|
||||||
|
# Create workflow node execution repository
|
||||||
|
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_DEBUGGING,
|
||||||
|
)
|
||||||
|
|
||||||
|
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory,
|
||||||
|
user=user,
|
||||||
|
app_id=application_generate_entity.app_config.app_id,
|
||||||
|
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._generate(
|
||||||
|
flask_app=current_app._get_current_object(), # type: ignore
|
||||||
|
pipeline=pipeline,
|
||||||
|
workflow=workflow,
|
||||||
|
user=user,
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
streaming=streaming,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_worker(
|
||||||
|
self,
|
||||||
|
flask_app: Flask,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
context: contextvars.Context,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Generate worker in a new thread.
|
||||||
|
:param flask_app: Flask app
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
with preserve_flask_contexts(flask_app, context_vars=context):
|
||||||
|
try:
|
||||||
|
# workflow app
|
||||||
|
runner = PipelineRunner(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
workflow_thread_pool_id=workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
runner.run()
|
||||||
|
except GenerateTaskStoppedError:
|
||||||
|
pass
|
||||||
|
except InvokeAuthorizationError:
|
||||||
|
queue_manager.publish_error(
|
||||||
|
InvokeAuthorizationError("Incorrect API key provided"), PublishFrom.APPLICATION_MANAGER
|
||||||
|
)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.exception("Validation Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except ValueError as e:
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
logger.exception("Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception("Unknown Error when generating")
|
||||||
|
queue_manager.publish_error(e, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
finally:
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
def _handle_response(
|
||||||
|
self,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
workflow: Workflow,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
user: Union[Account, EndUser],
|
||||||
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
|
stream: bool = False,
|
||||||
|
) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
|
||||||
|
"""
|
||||||
|
Handle response.
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param workflow: workflow
|
||||||
|
:param queue_manager: queue manager
|
||||||
|
:param user: account or end user
|
||||||
|
:param stream: is stream
|
||||||
|
:param workflow_node_execution_repository: optional repository for workflow node execution
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
# init generate task pipeline
|
||||||
|
generate_task_pipeline = WorkflowAppGenerateTaskPipeline(
|
||||||
|
application_generate_entity=application_generate_entity,
|
||||||
|
workflow=workflow,
|
||||||
|
queue_manager=queue_manager,
|
||||||
|
user=user,
|
||||||
|
stream=stream,
|
||||||
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return generate_task_pipeline.process()
|
||||||
|
except ValueError as e:
|
||||||
|
if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
|
||||||
|
raise GenerateTaskStoppedError()
|
||||||
|
else:
|
||||||
|
logger.exception(
|
||||||
|
f"Fails to process generate task pipeline, task_id: {application_generate_entity.task_id}"
|
||||||
|
)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
def _build_document(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
dataset_id: str,
|
||||||
|
built_in_field_enabled: bool,
|
||||||
|
datasource_type: str,
|
||||||
|
datasource_info: Mapping[str, Any],
|
||||||
|
created_from: str,
|
||||||
|
position: int,
|
||||||
|
account: Union[Account, EndUser],
|
||||||
|
batch: str,
|
||||||
|
document_form: str,
|
||||||
|
):
|
||||||
|
if datasource_type == "local_file":
|
||||||
|
name = datasource_info["name"]
|
||||||
|
elif datasource_type == "online_document":
|
||||||
|
name = datasource_info["page"]["page_name"]
|
||||||
|
elif datasource_type == "website_crawl":
|
||||||
|
name = datasource_info["title"]
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||||
|
|
||||||
|
document = Document(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
position=position,
|
||||||
|
data_source_type=datasource_type,
|
||||||
|
data_source_info=json.dumps(datasource_info),
|
||||||
|
batch=batch,
|
||||||
|
name=name,
|
||||||
|
created_from=created_from,
|
||||||
|
created_by=account.id,
|
||||||
|
doc_form=document_form,
|
||||||
|
)
|
||||||
|
doc_metadata = {}
|
||||||
|
if built_in_field_enabled:
|
||||||
|
doc_metadata = {
|
||||||
|
BuiltInField.document_name: name,
|
||||||
|
BuiltInField.uploader: account.name,
|
||||||
|
BuiltInField.upload_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
BuiltInField.last_update_date: datetime.datetime.now(datetime.UTC).strftime("%Y-%m-%d %H:%M:%S"),
|
||||||
|
BuiltInField.source: datasource_type,
|
||||||
|
}
|
||||||
|
if doc_metadata:
|
||||||
|
document.doc_metadata = doc_metadata
|
||||||
|
return document
|
||||||
@ -0,0 +1,44 @@
|
|||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager, GenerateTaskStoppedError, PublishFrom
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.app.entities.queue_entities import (
|
||||||
|
AppQueueEvent,
|
||||||
|
QueueErrorEvent,
|
||||||
|
QueueMessageEndEvent,
|
||||||
|
QueueStopEvent,
|
||||||
|
QueueWorkflowFailedEvent,
|
||||||
|
QueueWorkflowPartialSuccessEvent,
|
||||||
|
QueueWorkflowSucceededEvent,
|
||||||
|
WorkflowQueueMessage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineQueueManager(AppQueueManager):
|
||||||
|
def __init__(self, task_id: str, user_id: str, invoke_from: InvokeFrom, app_mode: str) -> None:
|
||||||
|
super().__init__(task_id, user_id, invoke_from)
|
||||||
|
|
||||||
|
self._app_mode = app_mode
|
||||||
|
|
||||||
|
def _publish(self, event: AppQueueEvent, pub_from: PublishFrom) -> None:
|
||||||
|
"""
|
||||||
|
Publish event to queue
|
||||||
|
:param event:
|
||||||
|
:param pub_from:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
message = WorkflowQueueMessage(task_id=self._task_id, app_mode=self._app_mode, event=event)
|
||||||
|
|
||||||
|
self._q.put(message)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
event,
|
||||||
|
QueueStopEvent
|
||||||
|
| QueueErrorEvent
|
||||||
|
| QueueMessageEndEvent
|
||||||
|
| QueueWorkflowSucceededEvent
|
||||||
|
| QueueWorkflowFailedEvent
|
||||||
|
| QueueWorkflowPartialSuccessEvent,
|
||||||
|
):
|
||||||
|
self.stop_listen()
|
||||||
|
|
||||||
|
if pub_from == PublishFrom.APPLICATION_MANAGER and self._is_stopped():
|
||||||
|
raise GenerateTaskStoppedError()
|
||||||
@ -0,0 +1,221 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, Optional, cast
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
|
from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig
|
||||||
|
from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner
|
||||||
|
from core.app.entities.app_invoke_entities import (
|
||||||
|
InvokeFrom,
|
||||||
|
RagPipelineGenerateEntity,
|
||||||
|
)
|
||||||
|
from core.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput
|
||||||
|
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.graph_engine.entities.graph import Graph
|
||||||
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Pipeline
|
||||||
|
from models.enums import UserFrom
|
||||||
|
from models.model import EndUser
|
||||||
|
from models.workflow import Workflow, WorkflowType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PipelineRunner(WorkflowBasedAppRunner):
|
||||||
|
"""
|
||||||
|
Pipeline Application Runner
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
application_generate_entity: RagPipelineGenerateEntity,
|
||||||
|
queue_manager: AppQueueManager,
|
||||||
|
workflow_thread_pool_id: Optional[str] = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
:param application_generate_entity: application generate entity
|
||||||
|
:param queue_manager: application queue manager
|
||||||
|
:param workflow_thread_pool_id: workflow thread pool id
|
||||||
|
"""
|
||||||
|
self.application_generate_entity = application_generate_entity
|
||||||
|
self.queue_manager = queue_manager
|
||||||
|
self.workflow_thread_pool_id = workflow_thread_pool_id
|
||||||
|
|
||||||
|
def _get_app_id(self) -> str:
|
||||||
|
return self.application_generate_entity.app_config.app_id
|
||||||
|
|
||||||
|
def run(self) -> None:
|
||||||
|
"""
|
||||||
|
Run application
|
||||||
|
"""
|
||||||
|
app_config = self.application_generate_entity.app_config
|
||||||
|
app_config = cast(PipelineConfig, app_config)
|
||||||
|
|
||||||
|
user_id = None
|
||||||
|
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
|
||||||
|
end_user = db.session.query(EndUser).filter(EndUser.id == self.application_generate_entity.user_id).first()
|
||||||
|
if end_user:
|
||||||
|
user_id = end_user.session_id
|
||||||
|
else:
|
||||||
|
user_id = self.application_generate_entity.user_id
|
||||||
|
|
||||||
|
pipeline = db.session.query(Pipeline).filter(Pipeline.id == app_config.app_id).first()
|
||||||
|
if not pipeline:
|
||||||
|
raise ValueError("Pipeline not found")
|
||||||
|
|
||||||
|
workflow = self.get_workflow(pipeline=pipeline, workflow_id=app_config.workflow_id)
|
||||||
|
if not workflow:
|
||||||
|
raise ValueError("Workflow not initialized")
|
||||||
|
|
||||||
|
db.session.close()
|
||||||
|
|
||||||
|
workflow_callbacks: list[WorkflowCallback] = []
|
||||||
|
if dify_config.DEBUG:
|
||||||
|
workflow_callbacks.append(WorkflowLoggingCallback())
|
||||||
|
|
||||||
|
# if only single iteration run is requested
|
||||||
|
if self.application_generate_entity.single_iteration_run:
|
||||||
|
# if only single iteration run is requested
|
||||||
|
graph, variable_pool = self._get_graph_and_variable_pool_of_single_iteration(
|
||||||
|
workflow=workflow,
|
||||||
|
node_id=self.application_generate_entity.single_iteration_run.node_id,
|
||||||
|
user_inputs=self.application_generate_entity.single_iteration_run.inputs,
|
||||||
|
)
|
||||||
|
elif self.application_generate_entity.single_loop_run:
|
||||||
|
# if only single loop run is requested
|
||||||
|
graph, variable_pool = self._get_graph_and_variable_pool_of_single_loop(
|
||||||
|
workflow=workflow,
|
||||||
|
node_id=self.application_generate_entity.single_loop_run.node_id,
|
||||||
|
user_inputs=self.application_generate_entity.single_loop_run.inputs,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
inputs = self.application_generate_entity.inputs
|
||||||
|
files = self.application_generate_entity.files
|
||||||
|
|
||||||
|
# Create a variable pool.
|
||||||
|
system_inputs = {
|
||||||
|
SystemVariableKey.FILES: files,
|
||||||
|
SystemVariableKey.USER_ID: user_id,
|
||||||
|
SystemVariableKey.APP_ID: app_config.app_id,
|
||||||
|
SystemVariableKey.WORKFLOW_ID: app_config.workflow_id,
|
||||||
|
SystemVariableKey.WORKFLOW_EXECUTION_ID: self.application_generate_entity.workflow_execution_id,
|
||||||
|
SystemVariableKey.DOCUMENT_ID: self.application_generate_entity.document_id,
|
||||||
|
SystemVariableKey.BATCH: self.application_generate_entity.batch,
|
||||||
|
SystemVariableKey.DATASET_ID: self.application_generate_entity.dataset_id,
|
||||||
|
SystemVariableKey.DATASOURCE_TYPE: self.application_generate_entity.datasource_type,
|
||||||
|
SystemVariableKey.DATASOURCE_INFO: self.application_generate_entity.datasource_info,
|
||||||
|
SystemVariableKey.INVOKE_FROM: self.application_generate_entity.invoke_from.value,
|
||||||
|
}
|
||||||
|
rag_pipeline_variables = []
|
||||||
|
if workflow.rag_pipeline_variables:
|
||||||
|
for v in workflow.rag_pipeline_variables:
|
||||||
|
rag_pipeline_variable = RAGPipelineVariable(**v)
|
||||||
|
if (
|
||||||
|
rag_pipeline_variable.belong_to_node_id
|
||||||
|
in (self.application_generate_entity.start_node_id, "shared")
|
||||||
|
) and rag_pipeline_variable.variable in inputs:
|
||||||
|
rag_pipeline_variables.append(
|
||||||
|
RAGPipelineVariableInput(
|
||||||
|
variable=rag_pipeline_variable,
|
||||||
|
value=inputs[rag_pipeline_variable.variable],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
variable_pool = VariablePool(
|
||||||
|
system_variables=system_inputs,
|
||||||
|
user_inputs=inputs,
|
||||||
|
environment_variables=workflow.environment_variables,
|
||||||
|
conversation_variables=[],
|
||||||
|
rag_pipeline_variables=rag_pipeline_variables,
|
||||||
|
)
|
||||||
|
|
||||||
|
# init graph
|
||||||
|
graph = self._init_rag_pipeline_graph(
|
||||||
|
graph_config=workflow.graph_dict,
|
||||||
|
start_node_id=self.application_generate_entity.start_node_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# RUN WORKFLOW
|
||||||
|
workflow_entry = WorkflowEntry(
|
||||||
|
tenant_id=workflow.tenant_id,
|
||||||
|
app_id=workflow.app_id,
|
||||||
|
workflow_id=workflow.id,
|
||||||
|
workflow_type=WorkflowType.value_of(workflow.type),
|
||||||
|
graph=graph,
|
||||||
|
graph_config=workflow.graph_dict,
|
||||||
|
user_id=self.application_generate_entity.user_id,
|
||||||
|
user_from=(
|
||||||
|
UserFrom.ACCOUNT
|
||||||
|
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||||
|
else UserFrom.END_USER
|
||||||
|
),
|
||||||
|
invoke_from=self.application_generate_entity.invoke_from,
|
||||||
|
call_depth=self.application_generate_entity.call_depth,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
thread_pool_id=self.workflow_thread_pool_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
generator = workflow_entry.run(callbacks=workflow_callbacks)
|
||||||
|
|
||||||
|
for event in generator:
|
||||||
|
self._handle_event(workflow_entry, event)
|
||||||
|
|
||||||
|
def get_workflow(self, pipeline: Pipeline, workflow_id: str) -> Optional[Workflow]:
|
||||||
|
"""
|
||||||
|
Get workflow
|
||||||
|
"""
|
||||||
|
# fetch workflow by workflow_id
|
||||||
|
workflow = (
|
||||||
|
db.session.query(Workflow)
|
||||||
|
.filter(
|
||||||
|
Workflow.tenant_id == pipeline.tenant_id, Workflow.app_id == pipeline.id, Workflow.id == workflow_id
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# return workflow
|
||||||
|
return workflow
|
||||||
|
|
||||||
|
def _init_rag_pipeline_graph(self, graph_config: Mapping[str, Any], start_node_id: Optional[str] = None) -> Graph:
|
||||||
|
"""
|
||||||
|
Init pipeline graph
|
||||||
|
"""
|
||||||
|
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||||
|
raise ValueError("nodes or edges not found in workflow graph")
|
||||||
|
|
||||||
|
if not isinstance(graph_config.get("nodes"), list):
|
||||||
|
raise ValueError("nodes in workflow graph must be a list")
|
||||||
|
|
||||||
|
if not isinstance(graph_config.get("edges"), list):
|
||||||
|
raise ValueError("edges in workflow graph must be a list")
|
||||||
|
nodes = graph_config.get("nodes", [])
|
||||||
|
edges = graph_config.get("edges", [])
|
||||||
|
real_run_nodes = []
|
||||||
|
real_edges = []
|
||||||
|
exclude_node_ids = []
|
||||||
|
for node in nodes:
|
||||||
|
node_id = node.get("id")
|
||||||
|
node_type = node.get("data", {}).get("type", "")
|
||||||
|
if node_type == "datasource":
|
||||||
|
if start_node_id != node_id:
|
||||||
|
exclude_node_ids.append(node_id)
|
||||||
|
continue
|
||||||
|
real_run_nodes.append(node)
|
||||||
|
for edge in edges:
|
||||||
|
if edge.get("source") in exclude_node_ids:
|
||||||
|
continue
|
||||||
|
real_edges.append(edge)
|
||||||
|
graph_config = dict(graph_config)
|
||||||
|
graph_config["nodes"] = real_run_nodes
|
||||||
|
graph_config["edges"] = real_edges
|
||||||
|
# init graph
|
||||||
|
graph = Graph.init(graph_config=graph_config)
|
||||||
|
|
||||||
|
if not graph:
|
||||||
|
raise ValueError("graph not found in workflow")
|
||||||
|
|
||||||
|
return graph
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourcePlugin(ABC):
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
) -> None:
|
||||||
|
self.entity = entity
|
||||||
|
self.runtime = runtime
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def datasource_provider_type(self) -> str:
|
||||||
|
"""
|
||||||
|
returns the type of the datasource provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
def fork_datasource_runtime(self, runtime: DatasourceRuntime) -> "DatasourcePlugin":
|
||||||
|
return self.__class__(
|
||||||
|
entity=self.entity.model_copy(),
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
@ -0,0 +1,118 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.impl.tool import PluginToolManager
|
||||||
|
from core.tools.errors import ToolProviderCredentialValidationError
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourcePluginProviderController(ABC):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
tenant_id: str
|
||||||
|
|
||||||
|
def __init__(self, entity: DatasourceProviderEntityWithPlugin, tenant_id: str) -> None:
|
||||||
|
self.entity = entity
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def need_credentials(self) -> bool:
|
||||||
|
"""
|
||||||
|
returns whether the provider needs credentials
|
||||||
|
|
||||||
|
:return: whether the provider needs credentials
|
||||||
|
"""
|
||||||
|
return self.entity.credentials_schema is not None and len(self.entity.credentials_schema) != 0
|
||||||
|
|
||||||
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
manager = PluginToolManager()
|
||||||
|
if not manager.validate_datasource_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
provider=self.entity.identity.name,
|
||||||
|
credentials=credentials,
|
||||||
|
):
|
||||||
|
raise ToolProviderCredentialValidationError("Invalid credentials")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the format of the credentials of the provider and set the default value if needed
|
||||||
|
|
||||||
|
:param credentials: the credentials of the tool
|
||||||
|
"""
|
||||||
|
credentials_schema = dict[str, ProviderConfig]()
|
||||||
|
if credentials_schema is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for credential in self.entity.credentials_schema:
|
||||||
|
credentials_schema[credential.name] = credential
|
||||||
|
|
||||||
|
credentials_need_to_validate: dict[str, ProviderConfig] = {}
|
||||||
|
for credential_name in credentials_schema:
|
||||||
|
credentials_need_to_validate[credential_name] = credentials_schema[credential_name]
|
||||||
|
|
||||||
|
for credential_name in credentials:
|
||||||
|
if credential_name not in credentials_need_to_validate:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} not found in provider {self.entity.identity.name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# check type
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if not credential_schema.required and credentials[credential_name] is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if credential_schema.type in {ProviderConfig.Type.SECRET_INPUT, ProviderConfig.Type.TEXT_INPUT}:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
elif credential_schema.type == ProviderConfig.Type.SELECT:
|
||||||
|
if not isinstance(credentials[credential_name], str):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} should be string")
|
||||||
|
|
||||||
|
options = credential_schema.options
|
||||||
|
if not isinstance(options, list):
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} options should be list")
|
||||||
|
|
||||||
|
if credentials[credential_name] not in [x.value for x in options]:
|
||||||
|
raise ToolProviderCredentialValidationError(
|
||||||
|
f"credential {credential_name} should be one of {options}"
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_need_to_validate.pop(credential_name)
|
||||||
|
|
||||||
|
for credential_name in credentials_need_to_validate:
|
||||||
|
credential_schema = credentials_need_to_validate[credential_name]
|
||||||
|
if credential_schema.required:
|
||||||
|
raise ToolProviderCredentialValidationError(f"credential {credential_name} is required")
|
||||||
|
|
||||||
|
# the credential is not set currently, set the default value if needed
|
||||||
|
if credential_schema.default is not None:
|
||||||
|
default_value = credential_schema.default
|
||||||
|
# parse default value into the correct type
|
||||||
|
if credential_schema.type in {
|
||||||
|
ProviderConfig.Type.SECRET_INPUT,
|
||||||
|
ProviderConfig.Type.TEXT_INPUT,
|
||||||
|
ProviderConfig.Type.SELECT,
|
||||||
|
}:
|
||||||
|
default_value = str(default_value)
|
||||||
|
|
||||||
|
credentials[credential_name] = default_value
|
||||||
@ -0,0 +1,36 @@
|
|||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from openai import BaseModel
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceRuntime(BaseModel):
|
||||||
|
"""
|
||||||
|
Meta data of a datasource call processing
|
||||||
|
"""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
datasource_id: Optional[str] = None
|
||||||
|
invoke_from: Optional[InvokeFrom] = None
|
||||||
|
datasource_invoke_from: Optional[DatasourceInvokeFrom] = None
|
||||||
|
credentials: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
runtime_parameters: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class FakeDatasourceRuntime(DatasourceRuntime):
|
||||||
|
"""
|
||||||
|
Fake datasource runtime for testing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(
|
||||||
|
tenant_id="fake_tenant_id",
|
||||||
|
datasource_id="fake_datasource_id",
|
||||||
|
invoke_from=InvokeFrom.DEBUGGER,
|
||||||
|
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
|
||||||
|
credentials={},
|
||||||
|
runtime_parameters={},
|
||||||
|
)
|
||||||
@ -0,0 +1,244 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from mimetypes import guess_extension, guess_type
|
||||||
|
from typing import Optional, Union
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.helper import ssrf_proxy
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from extensions.ext_storage import storage
|
||||||
|
from models.enums import CreatorUserRole
|
||||||
|
from models.model import MessageFile, UploadFile
|
||||||
|
from models.tools import ToolFile
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceFileManager:
|
||||||
|
@staticmethod
|
||||||
|
def sign_file(datasource_file_id: str, extension: str) -> str:
|
||||||
|
"""
|
||||||
|
sign file to get a temporary url
|
||||||
|
"""
|
||||||
|
base_url = dify_config.FILES_URL
|
||||||
|
file_preview_url = f"{base_url}/files/datasources/{datasource_file_id}{extension}"
|
||||||
|
|
||||||
|
timestamp = str(int(time.time()))
|
||||||
|
nonce = os.urandom(16).hex()
|
||||||
|
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||||
|
|
||||||
|
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def verify_file(datasource_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
|
||||||
|
"""
|
||||||
|
verify signature
|
||||||
|
"""
|
||||||
|
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||||
|
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
|
||||||
|
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||||
|
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
|
||||||
|
|
||||||
|
# verify signature
|
||||||
|
if sign != recalculated_encoded_sign:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = int(time.time())
|
||||||
|
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_file_by_raw(
|
||||||
|
*,
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: Optional[str],
|
||||||
|
file_binary: bytes,
|
||||||
|
mimetype: str,
|
||||||
|
filename: Optional[str] = None,
|
||||||
|
) -> UploadFile:
|
||||||
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
|
unique_name = uuid4().hex
|
||||||
|
unique_filename = f"{unique_name}{extension}"
|
||||||
|
# default just as before
|
||||||
|
present_filename = unique_filename
|
||||||
|
if filename is not None:
|
||||||
|
has_extension = len(filename.split(".")) > 1
|
||||||
|
# Add extension flexibly
|
||||||
|
present_filename = filename if has_extension else f"{filename}{extension}"
|
||||||
|
filepath = f"datasources/{tenant_id}/{unique_filename}"
|
||||||
|
storage.save(filepath, file_binary)
|
||||||
|
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
storage_type=dify_config.STORAGE_TYPE,
|
||||||
|
key=filepath,
|
||||||
|
name=present_filename,
|
||||||
|
size=len(file_binary),
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mimetype,
|
||||||
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
|
created_by=user_id,
|
||||||
|
used=False,
|
||||||
|
hash=hashlib.sha3_256(file_binary).hexdigest(),
|
||||||
|
source_url="",
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
db.session.refresh(upload_file)
|
||||||
|
|
||||||
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def create_file_by_url(
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
file_url: str,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
) -> UploadFile:
|
||||||
|
# try to download image
|
||||||
|
try:
|
||||||
|
response = ssrf_proxy.get(file_url)
|
||||||
|
response.raise_for_status()
|
||||||
|
blob = response.content
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
raise ValueError(f"timeout when downloading file from {file_url}")
|
||||||
|
|
||||||
|
mimetype = (
|
||||||
|
guess_type(file_url)[0]
|
||||||
|
or response.headers.get("Content-Type", "").split(";")[0].strip()
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
extension = guess_extension(mimetype) or ".bin"
|
||||||
|
unique_name = uuid4().hex
|
||||||
|
filename = f"{unique_name}{extension}"
|
||||||
|
filepath = f"tools/{tenant_id}/{filename}"
|
||||||
|
storage.save(filepath, blob)
|
||||||
|
|
||||||
|
upload_file = UploadFile(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
storage_type=dify_config.STORAGE_TYPE,
|
||||||
|
key=filepath,
|
||||||
|
name=filename,
|
||||||
|
size=len(blob),
|
||||||
|
extension=extension,
|
||||||
|
mime_type=mimetype,
|
||||||
|
created_by_role=CreatorUserRole.ACCOUNT,
|
||||||
|
created_by=user_id,
|
||||||
|
used=False,
|
||||||
|
hash=hashlib.sha3_256(blob).hexdigest(),
|
||||||
|
source_url=file_url,
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add(upload_file)
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return upload_file
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param id: the id of the file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
upload_file: UploadFile | None = (
|
||||||
|
db.session.query(UploadFile)
|
||||||
|
.filter(
|
||||||
|
UploadFile.id == id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
blob = storage.load_once(upload_file.key)
|
||||||
|
|
||||||
|
return blob, upload_file.mime_type
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param id: the id of the file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
message_file: MessageFile | None = (
|
||||||
|
db.session.query(MessageFile)
|
||||||
|
.filter(
|
||||||
|
MessageFile.id == id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check if message_file is not None
|
||||||
|
if message_file is not None:
|
||||||
|
# get tool file id
|
||||||
|
if message_file.url is not None:
|
||||||
|
tool_file_id = message_file.url.split("/")[-1]
|
||||||
|
# trim extension
|
||||||
|
tool_file_id = tool_file_id.split(".")[0]
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
|
else:
|
||||||
|
tool_file_id = None
|
||||||
|
|
||||||
|
tool_file: ToolFile | None = (
|
||||||
|
db.session.query(ToolFile)
|
||||||
|
.filter(
|
||||||
|
ToolFile.id == tool_file_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not tool_file:
|
||||||
|
return None
|
||||||
|
|
||||||
|
blob = storage.load_once(tool_file.file_key)
|
||||||
|
|
||||||
|
return blob, tool_file.mimetype
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_file_generator_by_upload_file_id(upload_file_id: str):
|
||||||
|
"""
|
||||||
|
get file binary
|
||||||
|
|
||||||
|
:param tool_file_id: the id of the tool file
|
||||||
|
|
||||||
|
:return: the binary of the file, mime type
|
||||||
|
"""
|
||||||
|
upload_file: UploadFile | None = (
|
||||||
|
db.session.query(UploadFile)
|
||||||
|
.filter(
|
||||||
|
UploadFile.id == upload_file_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
|
||||||
|
if not upload_file:
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
stream = storage.load_stream(upload_file.key)
|
||||||
|
|
||||||
|
return stream, upload_file.mime_type
|
||||||
|
|
||||||
|
|
||||||
|
# init tool_file_parser
|
||||||
|
# from core.file.datasource_file_parser import datasource_file_manager
|
||||||
|
#
|
||||||
|
# datasource_file_manager["manager"] = DatasourceFileManager
|
||||||
@ -0,0 +1,100 @@
|
|||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
import contexts
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.entities.common_entities import I18nObject
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderType
|
||||||
|
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||||
|
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||||
|
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||||
|
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceManager:
|
||||||
|
_builtin_provider_lock = Lock()
|
||||||
|
_hardcoded_providers: dict[str, DatasourcePluginProviderController] = {}
|
||||||
|
_builtin_providers_loaded = False
|
||||||
|
_builtin_tools_labels: dict[str, Union[I18nObject, None]] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_datasource_plugin_provider(
|
||||||
|
cls, provider_id: str, tenant_id: str, datasource_type: DatasourceProviderType
|
||||||
|
) -> DatasourcePluginProviderController:
|
||||||
|
"""
|
||||||
|
get the datasource plugin provider
|
||||||
|
"""
|
||||||
|
# check if context is set
|
||||||
|
try:
|
||||||
|
contexts.datasource_plugin_providers.get()
|
||||||
|
except LookupError:
|
||||||
|
contexts.datasource_plugin_providers.set({})
|
||||||
|
contexts.datasource_plugin_providers_lock.set(Lock())
|
||||||
|
|
||||||
|
with contexts.datasource_plugin_providers_lock.get():
|
||||||
|
datasource_plugin_providers = contexts.datasource_plugin_providers.get()
|
||||||
|
if provider_id in datasource_plugin_providers:
|
||||||
|
return datasource_plugin_providers[provider_id]
|
||||||
|
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
provider_entity = manager.fetch_datasource_provider(tenant_id, provider_id)
|
||||||
|
if not provider_entity:
|
||||||
|
raise DatasourceProviderNotFoundError(f"plugin provider {provider_id} not found")
|
||||||
|
|
||||||
|
match datasource_type:
|
||||||
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
|
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
|
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
|
controller = LocalFileDatasourcePluginProviderController(
|
||||||
|
entity=provider_entity.declaration,
|
||||||
|
plugin_id=provider_entity.plugin_id,
|
||||||
|
plugin_unique_identifier=provider_entity.plugin_unique_identifier,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise ValueError(f"Unsupported datasource type: {datasource_type}")
|
||||||
|
|
||||||
|
datasource_plugin_providers[provider_id] = controller
|
||||||
|
|
||||||
|
return controller
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_datasource_runtime(
|
||||||
|
cls,
|
||||||
|
provider_id: str,
|
||||||
|
datasource_name: str,
|
||||||
|
tenant_id: str,
|
||||||
|
datasource_type: DatasourceProviderType,
|
||||||
|
) -> DatasourcePlugin:
|
||||||
|
"""
|
||||||
|
get the datasource runtime
|
||||||
|
|
||||||
|
:param provider_type: the type of the provider
|
||||||
|
:param provider_id: the id of the provider
|
||||||
|
:param datasource_name: the name of the datasource
|
||||||
|
:param tenant_id: the tenant id
|
||||||
|
|
||||||
|
:return: the datasource plugin
|
||||||
|
"""
|
||||||
|
return cls.get_datasource_plugin_provider(
|
||||||
|
provider_id,
|
||||||
|
tenant_id,
|
||||||
|
datasource_type,
|
||||||
|
).get_datasource(datasource_name)
|
||||||
@ -0,0 +1,71 @@
|
|||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, field_validator
|
||||||
|
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||||
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceApiEntity(BaseModel):
|
||||||
|
author: str
|
||||||
|
name: str # identifier
|
||||||
|
label: I18nObject # label
|
||||||
|
description: I18nObject
|
||||||
|
parameters: Optional[list[DatasourceParameter]] = None
|
||||||
|
labels: list[str] = Field(default_factory=list)
|
||||||
|
output_schema: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderApiEntity(BaseModel):
|
||||||
|
id: str
|
||||||
|
author: str
|
||||||
|
name: str # identifier
|
||||||
|
description: I18nObject
|
||||||
|
icon: str | dict
|
||||||
|
label: I18nObject # label
|
||||||
|
type: str
|
||||||
|
masked_credentials: Optional[dict] = None
|
||||||
|
original_credentials: Optional[dict] = None
|
||||||
|
is_team_authorization: bool = False
|
||||||
|
allow_delete: bool = True
|
||||||
|
plugin_id: Optional[str] = Field(default="", description="The plugin id of the datasource")
|
||||||
|
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the datasource")
|
||||||
|
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||||
|
labels: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
@field_validator("datasources", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def convert_none_to_empty_list(cls, v):
|
||||||
|
return v if v is not None else []
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
# -------------
|
||||||
|
# overwrite datasource parameter types for temp fix
|
||||||
|
datasources = jsonable_encoder(self.datasources)
|
||||||
|
for datasource in datasources:
|
||||||
|
if datasource.get("parameters"):
|
||||||
|
for parameter in datasource.get("parameters"):
|
||||||
|
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||||
|
parameter["type"] = "files"
|
||||||
|
# -------------
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"author": self.author,
|
||||||
|
"name": self.name,
|
||||||
|
"plugin_id": self.plugin_id,
|
||||||
|
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||||
|
"description": self.description.to_dict(),
|
||||||
|
"icon": self.icon,
|
||||||
|
"label": self.label.to_dict(),
|
||||||
|
"type": self.type.value,
|
||||||
|
"team_credentials": self.masked_credentials,
|
||||||
|
"is_team_authorization": self.is_team_authorization,
|
||||||
|
"allow_delete": self.allow_delete,
|
||||||
|
"datasources": datasources,
|
||||||
|
"labels": self.labels,
|
||||||
|
}
|
||||||
@ -0,0 +1,23 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class I18nObject(BaseModel):
|
||||||
|
"""
|
||||||
|
Model class for i18n object.
|
||||||
|
"""
|
||||||
|
|
||||||
|
en_US: str
|
||||||
|
zh_Hans: Optional[str] = Field(default=None)
|
||||||
|
pt_BR: Optional[str] = Field(default=None)
|
||||||
|
ja_JP: Optional[str] = Field(default=None)
|
||||||
|
|
||||||
|
def __init__(self, **data):
|
||||||
|
super().__init__(**data)
|
||||||
|
self.zh_Hans = self.zh_Hans or self.en_US
|
||||||
|
self.pt_BR = self.pt_BR or self.en_US
|
||||||
|
self.ja_JP = self.ja_JP or self.en_US
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {"zh_Hans": self.zh_Hans, "en_US": self.en_US, "pt_BR": self.pt_BR, "ja_JP": self.ja_JP}
|
||||||
@ -0,0 +1,361 @@
|
|||||||
|
import enum
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, ValidationInfo, field_validator
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
from core.plugin.entities.oauth import OAuthSchema
|
||||||
|
from core.plugin.entities.parameters import (
|
||||||
|
PluginParameter,
|
||||||
|
PluginParameterOption,
|
||||||
|
PluginParameterType,
|
||||||
|
as_normal_type,
|
||||||
|
cast_parameter_value,
|
||||||
|
init_frontend_parameter,
|
||||||
|
)
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolLabelEnum
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderType(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
Enum class for datasource provider
|
||||||
|
"""
|
||||||
|
|
||||||
|
ONLINE_DOCUMENT = "online_document"
|
||||||
|
LOCAL_FILE = "local_file"
|
||||||
|
WEBSITE_CRAWL = "website_crawl"
|
||||||
|
ONLINE_DRIVE = "online_drive"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def value_of(cls, value: str) -> "DatasourceProviderType":
|
||||||
|
"""
|
||||||
|
Get value of given mode.
|
||||||
|
|
||||||
|
:param value: mode value
|
||||||
|
:return: mode
|
||||||
|
"""
|
||||||
|
for mode in cls:
|
||||||
|
if mode.value == value:
|
||||||
|
return mode
|
||||||
|
raise ValueError(f"invalid mode value {value}")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceParameter(PluginParameter):
|
||||||
|
"""
|
||||||
|
Overrides type
|
||||||
|
"""
|
||||||
|
|
||||||
|
class DatasourceParameterType(enum.StrEnum):
|
||||||
|
"""
|
||||||
|
removes TOOLS_SELECTOR from PluginParameterType
|
||||||
|
"""
|
||||||
|
|
||||||
|
STRING = PluginParameterType.STRING.value
|
||||||
|
NUMBER = PluginParameterType.NUMBER.value
|
||||||
|
BOOLEAN = PluginParameterType.BOOLEAN.value
|
||||||
|
SELECT = PluginParameterType.SELECT.value
|
||||||
|
SECRET_INPUT = PluginParameterType.SECRET_INPUT.value
|
||||||
|
FILE = PluginParameterType.FILE.value
|
||||||
|
FILES = PluginParameterType.FILES.value
|
||||||
|
|
||||||
|
# deprecated, should not use.
|
||||||
|
SYSTEM_FILES = PluginParameterType.SYSTEM_FILES.value
|
||||||
|
|
||||||
|
def as_normal_type(self):
|
||||||
|
return as_normal_type(self)
|
||||||
|
|
||||||
|
def cast_value(self, value: Any):
|
||||||
|
return cast_parameter_value(self, value)
|
||||||
|
|
||||||
|
type: DatasourceParameterType = Field(..., description="The type of the parameter")
|
||||||
|
description: I18nObject = Field(..., description="The description of the parameter")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_simple_instance(
|
||||||
|
cls,
|
||||||
|
name: str,
|
||||||
|
typ: DatasourceParameterType,
|
||||||
|
required: bool,
|
||||||
|
options: Optional[list[str]] = None,
|
||||||
|
) -> "DatasourceParameter":
|
||||||
|
"""
|
||||||
|
get a simple datasource parameter
|
||||||
|
|
||||||
|
:param name: the name of the parameter
|
||||||
|
:param llm_description: the description presented to the LLM
|
||||||
|
:param typ: the type of the parameter
|
||||||
|
:param required: if the parameter is required
|
||||||
|
:param options: the options of the parameter
|
||||||
|
"""
|
||||||
|
# convert options to ToolParameterOption
|
||||||
|
# FIXME fix the type error
|
||||||
|
if options:
|
||||||
|
option_objs = [
|
||||||
|
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))
|
||||||
|
for option in options
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
option_objs = []
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=name,
|
||||||
|
label=I18nObject(en_US="", zh_Hans=""),
|
||||||
|
placeholder=None,
|
||||||
|
type=typ,
|
||||||
|
required=required,
|
||||||
|
options=option_objs,
|
||||||
|
description=I18nObject(en_US="", zh_Hans=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
def init_frontend_parameter(self, value: Any):
|
||||||
|
return init_frontend_parameter(self, self.type, value)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceIdentity(BaseModel):
|
||||||
|
author: str = Field(..., description="The author of the datasource")
|
||||||
|
name: str = Field(..., description="The name of the datasource")
|
||||||
|
label: I18nObject = Field(..., description="The label of the datasource")
|
||||||
|
provider: str = Field(..., description="The provider of the datasource")
|
||||||
|
icon: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceEntity(BaseModel):
|
||||||
|
identity: DatasourceIdentity
|
||||||
|
parameters: list[DatasourceParameter] = Field(default_factory=list)
|
||||||
|
description: I18nObject = Field(..., description="The label of the datasource")
|
||||||
|
|
||||||
|
@field_validator("parameters", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def set_parameters(cls, v, validation_info: ValidationInfo) -> list[DatasourceParameter]:
|
||||||
|
return v or []
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderIdentity(BaseModel):
|
||||||
|
author: str = Field(..., description="The author of the tool")
|
||||||
|
name: str = Field(..., description="The name of the tool")
|
||||||
|
description: I18nObject = Field(..., description="The description of the tool")
|
||||||
|
icon: str = Field(..., description="The icon of the tool")
|
||||||
|
label: I18nObject = Field(..., description="The label of the tool")
|
||||||
|
tags: Optional[list[ToolLabelEnum]] = Field(
|
||||||
|
default=[],
|
||||||
|
description="The tags of the tool",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderEntity(BaseModel):
|
||||||
|
"""
|
||||||
|
Datasource provider entity
|
||||||
|
"""
|
||||||
|
|
||||||
|
identity: DatasourceProviderIdentity
|
||||||
|
credentials_schema: list[ProviderConfig] = Field(default_factory=list)
|
||||||
|
oauth_schema: Optional[OAuthSchema] = None
|
||||||
|
provider_type: DatasourceProviderType
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderEntityWithPlugin(DatasourceProviderEntity):
|
||||||
|
datasources: list[DatasourceEntity] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceInvokeMeta(BaseModel):
|
||||||
|
"""
|
||||||
|
Datasource invoke meta
|
||||||
|
"""
|
||||||
|
|
||||||
|
time_cost: float = Field(..., description="The time cost of the tool invoke")
|
||||||
|
error: Optional[str] = None
|
||||||
|
tool_config: Optional[dict] = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def empty(cls) -> "DatasourceInvokeMeta":
|
||||||
|
"""
|
||||||
|
Get an empty instance of DatasourceInvokeMeta
|
||||||
|
"""
|
||||||
|
return cls(time_cost=0.0, error=None, tool_config={})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def error_instance(cls, error: str) -> "DatasourceInvokeMeta":
|
||||||
|
"""
|
||||||
|
Get an instance of DatasourceInvokeMeta with error
|
||||||
|
"""
|
||||||
|
return cls(time_cost=0.0, error=error, tool_config={})
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"time_cost": self.time_cost,
|
||||||
|
"error": self.error,
|
||||||
|
"tool_config": self.tool_config,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceLabel(BaseModel):
|
||||||
|
"""
|
||||||
|
Datasource label
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = Field(..., description="The name of the tool")
|
||||||
|
label: I18nObject = Field(..., description="The label of the tool")
|
||||||
|
icon: str = Field(..., description="The icon of the tool")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceInvokeFrom(Enum):
|
||||||
|
"""
|
||||||
|
Enum class for datasource invoke
|
||||||
|
"""
|
||||||
|
|
||||||
|
RAG_PIPELINE = "rag_pipeline"
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentPage(BaseModel):
|
||||||
|
"""
|
||||||
|
Online document page
|
||||||
|
"""
|
||||||
|
|
||||||
|
page_id: str = Field(..., description="The page id")
|
||||||
|
page_name: str = Field(..., description="The page title")
|
||||||
|
page_icon: Optional[dict] = Field(None, description="The page icon")
|
||||||
|
type: str = Field(..., description="The type of the page")
|
||||||
|
last_edited_time: str = Field(..., description="The last edited time")
|
||||||
|
parent_id: Optional[str] = Field(None, description="The parent page id")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
Online document info
|
||||||
|
"""
|
||||||
|
|
||||||
|
workspace_id: str = Field(..., description="The workspace id")
|
||||||
|
workspace_name: str = Field(..., description="The workspace name")
|
||||||
|
workspace_icon: str = Field(..., description="The workspace icon")
|
||||||
|
total: int = Field(..., description="The total number of documents")
|
||||||
|
pages: list[OnlineDocumentPage] = Field(..., description="The pages of the online document")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentPagesMessage(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online document pages response
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: list[OnlineDocumentInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class GetOnlineDocumentPageContentRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online document page content request
|
||||||
|
"""
|
||||||
|
|
||||||
|
workspace_id: str = Field(..., description="The workspace id")
|
||||||
|
page_id: str = Field(..., description="The page id")
|
||||||
|
type: str = Field(..., description="The type of the page")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentPageContent(BaseModel):
|
||||||
|
"""
|
||||||
|
Online document page content
|
||||||
|
"""
|
||||||
|
|
||||||
|
workspace_id: str = Field(..., description="The workspace id")
|
||||||
|
page_id: str = Field(..., description="The page id")
|
||||||
|
content: str = Field(..., description="The content of the page")
|
||||||
|
|
||||||
|
|
||||||
|
class GetOnlineDocumentPageContentResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online document page content response
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: OnlineDocumentPageContent
|
||||||
|
|
||||||
|
|
||||||
|
class GetWebsiteCrawlRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Get website crawl request
|
||||||
|
"""
|
||||||
|
|
||||||
|
crawl_parameters: dict = Field(..., description="The crawl parameters")
|
||||||
|
|
||||||
|
|
||||||
|
class WebSiteInfoDetail(BaseModel):
|
||||||
|
source_url: str = Field(..., description="The url of the website")
|
||||||
|
content: str = Field(..., description="The content of the website")
|
||||||
|
title: str = Field(..., description="The title of the website")
|
||||||
|
description: str = Field(..., description="The description of the website")
|
||||||
|
|
||||||
|
|
||||||
|
class WebSiteInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
Website info
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: Optional[str] = Field(..., description="crawl job status")
|
||||||
|
web_info_list: Optional[list[WebSiteInfoDetail]] = []
|
||||||
|
total: Optional[int] = Field(default=0, description="The total number of websites")
|
||||||
|
completed: Optional[int] = Field(default=0, description="The number of completed websites")
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlMessage(BaseModel):
|
||||||
|
"""
|
||||||
|
Get website crawl response
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: WebSiteInfo = WebSiteInfo(status="", web_info_list=[], total=0, completed=0)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceMessage(ToolInvokeMessage):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
#########################
|
||||||
|
# Online driver file
|
||||||
|
#########################
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveFile(BaseModel):
|
||||||
|
"""
|
||||||
|
Online driver file
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str = Field(..., description="The key of the file")
|
||||||
|
size: int = Field(..., description="The size of the file")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveFileBucket(BaseModel):
|
||||||
|
"""
|
||||||
|
Online driver file bucket
|
||||||
|
"""
|
||||||
|
|
||||||
|
bucket: Optional[str] = Field(None, description="The bucket of the file")
|
||||||
|
files: list[OnlineDriveFile] = Field(..., description="The files of the bucket")
|
||||||
|
is_truncated: bool = Field(False, description="Whether the bucket has more files")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveBrowseFilesRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online driver file list request
|
||||||
|
"""
|
||||||
|
|
||||||
|
prefix: Optional[str] = Field(None, description="File path prefix for filtering eg: 'docs/dify/'")
|
||||||
|
bucket: Optional[str] = Field(None, description="Storage bucket name")
|
||||||
|
max_keys: int = Field(20, description="Maximum number of files to return")
|
||||||
|
start_after: Optional[str] = Field(
|
||||||
|
None, description="Pagination token for continuing from a specific file eg: 'docs/dify/1.txt'"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveBrowseFilesResponse(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online driver file list response
|
||||||
|
"""
|
||||||
|
|
||||||
|
result: list[OnlineDriveFileBucket] = Field(..., description="The bucket of the files")
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveDownloadFileRequest(BaseModel):
|
||||||
|
"""
|
||||||
|
Get online driver file
|
||||||
|
"""
|
||||||
|
|
||||||
|
key: str = Field(..., description="The name of the file")
|
||||||
|
bucket: Optional[str] = Field(None, description="The name of the bucket")
|
||||||
@ -0,0 +1,37 @@
|
|||||||
|
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderNotFoundError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceNotFoundError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceParameterValidationError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProviderCredentialValidationError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceNotSupportedError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceInvokeError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceApiSchemaError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceEngineInvokeError(Exception):
|
||||||
|
meta: DatasourceInvokeMeta
|
||||||
|
|
||||||
|
def __init__(self, meta, **kwargs):
|
||||||
|
self.meta = meta
|
||||||
|
super().__init__(**kwargs)
|
||||||
@ -0,0 +1,28 @@
|
|||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFileDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> str:
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
@ -0,0 +1,56 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class LocalFileDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, tenant_id)
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.LOCAL_FILE
|
||||||
|
|
||||||
|
def _validate_credentials(self, user_id: str, credentials: dict[str, Any]) -> None:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> LocalFileDatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return LocalFileDatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceMessage,
|
||||||
|
DatasourceProviderType,
|
||||||
|
GetOnlineDocumentPageContentRequest,
|
||||||
|
OnlineDocumentPagesMessage,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def get_online_document_pages(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: Mapping[str, Any],
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.get_online_document_pages(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_online_document_page_content(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[DatasourceMessage, None, None]:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.get_online_document_page_content(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> str:
|
||||||
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, tenant_id)
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.ONLINE_DOCUMENT
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> OnlineDocumentDatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return OnlineDocumentDatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,73 @@
|
|||||||
|
from collections.abc import Generator
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceMessage,
|
||||||
|
DatasourceProviderType,
|
||||||
|
OnlineDriveBrowseFilesRequest,
|
||||||
|
OnlineDriveBrowseFilesResponse,
|
||||||
|
OnlineDriveDownloadFileRequest,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def online_drive_browse_files(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
request: OnlineDriveBrowseFilesRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.online_drive_browse_files(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
request=request,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def online_drive_download_file(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
request: OnlineDriveDownloadFileRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[DatasourceMessage, None, None]:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.online_drive_download_file(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
request=request,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> str:
|
||||||
|
return DatasourceProviderType.ONLINE_DRIVE
|
||||||
@ -0,0 +1,48 @@
|
|||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDriveDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, entity: DatasourceProviderEntityWithPlugin, plugin_id: str, plugin_unique_identifier: str, tenant_id: str
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, tenant_id)
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.ONLINE_DRIVE
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> OnlineDriveDatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return OnlineDriveDatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,265 @@
|
|||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.entities.provider_entities import BasicProviderConfig
|
||||||
|
from core.helper import encrypter
|
||||||
|
from core.helper.tool_parameter_cache import ToolParameterCache, ToolParameterCacheType
|
||||||
|
from core.helper.tool_provider_cache import ToolProviderCredentialsCache, ToolProviderCredentialsCacheType
|
||||||
|
from core.tools.__base.tool import Tool
|
||||||
|
from core.tools.entities.tool_entities import (
|
||||||
|
ToolParameter,
|
||||||
|
ToolProviderType,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ProviderConfigEncrypter(BaseModel):
|
||||||
|
tenant_id: str
|
||||||
|
config: list[BasicProviderConfig]
|
||||||
|
provider_type: str
|
||||||
|
provider_identity: str
|
||||||
|
|
||||||
|
def _deep_copy(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
deep copy data
|
||||||
|
"""
|
||||||
|
return deepcopy(data)
|
||||||
|
|
||||||
|
def encrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
encrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with encrypted values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
encrypted = encrypter.encrypt_token(self.tenant_id, data[field_name] or "")
|
||||||
|
data[field_name] = encrypted
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def mask_tool_credentials(self, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
mask tool credentials
|
||||||
|
|
||||||
|
return a deep copy of credentials with masked values
|
||||||
|
"""
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
if len(data[field_name]) > 6:
|
||||||
|
data[field_name] = (
|
||||||
|
data[field_name][:2] + "*" * (len(data[field_name]) - 4) + data[field_name][-2:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
data[field_name] = "*" * len(data[field_name])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def decrypt(self, data: dict[str, str]) -> dict[str, str]:
|
||||||
|
"""
|
||||||
|
decrypt tool credentials with tenant id
|
||||||
|
|
||||||
|
return a deep copy of credentials with decrypted values
|
||||||
|
"""
|
||||||
|
cache = ToolProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||||
|
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||||
|
)
|
||||||
|
cached_credentials = cache.get()
|
||||||
|
if cached_credentials:
|
||||||
|
return cached_credentials
|
||||||
|
data = self._deep_copy(data)
|
||||||
|
# get fields need to be decrypted
|
||||||
|
fields = dict[str, BasicProviderConfig]()
|
||||||
|
for credential in self.config:
|
||||||
|
fields[credential.name] = credential
|
||||||
|
|
||||||
|
for field_name, field in fields.items():
|
||||||
|
if field.type == BasicProviderConfig.Type.SECRET_INPUT:
|
||||||
|
if field_name in data:
|
||||||
|
try:
|
||||||
|
# if the value is None or empty string, skip decrypt
|
||||||
|
if not data[field_name]:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data[field_name] = encrypter.decrypt_token(self.tenant_id, data[field_name])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
cache.set(data)
|
||||||
|
return data
|
||||||
|
|
||||||
|
def delete_tool_credentials_cache(self):
|
||||||
|
cache = ToolProviderCredentialsCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
identity_id=f"{self.provider_type}.{self.provider_identity}",
|
||||||
|
cache_type=ToolProviderCredentialsCacheType.PROVIDER,
|
||||||
|
)
|
||||||
|
cache.delete()
|
||||||
|
|
||||||
|
|
||||||
|
class ToolParameterConfigurationManager:
|
||||||
|
"""
|
||||||
|
Tool parameter configuration manager
|
||||||
|
"""
|
||||||
|
|
||||||
|
tenant_id: str
|
||||||
|
tool_runtime: Tool
|
||||||
|
provider_name: str
|
||||||
|
provider_type: ToolProviderType
|
||||||
|
identity_id: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self, tenant_id: str, tool_runtime: Tool, provider_name: str, provider_type: ToolProviderType, identity_id: str
|
||||||
|
) -> None:
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.tool_runtime = tool_runtime
|
||||||
|
self.provider_name = provider_name
|
||||||
|
self.provider_type = provider_type
|
||||||
|
self.identity_id = identity_id
|
||||||
|
|
||||||
|
def _deep_copy(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
deep copy parameters
|
||||||
|
"""
|
||||||
|
return deepcopy(parameters)
|
||||||
|
|
||||||
|
def _merge_parameters(self) -> list[ToolParameter]:
|
||||||
|
"""
|
||||||
|
merge parameters
|
||||||
|
"""
|
||||||
|
# get tool parameters
|
||||||
|
tool_parameters = self.tool_runtime.entity.parameters or []
|
||||||
|
# get tool runtime parameters
|
||||||
|
runtime_parameters = self.tool_runtime.get_runtime_parameters()
|
||||||
|
# override parameters
|
||||||
|
current_parameters = tool_parameters.copy()
|
||||||
|
for runtime_parameter in runtime_parameters:
|
||||||
|
found = False
|
||||||
|
for index, parameter in enumerate(current_parameters):
|
||||||
|
if parameter.name == runtime_parameter.name and parameter.form == runtime_parameter.form:
|
||||||
|
current_parameters[index] = runtime_parameter
|
||||||
|
found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not found and runtime_parameter.form == ToolParameter.ToolParameterForm.FORM:
|
||||||
|
current_parameters.append(runtime_parameter)
|
||||||
|
|
||||||
|
return current_parameters
|
||||||
|
|
||||||
|
def mask_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
mask tool parameters
|
||||||
|
|
||||||
|
return a deep copy of parameters with masked values
|
||||||
|
"""
|
||||||
|
parameters = self._deep_copy(parameters)
|
||||||
|
|
||||||
|
# override parameters
|
||||||
|
current_parameters = self._merge_parameters()
|
||||||
|
|
||||||
|
for parameter in current_parameters:
|
||||||
|
if (
|
||||||
|
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||||
|
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||||
|
):
|
||||||
|
if parameter.name in parameters:
|
||||||
|
if len(parameters[parameter.name]) > 6:
|
||||||
|
parameters[parameter.name] = (
|
||||||
|
parameters[parameter.name][:2]
|
||||||
|
+ "*" * (len(parameters[parameter.name]) - 4)
|
||||||
|
+ parameters[parameter.name][-2:]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
parameters[parameter.name] = "*" * len(parameters[parameter.name])
|
||||||
|
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
def encrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
encrypt tool parameters with tenant id
|
||||||
|
|
||||||
|
return a deep copy of parameters with encrypted values
|
||||||
|
"""
|
||||||
|
# override parameters
|
||||||
|
current_parameters = self._merge_parameters()
|
||||||
|
|
||||||
|
parameters = self._deep_copy(parameters)
|
||||||
|
|
||||||
|
for parameter in current_parameters:
|
||||||
|
if (
|
||||||
|
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||||
|
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||||
|
):
|
||||||
|
if parameter.name in parameters:
|
||||||
|
encrypted = encrypter.encrypt_token(self.tenant_id, parameters[parameter.name])
|
||||||
|
parameters[parameter.name] = encrypted
|
||||||
|
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
def decrypt_tool_parameters(self, parameters: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
decrypt tool parameters with tenant id
|
||||||
|
|
||||||
|
return a deep copy of parameters with decrypted values
|
||||||
|
"""
|
||||||
|
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||||
|
tool_name=self.tool_runtime.entity.identity.name,
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id=self.identity_id,
|
||||||
|
)
|
||||||
|
cached_parameters = cache.get()
|
||||||
|
if cached_parameters:
|
||||||
|
return cached_parameters
|
||||||
|
|
||||||
|
# override parameters
|
||||||
|
current_parameters = self._merge_parameters()
|
||||||
|
has_secret_input = False
|
||||||
|
|
||||||
|
for parameter in current_parameters:
|
||||||
|
if (
|
||||||
|
parameter.form == ToolParameter.ToolParameterForm.FORM
|
||||||
|
and parameter.type == ToolParameter.ToolParameterType.SECRET_INPUT
|
||||||
|
):
|
||||||
|
if parameter.name in parameters:
|
||||||
|
try:
|
||||||
|
has_secret_input = True
|
||||||
|
parameters[parameter.name] = encrypter.decrypt_token(self.tenant_id, parameters[parameter.name])
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if has_secret_input:
|
||||||
|
cache.set(parameters)
|
||||||
|
|
||||||
|
return parameters
|
||||||
|
|
||||||
|
def delete_tool_parameters_cache(self):
|
||||||
|
cache = ToolParameterCache(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=f"{self.provider_type.value}.{self.provider_name}",
|
||||||
|
tool_name=self.tool_runtime.entity.identity.name,
|
||||||
|
cache_type=ToolParameterCacheType.PARAMETER,
|
||||||
|
identity_id=self.identity_id,
|
||||||
|
)
|
||||||
|
cache.delete()
|
||||||
@ -0,0 +1,121 @@
|
|||||||
|
import logging
|
||||||
|
from collections.abc import Generator
|
||||||
|
from mimetypes import guess_extension
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceFileMessageTransformer:
|
||||||
|
@classmethod
|
||||||
|
def transform_datasource_invoke_messages(
|
||||||
|
cls,
|
||||||
|
messages: Generator[DatasourceMessage, None, None],
|
||||||
|
user_id: str,
|
||||||
|
tenant_id: str,
|
||||||
|
conversation_id: Optional[str] = None,
|
||||||
|
) -> Generator[DatasourceMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Transform datasource message and handle file download
|
||||||
|
"""
|
||||||
|
for message in messages:
|
||||||
|
if message.type in {DatasourceMessage.MessageType.TEXT, DatasourceMessage.MessageType.LINK}:
|
||||||
|
yield message
|
||||||
|
elif message.type == DatasourceMessage.MessageType.IMAGE and isinstance(
|
||||||
|
message.message, DatasourceMessage.TextMessage
|
||||||
|
):
|
||||||
|
# try to download image
|
||||||
|
try:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||||
|
|
||||||
|
file = DatasourceFileManager.create_file_by_url(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
file_url=message.message.text,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = f"/files/datasources/{file.id}{guess_extension(file.mime_type) or '.png'}"
|
||||||
|
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||||
|
message=DatasourceMessage.TextMessage(text=url),
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.TEXT,
|
||||||
|
message=DatasourceMessage.TextMessage(
|
||||||
|
text=f"Failed to download image: {message.message.text}: {e}"
|
||||||
|
),
|
||||||
|
meta=message.meta.copy() if message.meta is not None else {},
|
||||||
|
)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||||
|
# get mime type and save blob to storage
|
||||||
|
meta = message.meta or {}
|
||||||
|
|
||||||
|
mimetype = meta.get("mime_type", "application/octet-stream")
|
||||||
|
# get filename from meta
|
||||||
|
filename = meta.get("file_name", None)
|
||||||
|
# if message is str, encode it to bytes
|
||||||
|
|
||||||
|
if not isinstance(message.message, DatasourceMessage.BlobMessage):
|
||||||
|
raise ValueError("unexpected message type")
|
||||||
|
|
||||||
|
# FIXME: should do a type check here.
|
||||||
|
assert isinstance(message.message.blob, bytes)
|
||||||
|
file = DatasourceFileManager.create_file_by_raw(
|
||||||
|
user_id=user_id,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
file_binary=message.message.blob,
|
||||||
|
mimetype=mimetype,
|
||||||
|
filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
url = cls.get_datasource_file_url(datasource_file_id=file.id, extension=guess_extension(file.mime_type))
|
||||||
|
|
||||||
|
# check if file is image
|
||||||
|
if "image" in mimetype:
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||||
|
message=DatasourceMessage.TextMessage(text=url),
|
||||||
|
meta=meta.copy() if meta is not None else {},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.BINARY_LINK,
|
||||||
|
message=DatasourceMessage.TextMessage(text=url),
|
||||||
|
meta=meta.copy() if meta is not None else {},
|
||||||
|
)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||||
|
meta = message.meta or {}
|
||||||
|
file = meta.get("file", None)
|
||||||
|
if isinstance(file, File):
|
||||||
|
if file.transfer_method == FileTransferMethod.TOOL_FILE:
|
||||||
|
assert file.related_id is not None
|
||||||
|
url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension)
|
||||||
|
if file.type == FileType.IMAGE:
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||||
|
message=DatasourceMessage.TextMessage(text=url),
|
||||||
|
meta=meta.copy() if meta is not None else {},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield DatasourceMessage(
|
||||||
|
type=DatasourceMessage.MessageType.LINK,
|
||||||
|
message=DatasourceMessage.TextMessage(text=url),
|
||||||
|
meta=meta.copy() if meta is not None else {},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
yield message
|
||||||
|
else:
|
||||||
|
yield message
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_datasource_file_url(cls, datasource_file_id: str, extension: Optional[str]) -> str:
|
||||||
|
return f"/files/datasources/{datasource_file_id}{extension or '.bin'}"
|
||||||
@ -0,0 +1,389 @@
|
|||||||
|
import re
|
||||||
|
import uuid
|
||||||
|
from json import dumps as json_dumps
|
||||||
|
from json import loads as json_loads
|
||||||
|
from json.decoder import JSONDecodeError
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from flask import request
|
||||||
|
from requests import get
|
||||||
|
from yaml import YAMLError, safe_load # type: ignore
|
||||||
|
|
||||||
|
from core.tools.entities.common_entities import I18nObject
|
||||||
|
from core.tools.entities.tool_bundle import ApiToolBundle
|
||||||
|
from core.tools.entities.tool_entities import ApiProviderSchemaType, ToolParameter
|
||||||
|
from core.tools.errors import ToolApiSchemaError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||||
|
|
||||||
|
|
||||||
|
class ApiBasedToolSchemaParser:
|
||||||
|
@staticmethod
|
||||||
|
def parse_openapi_to_tool_bundle(
|
||||||
|
openapi: dict, extra_info: dict | None = None, warning: dict | None = None
|
||||||
|
) -> list[ApiToolBundle]:
|
||||||
|
warning = warning if warning is not None else {}
|
||||||
|
extra_info = extra_info if extra_info is not None else {}
|
||||||
|
|
||||||
|
# set description to extra_info
|
||||||
|
extra_info["description"] = openapi["info"].get("description", "")
|
||||||
|
|
||||||
|
if len(openapi["servers"]) == 0:
|
||||||
|
raise ToolProviderNotFoundError("No server found in the openapi yaml.")
|
||||||
|
|
||||||
|
server_url = openapi["servers"][0]["url"]
|
||||||
|
request_env = request.headers.get("X-Request-Env")
|
||||||
|
if request_env:
|
||||||
|
matched_servers = [server["url"] for server in openapi["servers"] if server["env"] == request_env]
|
||||||
|
server_url = matched_servers[0] if matched_servers else server_url
|
||||||
|
|
||||||
|
# list all interfaces
|
||||||
|
interfaces = []
|
||||||
|
for path, path_item in openapi["paths"].items():
|
||||||
|
methods = ["get", "post", "put", "delete", "patch", "head", "options", "trace"]
|
||||||
|
for method in methods:
|
||||||
|
if method in path_item:
|
||||||
|
interfaces.append(
|
||||||
|
{
|
||||||
|
"path": path,
|
||||||
|
"method": method,
|
||||||
|
"operation": path_item[method],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# get all parameters
|
||||||
|
bundles = []
|
||||||
|
for interface in interfaces:
|
||||||
|
# convert parameters
|
||||||
|
parameters = []
|
||||||
|
if "parameters" in interface["operation"]:
|
||||||
|
for parameter in interface["operation"]["parameters"]:
|
||||||
|
tool_parameter = ToolParameter(
|
||||||
|
name=parameter["name"],
|
||||||
|
label=I18nObject(en_US=parameter["name"], zh_Hans=parameter["name"]),
|
||||||
|
human_description=I18nObject(
|
||||||
|
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||||
|
),
|
||||||
|
type=ToolParameter.ToolParameterType.STRING,
|
||||||
|
required=parameter.get("required", False),
|
||||||
|
form=ToolParameter.ToolParameterForm.LLM,
|
||||||
|
llm_description=parameter.get("description"),
|
||||||
|
default=parameter["schema"]["default"]
|
||||||
|
if "schema" in parameter and "default" in parameter["schema"]
|
||||||
|
else None,
|
||||||
|
placeholder=I18nObject(
|
||||||
|
en_US=parameter.get("description", ""), zh_Hans=parameter.get("description", "")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if there is a type
|
||||||
|
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(parameter)
|
||||||
|
if typ:
|
||||||
|
tool_parameter.type = typ
|
||||||
|
|
||||||
|
parameters.append(tool_parameter)
|
||||||
|
# create tool bundle
|
||||||
|
# check if there is a request body
|
||||||
|
if "requestBody" in interface["operation"]:
|
||||||
|
request_body = interface["operation"]["requestBody"]
|
||||||
|
if "content" in request_body:
|
||||||
|
for content_type, content in request_body["content"].items():
|
||||||
|
# if there is a reference, get the reference and overwrite the content
|
||||||
|
if "schema" not in content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "$ref" in content["schema"]:
|
||||||
|
# get the reference
|
||||||
|
root = openapi
|
||||||
|
reference = content["schema"]["$ref"].split("/")[1:]
|
||||||
|
for ref in reference:
|
||||||
|
root = root[ref]
|
||||||
|
# overwrite the content
|
||||||
|
interface["operation"]["requestBody"]["content"][content_type]["schema"] = root
|
||||||
|
|
||||||
|
# parse body parameters
|
||||||
|
if "schema" in interface["operation"]["requestBody"]["content"][content_type]:
|
||||||
|
body_schema = interface["operation"]["requestBody"]["content"][content_type]["schema"]
|
||||||
|
required = body_schema.get("required", [])
|
||||||
|
properties = body_schema.get("properties", {})
|
||||||
|
for name, property in properties.items():
|
||||||
|
tool = ToolParameter(
|
||||||
|
name=name,
|
||||||
|
label=I18nObject(en_US=name, zh_Hans=name),
|
||||||
|
human_description=I18nObject(
|
||||||
|
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||||
|
),
|
||||||
|
type=ToolParameter.ToolParameterType.STRING,
|
||||||
|
required=name in required,
|
||||||
|
form=ToolParameter.ToolParameterForm.LLM,
|
||||||
|
llm_description=property.get("description", ""),
|
||||||
|
default=property.get("default", None),
|
||||||
|
placeholder=I18nObject(
|
||||||
|
en_US=property.get("description", ""), zh_Hans=property.get("description", "")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# check if there is a type
|
||||||
|
typ = ApiBasedToolSchemaParser._get_tool_parameter_type(property)
|
||||||
|
if typ:
|
||||||
|
tool.type = typ
|
||||||
|
|
||||||
|
parameters.append(tool)
|
||||||
|
|
||||||
|
# check if parameters is duplicated
|
||||||
|
parameters_count = {}
|
||||||
|
for parameter in parameters:
|
||||||
|
if parameter.name not in parameters_count:
|
||||||
|
parameters_count[parameter.name] = 0
|
||||||
|
parameters_count[parameter.name] += 1
|
||||||
|
for name, count in parameters_count.items():
|
||||||
|
if count > 1:
|
||||||
|
warning["duplicated_parameter"] = f"Parameter {name} is duplicated."
|
||||||
|
|
||||||
|
# check if there is a operation id, use $path_$method as operation id if not
|
||||||
|
if "operationId" not in interface["operation"]:
|
||||||
|
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||||
|
path = interface["path"]
|
||||||
|
if interface["path"].startswith("/"):
|
||||||
|
path = interface["path"][1:]
|
||||||
|
# remove special characters like / to ensure the operation id is valid ^[a-zA-Z0-9_-]{1,64}$
|
||||||
|
path = re.sub(r"[^a-zA-Z0-9_-]", "", path)
|
||||||
|
if not path:
|
||||||
|
path = str(uuid.uuid4())
|
||||||
|
|
||||||
|
interface["operation"]["operationId"] = f"{path}_{interface['method']}"
|
||||||
|
|
||||||
|
bundles.append(
|
||||||
|
ApiToolBundle(
|
||||||
|
server_url=server_url + interface["path"],
|
||||||
|
method=interface["method"],
|
||||||
|
summary=interface["operation"]["description"]
|
||||||
|
if "description" in interface["operation"]
|
||||||
|
else interface["operation"].get("summary", None),
|
||||||
|
operation_id=interface["operation"]["operationId"],
|
||||||
|
parameters=parameters,
|
||||||
|
author="",
|
||||||
|
icon=None,
|
||||||
|
openapi=interface["operation"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return bundles
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_tool_parameter_type(parameter: dict) -> Optional[ToolParameter.ToolParameterType]:
|
||||||
|
parameter = parameter or {}
|
||||||
|
typ: Optional[str] = None
|
||||||
|
if parameter.get("format") == "binary":
|
||||||
|
return ToolParameter.ToolParameterType.FILE
|
||||||
|
|
||||||
|
if "type" in parameter:
|
||||||
|
typ = parameter["type"]
|
||||||
|
elif "schema" in parameter and "type" in parameter["schema"]:
|
||||||
|
typ = parameter["schema"]["type"]
|
||||||
|
|
||||||
|
if typ in {"integer", "number"}:
|
||||||
|
return ToolParameter.ToolParameterType.NUMBER
|
||||||
|
elif typ == "boolean":
|
||||||
|
return ToolParameter.ToolParameterType.BOOLEAN
|
||||||
|
elif typ == "string":
|
||||||
|
return ToolParameter.ToolParameterType.STRING
|
||||||
|
elif typ == "array":
|
||||||
|
items = parameter.get("items") or parameter.get("schema", {}).get("items")
|
||||||
|
return ToolParameter.ToolParameterType.FILES if items and items.get("format") == "binary" else None
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_openapi_yaml_to_tool_bundle(
|
||||||
|
yaml: str, extra_info: dict | None = None, warning: dict | None = None
|
||||||
|
) -> list[ApiToolBundle]:
|
||||||
|
"""
|
||||||
|
parse openapi yaml to tool bundle
|
||||||
|
|
||||||
|
:param yaml: the yaml string
|
||||||
|
:param extra_info: the extra info
|
||||||
|
:param warning: the warning message
|
||||||
|
:return: the tool bundle
|
||||||
|
"""
|
||||||
|
warning = warning if warning is not None else {}
|
||||||
|
extra_info = extra_info if extra_info is not None else {}
|
||||||
|
|
||||||
|
openapi: dict = safe_load(yaml)
|
||||||
|
if openapi is None:
|
||||||
|
raise ToolApiSchemaError("Invalid openapi yaml.")
|
||||||
|
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(openapi, extra_info=extra_info, warning=warning)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_swagger_to_openapi(swagger: dict, extra_info: dict | None = None, warning: dict | None = None) -> dict:
|
||||||
|
warning = warning or {}
|
||||||
|
"""
|
||||||
|
parse swagger to openapi
|
||||||
|
|
||||||
|
:param swagger: the swagger dict
|
||||||
|
:return: the openapi dict
|
||||||
|
"""
|
||||||
|
# convert swagger to openapi
|
||||||
|
info = swagger.get("info", {"title": "Swagger", "description": "Swagger", "version": "1.0.0"})
|
||||||
|
|
||||||
|
servers = swagger.get("servers", [])
|
||||||
|
|
||||||
|
if len(servers) == 0:
|
||||||
|
raise ToolApiSchemaError("No server found in the swagger yaml.")
|
||||||
|
|
||||||
|
openapi = {
|
||||||
|
"openapi": "3.0.0",
|
||||||
|
"info": {
|
||||||
|
"title": info.get("title", "Swagger"),
|
||||||
|
"description": info.get("description", "Swagger"),
|
||||||
|
"version": info.get("version", "1.0.0"),
|
||||||
|
},
|
||||||
|
"servers": swagger["servers"],
|
||||||
|
"paths": {},
|
||||||
|
"components": {"schemas": {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
# check paths
|
||||||
|
if "paths" not in swagger or len(swagger["paths"]) == 0:
|
||||||
|
raise ToolApiSchemaError("No paths found in the swagger yaml.")
|
||||||
|
|
||||||
|
# convert paths
|
||||||
|
for path, path_item in swagger["paths"].items():
|
||||||
|
openapi["paths"][path] = {}
|
||||||
|
for method, operation in path_item.items():
|
||||||
|
if "operationId" not in operation:
|
||||||
|
raise ToolApiSchemaError(f"No operationId found in operation {method} {path}.")
|
||||||
|
|
||||||
|
if ("summary" not in operation or len(operation["summary"]) == 0) and (
|
||||||
|
"description" not in operation or len(operation["description"]) == 0
|
||||||
|
):
|
||||||
|
if warning is not None:
|
||||||
|
warning["missing_summary"] = f"No summary or description found in operation {method} {path}."
|
||||||
|
|
||||||
|
openapi["paths"][path][method] = {
|
||||||
|
"operationId": operation["operationId"],
|
||||||
|
"summary": operation.get("summary", ""),
|
||||||
|
"description": operation.get("description", ""),
|
||||||
|
"parameters": operation.get("parameters", []),
|
||||||
|
"responses": operation.get("responses", {}),
|
||||||
|
}
|
||||||
|
|
||||||
|
if "requestBody" in operation:
|
||||||
|
openapi["paths"][path][method]["requestBody"] = operation["requestBody"]
|
||||||
|
|
||||||
|
# convert definitions
|
||||||
|
for name, definition in swagger["definitions"].items():
|
||||||
|
openapi["components"]["schemas"][name] = definition
|
||||||
|
|
||||||
|
return openapi
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def parse_openai_plugin_json_to_tool_bundle(
|
||||||
|
json: str, extra_info: dict | None = None, warning: dict | None = None
|
||||||
|
) -> list[ApiToolBundle]:
|
||||||
|
"""
|
||||||
|
parse openapi plugin yaml to tool bundle
|
||||||
|
|
||||||
|
:param json: the json string
|
||||||
|
:param extra_info: the extra info
|
||||||
|
:param warning: the warning message
|
||||||
|
:return: the tool bundle
|
||||||
|
"""
|
||||||
|
warning = warning if warning is not None else {}
|
||||||
|
extra_info = extra_info if extra_info is not None else {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
openai_plugin = json_loads(json)
|
||||||
|
api = openai_plugin["api"]
|
||||||
|
api_url = api["url"]
|
||||||
|
api_type = api["type"]
|
||||||
|
except JSONDecodeError:
|
||||||
|
raise ToolProviderNotFoundError("Invalid openai plugin json.")
|
||||||
|
|
||||||
|
if api_type != "openapi":
|
||||||
|
raise ToolNotSupportedError("Only openapi is supported now.")
|
||||||
|
|
||||||
|
# get openapi yaml
|
||||||
|
response = get(api_url, headers={"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) "}, timeout=5)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise ToolProviderNotFoundError("cannot get openapi yaml from url.")
|
||||||
|
|
||||||
|
return ApiBasedToolSchemaParser.parse_openapi_yaml_to_tool_bundle(
|
||||||
|
response.text, extra_info=extra_info, warning=warning
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def auto_parse_to_tool_bundle(
|
||||||
|
content: str, extra_info: dict | None = None, warning: dict | None = None
|
||||||
|
) -> tuple[list[ApiToolBundle], str]:
|
||||||
|
"""
|
||||||
|
auto parse to tool bundle
|
||||||
|
|
||||||
|
:param content: the content
|
||||||
|
:param extra_info: the extra info
|
||||||
|
:param warning: the warning message
|
||||||
|
:return: tools bundle, schema_type
|
||||||
|
"""
|
||||||
|
warning = warning if warning is not None else {}
|
||||||
|
extra_info = extra_info if extra_info is not None else {}
|
||||||
|
|
||||||
|
content = content.strip()
|
||||||
|
loaded_content = None
|
||||||
|
json_error = None
|
||||||
|
yaml_error = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
loaded_content = json_loads(content)
|
||||||
|
except JSONDecodeError as e:
|
||||||
|
json_error = e
|
||||||
|
|
||||||
|
if loaded_content is None:
|
||||||
|
try:
|
||||||
|
loaded_content = safe_load(content)
|
||||||
|
except YAMLError as e:
|
||||||
|
yaml_error = e
|
||||||
|
if loaded_content is None:
|
||||||
|
raise ToolApiSchemaError(
|
||||||
|
f"Invalid api schema, schema is neither json nor yaml. json error: {str(json_error)},"
|
||||||
|
f" yaml error: {str(yaml_error)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
swagger_error = None
|
||||||
|
openapi_error = None
|
||||||
|
openapi_plugin_error = None
|
||||||
|
schema_type = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
openapi = ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||||
|
loaded_content, extra_info=extra_info, warning=warning
|
||||||
|
)
|
||||||
|
schema_type = ApiProviderSchemaType.OPENAPI.value
|
||||||
|
return openapi, schema_type
|
||||||
|
except ToolApiSchemaError as e:
|
||||||
|
openapi_error = e
|
||||||
|
|
||||||
|
# openai parse error, fallback to swagger
|
||||||
|
try:
|
||||||
|
converted_swagger = ApiBasedToolSchemaParser.parse_swagger_to_openapi(
|
||||||
|
loaded_content, extra_info=extra_info, warning=warning
|
||||||
|
)
|
||||||
|
schema_type = ApiProviderSchemaType.SWAGGER.value
|
||||||
|
return ApiBasedToolSchemaParser.parse_openapi_to_tool_bundle(
|
||||||
|
converted_swagger, extra_info=extra_info, warning=warning
|
||||||
|
), schema_type
|
||||||
|
except ToolApiSchemaError as e:
|
||||||
|
swagger_error = e
|
||||||
|
|
||||||
|
# swagger parse error, fallback to openai plugin
|
||||||
|
try:
|
||||||
|
openapi_plugin = ApiBasedToolSchemaParser.parse_openai_plugin_json_to_tool_bundle(
|
||||||
|
json_dumps(loaded_content), extra_info=extra_info, warning=warning
|
||||||
|
)
|
||||||
|
return openapi_plugin, ApiProviderSchemaType.OPENAI_PLUGIN.value
|
||||||
|
except ToolNotSupportedError as e:
|
||||||
|
# maybe it's not plugin at all
|
||||||
|
openapi_plugin_error = e
|
||||||
|
|
||||||
|
raise ToolApiSchemaError(
|
||||||
|
f"Invalid api schema, openapi error: {str(openapi_error)}, swagger error: {str(swagger_error)},"
|
||||||
|
f" openapi plugin error: {str(openapi_plugin_error)}"
|
||||||
|
)
|
||||||
@ -0,0 +1,17 @@
|
|||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
def remove_leading_symbols(text: str) -> str:
|
||||||
|
"""
|
||||||
|
Remove leading punctuation or symbols from the given text.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text (str): The input text to process.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The text with leading punctuation or symbols removed.
|
||||||
|
"""
|
||||||
|
# Match Unicode ranges for punctuation and symbols
|
||||||
|
# FIXME this pattern is confused quick fix for #11868 maybe refactor it later
|
||||||
|
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
|
||||||
|
return re.sub(pattern, "", text)
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_uuid(uuid_str: str) -> bool:
|
||||||
|
try:
|
||||||
|
uuid.UUID(uuid_str)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return False
|
||||||
@ -0,0 +1,43 @@
|
|||||||
|
from collections.abc import Mapping, Sequence
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.app.app_config.entities import VariableEntity
|
||||||
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowToolConfigurationUtils:
|
||||||
|
@classmethod
|
||||||
|
def check_parameter_configurations(cls, configurations: list[Mapping[str, Any]]):
|
||||||
|
for configuration in configurations:
|
||||||
|
WorkflowToolParameterConfiguration.model_validate(configuration)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_workflow_graph_variables(cls, graph: Mapping[str, Any]) -> Sequence[VariableEntity]:
|
||||||
|
"""
|
||||||
|
get workflow graph variables
|
||||||
|
"""
|
||||||
|
nodes = graph.get("nodes", [])
|
||||||
|
start_node = next(filter(lambda x: x.get("data", {}).get("type") == "start", nodes), None)
|
||||||
|
|
||||||
|
if not start_node:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [VariableEntity.model_validate(variable) for variable in start_node.get("data", {}).get("variables", [])]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_is_synced(
|
||||||
|
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
check is synced
|
||||||
|
|
||||||
|
raise ValueError if not synced
|
||||||
|
"""
|
||||||
|
variable_names = [variable.variable for variable in variables]
|
||||||
|
|
||||||
|
if len(tool_configurations) != len(variables):
|
||||||
|
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||||
|
|
||||||
|
for parameter in tool_configurations:
|
||||||
|
if parameter.name not in variable_names:
|
||||||
|
raise ValueError("parameter configuration mismatch, please republish the tool to update")
|
||||||
@ -0,0 +1,35 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import yaml # type: ignore
|
||||||
|
from yaml import YAMLError
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def load_yaml_file(file_path: str, ignore_error: bool = True, default_value: Any = {}) -> Any:
|
||||||
|
"""
|
||||||
|
Safe loading a YAML file
|
||||||
|
:param file_path: the path of the YAML file
|
||||||
|
:param ignore_error:
|
||||||
|
if True, return default_value if error occurs and the error will be logged in debug level
|
||||||
|
if False, raise error if error occurs
|
||||||
|
:param default_value: the value returned when errors ignored
|
||||||
|
:return: an object of the YAML content
|
||||||
|
"""
|
||||||
|
if not file_path or not Path(file_path).exists():
|
||||||
|
if ignore_error:
|
||||||
|
return default_value
|
||||||
|
else:
|
||||||
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
|
with open(file_path, encoding="utf-8") as yaml_file:
|
||||||
|
try:
|
||||||
|
yaml_content = yaml.safe_load(yaml_file)
|
||||||
|
return yaml_content or default_value
|
||||||
|
except Exception as e:
|
||||||
|
if ignore_error:
|
||||||
|
return default_value
|
||||||
|
else:
|
||||||
|
raise YAMLError(f"Failed to load YAML file {file_path}: {e}") from e
|
||||||
@ -0,0 +1,53 @@
|
|||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceEntity,
|
||||||
|
DatasourceProviderType,
|
||||||
|
WebsiteCrawlMessage,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.datasource import PluginDatasourceManager
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlDatasourcePlugin(DatasourcePlugin):
|
||||||
|
tenant_id: str
|
||||||
|
icon: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
entity: DatasourceEntity
|
||||||
|
runtime: DatasourceRuntime
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceEntity,
|
||||||
|
runtime: DatasourceRuntime,
|
||||||
|
tenant_id: str,
|
||||||
|
icon: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, runtime)
|
||||||
|
self.tenant_id = tenant_id
|
||||||
|
self.icon = icon
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
def get_website_crawl(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
datasource_parameters: Mapping[str, Any],
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||||
|
manager = PluginDatasourceManager()
|
||||||
|
|
||||||
|
return manager.get_website_crawl(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
datasource_provider=self.entity.identity.provider,
|
||||||
|
datasource_name=self.entity.identity.name,
|
||||||
|
credentials=self.runtime.credentials,
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
provider_type=provider_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
def datasource_provider_type(self) -> str:
|
||||||
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
@ -0,0 +1,52 @@
|
|||||||
|
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||||
|
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||||
|
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||||
|
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteCrawlDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||||
|
entity: DatasourceProviderEntityWithPlugin
|
||||||
|
plugin_id: str
|
||||||
|
plugin_unique_identifier: str
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
entity: DatasourceProviderEntityWithPlugin,
|
||||||
|
plugin_id: str,
|
||||||
|
plugin_unique_identifier: str,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(entity, tenant_id)
|
||||||
|
self.plugin_id = plugin_id
|
||||||
|
self.plugin_unique_identifier = plugin_unique_identifier
|
||||||
|
|
||||||
|
@property
|
||||||
|
def provider_type(self) -> DatasourceProviderType:
|
||||||
|
"""
|
||||||
|
returns the type of the provider
|
||||||
|
"""
|
||||||
|
return DatasourceProviderType.WEBSITE_CRAWL
|
||||||
|
|
||||||
|
def get_datasource(self, datasource_name: str) -> WebsiteCrawlDatasourcePlugin: # type: ignore
|
||||||
|
"""
|
||||||
|
return datasource with given name
|
||||||
|
"""
|
||||||
|
datasource_entity = next(
|
||||||
|
(
|
||||||
|
datasource_entity
|
||||||
|
for datasource_entity in self.entity.datasources
|
||||||
|
if datasource_entity.identity.name == datasource_name
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not datasource_entity:
|
||||||
|
raise ValueError(f"Datasource with name {datasource_name} not found")
|
||||||
|
|
||||||
|
return WebsiteCrawlDatasourcePlugin(
|
||||||
|
entity=datasource_entity,
|
||||||
|
runtime=DatasourceRuntime(tenant_id=self.tenant_id),
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
icon=self.entity.identity.icon,
|
||||||
|
plugin_unique_identifier=self.plugin_unique_identifier,
|
||||||
|
)
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
|
from core.datasource import datasource_file_manager
|
||||||
|
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||||
|
|
||||||
|
tool_file_manager: dict[str, Any] = {"manager": None}
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceFileParser:
|
||||||
|
@staticmethod
|
||||||
|
def get_datasource_file_manager() -> "DatasourceFileManager":
|
||||||
|
return cast("DatasourceFileManager", datasource_file_manager["manager"])
|
||||||
@ -0,0 +1,21 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from core.entities.provider_entities import ProviderConfig
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthSchema(BaseModel):
|
||||||
|
"""
|
||||||
|
OAuth schema
|
||||||
|
"""
|
||||||
|
|
||||||
|
client_schema: Sequence[ProviderConfig] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="client schema like client_id, client_secret, etc.",
|
||||||
|
)
|
||||||
|
|
||||||
|
credentials_schema: Sequence[ProviderConfig] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="credentials schema like access_token, refresh_token, etc.",
|
||||||
|
)
|
||||||
@ -0,0 +1,329 @@
|
|||||||
|
from collections.abc import Generator, Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceMessage,
|
||||||
|
GetOnlineDocumentPageContentRequest,
|
||||||
|
OnlineDocumentPagesMessage,
|
||||||
|
OnlineDriveBrowseFilesRequest,
|
||||||
|
OnlineDriveBrowseFilesResponse,
|
||||||
|
OnlineDriveDownloadFileRequest,
|
||||||
|
WebsiteCrawlMessage,
|
||||||
|
)
|
||||||
|
from core.plugin.entities.plugin import DatasourceProviderID, GenericProviderID
|
||||||
|
from core.plugin.entities.plugin_daemon import (
|
||||||
|
PluginBasicBooleanResponse,
|
||||||
|
PluginDatasourceProviderEntity,
|
||||||
|
)
|
||||||
|
from core.plugin.impl.base import BasePluginClient
|
||||||
|
from services.tools.tools_transform_service import ToolTransformService
|
||||||
|
|
||||||
|
|
||||||
|
class PluginDatasourceManager(BasePluginClient):
|
||||||
|
def fetch_datasource_providers(self, tenant_id: str) -> list[PluginDatasourceProviderEntity]:
|
||||||
|
"""
|
||||||
|
Fetch datasource providers for the given tenant.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
|
if json_response.get("data"):
|
||||||
|
for provider in json_response.get("data", []):
|
||||||
|
declaration = provider.get("declaration", {}) or {}
|
||||||
|
provider_name = declaration.get("identity", {}).get("name")
|
||||||
|
for datasource in declaration.get("datasources", []):
|
||||||
|
datasource["identity"]["provider"] = provider_name
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response(
|
||||||
|
"GET",
|
||||||
|
f"plugin/{tenant_id}/management/datasources",
|
||||||
|
list[PluginDatasourceProviderEntity],
|
||||||
|
params={"page": 1, "page_size": 256},
|
||||||
|
transformer=transformer,
|
||||||
|
)
|
||||||
|
local_file_datasource_provider = PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||||
|
|
||||||
|
for provider in response:
|
||||||
|
ToolTransformService.repack_provider(tenant_id=tenant_id, provider=provider)
|
||||||
|
all_response = [local_file_datasource_provider] + response
|
||||||
|
|
||||||
|
for provider in all_response:
|
||||||
|
provider.declaration.identity.name = f"{provider.plugin_id}/{provider.declaration.identity.name}"
|
||||||
|
|
||||||
|
# override the provider name for each tool to plugin_id/provider_name
|
||||||
|
for tool in provider.declaration.datasources:
|
||||||
|
tool.identity.provider = provider.declaration.identity.name
|
||||||
|
|
||||||
|
return all_response
|
||||||
|
|
||||||
|
def fetch_datasource_provider(self, tenant_id: str, provider_id: str) -> PluginDatasourceProviderEntity:
|
||||||
|
"""
|
||||||
|
Fetch datasource provider for the given tenant and plugin.
|
||||||
|
"""
|
||||||
|
if provider_id == "langgenius/file/file":
|
||||||
|
return PluginDatasourceProviderEntity(**self._get_local_file_datasource_provider())
|
||||||
|
|
||||||
|
tool_provider_id = DatasourceProviderID(provider_id)
|
||||||
|
|
||||||
|
def transformer(json_response: dict[str, Any]) -> dict:
|
||||||
|
data = json_response.get("data")
|
||||||
|
if data:
|
||||||
|
for datasource in data.get("declaration", {}).get("datasources", []):
|
||||||
|
datasource["identity"]["provider"] = tool_provider_id.provider_name
|
||||||
|
|
||||||
|
return json_response
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response(
|
||||||
|
"GET",
|
||||||
|
f"plugin/{tenant_id}/management/datasource",
|
||||||
|
PluginDatasourceProviderEntity,
|
||||||
|
params={"provider": tool_provider_id.provider_name, "plugin_id": tool_provider_id.plugin_id},
|
||||||
|
transformer=transformer,
|
||||||
|
)
|
||||||
|
|
||||||
|
response.declaration.identity.name = f"{response.plugin_id}/{response.declaration.identity.name}"
|
||||||
|
|
||||||
|
# override the provider name for each tool to plugin_id/provider_name
|
||||||
|
for datasource in response.declaration.datasources:
|
||||||
|
datasource.identity.provider = response.declaration.identity.name
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
def get_website_crawl(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
datasource_provider: str,
|
||||||
|
datasource_name: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
datasource_parameters: Mapping[str, Any],
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[WebsiteCrawlMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||||
|
|
||||||
|
return self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/get_website_crawl",
|
||||||
|
WebsiteCrawlMessage,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": datasource_provider_id.provider_name,
|
||||||
|
"datasource": datasource_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"datasource_parameters": datasource_parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_online_document_pages(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
datasource_provider: str,
|
||||||
|
datasource_name: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
datasource_parameters: Mapping[str, Any],
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[OnlineDocumentPagesMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||||
|
|
||||||
|
return self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_pages",
|
||||||
|
OnlineDocumentPagesMessage,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": datasource_provider_id.provider_name,
|
||||||
|
"datasource": datasource_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"datasource_parameters": datasource_parameters,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_online_document_page_content(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
datasource_provider: str,
|
||||||
|
datasource_name: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
datasource_parameters: GetOnlineDocumentPageContentRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[DatasourceMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||||
|
|
||||||
|
return self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/get_online_document_page_content",
|
||||||
|
DatasourceMessage,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": datasource_provider_id.provider_name,
|
||||||
|
"datasource": datasource_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"page": datasource_parameters.model_dump(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def online_drive_browse_files(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
datasource_provider: str,
|
||||||
|
datasource_name: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
request: OnlineDriveBrowseFilesRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[OnlineDriveBrowseFilesResponse, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/online_drive_browse_files",
|
||||||
|
OnlineDriveBrowseFilesResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": datasource_provider_id.provider_name,
|
||||||
|
"datasource": datasource_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"request": request.model_dump(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield from response
|
||||||
|
|
||||||
|
def online_drive_download_file(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
datasource_provider: str,
|
||||||
|
datasource_name: str,
|
||||||
|
credentials: dict[str, Any],
|
||||||
|
request: OnlineDriveDownloadFileRequest,
|
||||||
|
provider_type: str,
|
||||||
|
) -> Generator[DatasourceMessage, None, None]:
|
||||||
|
"""
|
||||||
|
Invoke the datasource with the given tenant, user, plugin, provider, name, credentials and parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
datasource_provider_id = GenericProviderID(datasource_provider)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/online_drive_download_file",
|
||||||
|
DatasourceMessage,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": datasource_provider_id.provider_name,
|
||||||
|
"datasource": datasource_name,
|
||||||
|
"credentials": credentials,
|
||||||
|
"request": request.model_dump(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": datasource_provider_id.plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
yield from response
|
||||||
|
|
||||||
|
def validate_provider_credentials(
|
||||||
|
self, tenant_id: str, user_id: str, provider: str, plugin_id: str, credentials: dict[str, Any]
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
validate the credentials of the provider
|
||||||
|
"""
|
||||||
|
# datasource_provider_id = GenericProviderID(provider_id)
|
||||||
|
|
||||||
|
response = self._request_with_plugin_daemon_response_stream(
|
||||||
|
"POST",
|
||||||
|
f"plugin/{tenant_id}/dispatch/datasource/validate_credentials",
|
||||||
|
PluginBasicBooleanResponse,
|
||||||
|
data={
|
||||||
|
"user_id": user_id,
|
||||||
|
"data": {
|
||||||
|
"provider": provider,
|
||||||
|
"credentials": credentials,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers={
|
||||||
|
"X-Plugin-ID": plugin_id,
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
for resp in response:
|
||||||
|
return resp.result
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _get_local_file_datasource_provider(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"id": "langgenius/file/file",
|
||||||
|
"plugin_id": "langgenius/file",
|
||||||
|
"provider": "file",
|
||||||
|
"plugin_unique_identifier": "langgenius/file:0.0.1@dify",
|
||||||
|
"declaration": {
|
||||||
|
"identity": {
|
||||||
|
"author": "langgenius",
|
||||||
|
"name": "file",
|
||||||
|
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||||
|
"icon": "https://assets.dify.ai/images/File%20Upload.svg",
|
||||||
|
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||||
|
},
|
||||||
|
"credentials_schema": [],
|
||||||
|
"provider_type": "local_file",
|
||||||
|
"datasources": [
|
||||||
|
{
|
||||||
|
"identity": {
|
||||||
|
"author": "langgenius",
|
||||||
|
"name": "upload-file",
|
||||||
|
"provider": "file",
|
||||||
|
"label": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||||
|
},
|
||||||
|
"parameters": [],
|
||||||
|
"description": {"zh_Hans": "File", "en_US": "File", "pt_BR": "File", "ja_JP": "File"},
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
}
|
||||||
@ -0,0 +1,38 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceStreamEvent(Enum):
|
||||||
|
"""
|
||||||
|
Datasource Stream event
|
||||||
|
"""
|
||||||
|
|
||||||
|
PROCESSING = "datasource_processing"
|
||||||
|
COMPLETED = "datasource_completed"
|
||||||
|
ERROR = "datasource_error"
|
||||||
|
|
||||||
|
|
||||||
|
class BaseDatasourceEvent(BaseModel):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceErrorEvent(BaseDatasourceEvent):
|
||||||
|
event: str = DatasourceStreamEvent.ERROR.value
|
||||||
|
error: str = Field(..., description="error message")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceCompletedEvent(BaseDatasourceEvent):
|
||||||
|
event: str = DatasourceStreamEvent.COMPLETED.value
|
||||||
|
data: Mapping[str, Any] | list = Field(..., description="result")
|
||||||
|
total: Optional[int] = Field(default=0, description="total")
|
||||||
|
completed: Optional[int] = Field(default=0, description="completed")
|
||||||
|
time_consuming: Optional[float] = Field(default=0.0, description="time consuming")
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceProcessingEvent(BaseDatasourceEvent):
|
||||||
|
event: str = DatasourceStreamEvent.PROCESSING.value
|
||||||
|
total: Optional[int] = Field(..., description="total")
|
||||||
|
completed: Optional[int] = Field(..., description="completed")
|
||||||
@ -1,3 +1,4 @@
|
|||||||
SYSTEM_VARIABLE_NODE_ID = "sys"
|
SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||||
|
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||||
|
|||||||
@ -0,0 +1,3 @@
|
|||||||
|
from .datasource_node import DatasourceNode
|
||||||
|
|
||||||
|
__all__ = ["DatasourceNode"]
|
||||||
@ -0,0 +1,468 @@
|
|||||||
|
from collections.abc import Generator, Mapping, Sequence
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from core.datasource.entities.datasource_entities import (
|
||||||
|
DatasourceMessage,
|
||||||
|
DatasourceParameter,
|
||||||
|
DatasourceProviderType,
|
||||||
|
GetOnlineDocumentPageContentRequest,
|
||||||
|
OnlineDriveDownloadFileRequest,
|
||||||
|
)
|
||||||
|
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||||
|
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||||
|
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||||
|
from core.file import File
|
||||||
|
from core.file.enums import FileTransferMethod, FileType
|
||||||
|
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||||
|
from core.variables.segments import ArrayAnySegment
|
||||||
|
from core.variables.variables import ArrayAnyVariable
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool, VariableValue
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.nodes.base import BaseNode
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
|
||||||
|
from core.workflow.nodes.tool.exc import ToolFileError
|
||||||
|
from core.workflow.utils.variable_template_parser import VariableTemplateParser
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from factories import file_factory
|
||||||
|
from models.model import UploadFile
|
||||||
|
from services.datasource_provider_service import DatasourceProviderService
|
||||||
|
|
||||||
|
from ...entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||||
|
from .entities import DatasourceNodeData
|
||||||
|
from .exc import DatasourceNodeError, DatasourceParameterError
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceNode(BaseNode[DatasourceNodeData]):
|
||||||
|
"""
|
||||||
|
Datasource Node
|
||||||
|
"""
|
||||||
|
|
||||||
|
_node_data_cls = DatasourceNodeData
|
||||||
|
_node_type = NodeType.DATASOURCE
|
||||||
|
|
||||||
|
def _run(self) -> Generator:
|
||||||
|
"""
|
||||||
|
Run the datasource node
|
||||||
|
"""
|
||||||
|
|
||||||
|
node_data = cast(DatasourceNodeData, self.node_data)
|
||||||
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
datasource_type = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE.value])
|
||||||
|
if not datasource_type:
|
||||||
|
raise DatasourceNodeError("Datasource type is not set")
|
||||||
|
datasource_type = datasource_type.value
|
||||||
|
datasource_info = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO.value])
|
||||||
|
if not datasource_info:
|
||||||
|
raise DatasourceNodeError("Datasource info is not set")
|
||||||
|
datasource_info = datasource_info.value
|
||||||
|
# get datasource runtime
|
||||||
|
try:
|
||||||
|
from core.datasource.datasource_manager import DatasourceManager
|
||||||
|
|
||||||
|
if datasource_type is None:
|
||||||
|
raise DatasourceNodeError("Datasource type is not set")
|
||||||
|
|
||||||
|
datasource_runtime = DatasourceManager.get_datasource_runtime(
|
||||||
|
provider_id=f"{node_data.plugin_id}/{node_data.provider_name}",
|
||||||
|
datasource_name=node_data.datasource_name or "",
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
datasource_type=DatasourceProviderType.value_of(datasource_type),
|
||||||
|
)
|
||||||
|
except DatasourceNodeError as e:
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs={},
|
||||||
|
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
|
error=f"Failed to get datasource runtime: {str(e)}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# get parameters
|
||||||
|
datasource_parameters = datasource_runtime.entity.parameters
|
||||||
|
parameters = self._generate_parameters(
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_data=self.node_data,
|
||||||
|
)
|
||||||
|
parameters_for_log = self._generate_parameters(
|
||||||
|
datasource_parameters=datasource_parameters,
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_data=self.node_data,
|
||||||
|
for_log=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
match datasource_type:
|
||||||
|
case DatasourceProviderType.ONLINE_DOCUMENT:
|
||||||
|
datasource_runtime = cast(OnlineDocumentDatasourcePlugin, datasource_runtime)
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=node_data.provider_name,
|
||||||
|
plugin_id=node_data.plugin_id,
|
||||||
|
)
|
||||||
|
if credentials:
|
||||||
|
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||||
|
online_document_result: Generator[DatasourceMessage, None, None] = (
|
||||||
|
datasource_runtime.get_online_document_page_content(
|
||||||
|
user_id=self.user_id,
|
||||||
|
datasource_parameters=GetOnlineDocumentPageContentRequest(
|
||||||
|
workspace_id=datasource_info.get("workspace_id"),
|
||||||
|
page_id=datasource_info.get("page").get("page_id"),
|
||||||
|
type=datasource_info.get("page").get("type"),
|
||||||
|
),
|
||||||
|
provider_type=datasource_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield from self._transform_message(
|
||||||
|
messages=online_document_result,
|
||||||
|
parameters_for_log=parameters_for_log,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.ONLINE_DRIVE:
|
||||||
|
datasource_runtime = cast(OnlineDriveDatasourcePlugin, datasource_runtime)
|
||||||
|
datasource_provider_service = DatasourceProviderService()
|
||||||
|
credentials = datasource_provider_service.get_real_datasource_credentials(
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
provider=node_data.provider_name,
|
||||||
|
plugin_id=node_data.plugin_id,
|
||||||
|
)
|
||||||
|
if credentials:
|
||||||
|
datasource_runtime.runtime.credentials = credentials[0].get("credentials")
|
||||||
|
online_drive_result: Generator[DatasourceMessage, None, None] = (
|
||||||
|
datasource_runtime.online_drive_download_file(
|
||||||
|
user_id=self.user_id,
|
||||||
|
request=OnlineDriveDownloadFileRequest(
|
||||||
|
key=datasource_info.get("key"),
|
||||||
|
bucket=datasource_info.get("bucket"),
|
||||||
|
),
|
||||||
|
provider_type=datasource_type,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
yield from self._transform_message(
|
||||||
|
messages=online_drive_result,
|
||||||
|
parameters_for_log=parameters_for_log,
|
||||||
|
datasource_info=datasource_info,
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.WEBSITE_CRAWL:
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs=parameters_for_log,
|
||||||
|
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
|
outputs={
|
||||||
|
**datasource_info,
|
||||||
|
"datasource_type": datasource_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case DatasourceProviderType.LOCAL_FILE:
|
||||||
|
related_id = datasource_info.get("related_id")
|
||||||
|
if not related_id:
|
||||||
|
raise DatasourceNodeError("File is not exist")
|
||||||
|
upload_file = db.session.query(UploadFile).filter(UploadFile.id == related_id).first()
|
||||||
|
if not upload_file:
|
||||||
|
raise ValueError("Invalid upload file Info")
|
||||||
|
|
||||||
|
file_info = File(
|
||||||
|
id=upload_file.id,
|
||||||
|
filename=upload_file.name,
|
||||||
|
extension="." + upload_file.extension,
|
||||||
|
mime_type=upload_file.mime_type,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
type=FileType.CUSTOM,
|
||||||
|
transfer_method=FileTransferMethod.LOCAL_FILE,
|
||||||
|
remote_url=upload_file.source_url,
|
||||||
|
related_id=upload_file.id,
|
||||||
|
size=upload_file.size,
|
||||||
|
storage_key=upload_file.key,
|
||||||
|
)
|
||||||
|
variable_pool.add([self.node_id, "file"], [file_info])
|
||||||
|
for key, value in datasource_info.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = ["file", key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=variable_pool,
|
||||||
|
node_id=self.node_id,
|
||||||
|
variable_key_list=new_key_list,
|
||||||
|
variable_value=value,
|
||||||
|
)
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs=parameters_for_log,
|
||||||
|
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
|
outputs={
|
||||||
|
"file_info": datasource_info,
|
||||||
|
"datasource_type": datasource_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
raise DatasourceNodeError(f"Unsupported datasource provider: {datasource_type}")
|
||||||
|
except PluginDaemonClientSideError as e:
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=parameters_for_log,
|
||||||
|
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
|
error=f"Failed to transform datasource message: {str(e)}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except DatasourceNodeError as e:
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=parameters_for_log,
|
||||||
|
metadata={WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info},
|
||||||
|
error=f"Failed to invoke datasource: {str(e)}",
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_parameters(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
datasource_parameters: Sequence[DatasourceParameter],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
node_data: DatasourceNodeData,
|
||||||
|
for_log: bool = False,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||||
|
variable_pool (VariablePool): The variable pool containing the variables.
|
||||||
|
node_data (ToolNodeData): The data associated with the tool node.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||||
|
|
||||||
|
"""
|
||||||
|
datasource_parameters_dictionary = {parameter.name: parameter for parameter in datasource_parameters}
|
||||||
|
|
||||||
|
result: dict[str, Any] = {}
|
||||||
|
if node_data.datasource_parameters:
|
||||||
|
for parameter_name in node_data.datasource_parameters:
|
||||||
|
parameter = datasource_parameters_dictionary.get(parameter_name)
|
||||||
|
if not parameter:
|
||||||
|
result[parameter_name] = None
|
||||||
|
continue
|
||||||
|
datasource_input = node_data.datasource_parameters[parameter_name]
|
||||||
|
if datasource_input.type == "variable":
|
||||||
|
variable = variable_pool.get(datasource_input.value)
|
||||||
|
if variable is None:
|
||||||
|
raise DatasourceParameterError(f"Variable {datasource_input.value} does not exist")
|
||||||
|
parameter_value = variable.value
|
||||||
|
elif datasource_input.type in {"mixed", "constant"}:
|
||||||
|
segment_group = variable_pool.convert_template(str(datasource_input.value))
|
||||||
|
parameter_value = segment_group.log if for_log else segment_group.text
|
||||||
|
else:
|
||||||
|
raise DatasourceParameterError(f"Unknown datasource input type '{datasource_input.type}'")
|
||||||
|
result[parameter_name] = parameter_value
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _fetch_files(self, variable_pool: VariablePool) -> list[File]:
|
||||||
|
variable = variable_pool.get(["sys", SystemVariableKey.FILES.value])
|
||||||
|
assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment)
|
||||||
|
return list(variable.value) if variable else []
|
||||||
|
|
||||||
|
def _append_variables_recursively(
|
||||||
|
self, variable_pool: VariablePool, node_id: str, variable_key_list: list[str], variable_value: VariableValue
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Append variables recursively
|
||||||
|
:param node_id: node id
|
||||||
|
:param variable_key_list: variable key list
|
||||||
|
:param variable_value: variable value
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
variable_pool.add([node_id] + variable_key_list, variable_value)
|
||||||
|
|
||||||
|
# if variable_value is a dict, then recursively append variables
|
||||||
|
if isinstance(variable_value, dict):
|
||||||
|
for key, value in variable_value.items():
|
||||||
|
# construct new key list
|
||||||
|
new_key_list = variable_key_list + [key]
|
||||||
|
self._append_variables_recursively(
|
||||||
|
variable_pool=variable_pool, node_id=node_id, variable_key_list=new_key_list, variable_value=value
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _extract_variable_selector_to_variable_mapping(
|
||||||
|
cls,
|
||||||
|
*,
|
||||||
|
graph_config: Mapping[str, Any],
|
||||||
|
node_id: str,
|
||||||
|
node_data: DatasourceNodeData,
|
||||||
|
) -> Mapping[str, Sequence[str]]:
|
||||||
|
"""
|
||||||
|
Extract variable selector to variable mapping
|
||||||
|
:param graph_config: graph config
|
||||||
|
:param node_id: node id
|
||||||
|
:param node_data: node data
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
result = {}
|
||||||
|
if node_data.datasource_parameters:
|
||||||
|
for parameter_name in node_data.datasource_parameters:
|
||||||
|
input = node_data.datasource_parameters[parameter_name]
|
||||||
|
if input.type == "mixed":
|
||||||
|
assert isinstance(input.value, str)
|
||||||
|
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||||
|
for selector in selectors:
|
||||||
|
result[selector.variable] = selector.value_selector
|
||||||
|
elif input.type == "variable":
|
||||||
|
result[parameter_name] = input.value
|
||||||
|
elif input.type == "constant":
|
||||||
|
pass
|
||||||
|
|
||||||
|
result = {node_id + "." + key: value for key, value in result.items()}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _transform_message(
|
||||||
|
self,
|
||||||
|
messages: Generator[DatasourceMessage, None, None],
|
||||||
|
parameters_for_log: dict[str, Any],
|
||||||
|
datasource_info: dict[str, Any],
|
||||||
|
) -> Generator:
|
||||||
|
"""
|
||||||
|
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||||
|
"""
|
||||||
|
# transform message and handle file storage
|
||||||
|
message_stream = DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||||
|
messages=messages,
|
||||||
|
user_id=self.user_id,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
conversation_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
text = ""
|
||||||
|
files: list[File] = []
|
||||||
|
json: list[dict] = []
|
||||||
|
|
||||||
|
variables: dict[str, Any] = {}
|
||||||
|
|
||||||
|
for message in message_stream:
|
||||||
|
if message.type in {
|
||||||
|
DatasourceMessage.MessageType.IMAGE_LINK,
|
||||||
|
DatasourceMessage.MessageType.BINARY_LINK,
|
||||||
|
DatasourceMessage.MessageType.IMAGE,
|
||||||
|
}:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||||
|
|
||||||
|
url = message.message.text
|
||||||
|
if message.meta:
|
||||||
|
transfer_method = message.meta.get("transfer_method", FileTransferMethod.DATASOURCE_FILE)
|
||||||
|
else:
|
||||||
|
transfer_method = FileTransferMethod.DATASOURCE_FILE
|
||||||
|
|
||||||
|
datasource_file_id = str(url).split("/")[-1].split(".")[0]
|
||||||
|
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||||
|
datasource_file = session.scalar(stmt)
|
||||||
|
if datasource_file is None:
|
||||||
|
raise ToolFileError(f"Tool file {datasource_file_id} does not exist")
|
||||||
|
|
||||||
|
mapping = {
|
||||||
|
"datasource_file_id": datasource_file_id,
|
||||||
|
"type": file_factory.get_file_type_by_mime_type(datasource_file.mime_type),
|
||||||
|
"transfer_method": transfer_method,
|
||||||
|
"url": url,
|
||||||
|
}
|
||||||
|
file = file_factory.build_from_mapping(
|
||||||
|
mapping=mapping,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
)
|
||||||
|
files.append(file)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.BLOB:
|
||||||
|
# get tool file id
|
||||||
|
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||||
|
assert message.meta
|
||||||
|
|
||||||
|
datasource_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||||
|
with Session(db.engine) as session:
|
||||||
|
stmt = select(UploadFile).where(UploadFile.id == datasource_file_id)
|
||||||
|
datasource_file = session.scalar(stmt)
|
||||||
|
if datasource_file is None:
|
||||||
|
raise ToolFileError(f"datasource file {datasource_file_id} not exists")
|
||||||
|
|
||||||
|
mapping = {
|
||||||
|
"datasource_file_id": datasource_file_id,
|
||||||
|
"transfer_method": FileTransferMethod.DATASOURCE_FILE,
|
||||||
|
}
|
||||||
|
|
||||||
|
files.append(
|
||||||
|
file_factory.build_from_mapping(
|
||||||
|
mapping=mapping,
|
||||||
|
tenant_id=self.tenant_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.TEXT:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||||
|
text += message.message.text
|
||||||
|
yield RunStreamChunkEvent(
|
||||||
|
chunk_content=message.message.text, from_variable_selector=[self.node_id, "text"]
|
||||||
|
)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.JSON:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.JsonMessage)
|
||||||
|
if self.node_type == NodeType.AGENT:
|
||||||
|
msg_metadata = message.message.json_object.pop("execution_metadata", {})
|
||||||
|
agent_execution_metadata = {
|
||||||
|
key: value
|
||||||
|
for key, value in msg_metadata.items()
|
||||||
|
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||||
|
}
|
||||||
|
json.append(message.message.json_object)
|
||||||
|
elif message.type == DatasourceMessage.MessageType.LINK:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.TextMessage)
|
||||||
|
stream_text = f"Link: {message.message.text}\n"
|
||||||
|
text += stream_text
|
||||||
|
yield RunStreamChunkEvent(chunk_content=stream_text, from_variable_selector=[self.node_id, "text"])
|
||||||
|
elif message.type == DatasourceMessage.MessageType.VARIABLE:
|
||||||
|
assert isinstance(message.message, DatasourceMessage.VariableMessage)
|
||||||
|
variable_name = message.message.variable_name
|
||||||
|
variable_value = message.message.variable_value
|
||||||
|
if message.message.stream:
|
||||||
|
if not isinstance(variable_value, str):
|
||||||
|
raise ValueError("When 'stream' is True, 'variable_value' must be a string.")
|
||||||
|
if variable_name not in variables:
|
||||||
|
variables[variable_name] = ""
|
||||||
|
variables[variable_name] += variable_value
|
||||||
|
|
||||||
|
yield RunStreamChunkEvent(
|
||||||
|
chunk_content=variable_value, from_variable_selector=[self.node_id, variable_name]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
variables[variable_name] = variable_value
|
||||||
|
elif message.type == DatasourceMessage.MessageType.FILE:
|
||||||
|
assert message.meta is not None
|
||||||
|
files.append(message.meta["file"])
|
||||||
|
|
||||||
|
yield RunCompletedEvent(
|
||||||
|
run_result=NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
outputs={"json": json, "files": files, **variables, "text": text},
|
||||||
|
metadata={
|
||||||
|
WorkflowNodeExecutionMetadataKey.DATASOURCE_INFO: datasource_info,
|
||||||
|
},
|
||||||
|
inputs=parameters_for_log,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
@ -0,0 +1,41 @@
|
|||||||
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel, field_validator
|
||||||
|
from pydantic_core.core_schema import ValidationInfo
|
||||||
|
|
||||||
|
from core.workflow.nodes.base.entities import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceEntity(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
|
provider_name: str # redundancy
|
||||||
|
provider_type: str
|
||||||
|
datasource_name: Optional[str] = "local_file"
|
||||||
|
datasource_configurations: dict[str, Any] | None = None
|
||||||
|
plugin_unique_identifier: str | None = None # redundancy
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceNodeData(BaseNodeData, DatasourceEntity):
|
||||||
|
class DatasourceInput(BaseModel):
|
||||||
|
# TODO: check this type
|
||||||
|
value: Union[Any, list[str]]
|
||||||
|
type: Optional[Literal["mixed", "variable", "constant"]] = None
|
||||||
|
|
||||||
|
@field_validator("type", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_type(cls, value, validation_info: ValidationInfo):
|
||||||
|
typ = value
|
||||||
|
value = validation_info.data.get("value")
|
||||||
|
if typ == "mixed" and not isinstance(value, str):
|
||||||
|
raise ValueError("value must be a string")
|
||||||
|
elif typ == "variable":
|
||||||
|
if not isinstance(value, list):
|
||||||
|
raise ValueError("value must be a list")
|
||||||
|
for val in value:
|
||||||
|
if not isinstance(val, str):
|
||||||
|
raise ValueError("value must be a list of strings")
|
||||||
|
elif typ == "constant" and not isinstance(value, str | int | float | bool):
|
||||||
|
raise ValueError("value must be a string, int, float, or bool")
|
||||||
|
return typ
|
||||||
|
|
||||||
|
datasource_parameters: dict[str, DatasourceInput] | None = None
|
||||||
@ -0,0 +1,16 @@
|
|||||||
|
class DatasourceNodeError(ValueError):
|
||||||
|
"""Base exception for datasource node errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceParameterError(DatasourceNodeError):
|
||||||
|
"""Exception raised for errors in datasource parameters."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class DatasourceFileError(DatasourceNodeError):
|
||||||
|
"""Exception raised for errors related to datasource files."""
|
||||||
|
|
||||||
|
pass
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
from .knowledge_index_node import KnowledgeIndexNode
|
||||||
|
|
||||||
|
__all__ = ["KnowledgeIndexNode"]
|
||||||
@ -0,0 +1,159 @@
|
|||||||
|
from typing import Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
|
|
||||||
|
|
||||||
|
class RerankingModelConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Reranking Model Config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
reranking_provider_name: str
|
||||||
|
reranking_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Vector Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_weight: float
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class KeywordSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Keyword Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_weight: float
|
||||||
|
|
||||||
|
|
||||||
|
class WeightedScoreConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Weighted score Config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
vector_setting: VectorSetting
|
||||||
|
keyword_setting: KeywordSetting
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Embedding Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
embedding_provider_name: str
|
||||||
|
embedding_model_name: str
|
||||||
|
|
||||||
|
|
||||||
|
class EconomySetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Economy Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword_number: int
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalSetting(BaseModel):
|
||||||
|
"""
|
||||||
|
Retrieval Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
search_method: Literal["semantic_search", "keyword_search", "fulltext_search", "hybrid_search"]
|
||||||
|
top_k: int
|
||||||
|
score_threshold: Optional[float] = 0.5
|
||||||
|
score_threshold_enabled: bool = False
|
||||||
|
reranking_mode: str = "reranking_model"
|
||||||
|
reranking_enable: bool = True
|
||||||
|
reranking_model: Optional[RerankingModelConfig] = None
|
||||||
|
weights: Optional[WeightedScoreConfig] = None
|
||||||
|
|
||||||
|
|
||||||
|
class IndexMethod(BaseModel):
|
||||||
|
"""
|
||||||
|
Knowledge Index Setting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
indexing_technique: Literal["high_quality", "economy"]
|
||||||
|
embedding_setting: EmbeddingSetting
|
||||||
|
economy_setting: EconomySetting
|
||||||
|
|
||||||
|
|
||||||
|
class FileInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
File Info.
|
||||||
|
"""
|
||||||
|
|
||||||
|
file_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentIcon(BaseModel):
|
||||||
|
"""
|
||||||
|
Document Icon.
|
||||||
|
"""
|
||||||
|
|
||||||
|
icon_url: str
|
||||||
|
icon_type: str
|
||||||
|
icon_emoji: str
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineDocumentInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
Online document info.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
workspace_id: str
|
||||||
|
page_id: str
|
||||||
|
page_type: str
|
||||||
|
icon: OnlineDocumentIcon
|
||||||
|
|
||||||
|
|
||||||
|
class WebsiteInfo(BaseModel):
|
||||||
|
"""
|
||||||
|
website import info.
|
||||||
|
"""
|
||||||
|
|
||||||
|
provider: str
|
||||||
|
url: str
|
||||||
|
|
||||||
|
|
||||||
|
class GeneralStructureChunk(BaseModel):
|
||||||
|
"""
|
||||||
|
General Structure Chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
general_chunks: list[str]
|
||||||
|
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class ParentChildChunk(BaseModel):
|
||||||
|
"""
|
||||||
|
Parent Child Chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent_content: str
|
||||||
|
child_contents: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ParentChildStructureChunk(BaseModel):
|
||||||
|
"""
|
||||||
|
Parent Child Structure Chunk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
parent_child_chunks: list[ParentChildChunk]
|
||||||
|
data_source_info: Union[FileInfo, OnlineDocumentInfo, WebsiteInfo]
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeIndexNodeData(BaseNodeData):
|
||||||
|
"""
|
||||||
|
Knowledge index Node Data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
type: str = "knowledge-index"
|
||||||
|
chunk_structure: str
|
||||||
|
index_chunk_variable_selector: list[str]
|
||||||
@ -0,0 +1,22 @@
|
|||||||
|
class KnowledgeIndexNodeError(ValueError):
|
||||||
|
"""Base class for KnowledgeIndexNode errors."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotExistError(KnowledgeIndexNodeError):
|
||||||
|
"""Raised when the model does not exist."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelCredentialsNotInitializedError(KnowledgeIndexNodeError):
|
||||||
|
"""Raised when the model credentials are not initialized."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelNotSupportedError(KnowledgeIndexNodeError):
|
||||||
|
"""Raised when the model is not supported."""
|
||||||
|
|
||||||
|
|
||||||
|
class ModelQuotaExceededError(KnowledgeIndexNodeError):
|
||||||
|
"""Raised when the model provider quota is exceeded."""
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidModelTypeError(KnowledgeIndexNodeError):
|
||||||
|
"""Raised when the model is not a Large Language Model."""
|
||||||
@ -0,0 +1,165 @@
|
|||||||
|
import datetime
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
|
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
|
||||||
|
from core.rag.retrieval.retrieval_methods import RetrievalMethod
|
||||||
|
from core.workflow.entities.node_entities import NodeRunResult
|
||||||
|
from core.workflow.entities.variable_pool import VariablePool
|
||||||
|
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||||
|
from core.workflow.enums import SystemVariableKey
|
||||||
|
from core.workflow.nodes.enums import NodeType
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from models.dataset import Dataset, Document, DocumentSegment
|
||||||
|
|
||||||
|
from ..base import BaseNode
|
||||||
|
from .entities import KnowledgeIndexNodeData
|
||||||
|
from .exc import (
|
||||||
|
KnowledgeIndexNodeError,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
default_retrieval_model = {
|
||||||
|
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
|
||||||
|
"reranking_enable": False,
|
||||||
|
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
|
||||||
|
"top_k": 2,
|
||||||
|
"score_threshold_enabled": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class KnowledgeIndexNode(BaseNode[KnowledgeIndexNodeData]):
|
||||||
|
_node_data_cls = KnowledgeIndexNodeData # type: ignore
|
||||||
|
_node_type = NodeType.KNOWLEDGE_INDEX
|
||||||
|
|
||||||
|
def _run(self) -> NodeRunResult: # type: ignore
|
||||||
|
node_data = cast(KnowledgeIndexNodeData, self.node_data)
|
||||||
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
dataset_id = variable_pool.get(["sys", SystemVariableKey.DATASET_ID])
|
||||||
|
if not dataset_id:
|
||||||
|
raise KnowledgeIndexNodeError("Dataset ID is required.")
|
||||||
|
dataset = db.session.query(Dataset).filter_by(id=dataset_id.value).first()
|
||||||
|
if not dataset:
|
||||||
|
raise KnowledgeIndexNodeError(f"Dataset {dataset_id.value} not found.")
|
||||||
|
|
||||||
|
# extract variables
|
||||||
|
variable = variable_pool.get(node_data.index_chunk_variable_selector)
|
||||||
|
if not variable:
|
||||||
|
raise KnowledgeIndexNodeError("Index chunk variable is required.")
|
||||||
|
invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM])
|
||||||
|
if invoke_from:
|
||||||
|
is_preview = invoke_from.value == InvokeFrom.DEBUGGER.value
|
||||||
|
else:
|
||||||
|
is_preview = False
|
||||||
|
chunks = variable.value
|
||||||
|
variables = {"chunks": chunks}
|
||||||
|
if not chunks:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error="Chunks is required."
|
||||||
|
)
|
||||||
|
|
||||||
|
# index knowledge
|
||||||
|
try:
|
||||||
|
if is_preview:
|
||||||
|
outputs = self._get_preview_output(node_data.chunk_structure, chunks)
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
inputs=variables,
|
||||||
|
process_data=None,
|
||||||
|
outputs=outputs,
|
||||||
|
)
|
||||||
|
results = self._invoke_knowledge_index(
|
||||||
|
dataset=dataset, node_data=node_data, chunks=chunks, variable_pool=variable_pool
|
||||||
|
)
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, process_data=None, outputs=results
|
||||||
|
)
|
||||||
|
|
||||||
|
except KnowledgeIndexNodeError as e:
|
||||||
|
logger.warning("Error when running knowledge index node")
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=variables,
|
||||||
|
error=str(e),
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
# Temporary handle all exceptions from DatasetRetrieval class here.
|
||||||
|
except Exception as e:
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.FAILED,
|
||||||
|
inputs=variables,
|
||||||
|
error=str(e),
|
||||||
|
error_type=type(e).__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _invoke_knowledge_index(
|
||||||
|
self,
|
||||||
|
dataset: Dataset,
|
||||||
|
node_data: KnowledgeIndexNodeData,
|
||||||
|
chunks: Mapping[str, Any],
|
||||||
|
variable_pool: VariablePool,
|
||||||
|
) -> Any:
|
||||||
|
document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID])
|
||||||
|
if not document_id:
|
||||||
|
raise KnowledgeIndexNodeError("Document ID is required.")
|
||||||
|
batch = variable_pool.get(["sys", SystemVariableKey.BATCH])
|
||||||
|
if not batch:
|
||||||
|
raise KnowledgeIndexNodeError("Batch is required.")
|
||||||
|
document = db.session.query(Document).filter_by(id=document_id.value).first()
|
||||||
|
if not document:
|
||||||
|
raise KnowledgeIndexNodeError(f"Document {document_id.value} not found.")
|
||||||
|
# chunk nodes by chunk size
|
||||||
|
indexing_start_at = time.perf_counter()
|
||||||
|
index_processor = IndexProcessorFactory(dataset.chunk_structure).init_index_processor()
|
||||||
|
index_processor.index(dataset, document, chunks)
|
||||||
|
indexing_end_at = time.perf_counter()
|
||||||
|
document.indexing_latency = indexing_end_at - indexing_start_at
|
||||||
|
# update document status
|
||||||
|
document.indexing_status = "completed"
|
||||||
|
document.completed_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
|
document.word_count = (
|
||||||
|
db.session.query(func.sum(DocumentSegment.word_count))
|
||||||
|
.filter(
|
||||||
|
DocumentSegment.document_id == document.id,
|
||||||
|
DocumentSegment.dataset_id == dataset.id,
|
||||||
|
)
|
||||||
|
.scalar()
|
||||||
|
)
|
||||||
|
db.session.add(document)
|
||||||
|
# update document segment status
|
||||||
|
db.session.query(DocumentSegment).filter(
|
||||||
|
DocumentSegment.document_id == document.id,
|
||||||
|
DocumentSegment.dataset_id == dataset.id,
|
||||||
|
).update(
|
||||||
|
{
|
||||||
|
DocumentSegment.status: "completed",
|
||||||
|
DocumentSegment.enabled: True,
|
||||||
|
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"dataset_id": dataset.id,
|
||||||
|
"dataset_name": dataset.name,
|
||||||
|
"batch": batch.value,
|
||||||
|
"document_id": document.id,
|
||||||
|
"document_name": document.name,
|
||||||
|
"created_at": document.created_at.timestamp(),
|
||||||
|
"display_status": document.indexing_status,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _get_preview_output(self, chunk_structure: str, chunks: Mapping[str, Any]) -> Mapping[str, Any]:
|
||||||
|
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
|
||||||
|
return index_processor.format_preview(chunks)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def version(cls) -> str:
|
||||||
|
return "1"
|
||||||
@ -0,0 +1,164 @@
|
|||||||
|
from flask_restful import fields # type: ignore
|
||||||
|
|
||||||
|
from fields.workflow_fields import workflow_partial_fields
|
||||||
|
from libs.helper import AppIconUrlField, TimestampField
|
||||||
|
|
||||||
|
pipeline_detail_kernel_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"icon_type": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"icon_url": AppIconUrlField,
|
||||||
|
}
|
||||||
|
|
||||||
|
related_app_list = {
|
||||||
|
"data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
|
||||||
|
"total": fields.Integer,
|
||||||
|
}
|
||||||
|
|
||||||
|
app_detail_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||||
|
"tracing": fields.Raw,
|
||||||
|
"created_by": fields.String,
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_by": fields.String,
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||||
|
|
||||||
|
app_partial_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String(attribute="desc_or_prompt"),
|
||||||
|
"icon_type": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"icon_url": AppIconUrlField,
|
||||||
|
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||||
|
"created_by": fields.String,
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_by": fields.String,
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
"tags": fields.List(fields.Nested(tag_fields)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
app_pagination_fields = {
|
||||||
|
"page": fields.Integer,
|
||||||
|
"limit": fields.Integer(attribute="per_page"),
|
||||||
|
"total": fields.Integer,
|
||||||
|
"has_more": fields.Boolean(attribute="has_next"),
|
||||||
|
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
template_fields = {
|
||||||
|
"name": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"mode": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
template_list_fields = {
|
||||||
|
"data": fields.List(fields.Nested(template_fields)),
|
||||||
|
}
|
||||||
|
|
||||||
|
site_fields = {
|
||||||
|
"access_token": fields.String(attribute="code"),
|
||||||
|
"code": fields.String,
|
||||||
|
"title": fields.String,
|
||||||
|
"icon_type": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"icon_url": AppIconUrlField,
|
||||||
|
"description": fields.String,
|
||||||
|
"default_language": fields.String,
|
||||||
|
"chat_color_theme": fields.String,
|
||||||
|
"chat_color_theme_inverted": fields.Boolean,
|
||||||
|
"customize_domain": fields.String,
|
||||||
|
"copyright": fields.String,
|
||||||
|
"privacy_policy": fields.String,
|
||||||
|
"custom_disclaimer": fields.String,
|
||||||
|
"customize_token_strategy": fields.String,
|
||||||
|
"prompt_public": fields.Boolean,
|
||||||
|
"app_base_url": fields.String,
|
||||||
|
"show_workflow_steps": fields.Boolean,
|
||||||
|
"use_icon_as_answer_icon": fields.Boolean,
|
||||||
|
"created_by": fields.String,
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_by": fields.String,
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
deleted_tool_fields = {
|
||||||
|
"type": fields.String,
|
||||||
|
"tool_name": fields.String,
|
||||||
|
"provider_id": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
app_detail_fields_with_site = {
|
||||||
|
"id": fields.String,
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||||
|
"icon_type": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"icon_url": AppIconUrlField,
|
||||||
|
"enable_site": fields.Boolean,
|
||||||
|
"enable_api": fields.Boolean,
|
||||||
|
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||||
|
"site": fields.Nested(site_fields),
|
||||||
|
"api_base_url": fields.String,
|
||||||
|
"use_icon_as_answer_icon": fields.Boolean,
|
||||||
|
"created_by": fields.String,
|
||||||
|
"created_at": TimestampField,
|
||||||
|
"updated_by": fields.String,
|
||||||
|
"updated_at": TimestampField,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
app_site_fields = {
|
||||||
|
"app_id": fields.String,
|
||||||
|
"access_token": fields.String(attribute="code"),
|
||||||
|
"code": fields.String,
|
||||||
|
"title": fields.String,
|
||||||
|
"icon": fields.String,
|
||||||
|
"icon_background": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"default_language": fields.String,
|
||||||
|
"customize_domain": fields.String,
|
||||||
|
"copyright": fields.String,
|
||||||
|
"privacy_policy": fields.String,
|
||||||
|
"custom_disclaimer": fields.String,
|
||||||
|
"customize_token_strategy": fields.String,
|
||||||
|
"prompt_public": fields.Boolean,
|
||||||
|
"show_workflow_steps": fields.Boolean,
|
||||||
|
"use_icon_as_answer_icon": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}
|
||||||
|
|
||||||
|
pipeline_import_fields = {
|
||||||
|
"id": fields.String,
|
||||||
|
"status": fields.String,
|
||||||
|
"pipeline_id": fields.String,
|
||||||
|
"dataset_id": fields.String,
|
||||||
|
"current_dsl_version": fields.String,
|
||||||
|
"imported_dsl_version": fields.String,
|
||||||
|
"error": fields.String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pipeline_import_check_dependencies_fields = {
|
||||||
|
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
|
||||||
|
}
|
||||||
@ -0,0 +1 @@
|
|||||||
|
{"not_installed": [], "plugin_install_failed": []}
|
||||||
@ -0,0 +1,113 @@
|
|||||||
|
"""add_pipeline_info
|
||||||
|
|
||||||
|
Revision ID: b35c3db83d09
|
||||||
|
Revises: d28f2004b072
|
||||||
|
Create Date: 2025-05-15 15:58:05.179877
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'b35c3db83d09'
|
||||||
|
down_revision = '0ab65e1cc7fa'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('pipeline_built_in_templates',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=False),
|
||||||
|
sa.Column('icon', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('copyright', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('privacy_policy', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('position', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('install_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('language', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='pipeline_built_in_template_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('pipeline_customized_templates',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('pipeline_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), nullable=False),
|
||||||
|
sa.Column('icon', sa.JSON(), nullable=False),
|
||||||
|
sa.Column('position', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('install_count', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('language', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='pipeline_customized_template_pkey')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||||
|
batch_op.create_index('pipeline_customized_template_tenant_idx', ['tenant_id'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('pipelines',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('name', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('description', sa.Text(), server_default=sa.text("''::character varying"), nullable=False),
|
||||||
|
sa.Column('mode', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column('is_public', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('is_published', sa.Boolean(), server_default=sa.text('false'), nullable=False),
|
||||||
|
sa.Column('created_by', models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='pipeline_pkey')
|
||||||
|
)
|
||||||
|
op.create_table('tool_builtin_datasource_providers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=True),
|
||||||
|
sa.Column('user_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=256), nullable=False),
|
||||||
|
sa.Column('encrypted_credentials', sa.Text(), nullable=True),
|
||||||
|
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='tool_builtin_datasource_provider_pkey'),
|
||||||
|
sa.UniqueConstraint('tenant_id', 'provider', name='unique_builtin_datasource_provider')
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('keyword_number', sa.Integer(), server_default=sa.text('10'), nullable=True))
|
||||||
|
batch_op.add_column(sa.Column('icon_info', postgresql.JSONB(astext_type=sa.Text()), nullable=True))
|
||||||
|
batch_op.add_column(sa.Column('runtime_mode', sa.String(length=255), server_default=sa.text("'general'::character varying"), nullable=True))
|
||||||
|
batch_op.add_column(sa.Column('pipeline_id', models.types.StringUUID(), nullable=True))
|
||||||
|
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=True))
|
||||||
|
|
||||||
|
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('rag_pipeline_variables', sa.Text(), server_default='{}', nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('workflows', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('rag_pipeline_variables')
|
||||||
|
|
||||||
|
with op.batch_alter_table('datasets', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('chunk_structure')
|
||||||
|
batch_op.drop_column('pipeline_id')
|
||||||
|
batch_op.drop_column('runtime_mode')
|
||||||
|
batch_op.drop_column('icon_info')
|
||||||
|
batch_op.drop_column('keyword_number')
|
||||||
|
|
||||||
|
op.drop_table('tool_builtin_datasource_providers')
|
||||||
|
op.drop_table('pipelines')
|
||||||
|
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index('pipeline_customized_template_tenant_idx')
|
||||||
|
|
||||||
|
op.drop_table('pipeline_customized_templates')
|
||||||
|
op.drop_table('pipeline_built_in_templates')
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,33 @@
|
|||||||
|
"""add_pipeline_info_2
|
||||||
|
|
||||||
|
Revision ID: abb18a379e62
|
||||||
|
Revises: b35c3db83d09
|
||||||
|
Create Date: 2025-05-16 16:59:16.423127
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'abb18a379e62'
|
||||||
|
down_revision = 'b35c3db83d09'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||||
|
batch_op.drop_column('mode')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('pipelines', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('mode', sa.VARCHAR(length=255), autoincrement=False, nullable=False))
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
@ -0,0 +1,70 @@
|
|||||||
|
"""add_pipeline_info_3
|
||||||
|
|
||||||
|
Revision ID: c459994abfa8
|
||||||
|
Revises: abb18a379e62
|
||||||
|
Create Date: 2025-05-30 00:33:14.068312
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import models as models
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'c459994abfa8'
|
||||||
|
down_revision = 'abb18a379e62'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('datasource_oauth_params',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('system_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='datasource_oauth_config_pkey'),
|
||||||
|
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_oauth_config_datasource_id_provider_idx')
|
||||||
|
)
|
||||||
|
op.create_table('datasource_providers',
|
||||||
|
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||||
|
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('plugin_id', models.types.StringUUID(), nullable=False),
|
||||||
|
sa.Column('provider', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('auth_type', sa.String(length=255), nullable=False),
|
||||||
|
sa.Column('encrypted_credentials', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||||
|
sa.Column('created_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.Column('updated_at', sa.DateTime(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('id', name='datasource_provider_pkey'),
|
||||||
|
sa.UniqueConstraint('plugin_id', 'provider', name='datasource_provider_plugin_id_provider_idx')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
|
||||||
|
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
|
||||||
|
batch_op.drop_column('pipeline_id')
|
||||||
|
|
||||||
|
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('chunk_structure', sa.String(length=255), nullable=False))
|
||||||
|
batch_op.add_column(sa.Column('yaml_content', sa.Text(), nullable=False))
|
||||||
|
batch_op.drop_column('pipeline_id')
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade():
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
|
||||||
|
with op.batch_alter_table('pipeline_customized_templates', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
|
||||||
|
batch_op.drop_column('yaml_content')
|
||||||
|
batch_op.drop_column('chunk_structure')
|
||||||
|
|
||||||
|
with op.batch_alter_table('pipeline_built_in_templates', schema=None) as batch_op:
|
||||||
|
batch_op.add_column(sa.Column('pipeline_id', sa.UUID(), autoincrement=False, nullable=False))
|
||||||
|
batch_op.drop_column('yaml_content')
|
||||||
|
batch_op.drop_column('chunk_structure')
|
||||||
|
|
||||||
|
op.drop_table('datasource_providers')
|
||||||
|
op.drop_table('datasource_oauth_params')
|
||||||
|
# ### end Alembic commands ###
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue