feat/datasource
parent
3f1363503b
commit
818eb46a8b
@ -0,0 +1,170 @@
|
||||
from flask_login import current_user # type: ignore # type: ignore
|
||||
from flask_restful import Resource, marshal, reqparse # type: ignore
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import DatasetNameDuplicateError
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_rate_limit_check,
|
||||
setup_required,
|
||||
)
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from libs.login import login_required
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import RagPipelineDatasetCreateEntity
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError("Name must be between 1 to 40 characters.")
|
||||
return name
|
||||
|
||||
|
||||
def _validate_description_length(description):
|
||||
if len(description) > 400:
|
||||
raise ValueError("Description cannot exceed 400 characters.")
|
||||
return description
|
||||
|
||||
|
||||
class CreateRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"yaml_content",
|
||||
type=str,
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="yaml_content is required.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
rag_pipeline_dataset_create_entity = RagPipelineDatasetCreateEntity(**args)
|
||||
try:
|
||||
import_info = DatasetService.create_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=rag_pipeline_dataset_create_entity,
|
||||
)
|
||||
if rag_pipeline_dataset_create_entity.permission == "partial_members":
|
||||
DatasetPermissionService.update_partial_member_list(
|
||||
current_user.current_tenant_id,
|
||||
import_info["dataset_id"],
|
||||
rag_pipeline_dataset_create_entity.partial_member_list,
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return import_info, 201
|
||||
|
||||
|
||||
class CreateEmptyRagPipelineDatasetApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@cloud_edition_billing_rate_limit_check("knowledge")
|
||||
def post(self):
|
||||
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
|
||||
if not current_user.is_dataset_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
"name",
|
||||
nullable=False,
|
||||
required=True,
|
||||
help="type is required. Name must be between 1 to 40 characters.",
|
||||
type=_validate_name,
|
||||
)
|
||||
parser.add_argument(
|
||||
"description",
|
||||
type=str,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default="",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"icon_info",
|
||||
type=dict,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default={},
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"permission",
|
||||
type=str,
|
||||
choices=(DatasetPermissionEnum.ONLY_ME, DatasetPermissionEnum.ALL_TEAM, DatasetPermissionEnum.PARTIAL_TEAM),
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=DatasetPermissionEnum.ONLY_ME,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"partial_member_list",
|
||||
type=list,
|
||||
nullable=True,
|
||||
required=False,
|
||||
default=[],
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
dataset = DatasetService.create_empty_rag_pipeline_dataset(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
rag_pipeline_dataset_create_entity=args,
|
||||
)
|
||||
return marshal(dataset, dataset_detail_fields), 201
|
||||
|
||||
|
||||
api.add_resource(CreateRagPipelineDatasetApi, "/rag/pipeline/dataset")
|
||||
api.add_resource(CreateEmptyRagPipelineDatasetApi, "/rag/pipeline/empty-dataset")
|
||||
@ -0,0 +1,147 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_login import current_user # type: ignore
|
||||
from flask_restful import Resource, marshal_with, reqparse # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.wraps import get_rag_pipeline
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from fields.rag_pipeline_fields import pipeline_import_check_dependencies_fields, pipeline_import_fields
|
||||
from libs.login import login_required
|
||||
from models import Account
|
||||
from models.dataset import Pipeline
|
||||
from services.app_dsl_service import ImportStatus
|
||||
from services.rag_pipeline.rag_pipeline_dsl_service import RagPipelineDslService
|
||||
|
||||
|
||||
class RagPipelineImportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("mode", type=str, required=True, location="json")
|
||||
parser.add_argument("yaml_content", type=str, location="json")
|
||||
parser.add_argument("yaml_url", type=str, location="json")
|
||||
parser.add_argument("name", type=str, location="json")
|
||||
parser.add_argument("description", type=str, location="json")
|
||||
parser.add_argument("icon_type", type=str, location="json")
|
||||
parser.add_argument("icon", type=str, location="json")
|
||||
parser.add_argument("icon_background", type=str, location="json")
|
||||
parser.add_argument("pipeline_id", type=str, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Import app
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.import_rag_pipeline(
|
||||
account=account,
|
||||
import_mode=args["mode"],
|
||||
yaml_content=args.get("yaml_content"),
|
||||
yaml_url=args.get("yaml_url"),
|
||||
pipeline_id=args.get("pipeline_id"),
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
status = result.status
|
||||
if status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
elif status == ImportStatus.PENDING.value:
|
||||
return result.model_dump(mode="json"), 202
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportConfirmApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_fields)
|
||||
def post(self, import_id):
|
||||
# Check user role first
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Create service with session
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
# Confirm import
|
||||
account = cast(Account, current_user)
|
||||
result = import_service.confirm_import(import_id=import_id, account=account)
|
||||
session.commit()
|
||||
|
||||
# Return appropriate status code based on result
|
||||
if result.status == ImportStatus.FAILED.value:
|
||||
return result.model_dump(mode="json"), 400
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineImportCheckDependenciesApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
import_service = RagPipelineDslService(session)
|
||||
result = import_service.check_dependencies(pipeline=pipeline)
|
||||
|
||||
return result.model_dump(mode="json"), 200
|
||||
|
||||
|
||||
class RagPipelineExportApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@get_rag_pipeline
|
||||
@account_initialization_required
|
||||
@marshal_with(pipeline_import_check_dependencies_fields)
|
||||
def get(self, pipeline: Pipeline):
|
||||
if not current_user.is_editor:
|
||||
raise Forbidden()
|
||||
|
||||
# Add include_secret params
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("include_secret", type=bool, default=False, location="args")
|
||||
args = parser.parse_args()
|
||||
|
||||
with Session(db.engine) as session:
|
||||
export_service = RagPipelineDslService(session)
|
||||
result = export_service.export_rag_pipeline_dsl(pipeline=pipeline, include_secret=args["include_secret"])
|
||||
|
||||
return {"data": result}, 200
|
||||
|
||||
|
||||
# Import Rag Pipeline
|
||||
api.add_resource(
|
||||
RagPipelineImportApi,
|
||||
"/rag/pipelines/imports",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportConfirmApi,
|
||||
"/rag/pipelines/imports/<string:import_id>/confirm",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineImportCheckDependenciesApi,
|
||||
"/rag/pipelines/imports/<string:pipeline_id>/check-dependencies",
|
||||
)
|
||||
api.add_resource(
|
||||
RagPipelineExportApi,
|
||||
"/rag/pipelines/<string:pipeline_id>/exports",
|
||||
)
|
||||
@ -0,0 +1,73 @@
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.__base.tool import ToolParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
|
||||
class DatasourceApiEntity(BaseModel):
|
||||
author: str
|
||||
name: str # identifier
|
||||
label: I18nObject # label
|
||||
description: I18nObject
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
output_schema: Optional[dict] = None
|
||||
|
||||
|
||||
ToolProviderTypeApiLiteral = Optional[Literal["builtin", "api", "workflow"]]
|
||||
|
||||
|
||||
class DatasourceProviderApiEntity(BaseModel):
|
||||
id: str
|
||||
author: str
|
||||
name: str # identifier
|
||||
description: I18nObject
|
||||
icon: str | dict
|
||||
label: I18nObject # label
|
||||
type: ToolProviderType
|
||||
masked_credentials: Optional[dict] = None
|
||||
original_credentials: Optional[dict] = None
|
||||
is_team_authorization: bool = False
|
||||
allow_delete: bool = True
|
||||
plugin_id: Optional[str] = Field(default="", description="The plugin id of the tool")
|
||||
plugin_unique_identifier: Optional[str] = Field(default="", description="The unique identifier of the tool")
|
||||
datasources: list[DatasourceApiEntity] = Field(default_factory=list)
|
||||
labels: list[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("datasources", mode="before")
|
||||
@classmethod
|
||||
def convert_none_to_empty_list(cls, v):
|
||||
return v if v is not None else []
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
# -------------
|
||||
# overwrite datasource parameter types for temp fix
|
||||
datasources = jsonable_encoder(self.datasources)
|
||||
for datasource in datasources:
|
||||
if datasource.get("parameters"):
|
||||
for parameter in datasource.get("parameters"):
|
||||
if parameter.get("type") == DatasourceParameter.DatasourceParameterType.SYSTEM_FILES.value:
|
||||
parameter["type"] = "files"
|
||||
# -------------
|
||||
|
||||
return {
|
||||
"id": self.id,
|
||||
"author": self.author,
|
||||
"name": self.name,
|
||||
"plugin_id": self.plugin_id,
|
||||
"plugin_unique_identifier": self.plugin_unique_identifier,
|
||||
"description": self.description.to_dict(),
|
||||
"icon": self.icon,
|
||||
"label": self.label.to_dict(),
|
||||
"type": self.type.value,
|
||||
"team_credentials": self.masked_credentials,
|
||||
"is_team_authorization": self.is_team_authorization,
|
||||
"allow_delete": self.allow_delete,
|
||||
"datasources": datasources,
|
||||
"labels": self.labels,
|
||||
}
|
||||
@ -0,0 +1,163 @@
|
||||
from flask_restful import fields # type: ignore
|
||||
|
||||
from fields.workflow_fields import workflow_partial_fields
|
||||
from libs.helper import AppIconUrlField, TimestampField
|
||||
|
||||
pipeline_detail_kernel_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
"data": fields.List(fields.Nested(pipeline_detail_kernel_fields)),
|
||||
"total": fields.Integer,
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"tracing": fields.Raw,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
tag_fields = {"id": fields.String, "name": fields.String, "type": fields.String}
|
||||
|
||||
app_partial_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String(attribute="desc_or_prompt"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
"tags": fields.List(fields.Nested(tag_fields)),
|
||||
}
|
||||
|
||||
|
||||
app_pagination_fields = {
|
||||
"page": fields.Integer,
|
||||
"limit": fields.Integer(attribute="per_page"),
|
||||
"total": fields.Integer,
|
||||
"has_more": fields.Boolean(attribute="has_next"),
|
||||
"data": fields.List(fields.Nested(app_partial_fields), attribute="items"),
|
||||
}
|
||||
|
||||
template_fields = {
|
||||
"name": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String,
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
"data": fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
site_fields = {
|
||||
"access_token": fields.String(attribute="code"),
|
||||
"code": fields.String,
|
||||
"title": fields.String,
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"description": fields.String,
|
||||
"default_language": fields.String,
|
||||
"chat_color_theme": fields.String,
|
||||
"chat_color_theme_inverted": fields.Boolean,
|
||||
"customize_domain": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"customize_token_strategy": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"app_base_url": fields.String,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
deleted_tool_fields = {
|
||||
"type": fields.String,
|
||||
"tool_name": fields.String,
|
||||
"provider_id": fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"description": fields.String,
|
||||
"mode": fields.String(attribute="mode_compatible_with_agent"),
|
||||
"icon_type": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"icon_url": AppIconUrlField,
|
||||
"enable_site": fields.Boolean,
|
||||
"enable_api": fields.Boolean,
|
||||
"workflow": fields.Nested(workflow_partial_fields, allow_null=True),
|
||||
"site": fields.Nested(site_fields),
|
||||
"api_base_url": fields.String,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField,
|
||||
"updated_by": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
|
||||
app_site_fields = {
|
||||
"app_id": fields.String,
|
||||
"access_token": fields.String(attribute="code"),
|
||||
"code": fields.String,
|
||||
"title": fields.String,
|
||||
"icon": fields.String,
|
||||
"icon_background": fields.String,
|
||||
"description": fields.String,
|
||||
"default_language": fields.String,
|
||||
"customize_domain": fields.String,
|
||||
"copyright": fields.String,
|
||||
"privacy_policy": fields.String,
|
||||
"custom_disclaimer": fields.String,
|
||||
"customize_token_strategy": fields.String,
|
||||
"prompt_public": fields.Boolean,
|
||||
"show_workflow_steps": fields.Boolean,
|
||||
"use_icon_as_answer_icon": fields.Boolean,
|
||||
}
|
||||
|
||||
leaked_dependency_fields = {"type": fields.String, "value": fields.Raw, "current_identifier": fields.String}
|
||||
|
||||
pipeline_import_fields = {
|
||||
"id": fields.String,
|
||||
"status": fields.String,
|
||||
"pipeline_id": fields.String,
|
||||
"current_dsl_version": fields.String,
|
||||
"imported_dsl_version": fields.String,
|
||||
"error": fields.String,
|
||||
}
|
||||
|
||||
pipeline_import_check_dependencies_fields = {
|
||||
"leaked_dependencies": fields.List(fields.Nested(leaked_dependency_fields)),
|
||||
}
|
||||
@ -0,0 +1,841 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml # type: ignore
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from packaging import version
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.helper import ssrf_proxy
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.plugin.entities.plugin import PluginDependency
|
||||
from core.workflow.nodes.enums import NodeType
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData
|
||||
from core.workflow.nodes.llm.entities import LLMNodeData
|
||||
from core.workflow.nodes.parameter_extractor.entities import ParameterExtractorNodeData
|
||||
from core.workflow.nodes.question_classifier.entities import QuestionClassifierNodeData
|
||||
from core.workflow.nodes.tool.entities import ToolNodeData
|
||||
from extensions.ext_redis import redis_client
|
||||
from factories import variable_factory
|
||||
from models import Account
|
||||
from models.dataset import Dataset, Pipeline
|
||||
from models.workflow import Workflow
|
||||
from services.dataset_service import DatasetCollectionBindingService
|
||||
from services.entities.knowledge_entities.rag_pipeline_entities import KnowledgeConfiguration
|
||||
from services.plugin.dependencies_analysis import DependenciesAnalysisService
|
||||
from services.rag_pipeline.rag_pipeline import RagPipelineService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
IMPORT_INFO_REDIS_KEY_PREFIX = "app_import_info:"
|
||||
CHECK_DEPENDENCIES_REDIS_KEY_PREFIX = "app_check_dependencies:"
|
||||
IMPORT_INFO_REDIS_EXPIRY = 10 * 60 # 10 minutes
|
||||
DSL_MAX_SIZE = 10 * 1024 * 1024 # 10MB
|
||||
CURRENT_DSL_VERSION = "0.1.0"
|
||||
|
||||
|
||||
class ImportMode(StrEnum):
|
||||
YAML_CONTENT = "yaml-content"
|
||||
YAML_URL = "yaml-url"
|
||||
|
||||
|
||||
class ImportStatus(StrEnum):
|
||||
COMPLETED = "completed"
|
||||
COMPLETED_WITH_WARNINGS = "completed-with-warnings"
|
||||
PENDING = "pending"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class RagPipelineImportInfo(BaseModel):
|
||||
id: str
|
||||
status: ImportStatus
|
||||
pipeline_id: Optional[str] = None
|
||||
current_dsl_version: str = CURRENT_DSL_VERSION
|
||||
imported_dsl_version: str = ""
|
||||
error: str = ""
|
||||
dataset_id: Optional[str] = None
|
||||
|
||||
|
||||
class CheckDependenciesResult(BaseModel):
|
||||
leaked_dependencies: list[PluginDependency] = Field(default_factory=list)
|
||||
|
||||
|
||||
def _check_version_compatibility(imported_version: str) -> ImportStatus:
|
||||
"""Determine import status based on version comparison"""
|
||||
try:
|
||||
current_ver = version.parse(CURRENT_DSL_VERSION)
|
||||
imported_ver = version.parse(imported_version)
|
||||
except version.InvalidVersion:
|
||||
return ImportStatus.FAILED
|
||||
|
||||
# If imported version is newer than current, always return PENDING
|
||||
if imported_ver > current_ver:
|
||||
return ImportStatus.PENDING
|
||||
|
||||
# If imported version is older than current's major, return PENDING
|
||||
if imported_ver.major < current_ver.major:
|
||||
return ImportStatus.PENDING
|
||||
|
||||
# If imported version is older than current's minor, return COMPLETED_WITH_WARNINGS
|
||||
if imported_ver.minor < current_ver.minor:
|
||||
return ImportStatus.COMPLETED_WITH_WARNINGS
|
||||
|
||||
# If imported version equals or is older than current's micro, return COMPLETED
|
||||
return ImportStatus.COMPLETED
|
||||
|
||||
|
||||
class RagPipelinePendingData(BaseModel):
|
||||
import_mode: str
|
||||
yaml_content: str
|
||||
name: str | None
|
||||
description: str | None
|
||||
icon_type: str | None
|
||||
icon: str | None
|
||||
icon_background: str | None
|
||||
pipeline_id: str | None
|
||||
|
||||
|
||||
class CheckDependenciesPendingData(BaseModel):
|
||||
dependencies: list[PluginDependency]
|
||||
pipeline_id: str | None
|
||||
|
||||
|
||||
class RagPipelineDslService:
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
def import_rag_pipeline(
|
||||
self,
|
||||
*,
|
||||
account: Account,
|
||||
import_mode: str,
|
||||
yaml_content: Optional[str] = None,
|
||||
yaml_url: Optional[str] = None,
|
||||
pipeline_id: Optional[str] = None,
|
||||
dataset: Optional[Dataset] = None,
|
||||
) -> RagPipelineImportInfo:
|
||||
"""Import an app from YAML content or URL."""
|
||||
import_id = str(uuid.uuid4())
|
||||
|
||||
# Validate import mode
|
||||
try:
|
||||
mode = ImportMode(import_mode)
|
||||
except ValueError:
|
||||
raise ValueError(f"Invalid import_mode: {import_mode}")
|
||||
|
||||
# Get YAML content
|
||||
content: str = ""
|
||||
if mode == ImportMode.YAML_URL:
|
||||
if not yaml_url:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="yaml_url is required when import_mode is yaml-url",
|
||||
)
|
||||
try:
|
||||
parsed_url = urlparse(yaml_url)
|
||||
if (
|
||||
parsed_url.scheme == "https"
|
||||
and parsed_url.netloc == "github.com"
|
||||
and parsed_url.path.endswith((".yml", ".yaml"))
|
||||
):
|
||||
yaml_url = yaml_url.replace("https://github.com", "https://raw.githubusercontent.com")
|
||||
yaml_url = yaml_url.replace("/blob/", "/")
|
||||
response = ssrf_proxy.get(yaml_url.strip(), follow_redirects=True, timeout=(10, 10))
|
||||
response.raise_for_status()
|
||||
content = response.content.decode()
|
||||
|
||||
if len(content) > DSL_MAX_SIZE:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="File size exceeds the limit of 10MB",
|
||||
)
|
||||
|
||||
if not content:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Empty content from url",
|
||||
)
|
||||
except Exception as e:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error=f"Error fetching YAML from URL: {str(e)}",
|
||||
)
|
||||
elif mode == ImportMode.YAML_CONTENT:
|
||||
if not yaml_content:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="yaml_content is required when import_mode is yaml-content",
|
||||
)
|
||||
content = yaml_content
|
||||
|
||||
# Process YAML content
|
||||
try:
|
||||
# Parse YAML to validate format
|
||||
data = yaml.safe_load(content)
|
||||
if not isinstance(data, dict):
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Invalid YAML format: content must be a mapping",
|
||||
)
|
||||
|
||||
# Validate and fix DSL version
|
||||
if not data.get("version"):
|
||||
data["version"] = "0.1.0"
|
||||
if not data.get("kind") or data.get("kind") != "rag-pipeline":
|
||||
data["kind"] = "rag-pipeline"
|
||||
|
||||
imported_version = data.get("version", "0.1.0")
|
||||
# check if imported_version is a float-like string
|
||||
if not isinstance(imported_version, str):
|
||||
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
|
||||
status = _check_version_compatibility(imported_version)
|
||||
|
||||
# Extract app data
|
||||
pipeline_data = data.get("pipeline")
|
||||
if not pipeline_data:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Missing pipeline data in YAML content",
|
||||
)
|
||||
|
||||
# If app_id is provided, check if it exists
|
||||
pipeline = None
|
||||
if pipeline_id:
|
||||
stmt = select(Pipeline).where(
|
||||
Pipeline.id == pipeline_id,
|
||||
Pipeline.tenant_id == account.current_tenant_id,
|
||||
)
|
||||
pipeline = self._session.scalar(stmt)
|
||||
|
||||
if not pipeline:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Pipeline not found",
|
||||
)
|
||||
|
||||
# If major version mismatch, store import info in Redis
|
||||
if status == ImportStatus.PENDING:
|
||||
pending_data = RagPipelinePendingData(
|
||||
import_mode=import_mode,
|
||||
yaml_content=content,
|
||||
pipeline_id=pipeline_id,
|
||||
)
|
||||
redis_client.setex(
|
||||
f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
pending_data.model_dump_json(),
|
||||
)
|
||||
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=status,
|
||||
pipeline_id=pipeline_id,
|
||||
imported_dsl_version=imported_version,
|
||||
)
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = data.get("dependencies", [])
|
||||
check_dependencies_pending_data = None
|
||||
if dependencies:
|
||||
check_dependencies_pending_data = [PluginDependency.model_validate(d) for d in dependencies]
|
||||
|
||||
# Create or update app
|
||||
pipeline = self._create_or_update_pipeline(
|
||||
pipeline=pipeline,
|
||||
data=data,
|
||||
account=account,
|
||||
dependencies=check_dependencies_pending_data,
|
||||
)
|
||||
# create dataset
|
||||
name = pipeline.name
|
||||
description = pipeline.description
|
||||
icon_type = data.get("rag_pipeline", {}).get("icon_type")
|
||||
icon = data.get("rag_pipeline", {}).get("icon")
|
||||
icon_background = data.get("rag_pipeline", {}).get("icon_background")
|
||||
icon_url = data.get("rag_pipeline", {}).get("icon_url")
|
||||
workflow = data.get("workflow", {})
|
||||
graph = workflow.get("graph", {})
|
||||
nodes = graph.get("nodes", [])
|
||||
dataset_id = None
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge_index":
|
||||
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
|
||||
if not dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=account.current_tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_info={
|
||||
"type": icon_type,
|
||||
"icon": icon,
|
||||
"background": icon_background,
|
||||
"url": icon_url,
|
||||
},
|
||||
indexing_technique=knowledge_configuration.index_method.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
else:
|
||||
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name
|
||||
)
|
||||
dataset.embedding_model_provider = (
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name
|
||||
)
|
||||
elif knowledge_configuration.index_method.indexing_technique == "economy":
|
||||
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
|
||||
dataset.pipeline_id = pipeline.id
|
||||
self._session.add(dataset)
|
||||
self._session.commit()
|
||||
dataset_id = dataset.id
|
||||
if not dataset_id:
|
||||
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
|
||||
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=status,
|
||||
pipeline_id=pipeline.id,
|
||||
dataset_id=dataset_id,
|
||||
imported_dsl_version=imported_version,
|
||||
)
|
||||
|
||||
except yaml.YAMLError as e:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error=f"Invalid YAML format: {str(e)}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to import app")
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def confirm_import(self, *, import_id: str, account: Account) -> RagPipelineImportInfo:
|
||||
"""
|
||||
Confirm an import that requires confirmation
|
||||
"""
|
||||
redis_key = f"{IMPORT_INFO_REDIS_KEY_PREFIX}{import_id}"
|
||||
pending_data = redis_client.get(redis_key)
|
||||
|
||||
if not pending_data:
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Import information expired or does not exist",
|
||||
)
|
||||
|
||||
try:
|
||||
if not isinstance(pending_data, str | bytes):
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error="Invalid import information",
|
||||
)
|
||||
pending_data = RagPipelinePendingData.model_validate_json(pending_data)
|
||||
data = yaml.safe_load(pending_data.yaml_content)
|
||||
|
||||
pipeline = None
|
||||
if pending_data.pipeline_id:
|
||||
stmt = select(Pipeline).where(
|
||||
Pipeline.id == pending_data.pipeline_id,
|
||||
Pipeline.tenant_id == account.current_tenant_id,
|
||||
)
|
||||
pipeline = self._session.scalar(stmt)
|
||||
|
||||
# Create or update app
|
||||
pipeline = self._create_or_update_pipeline(
|
||||
pipeline=pipeline,
|
||||
data=data,
|
||||
account=account,
|
||||
)
|
||||
|
||||
# create dataset
|
||||
name = pipeline.name
|
||||
description = pipeline.description
|
||||
icon_type = data.get("rag_pipeline", {}).get("icon_type")
|
||||
icon = data.get("rag_pipeline", {}).get("icon")
|
||||
icon_background = data.get("rag_pipeline", {}).get("icon_background")
|
||||
icon_url = data.get("rag_pipeline", {}).get("icon_url")
|
||||
workflow = data.get("workflow", {})
|
||||
graph = workflow.get("graph", {})
|
||||
nodes = graph.get("nodes", [])
|
||||
dataset_id = None
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == "knowledge_index":
|
||||
knowledge_configuration = node.get("data", {}).get("knowledge_configuration", {})
|
||||
knowledge_configuration = KnowledgeConfiguration(**knowledge_configuration)
|
||||
if not dataset:
|
||||
dataset = Dataset(
|
||||
tenant_id=account.current_tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
icon_info={
|
||||
"type": icon_type,
|
||||
"icon": icon,
|
||||
"background": icon_background,
|
||||
"url": icon_url,
|
||||
},
|
||||
indexing_technique=knowledge_configuration.index_method.indexing_technique,
|
||||
created_by=account.id,
|
||||
retrieval_model=knowledge_configuration.retrieval_setting.model_dump(),
|
||||
runtime_mode="rag_pipeline",
|
||||
chunk_structure=knowledge_configuration.chunk_structure,
|
||||
)
|
||||
else:
|
||||
dataset.indexing_technique = knowledge_configuration.index_method.indexing_technique
|
||||
dataset.retrieval_model = knowledge_configuration.retrieval_setting.model_dump()
|
||||
dataset.runtime_mode = "rag_pipeline"
|
||||
dataset.chunk_structure = knowledge_configuration.chunk_structure
|
||||
if knowledge_configuration.index_method.indexing_technique == "high_quality":
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name, # type: ignore
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name, # type: ignore
|
||||
)
|
||||
dataset_collection_binding_id = dataset_collection_binding.id
|
||||
dataset.collection_binding_id = dataset_collection_binding_id
|
||||
dataset.embedding_model = (
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_model_name
|
||||
)
|
||||
dataset.embedding_model_provider = (
|
||||
knowledge_configuration.index_method.embedding_setting.embedding_provider_name
|
||||
)
|
||||
elif knowledge_configuration.index_method.indexing_technique == "economy":
|
||||
dataset.keyword_number = knowledge_configuration.index_method.economy_setting.keyword_number
|
||||
dataset.pipeline_id = pipeline.id
|
||||
self._session.add(dataset)
|
||||
self._session.commit()
|
||||
dataset_id = dataset.id
|
||||
if not dataset_id:
|
||||
raise ValueError("DSL is not valid, please check the Knowledge Index node.")
|
||||
|
||||
# Delete import info from Redis
|
||||
redis_client.delete(redis_key)
|
||||
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.COMPLETED,
|
||||
pipeline_id=pipeline.id,
|
||||
dataset_id=dataset_id,
|
||||
current_dsl_version=CURRENT_DSL_VERSION,
|
||||
imported_dsl_version=data.get("version", "0.1.0"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error confirming import")
|
||||
return RagPipelineImportInfo(
|
||||
id=import_id,
|
||||
status=ImportStatus.FAILED,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
def check_dependencies(
|
||||
self,
|
||||
*,
|
||||
pipeline: Pipeline,
|
||||
) -> CheckDependenciesResult:
|
||||
"""Check dependencies"""
|
||||
# Get dependencies from Redis
|
||||
redis_key = f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}"
|
||||
dependencies = redis_client.get(redis_key)
|
||||
if not dependencies:
|
||||
return CheckDependenciesResult()
|
||||
|
||||
# Extract dependencies
|
||||
dependencies = CheckDependenciesPendingData.model_validate_json(dependencies)
|
||||
|
||||
# Get leaked dependencies
|
||||
leaked_dependencies = DependenciesAnalysisService.get_leaked_dependencies(
|
||||
tenant_id=pipeline.tenant_id, dependencies=dependencies.dependencies
|
||||
)
|
||||
return CheckDependenciesResult(
|
||||
leaked_dependencies=leaked_dependencies,
|
||||
)
|
||||
|
||||
def _create_or_update_pipeline(
|
||||
self,
|
||||
*,
|
||||
pipeline: Optional[Pipeline],
|
||||
data: dict,
|
||||
account: Account,
|
||||
dependencies: Optional[list[PluginDependency]] = None,
|
||||
) -> Pipeline:
|
||||
"""Create a new app or update an existing one."""
|
||||
pipeline_data = data.get("pipeline", {})
|
||||
pipeline_mode = pipeline_data.get("mode")
|
||||
if not pipeline_mode:
|
||||
raise ValueError("loss pipeline mode")
|
||||
# Set icon type
|
||||
icon_type_value = icon_type or pipeline_data.get("icon_type")
|
||||
if icon_type_value in ["emoji", "link"]:
|
||||
icon_type = icon_type_value
|
||||
else:
|
||||
icon_type = "emoji"
|
||||
icon = icon or str(pipeline_data.get("icon", ""))
|
||||
|
||||
if pipeline:
|
||||
# Update existing pipeline
|
||||
pipeline.name = pipeline_data.get("name", pipeline.name)
|
||||
pipeline.description = pipeline_data.get("description", pipeline.description)
|
||||
pipeline.icon_type = icon_type
|
||||
pipeline.icon = icon
|
||||
pipeline.icon_background = pipeline_data.get("icon_background", pipeline.icon_background)
|
||||
pipeline.updated_by = account.id
|
||||
else:
|
||||
if account.current_tenant_id is None:
|
||||
raise ValueError("Current tenant is not set")
|
||||
|
||||
# Create new app
|
||||
pipeline = Pipeline()
|
||||
pipeline.id = str(uuid4())
|
||||
pipeline.tenant_id = account.current_tenant_id
|
||||
pipeline.mode = pipeline_mode.value
|
||||
pipeline.name = pipeline_data.get("name", "")
|
||||
pipeline.description = pipeline_data.get("description", "")
|
||||
pipeline.icon_type = icon_type
|
||||
pipeline.icon = icon
|
||||
pipeline.icon_background = pipeline_data.get("icon_background", "#FFFFFF")
|
||||
pipeline.enable_site = True
|
||||
pipeline.enable_api = True
|
||||
pipeline.use_icon_as_answer_icon = pipeline_data.get("use_icon_as_answer_icon", False)
|
||||
pipeline.created_by = account.id
|
||||
pipeline.updated_by = account.id
|
||||
|
||||
self._session.add(pipeline)
|
||||
self._session.commit()
|
||||
# save dependencies
|
||||
if dependencies:
|
||||
redis_client.setex(
|
||||
f"{CHECK_DEPENDENCIES_REDIS_KEY_PREFIX}{pipeline.id}",
|
||||
IMPORT_INFO_REDIS_EXPIRY,
|
||||
CheckDependenciesPendingData(pipeline_id=pipeline.id, dependencies=dependencies).model_dump_json(),
|
||||
)
|
||||
|
||||
# Initialize pipeline based on mode
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for rag pipeline")
|
||||
|
||||
environment_variables_list = workflow_data.get("environment_variables", [])
|
||||
environment_variables = [
|
||||
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
|
||||
]
|
||||
conversation_variables_list = workflow_data.get("conversation_variables", [])
|
||||
conversation_variables = [
|
||||
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
|
||||
]
|
||||
rag_pipeline_variables_list = workflow_data.get("rag_pipeline_variables", [])
|
||||
rag_pipeline_variables = [
|
||||
variable_factory.build_pipeline_variable_from_mapping(obj) for obj in rag_pipeline_variables_list
|
||||
]
|
||||
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
current_draft_workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if current_draft_workflow:
|
||||
unique_hash = current_draft_workflow.unique_hash
|
||||
else:
|
||||
unique_hash = None
|
||||
graph = workflow_data.get("graph", {})
|
||||
for node in graph.get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
decrypted_id
|
||||
for dataset_id in dataset_ids
|
||||
if (
|
||||
decrypted_id := self.decrypt_dataset_id(
|
||||
encrypted_data=dataset_id,
|
||||
tenant_id=pipeline.tenant_id,
|
||||
)
|
||||
)
|
||||
]
|
||||
rag_pipeline_service.sync_draft_workflow(
|
||||
pipeline=pipeline,
|
||||
graph=workflow_data.get("graph", {}),
|
||||
features=workflow_data.get("features", {}),
|
||||
unique_hash=unique_hash,
|
||||
account=account,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
return pipeline
|
||||
|
||||
@classmethod
|
||||
def export_rag_pipeline_dsl(cls, pipeline: Pipeline, include_secret: bool = False) -> str:
|
||||
"""
|
||||
Export pipeline
|
||||
:param pipeline: Pipeline instance
|
||||
:param include_secret: Whether include secret variable
|
||||
:return:
|
||||
"""
|
||||
export_data = {
|
||||
"version": CURRENT_DSL_VERSION,
|
||||
"kind": "rag_pipeline",
|
||||
"pipeline": {
|
||||
"name": pipeline.name,
|
||||
"mode": pipeline.mode,
|
||||
"icon": "🤖" if pipeline.icon_type == "image" else pipeline.icon,
|
||||
"icon_background": "#FFEAD5" if pipeline.icon_type == "image" else pipeline.icon_background,
|
||||
"description": pipeline.description,
|
||||
"use_icon_as_answer_icon": pipeline.use_icon_as_answer_icon,
|
||||
},
|
||||
}
|
||||
|
||||
cls._append_workflow_export_data(export_data=export_data, pipeline=pipeline, include_secret=include_secret)
|
||||
|
||||
return yaml.dump(export_data, allow_unicode=True) # type: ignore
|
||||
|
||||
@classmethod
|
||||
def _append_workflow_export_data(cls, *, export_data: dict, pipeline: Pipeline, include_secret: bool) -> None:
|
||||
"""
|
||||
Append workflow export data
|
||||
:param export_data: export data
|
||||
:param pipeline: Pipeline instance
|
||||
"""
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
workflow = rag_pipeline_service.get_draft_workflow(pipeline=pipeline)
|
||||
if not workflow:
|
||||
raise ValueError("Missing draft workflow configuration, please check.")
|
||||
|
||||
workflow_dict = workflow.to_dict(include_secret=include_secret)
|
||||
for node in workflow_dict.get("graph", {}).get("nodes", []):
|
||||
if node.get("data", {}).get("type", "") == NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
dataset_ids = node["data"].get("dataset_ids", [])
|
||||
node["data"]["dataset_ids"] = [
|
||||
cls.encrypt_dataset_id(dataset_id=dataset_id, tenant_id=pipeline.tenant_id)
|
||||
for dataset_id in dataset_ids
|
||||
]
|
||||
export_data["workflow"] = workflow_dict
|
||||
dependencies = cls._extract_dependencies_from_workflow(workflow)
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=pipeline.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _append_model_config_export_data(cls, export_data: dict, pipeline: Pipeline) -> None:
|
||||
"""
|
||||
Append model config export data
|
||||
:param export_data: export data
|
||||
:param pipeline: Pipeline instance
|
||||
"""
|
||||
app_model_config = pipeline.app_model_config
|
||||
if not app_model_config:
|
||||
raise ValueError("Missing app configuration, please check.")
|
||||
|
||||
export_data["model_config"] = app_model_config.to_dict()
|
||||
dependencies = cls._extract_dependencies_from_model_config(app_model_config.to_dict())
|
||||
export_data["dependencies"] = [
|
||||
jsonable_encoder(d.model_dump())
|
||||
for d in DependenciesAnalysisService.generate_dependencies(
|
||||
tenant_id=pipeline.tenant_id, dependencies=dependencies
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow(cls, workflow: Workflow) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow
|
||||
:param workflow: Workflow instance
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
graph = workflow.graph_dict
|
||||
dependencies = cls._extract_dependencies_from_workflow_graph(graph)
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_workflow_graph(cls, graph: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from workflow graph
|
||||
:param graph: Workflow graph
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
for node in graph.get("nodes", []):
|
||||
try:
|
||||
typ = node.get("data", {}).get("type")
|
||||
match typ:
|
||||
case NodeType.TOOL.value:
|
||||
tool_entity = ToolNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(tool_entity.provider_id),
|
||||
)
|
||||
case NodeType.LLM.value:
|
||||
llm_entity = LLMNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(llm_entity.model.provider),
|
||||
)
|
||||
case NodeType.QUESTION_CLASSIFIER.value:
|
||||
question_classifier_entity = QuestionClassifierNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
question_classifier_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.PARAMETER_EXTRACTOR.value:
|
||||
parameter_extractor_entity = ParameterExtractorNodeData(**node["data"])
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
parameter_extractor_entity.model.provider
|
||||
),
|
||||
)
|
||||
case NodeType.KNOWLEDGE_RETRIEVAL.value:
|
||||
knowledge_retrieval_entity = KnowledgeRetrievalNodeData(**node["data"])
|
||||
if knowledge_retrieval_entity.retrieval_mode == "multiple":
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config:
|
||||
if (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "reranking_model"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.reranking_model:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_model.provider
|
||||
),
|
||||
)
|
||||
elif (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.reranking_mode
|
||||
== "weighted_score"
|
||||
):
|
||||
if knowledge_retrieval_entity.multiple_retrieval_config.weights:
|
||||
vector_setting = (
|
||||
knowledge_retrieval_entity.multiple_retrieval_config.weights.vector_setting
|
||||
)
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
vector_setting.embedding_provider_name
|
||||
),
|
||||
)
|
||||
elif knowledge_retrieval_entity.retrieval_mode == "single":
|
||||
model_config = knowledge_retrieval_entity.single_retrieval_config
|
||||
if model_config:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
model_config.model.provider
|
||||
),
|
||||
)
|
||||
case _:
|
||||
# TODO: Handle default case or unknown node types
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting node dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def _extract_dependencies_from_model_config(cls, model_config: Mapping) -> list[str]:
|
||||
"""
|
||||
Extract dependencies from model config
|
||||
:param model_config: model config dict
|
||||
:return: dependencies list format like ["langgenius/google"]
|
||||
"""
|
||||
dependencies = []
|
||||
|
||||
try:
|
||||
# completion model
|
||||
model_dict = model_config.get("model", {})
|
||||
if model_dict:
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(model_dict.get("provider", ""))
|
||||
)
|
||||
|
||||
# reranking model
|
||||
dataset_configs = model_config.get("dataset_configs", {})
|
||||
if dataset_configs:
|
||||
for dataset_config in dataset_configs.get("datasets", {}).get("datasets", []):
|
||||
if dataset_config.get("reranking_model"):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_model_provider_dependency(
|
||||
dataset_config.get("reranking_model", {})
|
||||
.get("reranking_provider_name", {})
|
||||
.get("provider")
|
||||
)
|
||||
)
|
||||
|
||||
# tools
|
||||
agent_configs = model_config.get("agent_mode", {})
|
||||
if agent_configs:
|
||||
for agent_config in agent_configs.get("tools", []):
|
||||
dependencies.append(
|
||||
DependenciesAnalysisService.analyze_tool_dependency(agent_config.get("provider_id"))
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error extracting model config dependency", exc_info=e)
|
||||
|
||||
return dependencies
|
||||
|
||||
@classmethod
|
||||
def get_leaked_dependencies(cls, tenant_id: str, dsl_dependencies: list[dict]) -> list[PluginDependency]:
|
||||
"""
|
||||
Returns the leaked dependencies in current workspace
|
||||
"""
|
||||
dependencies = [PluginDependency(**dep) for dep in dsl_dependencies]
|
||||
if not dependencies:
|
||||
return []
|
||||
|
||||
return DependenciesAnalysisService.get_leaked_dependencies(tenant_id=tenant_id, dependencies=dependencies)
|
||||
|
||||
@staticmethod
|
||||
def _generate_aes_key(tenant_id: str) -> bytes:
|
||||
"""Generate AES key based on tenant_id"""
|
||||
return hashlib.sha256(tenant_id.encode()).digest()
|
||||
|
||||
@classmethod
|
||||
def encrypt_dataset_id(cls, dataset_id: str, tenant_id: str) -> str:
|
||||
"""Encrypt dataset_id using AES-CBC mode"""
|
||||
key = cls._generate_aes_key(tenant_id)
|
||||
iv = key[:16]
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
ct_bytes = cipher.encrypt(pad(dataset_id.encode(), AES.block_size))
|
||||
return base64.b64encode(ct_bytes).decode()
|
||||
|
||||
@classmethod
|
||||
def decrypt_dataset_id(cls, encrypted_data: str, tenant_id: str) -> str | None:
|
||||
"""AES decryption"""
|
||||
try:
|
||||
key = cls._generate_aes_key(tenant_id)
|
||||
iv = key[:16]
|
||||
cipher = AES.new(key, AES.MODE_CBC, iv)
|
||||
pt = unpad(cipher.decrypt(base64.b64decode(encrypted_data)), AES.block_size)
|
||||
return pt.decode()
|
||||
except Exception:
|
||||
return None
|
||||
Loading…
Reference in New Issue