Merge branch 'langgenius:main' into add-document-status-update

pull/18235/head
GuanMu 1 year ago committed by GitHub
commit 2e2a60b92d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -6,6 +6,7 @@ from typing import Optional
import click import click
from flask import current_app from flask import current_app
from sqlalchemy import select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from configs import dify_config from configs import dify_config
@ -297,11 +298,11 @@ def migrate_knowledge_vector_database():
page = 1 page = 1
while True: while True:
try: try:
datasets = ( stmt = (
Dataset.query.filter(Dataset.indexing_technique == "high_quality") select(Dataset).filter(Dataset.indexing_technique == "high_quality").order_by(Dataset.created_at.desc())
.order_by(Dataset.created_at.desc())
.paginate(page=page, per_page=50)
) )
datasets = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound: except NotFound:
break break
@ -551,11 +552,12 @@ def old_metadata_migration():
page = 1 page = 1
while True: while True:
try: try:
documents = ( stmt = (
DatasetDocument.query.filter(DatasetDocument.doc_metadata is not None) select(DatasetDocument)
.filter(DatasetDocument.doc_metadata.is_not(None))
.order_by(DatasetDocument.created_at.desc()) .order_by(DatasetDocument.created_at.desc())
.paginate(page=page, per_page=50)
) )
documents = db.paginate(select=stmt, page=page, per_page=50, max_per_page=50, error_out=False)
except NotFound: except NotFound:
break break
if not documents: if not documents:
@ -592,11 +594,15 @@ def old_metadata_migration():
) )
db.session.add(dataset_metadata_binding) db.session.add(dataset_metadata_binding)
else: else:
dataset_metadata_binding = DatasetMetadataBinding.query.filter( dataset_metadata_binding = (
db.session.query(DatasetMetadataBinding) # type: ignore
.filter(
DatasetMetadataBinding.dataset_id == document.dataset_id, DatasetMetadataBinding.dataset_id == document.dataset_id,
DatasetMetadataBinding.document_id == document.id, DatasetMetadataBinding.document_id == document.id,
DatasetMetadataBinding.metadata_id == dataset_metadata.id, DatasetMetadataBinding.metadata_id == dataset_metadata.id,
).first() )
.first()
)
if not dataset_metadata_binding: if not dataset_metadata_binding:
dataset_metadata_binding = DatasetMetadataBinding( dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=document.tenant_id, tenant_id=document.tenant_id,

@ -74,7 +74,7 @@ class CodeExecutionSandboxConfig(BaseSettings):
CODE_EXECUTION_ENDPOINT: HttpUrl = Field( CODE_EXECUTION_ENDPOINT: HttpUrl = Field(
description="URL endpoint for the code execution service", description="URL endpoint for the code execution service",
default="http://sandbox:8194", default=HttpUrl("http://sandbox:8194"),
) )
CODE_EXECUTION_API_KEY: str = Field( CODE_EXECUTION_API_KEY: str = Field(
@ -145,7 +145,7 @@ class PluginConfig(BaseSettings):
PLUGIN_DAEMON_URL: HttpUrl = Field( PLUGIN_DAEMON_URL: HttpUrl = Field(
description="Plugin API URL", description="Plugin API URL",
default="http://localhost:5002", default=HttpUrl("http://localhost:5002"),
) )
PLUGIN_DAEMON_KEY: str = Field( PLUGIN_DAEMON_KEY: str = Field(
@ -188,7 +188,7 @@ class MarketplaceConfig(BaseSettings):
MARKETPLACE_API_URL: HttpUrl = Field( MARKETPLACE_API_URL: HttpUrl = Field(
description="Marketplace API URL", description="Marketplace API URL",
default="https://marketplace.dify.ai", default=HttpUrl("https://marketplace.dify.ai"),
) )

@ -1,6 +1,6 @@
import os import os
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
from urllib.parse import quote_plus from urllib.parse import parse_qsl, quote_plus
from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field from pydantic import Field, NonNegativeInt, PositiveFloat, PositiveInt, computed_field
from pydantic_settings import BaseSettings from pydantic_settings import BaseSettings
@ -173,17 +173,31 @@ class DatabaseConfig(BaseSettings):
RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field( RETRIEVAL_SERVICE_EXECUTORS: NonNegativeInt = Field(
description="Number of processes for the retrieval service, default to CPU cores.", description="Number of processes for the retrieval service, default to CPU cores.",
default=os.cpu_count(), default=os.cpu_count() or 1,
) )
@computed_field @computed_field # type: ignore[misc]
@property
def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]: def SQLALCHEMY_ENGINE_OPTIONS(self) -> dict[str, Any]:
# Parse DB_EXTRAS for 'options'
db_extras_dict = dict(parse_qsl(self.DB_EXTRAS))
options = db_extras_dict.get("options", "")
# Always include timezone
timezone_opt = "-c timezone=UTC"
if options:
# Merge user options and timezone
merged_options = f"{options} {timezone_opt}"
else:
merged_options = timezone_opt
connect_args = {"options": merged_options}
return { return {
"pool_size": self.SQLALCHEMY_POOL_SIZE, "pool_size": self.SQLALCHEMY_POOL_SIZE,
"max_overflow": self.SQLALCHEMY_MAX_OVERFLOW, "max_overflow": self.SQLALCHEMY_MAX_OVERFLOW,
"pool_recycle": self.SQLALCHEMY_POOL_RECYCLE, "pool_recycle": self.SQLALCHEMY_POOL_RECYCLE,
"pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING, "pool_pre_ping": self.SQLALCHEMY_POOL_PRE_PING,
"connect_args": {"options": "-c timezone=UTC"}, "connect_args": connect_args,
} }

@ -83,3 +83,13 @@ class RedisConfig(BaseSettings):
description="Password for Redis Clusters authentication (if required)", description="Password for Redis Clusters authentication (if required)",
default=None, default=None,
) )
REDIS_SERIALIZATION_PROTOCOL: int = Field(
description="Redis serialization protocol (RESP) version",
default=3,
)
REDIS_ENABLE_CLIENT_SIDE_CACHE: bool = Field(
description="Enable client side cache in redis",
default=False,
)

@ -1,5 +1,7 @@
from flask_restful import fields from flask_restful import fields
from libs.helper import AppIconUrlField
parameters__system_parameters = { parameters__system_parameters = {
"image_file_size_limit": fields.Integer, "image_file_size_limit": fields.Integer,
"video_file_size_limit": fields.Integer, "video_file_size_limit": fields.Integer,
@ -22,3 +24,20 @@ parameters_fields = {
"file_upload": fields.Raw, "file_upload": fields.Raw,
"system_parameters": fields.Nested(parameters__system_parameters), "system_parameters": fields.Nested(parameters__system_parameters),
} }
site_fields = {
"title": fields.String,
"chat_color_theme": fields.String,
"chat_color_theme_inverted": fields.Boolean,
"icon_type": fields.String,
"icon": fields.String,
"icon_background": fields.String,
"icon_url": AppIconUrlField,
"description": fields.String,
"copyright": fields.String,
"privacy_policy": fields.String,
"custom_disclaimer": fields.String,
"default_language": fields.String,
"show_workflow_steps": fields.Boolean,
"use_icon_as_answer_icon": fields.Boolean,
}

@ -526,14 +526,20 @@ class DatasetIndexingStatusApi(Resource):
) )
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
).count() )
total_segments = DocumentSegment.query.filter( .count()
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
documents_status.append(marshal(document, document_status_fields)) documents_status.append(marshal(document, document_status_fields))

@ -6,7 +6,7 @@ from typing import cast
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, marshal, marshal_with, reqparse from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import asc, desc from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -109,7 +109,7 @@ class GetProcessRuleApi(Resource):
limits = DocumentService.DEFAULT_RULES["limits"] limits = DocumentService.DEFAULT_RULES["limits"]
if document_id: if document_id:
# get the latest process rule # get the latest process rule
document = Document.query.get_or_404(document_id) document = db.get_or_404(Document, document_id)
dataset = DatasetService.get_dataset(document.dataset_id) dataset = DatasetService.get_dataset(document.dataset_id)
@ -172,7 +172,7 @@ class DatasetDocumentListApi(Resource):
except services.errors.account.NoPermissionError as e: except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e)) raise Forbidden(str(e))
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=current_user.current_tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -206,18 +206,24 @@ class DatasetDocumentListApi(Resource):
desc(Document.position), desc(Document.position),
) )
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items
if fetch: if fetch:
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
).count() )
total_segments = DocumentSegment.query.filter( .count()
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
data = marshal(documents, document_with_segments_fields) data = marshal(documents, document_with_segments_fields)
@ -560,14 +566,20 @@ class DocumentBatchIndexingStatusApi(DocumentResource):
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
).count() )
total_segments = DocumentSegment.query.filter( .count()
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:
@ -586,14 +598,20 @@ class DocumentIndexingStatusApi(DocumentResource):
document_id = str(document_id) document_id = str(document_id)
document = self.get_document(dataset_id, document_id) document = self.get_document(dataset_id, document_id)
completed_segments = DocumentSegment.query.filter( completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document_id), DocumentSegment.document_id == str(document_id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
).count() )
total_segments = DocumentSegment.query.filter( .count()
DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment" )
).count() total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document_id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments

@ -4,6 +4,7 @@ import pandas as pd
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, marshal, reqparse from flask_restful import Resource, marshal, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Forbidden, NotFound from werkzeug.exceptions import Forbidden, NotFound
import services import services
@ -26,6 +27,7 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import child_chunk_fields, segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import login_required
@ -74,9 +76,14 @@ class DatasetDocumentSegmentListApi(Resource):
hit_count_gte = args["hit_count_gte"] hit_count_gte = args["hit_count_gte"]
keyword = args["keyword"] keyword = args["keyword"]
query = DocumentSegment.query.filter( query = (
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id select(DocumentSegment)
).order_by(DocumentSegment.position.asc()) .filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id,
)
.order_by(DocumentSegment.position.asc())
)
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
@ -93,7 +100,7 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False) query = query.filter(DocumentSegment.enabled == False)
segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
response = { response = {
"data": marshal(segments.items, segment_fields), "data": marshal(segments.items, segment_fields),
@ -276,9 +283,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -320,9 +329,11 @@ class DatasetDocumentSegmentUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -423,9 +434,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
@ -478,9 +491,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
@ -523,9 +538,11 @@ class ChildChunkAddApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -567,16 +584,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# check child chunk # check child chunk
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id db.session.query(ChildChunk)
).first() .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor
@ -612,16 +633,20 @@ class ChildChunkUpdateApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
# check segment # check segment
segment_id = str(segment_id) segment_id = str(segment_id)
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
# check child chunk # check child chunk
child_chunk_id = str(child_chunk_id) child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter( child_chunk = (
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id db.session.query(ChildChunk)
).first() .filter(ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id)
.first()
)
if not child_chunk: if not child_chunk:
raise NotFound("Child chunk not found.") raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin, owner, dataset_operator, or editor # The role of the current user in the ta table must be admin, owner, dataset_operator, or editor

@ -209,6 +209,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("query", type=str, location="json") parser.add_argument("query", type=str, location="json")
parser.add_argument("external_retrieval_model", type=dict, required=False, location="json") parser.add_argument("external_retrieval_model", type=dict, required=False, location="json")
parser.add_argument("metadata_filtering_conditions", type=dict, required=False, location="json")
args = parser.parse_args() args = parser.parse_args()
HitTestingService.hit_testing_args_check(args) HitTestingService.hit_testing_args_check(args)
@ -219,6 +220,7 @@ class ExternalKnowledgeHitTestingApi(Resource):
query=args["query"], query=args["query"],
account=current_user, account=current_user,
external_retrieval_model=args["external_retrieval_model"], external_retrieval_model=args["external_retrieval_model"],
metadata_filtering_conditions=args["metadata_filtering_conditions"],
) )
return response return response

@ -66,7 +66,7 @@ class InstalledAppsListApi(Resource):
parser.add_argument("app_id", type=str, required=True, help="Invalid app_id") parser.add_argument("app_id", type=str, required=True, help="Invalid app_id")
args = parser.parse_args() args = parser.parse_args()
recommended_app = RecommendedApp.query.filter(RecommendedApp.app_id == args["app_id"]).first() recommended_app = db.session.query(RecommendedApp).filter(RecommendedApp.app_id == args["app_id"]).first()
if recommended_app is None: if recommended_app is None:
raise NotFound("App not found") raise NotFound("App not found")
@ -79,9 +79,11 @@ class InstalledAppsListApi(Resource):
if not app.is_public: if not app.is_public:
raise Forbidden("You can't install a non-public app") raise Forbidden("You can't install a non-public app")
installed_app = InstalledApp.query.filter( installed_app = (
and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id) db.session.query(InstalledApp)
).first() .filter(and_(InstalledApp.app_id == args["app_id"], InstalledApp.tenant_id == current_tenant_id))
.first()
)
if installed_app is None: if installed_app is None:
# todo: position # todo: position

@ -71,7 +71,6 @@ class MemberInviteEmailApi(Resource):
invitation_results.append( invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"} {"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
) )
break
except Exception as e: except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)}) invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})

@ -3,6 +3,7 @@ import logging
from flask import request from flask import request
from flask_login import current_user from flask_login import current_user
from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse from flask_restful import Resource, fields, inputs, marshal, marshal_with, reqparse
from sqlalchemy import select
from werkzeug.exceptions import Unauthorized from werkzeug.exceptions import Unauthorized
import services import services
@ -88,9 +89,8 @@ class WorkspaceListApi(Resource):
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args") parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
args = parser.parse_args() args = parser.parse_args()
tenants = Tenant.query.order_by(Tenant.created_at.desc()).paginate( stmt = select(Tenant).order_by(Tenant.created_at.desc())
page=args["page"], per_page=args["limit"], error_out=False tenants = db.paginate(select=stmt, page=args["page"], per_page=args["limit"], error_out=False)
)
has_more = False has_more = False
if tenants.has_next: if tenants.has_next:
@ -162,7 +162,7 @@ class CustomConfigWorkspaceApi(Resource):
parser.add_argument("replace_webapp_logo", type=str, location="json") parser.add_argument("replace_webapp_logo", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
custom_config_dict = { custom_config_dict = {
"remove_webapp_brand": args["remove_webapp_brand"], "remove_webapp_brand": args["remove_webapp_brand"],
@ -226,7 +226,7 @@ class WorkspaceInfoApi(Resource):
parser.add_argument("name", type=str, required=True, location="json") parser.add_argument("name", type=str, required=True, location="json")
args = parser.parse_args() args = parser.parse_args()
tenant = Tenant.query.filter(Tenant.id == current_user.current_tenant_id).one_or_404() tenant = db.get_or_404(Tenant, current_user.current_tenant_id)
tenant.name = args["name"] tenant.name = args["name"]
db.session.commit() db.session.commit()

@ -6,6 +6,6 @@ bp = Blueprint("service_api", __name__, url_prefix="/v1")
api = ExternalApi(bp) api = ExternalApi(bp)
from . import index from . import index
from .app import annotation, app, audio, completion, conversation, file, message, workflow from .app import annotation, app, audio, completion, conversation, file, message, site, workflow
from .dataset import dataset, document, hit_testing, metadata, segment, upload_file from .dataset import dataset, document, hit_testing, metadata, segment, upload_file
from .workspace import models from .workspace import models

@ -93,6 +93,18 @@ class MessageFeedbackApi(Resource):
return {"result": "success"} return {"result": "success"}
class AppGetFeedbacksApi(Resource):
@validate_app_token
def get(self, app_model: App):
"""Get All Feedbacks of an app"""
parser = reqparse.RequestParser()
parser.add_argument("page", type=int, default=1, location="args")
parser.add_argument("limit", type=int_range(1, 101), required=False, default=20, location="args")
args = parser.parse_args()
feedbacks = MessageService.get_all_messages_feedbacks(app_model, page=args["page"], limit=args["limit"])
return {"data": feedbacks}
class MessageSuggestedApi(Resource): class MessageSuggestedApi(Resource):
@validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True)) @validate_app_token(fetch_user_arg=FetchUserArg(fetch_from=WhereisUserArg.QUERY, required=True))
def get(self, app_model: App, end_user: EndUser, message_id): def get(self, app_model: App, end_user: EndUser, message_id):
@ -119,3 +131,4 @@ class MessageSuggestedApi(Resource):
api.add_resource(MessageListApi, "/messages") api.add_resource(MessageListApi, "/messages")
api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks") api.add_resource(MessageFeedbackApi, "/messages/<uuid:message_id>/feedbacks")
api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested") api.add_resource(MessageSuggestedApi, "/messages/<uuid:message_id>/suggested")
api.add_resource(AppGetFeedbacksApi, "/app/feedbacks")

@ -0,0 +1,30 @@
from flask_restful import Resource, marshal_with
from werkzeug.exceptions import Forbidden
from controllers.common import fields
from controllers.service_api import api
from controllers.service_api.wraps import validate_app_token
from extensions.ext_database import db
from models.account import TenantStatus
from models.model import App, Site
class AppSiteApi(Resource):
"""Resource for app sites."""
@validate_app_token
@marshal_with(fields.site_fields)
def get(self, app_model: App):
"""Retrieve app site info."""
site = db.session.query(Site).filter(Site.app_id == app_model.id).first()
if not site:
raise Forbidden()
if app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return site
api.add_resource(AppSiteApi, "/site")

@ -2,10 +2,10 @@ import json
from flask import request from flask import request
from flask_restful import marshal, reqparse from flask_restful import marshal, reqparse
from sqlalchemy import desc from sqlalchemy import desc, select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import services.dataset_service import services
from controllers.common.errors import FilenameNotExistsError from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ( from controllers.service_api.app.error import (
@ -337,7 +337,7 @@ class DocumentListApi(DatasetApiResource):
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
query = Document.query.filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id) query = select(Document).filter_by(dataset_id=str(dataset_id), tenant_id=tenant_id)
if search: if search:
search = f"%{search}%" search = f"%{search}%"
@ -345,7 +345,7 @@ class DocumentListApi(DatasetApiResource):
query = query.order_by(desc(Document.created_at), desc(Document.position)) query = query.order_by(desc(Document.created_at), desc(Document.position))
paginated_documents = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items documents = paginated_documents.items
response = { response = {
@ -374,14 +374,20 @@ class DocumentIndexingStatusApi(DatasetApiResource):
raise NotFound("Documents not found.") raise NotFound("Documents not found.")
documents_status = [] documents_status = []
for document in documents: for document in documents:
completed_segments = DocumentSegment.query.filter( completed_segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.document_id == str(document.id), DocumentSegment.document_id == str(document.id),
DocumentSegment.status != "re_segment", DocumentSegment.status != "re_segment",
).count() )
total_segments = DocumentSegment.query.filter( .count()
DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment" )
).count() total_segments = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.document_id == str(document.id), DocumentSegment.status != "re_segment")
.count()
)
document.completed_segments = completed_segments document.completed_segments = completed_segments
document.total_segments = total_segments document.total_segments = total_segments
if document.is_paused: if document.is_paused:

@ -91,6 +91,8 @@ class BaseAgentRunner(AppRunner):
return_resource=app_config.additional_features.show_retrieve_source, return_resource=app_config.additional_features.show_retrieve_source,
invoke_from=application_generate_entity.invoke_from, invoke_from=application_generate_entity.invoke_from,
hit_callback=hit_callback, hit_callback=hit_callback,
user_id=user_id,
inputs=cast(dict, application_generate_entity.inputs),
) )
# get how many agent thoughts have been created # get how many agent thoughts have been created
self.agent_thought_count = ( self.agent_thought_count = (

@ -69,13 +69,6 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools self._prompt_messages_tools = prompt_messages_tools
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
function_call_state = True function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None} llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = "" final_answer = ""
@ -87,6 +80,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
llm_usage = final_llm_usage_dict["usage"] llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price llm_usage.total_price += usage.total_price

@ -45,13 +45,6 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format # convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools() tool_instances, prompt_messages_tools = self._init_prompt_tools()
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
assert app_config.agent assert app_config.agent
iteration_step = 1 iteration_step = 1
@ -72,6 +65,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
llm_usage = final_llm_usage_dict["usage"] llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price llm_usage.total_price += usage.total_price

@ -1,3 +1,5 @@
import logging
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
@ -7,6 +9,8 @@ from extensions.ext_database import db
from models.dataset import ChildChunk, DatasetQuery, DocumentSegment from models.dataset import ChildChunk, DatasetQuery, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
_logger = logging.getLogger(__name__)
class DatasetIndexToolCallbackHandler: class DatasetIndexToolCallbackHandler:
"""Callback handler for dataset tool.""" """Callback handler for dataset tool."""
@ -42,19 +46,32 @@ class DatasetIndexToolCallbackHandler:
"""Handle tool end.""" """Handle tool end."""
for document in documents: for document in documents:
if document.metadata is not None: if document.metadata is not None:
dataset_document = DatasetDocument.query.filter( document_id = document.metadata["document_id"]
DatasetDocument.id == document.metadata["document_id"] dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
).first() if not dataset_document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s",
document_id,
)
continue
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter( child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
).first() )
.first()
)
if child_chunk: if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( segment = (
db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
) )
)
else: else:
query = db.session.query(DocumentSegment).filter( query = db.session.query(DocumentSegment).filter(
DocumentSegment.index_node_id == document.metadata["doc_id"] DocumentSegment.index_node_id == document.metadata["doc_id"]

@ -51,7 +51,7 @@ class IndexingRunner:
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
@ -103,15 +103,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is splitting.""" """Run the indexing process when the index_status is splitting."""
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by( document_segments = (
dataset_id=dataset.id, document_id=dataset_document.id db.session.query(DocumentSegment)
).all() .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
for document_segment in document_segments: for document_segment in document_segments:
db.session.delete(document_segment) db.session.delete(document_segment)
@ -162,15 +164,17 @@ class IndexingRunner:
"""Run the indexing process when the index_status is indexing.""" """Run the indexing process when the index_status is indexing."""
try: try:
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=dataset_document.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_document.dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
# get exist document_segment list and delete # get exist document_segment list and delete
document_segments = DocumentSegment.query.filter_by( document_segments = (
dataset_id=dataset.id, document_id=dataset_document.id db.session.query(DocumentSegment)
).all() .filter_by(dataset_id=dataset.id, document_id=dataset_document.id)
.all()
)
documents = [] documents = []
if document_segments: if document_segments:
@ -254,7 +258,7 @@ class IndexingRunner:
embedding_model_instance = None embedding_model_instance = None
if dataset_id: if dataset_id:
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise ValueError("Dataset not found.") raise ValueError("Dataset not found.")
if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality" or indexing_technique == "high_quality":
@ -587,7 +591,7 @@ class IndexingRunner:
@staticmethod @staticmethod
def _process_keyword_index(flask_app, dataset_id, document_id, documents): def _process_keyword_index(flask_app, dataset_id, document_id, documents):
with flask_app.app_context(): with flask_app.app_context():
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise ValueError("no dataset found") raise ValueError("no dataset found")
keyword = Keyword(dataset) keyword = Keyword(dataset)
@ -656,10 +660,10 @@ class IndexingRunner:
""" """
Update the document indexing status. Update the document indexing status.
""" """
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count() count = db.session.query(DatasetDocument).filter_by(id=document_id, is_paused=True).count()
if count > 0: if count > 0:
raise DocumentIsPausedError() raise DocumentIsPausedError()
document = DatasetDocument.query.filter_by(id=document_id).first() document = db.session.query(DatasetDocument).filter_by(id=document_id).first()
if not document: if not document:
raise DocumentIsDeletedPausedError() raise DocumentIsDeletedPausedError()
@ -668,7 +672,7 @@ class IndexingRunner:
if extra_update_params: if extra_update_params:
update_params.update(extra_update_params) update_params.update(extra_update_params)
DatasetDocument.query.filter_by(id=document_id).update(update_params) db.session.query(DatasetDocument).filter_by(id=document_id).update(update_params)
db.session.commit() db.session.commit()
@staticmethod @staticmethod
@ -676,7 +680,7 @@ class IndexingRunner:
""" """
Update the document segment by document id. Update the document segment by document id.
""" """
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) db.session.query(DocumentSegment).filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
def _transform( def _transform(

@ -1,9 +1,9 @@
from enum import Enum from enum import StrEnum
from pydantic import BaseModel, ValidationInfo, field_validator from pydantic import BaseModel, ValidationInfo, field_validator
class TracingProviderEnum(Enum): class TracingProviderEnum(StrEnum):
LANGFUSE = "langfuse" LANGFUSE = "langfuse"
LANGSMITH = "langsmith" LANGSMITH = "langsmith"
OPIK = "opik" OPIK = "opik"

@ -16,11 +16,7 @@ from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import ( from core.ops.entities.config_entity import (
OPS_FILE_PATH, OPS_FILE_PATH,
LangfuseConfig,
LangSmithConfig,
OpikConfig,
TracingProviderEnum, TracingProviderEnum,
WeaveConfig,
) )
from core.ops.entities.trace_entity import ( from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo, DatasetRetrievalTraceInfo,
@ -33,11 +29,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
from core.ops.opik_trace.opik_trace import OpikDataTrace
from core.ops.utils import get_message_data from core.ops.utils import get_message_data
from core.ops.weave_trace.weave_trace import WeaveDataTrace
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_storage import storage from extensions.ext_storage import storage
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
@ -45,36 +37,58 @@ from models.workflow import WorkflowAppLog, WorkflowRun
from tasks.ops_trace_task import process_trace_tasks from tasks.ops_trace_task import process_trace_tasks
def build_opik_trace_instance(config: OpikConfig): class OpsTraceProviderConfigMap(dict[str, dict[str, Any]]):
return OpikDataTrace(config) def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
return {
provider_config_map: dict[str, dict[str, Any]] = {
TracingProviderEnum.LANGFUSE.value: {
"config_class": LangfuseConfig, "config_class": LangfuseConfig,
"secret_keys": ["public_key", "secret_key"], "secret_keys": ["public_key", "secret_key"],
"other_keys": ["host", "project_key"], "other_keys": ["host", "project_key"],
"trace_instance": LangFuseDataTrace, "trace_instance": LangFuseDataTrace,
}, }
TracingProviderEnum.LANGSMITH.value: {
case TracingProviderEnum.LANGSMITH:
from core.ops.entities.config_entity import LangSmithConfig
from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace
return {
"config_class": LangSmithConfig, "config_class": LangSmithConfig,
"secret_keys": ["api_key"], "secret_keys": ["api_key"],
"other_keys": ["project", "endpoint"], "other_keys": ["project", "endpoint"],
"trace_instance": LangSmithDataTrace, "trace_instance": LangSmithDataTrace,
}, }
TracingProviderEnum.OPIK.value: {
case TracingProviderEnum.OPIK:
from core.ops.entities.config_entity import OpikConfig
from core.ops.opik_trace.opik_trace import OpikDataTrace
return {
"config_class": OpikConfig, "config_class": OpikConfig,
"secret_keys": ["api_key"], "secret_keys": ["api_key"],
"other_keys": ["project", "url", "workspace"], "other_keys": ["project", "url", "workspace"],
"trace_instance": lambda config: build_opik_trace_instance(config), "trace_instance": OpikDataTrace,
}, }
TracingProviderEnum.WEAVE.value: {
case TracingProviderEnum.WEAVE:
from core.ops.entities.config_entity import WeaveConfig
from core.ops.weave_trace.weave_trace import WeaveDataTrace
return {
"config_class": WeaveConfig, "config_class": WeaveConfig,
"secret_keys": ["api_key"], "secret_keys": ["api_key"],
"other_keys": ["project", "entity", "endpoint"], "other_keys": ["project", "entity", "endpoint"],
"trace_instance": WeaveDataTrace, "trace_instance": WeaveDataTrace,
}, }
}
case _:
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map: dict[str, dict[str, Any]] = OpsTraceProviderConfigMap()
class OpsTraceManager: class OpsTraceManager:

@ -24,7 +24,7 @@ class EndpointProviderDeclaration(BaseModel):
""" """
settings: list[ProviderConfig] = Field(default_factory=list) settings: list[ProviderConfig] = Field(default_factory=list)
endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list) endpoints: Optional[list[EndpointDeclaration]] = Field(default_factory=list[EndpointDeclaration])
class EndpointEntity(BasePluginEntity): class EndpointEntity(BasePluginEntity):

@ -52,7 +52,7 @@ class PluginResourceRequirements(BaseModel):
model: Optional[Model] = Field(default=None) model: Optional[Model] = Field(default=None)
node: Optional[Node] = Field(default=None) node: Optional[Node] = Field(default=None)
endpoint: Optional[Endpoint] = Field(default=None) endpoint: Optional[Endpoint] = Field(default=None)
storage: Storage = Field(default=None) storage: Optional[Storage] = Field(default=None)
permission: Optional[Permission] = Field(default=None) permission: Optional[Permission] = Field(default=None)
@ -66,9 +66,9 @@ class PluginCategory(enum.StrEnum):
class PluginDeclaration(BaseModel): class PluginDeclaration(BaseModel):
class Plugins(BaseModel): class Plugins(BaseModel):
tools: Optional[list[str]] = Field(default_factory=list) tools: Optional[list[str]] = Field(default_factory=list[str])
models: Optional[list[str]] = Field(default_factory=list) models: Optional[list[str]] = Field(default_factory=list[str])
endpoints: Optional[list[str]] = Field(default_factory=list) endpoints: Optional[list[str]] = Field(default_factory=list[str])
class Meta(BaseModel): class Meta(BaseModel):
minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$") minimum_dify_version: Optional[str] = Field(default=None, pattern=r"^\d{1,4}(\.\d{1,4}){1,3}(-\w{1,16})?$")
@ -84,6 +84,7 @@ class PluginDeclaration(BaseModel):
resource: PluginResourceRequirements resource: PluginResourceRequirements
plugins: Plugins plugins: Plugins
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
repo: Optional[str] = Field(default=None)
verified: bool = Field(default=False) verified: bool = Field(default=False)
tool: Optional[ToolProviderEntity] = None tool: Optional[ToolProviderEntity] = None
model: Optional[ProviderEntity] = None model: Optional[ProviderEntity] = None

@ -55,8 +55,8 @@ class RequestInvokeLLM(BaseRequestInvokeModel):
mode: str mode: str
completion_params: dict[str, Any] = Field(default_factory=dict) completion_params: dict[str, Any] = Field(default_factory=dict)
prompt_messages: list[PromptMessage] = Field(default_factory=list) prompt_messages: list[PromptMessage] = Field(default_factory=list)
tools: Optional[list[PromptMessageTool]] = Field(default_factory=list) tools: Optional[list[PromptMessageTool]] = Field(default_factory=list[PromptMessageTool])
stop: Optional[list[str]] = Field(default_factory=list) stop: Optional[list[str]] = Field(default_factory=list[str])
stream: Optional[bool] = False stream: Optional[bool] = False
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())

@ -10,6 +10,7 @@ from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.rerank_type import RerankMode
@ -119,12 +120,25 @@ class RetrievalService:
return all_documents return all_documents
@classmethod @classmethod
def external_retrieve(cls, dataset_id: str, query: str, external_retrieval_model: Optional[dict] = None): def external_retrieve(
cls,
dataset_id: str,
query: str,
external_retrieval_model: Optional[dict] = None,
metadata_filtering_conditions: Optional[dict] = None,
):
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
return [] return []
metadata_condition = (
MetadataCondition(**metadata_filtering_conditions) if metadata_filtering_conditions else None
)
all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( all_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
dataset.tenant_id, dataset_id, query, external_retrieval_model or {} dataset.tenant_id,
dataset_id,
query,
external_retrieval_model or {},
metadata_condition=metadata_condition,
) )
return all_documents return all_documents

@ -317,7 +317,7 @@ class NotionExtractor(BaseExtractor):
data_source_info["last_edited_time"] = last_edited_time data_source_info["last_edited_time"] = last_edited_time
update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)} update_params = {DocumentModel.data_source_info: json.dumps(data_source_info)}
DocumentModel.query.filter_by(id=document_model.id).update(update_params) db.session.query(DocumentModel).filter_by(id=document_model.id).update(update_params)
db.session.commit() db.session.commit()
def get_notion_last_edited_time(self) -> str: def get_notion_last_edited_time(self) -> str:
@ -347,14 +347,18 @@ class NotionExtractor(BaseExtractor):
@classmethod @classmethod
def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str: def _get_access_token(cls, tenant_id: str, notion_workspace_id: str) -> str:
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == tenant_id, DataSourceOauthBinding.tenant_id == tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"', DataSourceOauthBinding.source_info["workspace_id"] == f'"{notion_workspace_id}"',
) )
).first() )
.first()
)
if not data_source_binding: if not data_source_binding:
raise Exception( raise Exception(

@ -76,8 +76,7 @@ class WordExtractor(BaseExtractor):
parsed = urlparse(url) parsed = urlparse(url)
return bool(parsed.netloc) and bool(parsed.scheme) return bool(parsed.netloc) and bool(parsed.scheme)
def _extract_images_from_docx(self, doc, image_folder): def _extract_images_from_docx(self, doc):
os.makedirs(image_folder, exist_ok=True)
image_count = 0 image_count = 0
image_map = {} image_map = {}
@ -210,7 +209,7 @@ class WordExtractor(BaseExtractor):
content = [] content = []
image_map = self._extract_images_from_docx(doc, image_folder) image_map = self._extract_images_from_docx(doc)
hyperlinks_url = None hyperlinks_url = None
url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+") url_pattern = re.compile(r"http://[^\s+]+//|https://[^\s+]+")
@ -225,7 +224,7 @@ class WordExtractor(BaseExtractor):
xml = ElementTree.XML(run.element.xml) xml = ElementTree.XML(run.element.xml)
x_child = [c for c in xml.iter() if c is not None] x_child = [c for c in xml.iter() if c is not None]
for x in x_child: for x in x_child:
if x_child is None: if x is None:
continue continue
if x.tag.endswith("instrText"): if x.tag.endswith("instrText"):
if x.text is None: if x.text is None:

@ -149,7 +149,7 @@ class DatasetRetrieval:
else: else:
inputs = {} inputs = {}
available_datasets_ids = [dataset.id for dataset in available_datasets] available_datasets_ids = [dataset.id for dataset in available_datasets]
metadata_filter_document_ids, metadata_condition = self._get_metadata_filter_condition( metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
available_datasets_ids, available_datasets_ids,
query, query,
tenant_id, tenant_id,
@ -237,12 +237,16 @@ class DatasetRetrieval:
if show_retrieve_source: if show_retrieve_source:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = (
db.session.query(DatasetDocument)
.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
).first() )
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"dataset_id": dataset.id, "dataset_id": dataset.id,
@ -506,19 +510,30 @@ class DatasetRetrieval:
dify_documents = [document for document in documents if document.provider == "dify"] dify_documents = [document for document in documents if document.provider == "dify"]
for document in dify_documents: for document in dify_documents:
if document.metadata is not None: if document.metadata is not None:
dataset_document = DatasetDocument.query.filter( dataset_document = (
DatasetDocument.id == document.metadata["document_id"] db.session.query(DatasetDocument)
).first() .filter(DatasetDocument.id == document.metadata["document_id"])
.first()
)
if dataset_document: if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX: if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunk = ChildChunk.query.filter( child_chunk = (
db.session.query(ChildChunk)
.filter(
ChildChunk.index_node_id == document.metadata["doc_id"], ChildChunk.index_node_id == document.metadata["doc_id"],
ChildChunk.dataset_id == dataset_document.dataset_id, ChildChunk.dataset_id == dataset_document.dataset_id,
ChildChunk.document_id == dataset_document.id, ChildChunk.document_id == dataset_document.id,
).first() )
.first()
)
if child_chunk: if child_chunk:
segment = DocumentSegment.query.filter(DocumentSegment.id == child_chunk.segment_id).update( segment = (
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False db.session.query(DocumentSegment)
.filter(DocumentSegment.id == child_chunk.segment_id)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1},
synchronize_session=False,
)
) )
db.session.commit() db.session.commit()
else: else:
@ -649,6 +664,8 @@ class DatasetRetrieval:
return_resource: bool, return_resource: bool,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> Optional[list[DatasetRetrieverBaseTool]]: ) -> Optional[list[DatasetRetrieverBaseTool]]:
""" """
A dataset tool is a tool that can be used to retrieve information from a dataset A dataset tool is a tool that can be used to retrieve information from a dataset
@ -706,6 +723,9 @@ class DatasetRetrieval:
hit_callbacks=[hit_callback], hit_callbacks=[hit_callback],
return_resource=return_resource, return_resource=return_resource,
retriever_from=invoke_from.to_source(), retriever_from=invoke_from.to_source(),
retrieve_config=retrieve_config,
user_id=user_id,
inputs=inputs,
) )
tools.append(tool) tools.append(tool)
@ -826,7 +846,7 @@ class DatasetRetrieval:
) )
return filter_documents[:top_k] if top_k else filter_documents return filter_documents[:top_k] if top_k else filter_documents
def _get_metadata_filter_condition( def get_metadata_filter_condition(
self, self,
dataset_ids: list, dataset_ids: list,
query: str, query: str,
@ -876,13 +896,20 @@ class DatasetRetrieval:
) )
elif metadata_filtering_mode == "manual": elif metadata_filtering_mode == "manual":
if metadata_filtering_conditions: if metadata_filtering_conditions:
metadata_condition = MetadataCondition(**metadata_filtering_conditions.model_dump()) conditions = []
for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore for sequence, condition in enumerate(metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name metadata_name = condition.name
expected_value = condition.value expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str): if isinstance(expected_value, str):
expected_value = self._replace_metadata_filter_value(expected_value, inputs) expected_value = self._replace_metadata_filter_value(expected_value, inputs)
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = self._process_metadata_filter_func( filters = self._process_metadata_filter_func(
sequence, sequence,
condition.comparison_operator, condition.comparison_operator,
@ -890,6 +917,10 @@ class DatasetRetrieval:
expected_value, expected_value,
filters, filters,
) )
metadata_condition = MetadataCondition(
logical_operator=metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else: else:
raise ValueError("Invalid metadata filtering mode") raise ValueError("Invalid metadata filtering mode")
if filters: if filters:

@ -84,13 +84,17 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
document_context_list = [] document_context_list = []
index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata] index_node_ids = [document.metadata["doc_id"] for document in all_documents if document.metadata]
segments = DocumentSegment.query.filter( segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id.in_(self.dataset_ids), DocumentSegment.dataset_id.in_(self.dataset_ids),
DocumentSegment.completed_at.isnot(None), DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed", DocumentSegment.status == "completed",
DocumentSegment.enabled == True, DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids), DocumentSegment.index_node_id.in_(index_node_ids),
).all() )
.all()
)
if segments: if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)} index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
@ -106,12 +110,16 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list = [] context_list = []
resource_number = 1 resource_number = 1
for segment in sorted_segments: for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
).first() )
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"position": resource_number, "position": resource_number,

@ -1,11 +1,12 @@
from typing import Any from typing import Any, Optional, cast
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from core.app.app_config.entities import DatasetRetrieveConfigEntity, ModelConfig
from core.rag.datasource.retrieval_service import RetrievalService from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.models.document import Document as RetrievalDocument from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db from extensions.ext_database import db
@ -34,7 +35,9 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
args_schema: type[BaseModel] = DatasetRetrieverToolInput args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. " description: str = "use this to retrieve a dataset. "
dataset_id: str dataset_id: str
metadata_filtering_conditions: MetadataCondition user_id: Optional[str] = None
retrieve_config: DatasetRetrieveConfigEntity
inputs: dict
@classmethod @classmethod
def from_dataset(cls, dataset: Dataset, **kwargs): def from_dataset(cls, dataset: Dataset, **kwargs):
@ -48,7 +51,6 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
tenant_id=dataset.tenant_id, tenant_id=dataset.tenant_id,
dataset_id=dataset.id, dataset_id=dataset.id,
description=description, description=description,
metadata_filtering_conditions=MetadataCondition(),
**kwargs, **kwargs,
) )
@ -61,6 +63,21 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return "" return ""
for hit_callback in self.hit_callbacks: for hit_callback in self.hit_callbacks:
hit_callback.on_query(query, dataset.id) hit_callback.on_query(query, dataset.id)
dataset_retrieval = DatasetRetrieval()
metadata_filter_document_ids, metadata_condition = dataset_retrieval.get_metadata_filter_condition(
[dataset.id],
query,
self.tenant_id,
self.user_id or "unknown",
cast(str, self.retrieve_config.metadata_filtering_mode),
cast(ModelConfig, self.retrieve_config.metadata_model_config),
self.retrieve_config.metadata_filtering_conditions,
self.inputs,
)
if metadata_filter_document_ids:
document_ids_filter = metadata_filter_document_ids.get(dataset.id, [])
else:
document_ids_filter = None
if dataset.provider == "external": if dataset.provider == "external":
results = [] results = []
external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval( external_documents = ExternalDatasetService.fetch_external_knowledge_retrieval(
@ -68,7 +85,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id, dataset_id=dataset.id,
query=query, query=query,
external_retrieval_parameters=dataset.retrieval_model, external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=self.metadata_filtering_conditions, metadata_condition=metadata_condition,
) )
for external_document in external_documents: for external_document in external_documents:
document = RetrievalDocument( document = RetrievalDocument(
@ -104,12 +121,18 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
return str("\n".join([item.page_content for item in results])) return str("\n".join([item.page_content for item in results]))
else: else:
if metadata_condition and not document_ids_filter:
return ""
# get retrieval model , if the model is not setting , using default # get retrieval model , if the model is not setting , using default
retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model retrieval_model: dict[str, Any] = dataset.retrieval_model or default_retrieval_model
if dataset.indexing_technique == "economy": if dataset.indexing_technique == "economy":
# use keyword table query # use keyword table query
documents = RetrievalService.retrieve( documents = RetrievalService.retrieve(
retrieval_method="keyword_search", dataset_id=dataset.id, query=query, top_k=self.top_k retrieval_method="keyword_search",
dataset_id=dataset.id,
query=query,
top_k=self.top_k,
document_ids_filter=document_ids_filter,
) )
return str("\n".join([document.page_content for document in documents])) return str("\n".join([document.page_content for document in documents]))
else: else:
@ -128,6 +151,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
else None, else None,
reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model", reranking_mode=retrieval_model.get("reranking_mode") or "reranking_model",
weights=retrieval_model.get("weights"), weights=retrieval_model.get("weights"),
document_ids_filter=document_ids_filter,
) )
else: else:
documents = [] documents = []
@ -161,12 +185,16 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource: if self.return_resource:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = (
db.session.query(DatasetDocument) # type: ignore
.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True, DatasetDocument.enabled == True,
DatasetDocument.archived == False, DatasetDocument.archived == False,
).first() )
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"dataset_id": dataset.id, "dataset_id": dataset.id,

@ -34,6 +34,8 @@ class DatasetRetrieverTool(Tool):
return_resource: bool, return_resource: bool,
invoke_from: InvokeFrom, invoke_from: InvokeFrom,
hit_callback: DatasetIndexToolCallbackHandler, hit_callback: DatasetIndexToolCallbackHandler,
user_id: str,
inputs: dict,
) -> list["DatasetRetrieverTool"]: ) -> list["DatasetRetrieverTool"]:
""" """
get dataset tool get dataset tool
@ -57,6 +59,8 @@ class DatasetRetrieverTool(Tool):
return_resource=return_resource, return_resource=return_resource,
invoke_from=invoke_from, invoke_from=invoke_from,
hit_callback=hit_callback, hit_callback=hit_callback,
user_id=user_id,
inputs=inputs,
) )
if retrieval_tools is None or len(retrieval_tools) == 0: if retrieval_tools is None or len(retrieval_tools) == 0:
return [] return []

@ -30,7 +30,7 @@ class Variable(Segment):
""" """
id: str = Field( id: str = Field(
default=lambda _: str(uuid4()), default_factory=lambda: str(uuid4()),
description="Unique identity for variable.", description="Unique identity for variable.",
) )
name: str name: str

@ -36,7 +36,7 @@ class Graph(BaseModel):
root_node_id: str = Field(..., description="root node id of the graph") root_node_id: str = Field(..., description="root node id of the graph")
node_ids: list[str] = Field(default_factory=list, description="graph node ids") node_ids: list[str] = Field(default_factory=list, description="graph node ids")
node_id_config_mapping: dict[str, dict] = Field( node_id_config_mapping: dict[str, dict] = Field(
default_factory=list, description="node configs mapping (node id: node config)" default_factory=dict, description="node configs mapping (node id: node config)"
) )
edge_mapping: dict[str, list[GraphEdge]] = Field( edge_mapping: dict[str, list[GraphEdge]] = Field(
default_factory=dict, description="graph edge mapping (source node id: edges)" default_factory=dict, description="graph edge mapping (source node id: edges)"

@ -95,7 +95,12 @@ class StreamProcessor(ABC):
if node_id not in self.rest_node_ids: if node_id not in self.rest_node_ids:
return return
if node_id in reachable_node_ids:
return
self.rest_node_ids.remove(node_id) self.rest_node_ids.remove(node_id)
self.rest_node_ids.extend(set(reachable_node_ids) - set(self.rest_node_ids))
for edge in self.graph.edge_mapping.get(node_id, []): for edge in self.graph.edge_mapping.get(node_id, []):
if edge.target_node_id in reachable_node_ids: if edge.target_node_id in reachable_node_ids:
continue continue

@ -127,7 +127,7 @@ class CodeNode(BaseNode[CodeNodeData]):
depth: int = 1, depth: int = 1,
): ):
if depth > dify_config.CODE_MAX_DEPTH: if depth > dify_config.CODE_MAX_DEPTH:
raise DepthLimitError(f"Depth limit ${dify_config.CODE_MAX_DEPTH} reached, object too deep.") raise DepthLimitError(f"Depth limit {dify_config.CODE_MAX_DEPTH} reached, object too deep.")
transformed_result: dict[str, Any] = {} transformed_result: dict[str, Any] = {}
if output_schema is None: if output_schema is None:

@ -353,27 +353,26 @@ class IterationNode(BaseNode[IterationNodeData]):
) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent: ) -> NodeRunStartedEvent | BaseNodeEvent | InNodeEvent:
""" """
add iteration metadata to event. add iteration metadata to event.
ensures iteration context (ID, index/parallel_run_id) is added to metadata,
""" """
if not isinstance(event, BaseNodeEvent): if not isinstance(event, BaseNodeEvent):
return event return event
if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent): if self.node_data.is_parallel and isinstance(event, NodeRunStartedEvent):
event.parallel_mode_run_id = parallel_mode_run_id event.parallel_mode_run_id = parallel_mode_run_id
return event
if event.route_node_state.node_run_result: iter_metadata = {
metadata = event.route_node_state.node_run_result.metadata
if not metadata:
metadata = {}
if NodeRunMetadataKey.ITERATION_ID not in metadata:
metadata = {
**metadata,
NodeRunMetadataKey.ITERATION_ID: self.node_id, NodeRunMetadataKey.ITERATION_ID: self.node_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID NodeRunMetadataKey.ITERATION_INDEX: iter_run_index,
if self.node_data.is_parallel
else NodeRunMetadataKey.ITERATION_INDEX: parallel_mode_run_id
if self.node_data.is_parallel
else iter_run_index,
} }
event.route_node_state.node_run_result.metadata = metadata if parallel_mode_run_id:
# for parallel, the specific branch ID is more important than the sequential index
iter_metadata[NodeRunMetadataKey.PARALLEL_MODE_RUN_ID] = parallel_mode_run_id
if event.route_node_state.node_run_result:
current_metadata = event.route_node_state.node_run_result.metadata or {}
if NodeRunMetadataKey.ITERATION_ID not in current_metadata:
event.route_node_state.node_run_result.metadata = {**current_metadata, **iter_metadata}
return event return event
def _run_single_iter( def _run_single_iter(

@ -264,6 +264,7 @@ class KnowledgeRetrievalNode(LLMNode):
"data_source_type": "external", "data_source_type": "external",
"retriever_from": "workflow", "retriever_from": "workflow",
"score": item.metadata.get("score"), "score": item.metadata.get("score"),
"doc_metadata": item.metadata,
}, },
"title": item.metadata.get("title"), "title": item.metadata.get("title"),
"content": item.page_content, "content": item.page_content,
@ -275,12 +276,16 @@ class KnowledgeRetrievalNode(LLMNode):
if records: if records:
for record in records: for record in records:
segment = record.segment segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
document = Document.query.filter( document = (
db.session.query(Document)
.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
).first() )
.first()
)
if dataset and document: if dataset and document:
source = { source = {
"metadata": { "metadata": {
@ -289,7 +294,7 @@ class KnowledgeRetrievalNode(LLMNode):
"dataset_name": dataset.name, "dataset_name": dataset.name,
"document_id": document.id, "document_id": document.id,
"document_name": document.name, "document_name": document.name,
"document_data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"segment_id": segment.id, "segment_id": segment.id,
"retriever_from": "workflow", "retriever_from": "workflow",
"score": record.score or 0.0, "score": record.score or 0.0,
@ -356,12 +361,12 @@ class KnowledgeRetrievalNode(LLMNode):
) )
elif node_data.metadata_filtering_mode == "manual": elif node_data.metadata_filtering_mode == "manual":
if node_data.metadata_filtering_conditions: if node_data.metadata_filtering_conditions:
metadata_condition = MetadataCondition(**node_data.metadata_filtering_conditions.model_dump()) conditions = []
if node_data.metadata_filtering_conditions: if node_data.metadata_filtering_conditions:
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name metadata_name = condition.name
expected_value = condition.value expected_value = condition.value
if expected_value is not None or condition.comparison_operator in ("empty", "not empty"): if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str): if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template( expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value expected_value
@ -372,6 +377,13 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() # type: ignore
else: else:
raise ValueError("Invalid expected metadata value type") raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = self._process_metadata_filter_func( filters = self._process_metadata_filter_func(
sequence, sequence,
condition.comparison_operator, condition.comparison_operator,
@ -379,6 +391,10 @@ class KnowledgeRetrievalNode(LLMNode):
expected_value, expected_value,
filters, filters,
) )
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
else: else:
raise ValueError("Invalid metadata filtering mode") raise ValueError("Invalid metadata filtering mode")
if filters: if filters:

@ -506,7 +506,7 @@ class LLMNode(BaseNode[LLMNodeData]):
"dataset_name": metadata.get("dataset_name"), "dataset_name": metadata.get("dataset_name"),
"document_id": metadata.get("document_id"), "document_id": metadata.get("document_id"),
"document_name": metadata.get("document_name"), "document_name": metadata.get("document_name"),
"data_source_type": metadata.get("document_data_source_type"), "data_source_type": metadata.get("data_source_type"),
"segment_id": metadata.get("segment_id"), "segment_id": metadata.get("segment_id"),
"retriever_from": metadata.get("retriever_from"), "retriever_from": metadata.get("retriever_from"),
"score": metadata.get("score"), "score": metadata.get("score"),

@ -26,7 +26,7 @@ class LoopNodeData(BaseLoopNodeData):
loop_count: int # Maximum number of loops loop_count: int # Maximum number of loops
break_conditions: list[Condition] # Conditions to break the loop break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"] logical_operator: Literal["and", "or"]
loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list) loop_variables: Optional[list[LoopVariableData]] = Field(default_factory=list[LoopVariableData])
outputs: Optional[Mapping[str, Any]] = None outputs: Optional[Mapping[str, Any]] = None

@ -337,7 +337,7 @@ class LoopNode(BaseNode[LoopNodeData]):
return {"check_break_result": True} return {"check_break_result": True}
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
# Loop run failed # Loop run failed
yield event yield self._handle_event_metadata(event=event, iter_run_index=current_index)
yield LoopRunFailedEvent( yield LoopRunFailedEvent(
loop_id=self.id, loop_id=self.id,
loop_node_id=self.node_id, loop_node_id=self.node_id,

@ -39,7 +39,7 @@ class SubCondition(BaseModel):
class SubVariableCondition(BaseModel): class SubVariableCondition(BaseModel):
logical_operator: Literal["and", "or"] logical_operator: Literal["and", "or"]
conditions: list[SubCondition] = Field(default=list) conditions: list[SubCondition] = Field(default_factory=list)
class Condition(BaseModel): class Condition(BaseModel):

@ -39,6 +39,10 @@ def init_app(app: DifyApp):
handlers=log_handlers, handlers=log_handlers,
force=True, force=True,
) )
# Apply RequestIdFormatter to all handlers
apply_request_id_formatter()
# Disable propagation for noisy loggers to avoid duplicate logs # Disable propagation for noisy loggers to avoid duplicate logs
logging.getLogger("sqlalchemy.engine").propagate = False logging.getLogger("sqlalchemy.engine").propagate = False
log_tz = dify_config.LOG_TZ log_tz = dify_config.LOG_TZ
@ -74,3 +78,16 @@ class RequestIdFilter(logging.Filter):
def filter(self, record): def filter(self, record):
record.req_id = get_request_id() if flask.has_request_context() else "" record.req_id = get_request_id() if flask.has_request_context() else ""
return True return True
class RequestIdFormatter(logging.Formatter):
def format(self, record):
if not hasattr(record, "req_id"):
record.req_id = ""
return super().format(record)
def apply_request_id_formatter():
for handler in logging.root.handlers:
if handler.formatter:
handler.formatter = RequestIdFormatter(dify_config.LOG_FORMAT, dify_config.LOG_DATEFORMAT)

@ -26,7 +26,7 @@ class Mail:
match mail_type: match mail_type:
case "resend": case "resend":
import resend # type: ignore import resend
api_key = dify_config.RESEND_API_KEY api_key = dify_config.RESEND_API_KEY
if not api_key: if not api_key:

@ -1,6 +1,7 @@
from typing import Any, Union from typing import Any, Union
import redis import redis
from redis.cache import CacheConfig
from redis.cluster import ClusterNode, RedisCluster from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel from redis.sentinel import Sentinel
@ -51,6 +52,14 @@ def init_app(app: DifyApp):
connection_class: type[Union[Connection, SSLConnection]] = Connection connection_class: type[Union[Connection, SSLConnection]] = Connection
if dify_config.REDIS_USE_SSL: if dify_config.REDIS_USE_SSL:
connection_class = SSLConnection connection_class = SSLConnection
resp_protocol = dify_config.REDIS_SERIALIZATION_PROTOCOL
if dify_config.REDIS_ENABLE_CLIENT_SIDE_CACHE:
if resp_protocol >= 3:
clientside_cache_config = CacheConfig()
else:
raise ValueError("Client side cache is only supported in RESP3")
else:
clientside_cache_config = None
redis_params: dict[str, Any] = { redis_params: dict[str, Any] = {
"username": dify_config.REDIS_USERNAME, "username": dify_config.REDIS_USERNAME,
@ -59,6 +68,8 @@ def init_app(app: DifyApp):
"encoding": "utf-8", "encoding": "utf-8",
"encoding_errors": "strict", "encoding_errors": "strict",
"decode_responses": False, "decode_responses": False,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
} }
if dify_config.REDIS_USE_SENTINEL: if dify_config.REDIS_USE_SENTINEL:
@ -82,14 +93,22 @@ def init_app(app: DifyApp):
ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1])) ClusterNode(host=node.split(":")[0], port=int(node.split(":")[1]))
for node in dify_config.REDIS_CLUSTERS.split(",") for node in dify_config.REDIS_CLUSTERS.split(",")
] ]
# FIXME: mypy error here, try to figure out how to fix it redis_client.initialize(
redis_client.initialize(RedisCluster(startup_nodes=nodes, password=dify_config.REDIS_CLUSTERS_PASSWORD)) # type: ignore RedisCluster(
startup_nodes=nodes,
password=dify_config.REDIS_CLUSTERS_PASSWORD,
protocol=resp_protocol,
cache_config=clientside_cache_config,
)
)
else: else:
redis_params.update( redis_params.update(
{ {
"host": dify_config.REDIS_HOST, "host": dify_config.REDIS_HOST,
"port": dify_config.REDIS_PORT, "port": dify_config.REDIS_PORT,
"connection_class": connection_class, "connection_class": connection_class,
"protocol": resp_protocol,
"cache_config": clientside_cache_config,
} }
) )
pool = redis.ConnectionPool(**redis_params) pool = redis.ConnectionPool(**redis_params)

@ -61,13 +61,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
).first() )
.first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -97,13 +101,17 @@ class NotionOAuth(OAuthDataSource):
"total": len(pages), "total": len(pages),
} }
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.access_token == access_token, DataSourceOauthBinding.access_token == access_token,
) )
).first() )
.first()
)
if data_source_binding: if data_source_binding:
data_source_binding.source_info = source_info data_source_binding.source_info = source_info
data_source_binding.disabled = False data_source_binding.disabled = False
@ -121,14 +129,18 @@ class NotionOAuth(OAuthDataSource):
def sync_data_source(self, binding_id: str): def sync_data_source(self, binding_id: str):
# save data source binding # save data source binding
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.id == binding_id, DataSourceOauthBinding.id == binding_id,
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
) )
).first() )
.first()
)
if data_source_binding: if data_source_binding:
# get all authorized pages # get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token) pages = self.get_authorized_pages(data_source_binding.access_token)

@ -0,0 +1,33 @@
"""add index for workflow_conversation_variables.conversation_id
Revision ID: d28f2004b072
Revises: 6a9f914f656c
Create Date: 2025-05-14 14:03:36.713828
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'd28f2004b072'
down_revision = '6a9f914f656c'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.create_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'), ['conversation_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('workflow_conversation_variables', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('workflow_conversation_variables_conversation_id_idx'))
# ### end Alembic commands ###

@ -1,5 +1,6 @@
import enum import enum
import json import json
from typing import cast
from flask_login import UserMixin # type: ignore from flask_login import UserMixin # type: ignore
from sqlalchemy import func from sqlalchemy import func
@ -46,13 +47,12 @@ class Account(UserMixin, Base):
@property @property
def current_tenant(self): def current_tenant(self):
# FIXME: fix the type error later, because the type is important maybe cause some bugs
return self._current_tenant # type: ignore return self._current_tenant # type: ignore
@current_tenant.setter @current_tenant.setter
def current_tenant(self, value: "Tenant"): def current_tenant(self, value: "Tenant"):
tenant = value tenant = value
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=self.id).first() ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=self.id).first()
if ta: if ta:
tenant.current_role = ta.role tenant.current_role = ta.role
else: else:
@ -64,25 +64,23 @@ class Account(UserMixin, Base):
def current_tenant_id(self) -> str | None: def current_tenant_id(self) -> str | None:
return self._current_tenant.id if self._current_tenant else None return self._current_tenant.id if self._current_tenant else None
@current_tenant_id.setter def set_tenant_id(self, tenant_id: str):
def current_tenant_id(self, value: str): tenant_account_join = cast(
try: tuple[Tenant, TenantAccountJoin],
tenant_account_join = ( (
db.session.query(Tenant, TenantAccountJoin) db.session.query(Tenant, TenantAccountJoin)
.filter(Tenant.id == value) .filter(Tenant.id == tenant_id)
.filter(TenantAccountJoin.tenant_id == Tenant.id) .filter(TenantAccountJoin.tenant_id == Tenant.id)
.filter(TenantAccountJoin.account_id == self.id) .filter(TenantAccountJoin.account_id == self.id)
.one_or_none() .one_or_none()
),
) )
if tenant_account_join: if not tenant_account_join:
tenant, ta = tenant_account_join return
tenant.current_role = ta.role
else:
tenant = None
except Exception:
tenant = None
tenant, join = tenant_account_join
tenant.current_role = join.role
self._current_tenant = tenant self._current_tenant = tenant
@property @property
@ -191,7 +189,7 @@ class TenantAccountRole(enum.StrEnum):
} }
class Tenant(db.Model): # type: ignore[name-defined] class Tenant(Base):
__tablename__ = "tenants" __tablename__ = "tenants"
__table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),) __table_args__ = (db.PrimaryKeyConstraint("id", name="tenant_pkey"),)
@ -220,7 +218,7 @@ class Tenant(db.Model): # type: ignore[name-defined]
self.custom_config = json.dumps(value) self.custom_config = json.dumps(value)
class TenantAccountJoin(db.Model): # type: ignore[name-defined] class TenantAccountJoin(Base):
__tablename__ = "tenant_account_joins" __tablename__ = "tenant_account_joins"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"), db.PrimaryKeyConstraint("id", name="tenant_account_join_pkey"),
@ -239,7 +237,7 @@ class TenantAccountJoin(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class AccountIntegrate(db.Model): # type: ignore[name-defined] class AccountIntegrate(Base):
__tablename__ = "account_integrates" __tablename__ = "account_integrates"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="account_integrate_pkey"), db.PrimaryKeyConstraint("id", name="account_integrate_pkey"),
@ -256,7 +254,7 @@ class AccountIntegrate(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class InvitationCode(db.Model): # type: ignore[name-defined] class InvitationCode(Base):
__tablename__ = "invitation_codes" __tablename__ = "invitation_codes"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="invitation_code_pkey"), db.PrimaryKeyConstraint("id", name="invitation_code_pkey"),

@ -2,6 +2,7 @@ import enum
from sqlalchemy import func from sqlalchemy import func
from .base import Base
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID
@ -13,7 +14,7 @@ class APIBasedExtensionPoint(enum.Enum):
APP_MODERATION_OUTPUT = "app.moderation.output" APP_MODERATION_OUTPUT = "app.moderation.output"
class APIBasedExtension(db.Model): # type: ignore[name-defined] class APIBasedExtension(Base):
__tablename__ = "api_based_extensions" __tablename__ = "api_based_extensions"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"), db.PrimaryKeyConstraint("id", name="api_based_extension_pkey"),

@ -22,6 +22,7 @@ from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account from .account import Account
from .base import Base
from .engine import db from .engine import db
from .model import App, Tag, TagBinding, UploadFile from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID from .types import StringUUID
@ -33,7 +34,7 @@ class DatasetPermissionEnum(enum.StrEnum):
PARTIAL_TEAM = "partial_members" PARTIAL_TEAM = "partial_members"
class Dataset(db.Model): # type: ignore[name-defined] class Dataset(Base):
__tablename__ = "datasets" __tablename__ = "datasets"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_pkey"), db.PrimaryKeyConstraint("id", name="dataset_pkey"),
@ -92,7 +93,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
@property @property
def latest_process_rule(self): def latest_process_rule(self):
return ( return (
DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id) db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.dataset_id == self.id)
.order_by(DatasetProcessRule.created_at.desc()) .order_by(DatasetProcessRule.created_at.desc())
.first() .first()
) )
@ -137,7 +139,8 @@ class Dataset(db.Model): # type: ignore[name-defined]
@property @property
def word_count(self): def word_count(self):
return ( return (
Document.query.with_entities(func.coalesce(func.sum(Document.word_count))) db.session.query(Document)
.with_entities(func.coalesce(func.sum(Document.word_count)))
.filter(Document.dataset_id == self.id) .filter(Document.dataset_id == self.id)
.scalar() .scalar()
) )
@ -255,7 +258,7 @@ class Dataset(db.Model): # type: ignore[name-defined]
return f"Vector_index_{normalized_dataset_id}_Node" return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model): # type: ignore[name-defined] class DatasetProcessRule(Base):
__tablename__ = "dataset_process_rules" __tablename__ = "dataset_process_rules"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"), db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
@ -295,7 +298,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
return None return None
class Document(db.Model): # type: ignore[name-defined] class Document(Base):
__tablename__ = "documents" __tablename__ = "documents"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="document_pkey"), db.PrimaryKeyConstraint("id", name="document_pkey"),
@ -439,12 +442,13 @@ class Document(db.Model): # type: ignore[name-defined]
@property @property
def segment_count(self): def segment_count(self):
return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count() return db.session.query(DocumentSegment).filter(DocumentSegment.document_id == self.id).count()
@property @property
def hit_count(self): def hit_count(self):
return ( return (
DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count))) db.session.query(DocumentSegment)
.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
.filter(DocumentSegment.document_id == self.id) .filter(DocumentSegment.document_id == self.id)
.scalar() .scalar()
) )
@ -635,7 +639,7 @@ class Document(db.Model): # type: ignore[name-defined]
) )
class DocumentSegment(db.Model): # type: ignore[name-defined] class DocumentSegment(Base):
__tablename__ = "document_segments" __tablename__ = "document_segments"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="document_segment_pkey"), db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
@ -786,7 +790,7 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text return text
class ChildChunk(db.Model): # type: ignore[name-defined] class ChildChunk(Base):
__tablename__ = "child_chunks" __tablename__ = "child_chunks"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"), db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
@ -829,7 +833,7 @@ class ChildChunk(db.Model): # type: ignore[name-defined]
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first() return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined] class AppDatasetJoin(Base):
__tablename__ = "app_dataset_joins" __tablename__ = "app_dataset_joins"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"), db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
@ -846,7 +850,7 @@ class AppDatasetJoin(db.Model): # type: ignore[name-defined]
return db.session.get(App, self.app_id) return db.session.get(App, self.app_id)
class DatasetQuery(db.Model): # type: ignore[name-defined] class DatasetQuery(Base):
__tablename__ = "dataset_queries" __tablename__ = "dataset_queries"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_query_pkey"), db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
@ -863,7 +867,7 @@ class DatasetQuery(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model): # type: ignore[name-defined] class DatasetKeywordTable(Base):
__tablename__ = "dataset_keyword_tables" __tablename__ = "dataset_keyword_tables"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"), db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
@ -891,7 +895,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return dct return dct
# get dataset # get dataset
dataset = Dataset.query.filter_by(id=self.dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=self.dataset_id).first()
if not dataset: if not dataset:
return None return None
if self.data_source_type == "database": if self.data_source_type == "database":
@ -908,7 +912,7 @@ class DatasetKeywordTable(db.Model): # type: ignore[name-defined]
return None return None
class Embedding(db.Model): # type: ignore[name-defined] class Embedding(Base):
__tablename__ = "embeddings" __tablename__ = "embeddings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="embedding_pkey"), db.PrimaryKeyConstraint("id", name="embedding_pkey"),
@ -932,7 +936,7 @@ class Embedding(db.Model): # type: ignore[name-defined]
return cast(list[float], pickle.loads(self.embedding)) # noqa: S301 return cast(list[float], pickle.loads(self.embedding)) # noqa: S301
class DatasetCollectionBinding(db.Model): # type: ignore[name-defined] class DatasetCollectionBinding(Base):
__tablename__ = "dataset_collection_bindings" __tablename__ = "dataset_collection_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"), db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
@ -947,7 +951,7 @@ class DatasetCollectionBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class TidbAuthBinding(db.Model): # type: ignore[name-defined] class TidbAuthBinding(Base):
__tablename__ = "tidb_auth_bindings" __tablename__ = "tidb_auth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"), db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
@ -967,7 +971,7 @@ class TidbAuthBinding(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class Whitelist(db.Model): # type: ignore[name-defined] class Whitelist(Base):
__tablename__ = "whitelists" __tablename__ = "whitelists"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="whitelists_pkey"), db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
@ -979,7 +983,7 @@ class Whitelist(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetPermission(db.Model): # type: ignore[name-defined] class DatasetPermission(Base):
__tablename__ = "dataset_permissions" __tablename__ = "dataset_permissions"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"), db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
@ -996,7 +1000,7 @@ class DatasetPermission(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined] class ExternalKnowledgeApis(Base):
__tablename__ = "external_knowledge_apis" __tablename__ = "external_knowledge_apis"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"), db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
@ -1049,7 +1053,7 @@ class ExternalKnowledgeApis(db.Model): # type: ignore[name-defined]
return dataset_bindings return dataset_bindings
class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined] class ExternalKnowledgeBindings(Base):
__tablename__ = "external_knowledge_bindings" __tablename__ = "external_knowledge_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"), db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
@ -1070,7 +1074,7 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined] class DatasetAutoDisableLog(Base):
__tablename__ = "dataset_auto_disable_logs" __tablename__ = "dataset_auto_disable_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"), db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
@ -1087,7 +1091,7 @@ class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class RateLimitLog(db.Model): # type: ignore[name-defined] class RateLimitLog(Base):
__tablename__ = "rate_limit_logs" __tablename__ = "rate_limit_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"), db.PrimaryKeyConstraint("id", name="rate_limit_log_pkey"),
@ -1102,7 +1106,7 @@ class RateLimitLog(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)")) created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetMetadata(db.Model): # type: ignore[name-defined] class DatasetMetadata(Base):
__tablename__ = "dataset_metadatas" __tablename__ = "dataset_metadatas"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"), db.PrimaryKeyConstraint("id", name="dataset_metadata_pkey"),
@ -1121,7 +1125,7 @@ class DatasetMetadata(db.Model): # type: ignore[name-defined]
updated_by = db.Column(StringUUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
class DatasetMetadataBinding(db.Model): # type: ignore[name-defined] class DatasetMetadataBinding(Base):
__tablename__ = "dataset_metadata_bindings" __tablename__ = "dataset_metadata_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"), db.PrimaryKeyConstraint("id", name="dataset_metadata_binding_pkey"),

@ -16,7 +16,7 @@ if TYPE_CHECKING:
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from flask_login import UserMixin # type: ignore from flask_login import UserMixin
from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text from sqlalchemy import Float, Index, PrimaryKeyConstraint, func, text
from sqlalchemy.orm import Mapped, Session, mapped_column from sqlalchemy.orm import Mapped, Session, mapped_column
@ -25,13 +25,13 @@ from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers from core.file import helpers as file_helpers
from libs.helper import generate_string from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
from models.workflow import WorkflowRunStatus
from .account import Account, Tenant from .account import Account, Tenant
from .base import Base
from .engine import db from .engine import db
from .enums import CreatedByRole
from .types import StringUUID from .types import StringUUID
from .workflow import WorkflowRunStatus
if TYPE_CHECKING: if TYPE_CHECKING:
from .workflow import Workflow from .workflow import Workflow
@ -602,7 +602,7 @@ class InstalledApp(Base):
return tenant return tenant
class Conversation(db.Model): # type: ignore[name-defined] class Conversation(Base):
__tablename__ = "conversations" __tablename__ = "conversations"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="conversation_pkey"), db.PrimaryKeyConstraint("id", name="conversation_pkey"),
@ -794,7 +794,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
for message in messages: for message in messages:
if message.workflow_run: if message.workflow_run:
status_counts[message.workflow_run.status] += 1 status_counts[WorkflowRunStatus(message.workflow_run.status)] += 1
return ( return (
{ {
@ -864,7 +864,7 @@ class Conversation(db.Model): # type: ignore[name-defined]
} }
class Message(db.Model): # type: ignore[name-defined] class Message(Base):
__tablename__ = "messages" __tablename__ = "messages"
__table_args__ = ( __table_args__ = (
PrimaryKeyConstraint("id", name="message_pkey"), PrimaryKeyConstraint("id", name="message_pkey"),
@ -1211,7 +1211,7 @@ class Message(db.Model): # type: ignore[name-defined]
) )
class MessageFeedback(db.Model): # type: ignore[name-defined] class MessageFeedback(Base):
__tablename__ = "message_feedbacks" __tablename__ = "message_feedbacks"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_feedback_pkey"), db.PrimaryKeyConstraint("id", name="message_feedback_pkey"),
@ -1237,8 +1237,23 @@ class MessageFeedback(db.Model): # type: ignore[name-defined]
account = db.session.query(Account).filter(Account.id == self.from_account_id).first() account = db.session.query(Account).filter(Account.id == self.from_account_id).first()
return account return account
def to_dict(self):
return {
"id": str(self.id),
"app_id": str(self.app_id),
"conversation_id": str(self.conversation_id),
"message_id": str(self.message_id),
"rating": self.rating,
"content": self.content,
"from_source": self.from_source,
"from_end_user_id": str(self.from_end_user_id) if self.from_end_user_id else None,
"from_account_id": str(self.from_account_id) if self.from_account_id else None,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
class MessageFile(db.Model): # type: ignore[name-defined]
class MessageFile(Base):
__tablename__ = "message_files" __tablename__ = "message_files"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_file_pkey"), db.PrimaryKeyConstraint("id", name="message_file_pkey"),
@ -1279,7 +1294,7 @@ class MessageFile(db.Model): # type: ignore[name-defined]
created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class MessageAnnotation(db.Model): # type: ignore[name-defined] class MessageAnnotation(Base):
__tablename__ = "message_annotations" __tablename__ = "message_annotations"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="message_annotation_pkey"), db.PrimaryKeyConstraint("id", name="message_annotation_pkey"),
@ -1310,7 +1325,7 @@ class MessageAnnotation(db.Model): # type: ignore[name-defined]
return account return account
class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined] class AppAnnotationHitHistory(Base):
__tablename__ = "app_annotation_hit_histories" __tablename__ = "app_annotation_hit_histories"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_hit_histories_pkey"),
@ -1322,7 +1337,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
annotation_id = db.Column(StringUUID, nullable=False) annotation_id: Mapped[str] = db.Column(StringUUID, nullable=False)
source = db.Column(db.Text, nullable=False) source = db.Column(db.Text, nullable=False)
question = db.Column(db.Text, nullable=False) question = db.Column(db.Text, nullable=False)
account_id = db.Column(StringUUID, nullable=False) account_id = db.Column(StringUUID, nullable=False)
@ -1348,7 +1363,7 @@ class AppAnnotationHitHistory(db.Model): # type: ignore[name-defined]
return account return account
class AppAnnotationSetting(db.Model): # type: ignore[name-defined] class AppAnnotationSetting(Base):
__tablename__ = "app_annotation_settings" __tablename__ = "app_annotation_settings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"), db.PrimaryKeyConstraint("id", name="app_annotation_settings_pkey"),
@ -1364,26 +1379,6 @@ class AppAnnotationSetting(db.Model): # type: ignore[name-defined]
updated_user_id = db.Column(StringUUID, nullable=False) updated_user_id = db.Column(StringUUID, nullable=False)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property
def created_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.created_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property
def updated_account(self):
account = (
db.session.query(Account)
.join(AppAnnotationSetting, AppAnnotationSetting.updated_user_id == Account.id)
.filter(AppAnnotationSetting.id == self.annotation_id)
.first()
)
return account
@property @property
def collection_binding_detail(self): def collection_binding_detail(self):
from .dataset import DatasetCollectionBinding from .dataset import DatasetCollectionBinding

@ -2,8 +2,7 @@ from enum import Enum
from sqlalchemy import func from sqlalchemy import func
from models.base import Base from .base import Base
from .engine import db from .engine import db
from .types import StringUUID from .types import StringUUID

@ -9,7 +9,7 @@ from .engine import db
from .types import StringUUID from .types import StringUUID
class DataSourceOauthBinding(db.Model): # type: ignore[name-defined] class DataSourceOauthBinding(Base):
__tablename__ = "data_source_oauth_bindings" __tablename__ = "data_source_oauth_bindings"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="source_binding_pkey"), db.PrimaryKeyConstraint("id", name="source_binding_pkey"),

@ -9,7 +9,7 @@ if TYPE_CHECKING:
from models.model import AppMode from models.model import AppMode
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import Index, PrimaryKeyConstraint, func from sqlalchemy import func
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
import contexts import contexts
@ -18,11 +18,11 @@ from core.helper import encrypter
from core.variables import SecretVariable, Variable from core.variables import SecretVariable, Variable
from factories import variable_factory from factories import variable_factory
from libs import helper from libs import helper
from models.base import Base
from models.enums import CreatedByRole
from .account import Account from .account import Account
from .base import Base
from .engine import db from .engine import db
from .enums import CreatedByRole
from .types import StringUUID from .types import StringUUID
if TYPE_CHECKING: if TYPE_CHECKING:
@ -736,8 +736,7 @@ class WorkflowAppLog(Base):
__tablename__ = "workflow_app_logs" __tablename__ = "workflow_app_logs"
__table_args__ = ( __table_args__ = (
db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"), db.PrimaryKeyConstraint("id", name="workflow_app_log_pkey"),
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id", "created_at"), db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
db.Index("workflow_app_log_workflow_run_idx", "workflow_run_id"),
) )
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
@ -769,17 +768,12 @@ class WorkflowAppLog(Base):
class ConversationVariable(Base): class ConversationVariable(Base):
__tablename__ = "workflow_conversation_variables" __tablename__ = "workflow_conversation_variables"
__table_args__ = (
PrimaryKeyConstraint("id", "conversation_id", name="workflow_conversation_variables_pkey"),
Index("workflow__conversation_variables_app_id_idx", "app_id"),
Index("workflow__conversation_variables_created_at_idx", "created_at"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True) id: Mapped[str] = mapped_column(StringUUID, primary_key=True)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True) conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False, primary_key=True, index=True)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
data = mapped_column(db.Text, nullable=False) data = mapped_column(db.Text, nullable=False)
created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp(), index=True)
updated_at = mapped_column( updated_at = mapped_column(
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp() db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
) )

@ -14,7 +14,7 @@ dependencies = [
"chardet~=5.1.0", "chardet~=5.1.0",
"flask~=3.1.0", "flask~=3.1.0",
"flask-compress~=1.17", "flask-compress~=1.17",
"flask-cors~=4.0.0", "flask-cors~=5.0.0",
"flask-login~=0.6.3", "flask-login~=0.6.3",
"flask-migrate~=4.0.7", "flask-migrate~=4.0.7",
"flask-restful~=0.3.10", "flask-restful~=0.3.10",
@ -63,25 +63,24 @@ dependencies = [
"psycogreen~=1.0.2", "psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6", "psycopg2-binary~=2.9.6",
"pycryptodome==3.19.1", "pycryptodome==3.19.1",
"pydantic~=2.9.2", "pydantic~=2.11.4",
"pydantic-extra-types~=2.9.0", "pydantic-extra-types~=2.10.3",
"pydantic-settings~=2.6.0", "pydantic-settings~=2.9.1",
"pyjwt~=2.8.0", "pyjwt~=2.8.0",
"pypdfium2~=4.30.0", "pypdfium2==4.30.0",
"python-docx~=1.1.0", "python-docx~=1.1.0",
"python-dotenv==1.0.1", "python-dotenv==1.0.1",
"pyyaml~=6.0.1", "pyyaml~=6.0.1",
"readabilipy==0.2.0", "readabilipy~=0.3.0",
"redis[hiredis]~=5.0.3", "redis[hiredis]~=6.0.0",
"resend~=0.7.0", "resend~=2.9.0",
"sentry-sdk[flask]~=1.44.1", "sentry-sdk[flask]~=2.28.0",
"sqlalchemy~=2.0.29", "sqlalchemy~=2.0.29",
"starlette==0.41.0", "starlette==0.41.0",
"tiktoken~=0.9.0", "tiktoken~=0.9.0",
"tokenizers~=0.15.0", "transformers~=4.51.0",
"transformers~=4.35.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1", "unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"weave~=0.51.34", "weave~=0.51.0",
"yarl~=1.18.3", "yarl~=1.18.3",
"webvtt-py~=0.5.1", "webvtt-py~=0.5.1",
] ]
@ -195,7 +194,7 @@ vdb = [
"tcvectordb~=1.6.4", "tcvectordb~=1.6.4",
"tidb-vector==0.0.9", "tidb-vector==0.0.9",
"upstash-vector==0.6.0", "upstash-vector==0.6.0",
"volcengine-compat~=1.0.156", "volcengine-compat~=1.0.0",
"weaviate-client~=3.24.0", "weaviate-client~=3.24.0",
"xinference-client~=1.2.2", "xinference-client~=1.2.2",
] ]

@ -1,4 +1,5 @@
import datetime import datetime
import logging
import time import time
import click import click
@ -20,6 +21,8 @@ from models.model import (
from models.web import SavedMessage from models.web import SavedMessage
from services.feature_service import FeatureService from services.feature_service import FeatureService
_logger = logging.getLogger(__name__)
@app.celery.task(queue="dataset") @app.celery.task(queue="dataset")
def clean_messages(): def clean_messages():
@ -46,7 +49,14 @@ def clean_messages():
break break
for message in messages: for message in messages:
plan_sandbox_clean_message_day = message.created_at plan_sandbox_clean_message_day = message.created_at
app = App.query.filter_by(id=message.app_id).first() app = db.session.query(App).filter_by(id=message.app_id).first()
if not app:
_logger.warning(
"Expected App record to exist, but none was found, app_id=%s, message_id=%s",
message.app_id,
message.id,
)
continue
features_cache_key = f"features:{app.tenant_id}" features_cache_key = f"features:{app.tenant_id}"
plan_cache = redis_client.get(features_cache_key) plan_cache = redis_client.get(features_cache_key)
if plan_cache is None: if plan_cache is None:

@ -2,7 +2,7 @@ import datetime
import time import time
import click import click
from sqlalchemy import func from sqlalchemy import func, select
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
import app import app
@ -51,8 +51,9 @@ def clean_unused_datasets_task():
) )
# Main query with join and filter # Main query with join and filter
datasets = ( stmt = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .filter(
Dataset.created_at < plan_sandbox_clean_day, Dataset.created_at < plan_sandbox_clean_day,
@ -60,9 +61,10 @@ def clean_unused_datasets_task():
func.coalesce(document_subquery_old.c.document_count, 0) > 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0,
) )
.order_by(Dataset.created_at.desc()) .order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
) )
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound: except NotFound:
break break
if datasets.items is None or len(datasets.items) == 0: if datasets.items is None or len(datasets.items) == 0:
@ -99,7 +101,7 @@ def clean_unused_datasets_task():
# update document # update document
update_params = {Document.enabled: False} update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit() db.session.commit()
click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")) click.echo(click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green"))
except Exception as e: except Exception as e:
@ -135,8 +137,9 @@ def clean_unused_datasets_task():
) )
# Main query with join and filter # Main query with join and filter
datasets = ( stmt = (
Dataset.query.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id) select(Dataset)
.outerjoin(document_subquery_new, Dataset.id == document_subquery_new.c.dataset_id)
.outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id) .outerjoin(document_subquery_old, Dataset.id == document_subquery_old.c.dataset_id)
.filter( .filter(
Dataset.created_at < plan_pro_clean_day, Dataset.created_at < plan_pro_clean_day,
@ -144,8 +147,8 @@ def clean_unused_datasets_task():
func.coalesce(document_subquery_old.c.document_count, 0) > 0, func.coalesce(document_subquery_old.c.document_count, 0) > 0,
) )
.order_by(Dataset.created_at.desc()) .order_by(Dataset.created_at.desc())
.paginate(page=1, per_page=50)
) )
datasets = db.paginate(stmt, page=1, per_page=50)
except NotFound: except NotFound:
break break
@ -175,7 +178,7 @@ def clean_unused_datasets_task():
# update document # update document
update_params = {Document.enabled: False} update_params = {Document.enabled: False}
Document.query.filter_by(dataset_id=dataset.id).update(update_params) db.session.query(Document).filter_by(dataset_id=dataset.id).update(update_params)
db.session.commit() db.session.commit()
click.echo( click.echo(
click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green") click.style("Cleaned unused dataset {} from db success!".format(dataset.id), fg="green")

@ -19,7 +19,9 @@ def create_tidb_serverless_task():
while True: while True:
try: try:
# check the number of idle tidb serverless # check the number of idle tidb serverless
idle_tidb_serverless_number = TidbAuthBinding.query.filter(TidbAuthBinding.active == False).count() idle_tidb_serverless_number = (
db.session.query(TidbAuthBinding).filter(TidbAuthBinding.active == False).count()
)
if idle_tidb_serverless_number >= tidb_serverless_number: if idle_tidb_serverless_number >= tidb_serverless_number:
break break
# create tidb serverless # create tidb serverless

@ -29,7 +29,9 @@ def mail_clean_document_notify_task():
# send document clean notify mail # send document clean notify mail
try: try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all() dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog).filter(DatasetAutoDisableLog.notified == False).all()
)
# group by tenant_id # group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list) dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs: for dataset_auto_disable_log in dataset_auto_disable_logs:
@ -43,14 +45,16 @@ def mail_clean_document_notify_task():
if plan != "sandbox": if plan != "sandbox":
knowledge_details = [] knowledge_details = []
# check tenant # check tenant
tenant = Tenant.query.filter(Tenant.id == tenant_id).first() tenant = db.session.query(Tenant).filter(Tenant.id == tenant_id).first()
if not tenant: if not tenant:
continue continue
# check current owner # check current owner
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if not current_owner_join: if not current_owner_join:
continue continue
account = Account.query.filter(Account.id == current_owner_join.account_id).first() account = db.session.query(Account).filter(Account.id == current_owner_join.account_id).first()
if not account: if not account:
continue continue
@ -63,7 +67,7 @@ def mail_clean_document_notify_task():
) )
for dataset_id, document_ids in dataset_auto_dataset_map.items(): for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if dataset: if dataset:
document_count = len(document_ids) document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents") knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")

@ -5,6 +5,7 @@ import click
import app import app
from configs import dify_config from configs import dify_config
from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService from core.rag.datasource.vdb.tidb_on_qdrant.tidb_service import TidbService
from extensions.ext_database import db
from models.dataset import TidbAuthBinding from models.dataset import TidbAuthBinding
@ -14,9 +15,11 @@ def update_tidb_serverless_status_task():
start_at = time.perf_counter() start_at = time.perf_counter()
try: try:
# check the number of idle tidb serverless # check the number of idle tidb serverless
tidb_serverless_list = TidbAuthBinding.query.filter( tidb_serverless_list = (
TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING" db.session.query(TidbAuthBinding)
).all() .filter(TidbAuthBinding.active == False, TidbAuthBinding.status == "CREATING")
.all()
)
if len(tidb_serverless_list) == 0: if len(tidb_serverless_list) == 0:
return return
# update tidb serverless status # update tidb serverless status

@ -108,17 +108,20 @@ class AccountService:
if account.status == AccountStatus.BANNED.value: if account.status == AccountStatus.BANNED.value:
raise Unauthorized("Account is banned.") raise Unauthorized("Account is banned.")
current_tenant = TenantAccountJoin.query.filter_by(account_id=account.id, current=True).first() current_tenant = db.session.query(TenantAccountJoin).filter_by(account_id=account.id, current=True).first()
if current_tenant: if current_tenant:
account.current_tenant_id = current_tenant.tenant_id account.set_tenant_id(current_tenant.tenant_id)
else: else:
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if not available_ta: if not available_ta:
return None return None
account.current_tenant_id = available_ta.tenant_id account.set_tenant_id(available_ta.tenant_id)
available_ta.current = True available_ta.current = True
db.session.commit() db.session.commit()
@ -297,9 +300,9 @@ class AccountService:
"""Link account integrate""" """Link account integrate"""
try: try:
# Query whether there is an existing binding record for the same provider # Query whether there is an existing binding record for the same provider
account_integrate: Optional[AccountIntegrate] = AccountIntegrate.query.filter_by( account_integrate: Optional[AccountIntegrate] = (
account_id=account.id, provider=provider db.session.query(AccountIntegrate).filter_by(account_id=account.id, provider=provider).first()
).first() )
if account_integrate: if account_integrate:
# If it exists, update the record # If it exists, update the record
@ -612,7 +615,10 @@ class TenantService:
): ):
"""Check if user have a workspace or not""" """Check if user have a workspace or not"""
available_ta = ( available_ta = (
TenantAccountJoin.query.filter_by(account_id=account.id).order_by(TenantAccountJoin.id.asc()).first() db.session.query(TenantAccountJoin)
.filter_by(account_id=account.id)
.order_by(TenantAccountJoin.id.asc())
.first()
) )
if available_ta: if available_ta:
@ -666,7 +672,7 @@ class TenantService:
if not tenant: if not tenant:
raise TenantNotFoundError("Tenant not found.") raise TenantNotFoundError("Tenant not found.")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if ta: if ta:
tenant.role = ta.role tenant.role = ta.role
else: else:
@ -695,12 +701,12 @@ class TenantService:
if not tenant_account_join: if not tenant_account_join:
raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.") raise AccountNotLinkTenantError("Tenant not found or account is not a member of the tenant.")
else: else:
TenantAccountJoin.query.filter( db.session.query(TenantAccountJoin).filter(
TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id TenantAccountJoin.account_id == account.id, TenantAccountJoin.tenant_id != tenant_id
).update({"current": False}) ).update({"current": False})
tenant_account_join.current = True tenant_account_join.current = True
# Set the current tenant for the account # Set the current tenant for the account
account.current_tenant_id = tenant_account_join.tenant_id account.set_tenant_id(tenant_account_join.tenant_id)
db.session.commit() db.session.commit()
@staticmethod @staticmethod
@ -787,7 +793,7 @@ class TenantService:
if operator.id == member.id: if operator.id == member.id:
raise CannotOperateSelfError("Cannot operate self.") raise CannotOperateSelfError("Cannot operate self.")
ta_operator = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=operator.id).first() ta_operator = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=operator.id).first()
if not ta_operator or ta_operator.role not in perms[action]: if not ta_operator or ta_operator.role not in perms[action]:
raise NoPermissionError(f"No permission to {action} member.") raise NoPermissionError(f"No permission to {action} member.")
@ -800,7 +806,7 @@ class TenantService:
TenantService.check_member_permission(tenant, operator, account, "remove") TenantService.check_member_permission(tenant, operator, account, "remove")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta: if not ta:
raise MemberNotInTenantError("Member not in tenant.") raise MemberNotInTenantError("Member not in tenant.")
@ -812,14 +818,22 @@ class TenantService:
"""Update member role""" """Update member role"""
TenantService.check_member_permission(tenant, operator, member, "update") TenantService.check_member_permission(tenant, operator, member, "update")
target_member_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=member.id).first() target_member_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=member.id).first()
)
if not target_member_join:
raise MemberNotInTenantError("Member not in tenant.")
if target_member_join.role == new_role: if target_member_join.role == new_role:
raise RoleAlreadyAssignedError("The provided role is already assigned to the member.") raise RoleAlreadyAssignedError("The provided role is already assigned to the member.")
if new_role == "owner": if new_role == "owner":
# Find the current owner and change their role to 'admin' # Find the current owner and change their role to 'admin'
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first() current_owner_join = (
db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, role="owner").first()
)
if current_owner_join:
current_owner_join.role = "admin" current_owner_join.role = "admin"
# Update the role of the target member # Update the role of the target member
@ -837,7 +851,7 @@ class TenantService:
@staticmethod @staticmethod
def get_custom_config(tenant_id: str) -> dict: def get_custom_config(tenant_id: str) -> dict:
tenant = Tenant.query.filter(Tenant.id == tenant_id).one_or_404() tenant = db.get_or_404(Tenant, tenant_id)
return cast(dict, tenant.custom_config_dict) return cast(dict, tenant.custom_config_dict)
@ -959,7 +973,7 @@ class RegisterService:
TenantService.switch_tenant(account, tenant.id) TenantService.switch_tenant(account, tenant.id)
else: else:
TenantService.check_member_permission(tenant, inviter, account, "add") TenantService.check_member_permission(tenant, inviter, account, "add")
ta = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, account_id=account.id).first() ta = db.session.query(TenantAccountJoin).filter_by(tenant_id=tenant.id, account_id=account.id).first()
if not ta: if not ta:
TenantService.create_tenant_member(tenant, account, role) TenantService.create_tenant_member(tenant, account, role)

@ -4,7 +4,7 @@ from typing import cast
import pandas as pd import pandas as pd
from flask_login import current_user from flask_login import current_user
from sqlalchemy import or_ from sqlalchemy import or_, select
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -124,8 +124,9 @@ class AppAnnotationService:
if not app: if not app:
raise NotFound("App not found") raise NotFound("App not found")
if keyword: if keyword:
annotations = ( stmt = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.filter( .filter(
or_( or_(
MessageAnnotation.question.ilike("%{}%".format(keyword)), MessageAnnotation.question.ilike("%{}%".format(keyword)),
@ -133,14 +134,14 @@ class AppAnnotationService:
) )
) )
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
) )
else: else:
annotations = ( stmt = (
MessageAnnotation.query.filter(MessageAnnotation.app_id == app_id) select(MessageAnnotation)
.filter(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc()) .order_by(MessageAnnotation.created_at.desc(), MessageAnnotation.id.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
) )
annotations = db.paginate(select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False)
return annotations.items, annotations.total return annotations.items, annotations.total
@classmethod @classmethod
@ -325,13 +326,16 @@ class AppAnnotationService:
if not annotation: if not annotation:
raise NotFound("Annotation not found") raise NotFound("Annotation not found")
annotation_hit_histories = ( stmt = (
AppAnnotationHitHistory.query.filter( select(AppAnnotationHitHistory)
.filter(
AppAnnotationHitHistory.app_id == app_id, AppAnnotationHitHistory.app_id == app_id,
AppAnnotationHitHistory.annotation_id == annotation_id, AppAnnotationHitHistory.annotation_id == annotation_id,
) )
.order_by(AppAnnotationHitHistory.created_at.desc()) .order_by(AppAnnotationHitHistory.created_at.desc())
.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) )
annotation_hit_histories = db.paginate(
select=stmt, page=page, per_page=limit, max_per_page=100, error_out=False
) )
return annotation_hit_histories.items, annotation_hit_histories.total return annotation_hit_histories.items, annotation_hit_histories.total

@ -9,7 +9,7 @@ from collections import Counter
from typing import Any, Optional from typing import Any, Optional
from flask_login import current_user from flask_login import current_user
from sqlalchemy import func from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
@ -79,11 +79,13 @@ from tasks.sync_website_document_indexing_task import sync_website_document_inde
class DatasetService: class DatasetService:
@staticmethod @staticmethod
def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False): def get_datasets(page, per_page, tenant_id=None, user=None, search=None, tag_ids=None, include_all=False):
query = Dataset.query.filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc()) query = select(Dataset).filter(Dataset.tenant_id == tenant_id).order_by(Dataset.created_at.desc())
if user: if user:
# get permitted dataset ids # get permitted dataset ids
dataset_permission = DatasetPermission.query.filter_by(account_id=user.id, tenant_id=tenant_id).all() dataset_permission = (
db.session.query(DatasetPermission).filter_by(account_id=user.id, tenant_id=tenant_id).all()
)
permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None permitted_dataset_ids = {dp.dataset_id for dp in dataset_permission} if dataset_permission else None
if user.current_role == TenantAccountRole.DATASET_OPERATOR: if user.current_role == TenantAccountRole.DATASET_OPERATOR:
@ -131,7 +133,7 @@ class DatasetService:
else: else:
return [], 0 return [], 0
datasets = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) datasets = db.paginate(select=query, page=page, per_page=per_page, max_per_page=100, error_out=False)
return datasets.items, datasets.total return datasets.items, datasets.total
@ -155,9 +157,10 @@ class DatasetService:
@staticmethod @staticmethod
def get_datasets_by_ids(ids, tenant_id): def get_datasets_by_ids(ids, tenant_id):
datasets = Dataset.query.filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id).paginate( stmt = select(Dataset).filter(Dataset.id.in_(ids), Dataset.tenant_id == tenant_id)
page=1, per_page=len(ids), max_per_page=len(ids), error_out=False
) datasets = db.paginate(select=stmt, page=1, per_page=len(ids), max_per_page=len(ids), error_out=False)
return datasets.items, datasets.total return datasets.items, datasets.total
@staticmethod @staticmethod
@ -176,7 +179,7 @@ class DatasetService:
retrieval_model: Optional[RetrievalModel] = None, retrieval_model: Optional[RetrievalModel] = None,
): ):
# check if dataset name already exists # check if dataset name already exists
if Dataset.query.filter_by(name=name, tenant_id=tenant_id).first(): if db.session.query(Dataset).filter_by(name=name, tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.")
embedding_model = None embedding_model = None
if indexing_technique == "high_quality": if indexing_technique == "high_quality":
@ -237,7 +240,7 @@ class DatasetService:
@staticmethod @staticmethod
def get_dataset(dataset_id) -> Optional[Dataset]: def get_dataset(dataset_id) -> Optional[Dataset]:
dataset: Optional[Dataset] = Dataset.query.filter_by(id=dataset_id).first() dataset: Optional[Dataset] = db.session.query(Dataset).filter_by(id=dataset_id).first()
return dataset return dataset
@staticmethod @staticmethod
@ -438,7 +441,7 @@ class DatasetService:
# update Retrieval model # update Retrieval model
filtered_data["retrieval_model"] = data["retrieval_model"] filtered_data["retrieval_model"] = data["retrieval_model"]
dataset.query.filter_by(id=dataset_id).update(filtered_data) db.session.query(Dataset).filter_by(id=dataset_id).update(filtered_data)
db.session.commit() db.session.commit()
if action: if action:
@ -462,7 +465,7 @@ class DatasetService:
@staticmethod @staticmethod
def dataset_use_check(dataset_id) -> bool: def dataset_use_check(dataset_id) -> bool:
count = AppDatasetJoin.query.filter_by(dataset_id=dataset_id).count() count = db.session.query(AppDatasetJoin).filter_by(dataset_id=dataset_id).count()
if count > 0: if count > 0:
return True return True
return False return False
@ -477,7 +480,9 @@ class DatasetService:
logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}") logging.debug(f"User {user.id} does not have permission to access dataset {dataset.id}")
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
if dataset.permission == "partial_members": if dataset.permission == "partial_members":
user_permission = DatasetPermission.query.filter_by(dataset_id=dataset.id, account_id=user.id).first() user_permission = (
db.session.query(DatasetPermission).filter_by(dataset_id=dataset.id, account_id=user.id).first()
)
if ( if (
not user_permission not user_permission
and dataset.tenant_id != user.current_tenant_id and dataset.tenant_id != user.current_tenant_id
@ -501,23 +506,24 @@ class DatasetService:
elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM: elif dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM:
if not any( if not any(
dp.dataset_id == dataset.id for dp in DatasetPermission.query.filter_by(account_id=user.id).all() dp.dataset_id == dataset.id
for dp in db.session.query(DatasetPermission).filter_by(account_id=user.id).all()
): ):
raise NoPermissionError("You do not have permission to access this dataset.") raise NoPermissionError("You do not have permission to access this dataset.")
@staticmethod @staticmethod
def get_dataset_queries(dataset_id: str, page: int, per_page: int): def get_dataset_queries(dataset_id: str, page: int, per_page: int):
dataset_queries = ( stmt = select(DatasetQuery).filter_by(dataset_id=dataset_id).order_by(db.desc(DatasetQuery.created_at))
DatasetQuery.query.filter_by(dataset_id=dataset_id)
.order_by(db.desc(DatasetQuery.created_at)) dataset_queries = db.paginate(select=stmt, page=page, per_page=per_page, max_per_page=100, error_out=False)
.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False)
)
return dataset_queries.items, dataset_queries.total return dataset_queries.items, dataset_queries.total
@staticmethod @staticmethod
def get_related_apps(dataset_id: str): def get_related_apps(dataset_id: str):
return ( return (
AppDatasetJoin.query.filter(AppDatasetJoin.dataset_id == dataset_id) db.session.query(AppDatasetJoin)
.filter(AppDatasetJoin.dataset_id == dataset_id)
.order_by(db.desc(AppDatasetJoin.created_at)) .order_by(db.desc(AppDatasetJoin.created_at))
.all() .all()
) )
@ -532,10 +538,14 @@ class DatasetService:
} }
# get recent 30 days auto disable logs # get recent 30 days auto disable logs
start_date = datetime.datetime.now() - datetime.timedelta(days=30) start_date = datetime.datetime.now() - datetime.timedelta(days=30)
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter( dataset_auto_disable_logs = (
db.session.query(DatasetAutoDisableLog)
.filter(
DatasetAutoDisableLog.dataset_id == dataset_id, DatasetAutoDisableLog.dataset_id == dataset_id,
DatasetAutoDisableLog.created_at >= start_date, DatasetAutoDisableLog.created_at >= start_date,
).all() )
.all()
)
if dataset_auto_disable_logs: if dataset_auto_disable_logs:
return { return {
"document_ids": [log.document_id for log in dataset_auto_disable_logs], "document_ids": [log.document_id for log in dataset_auto_disable_logs],
@ -875,7 +885,9 @@ class DocumentService:
@staticmethod @staticmethod
def get_documents_position(dataset_id): def get_documents_position(dataset_id):
document = Document.query.filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first() document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
)
if document: if document:
return document.position + 1 return document.position + 1
else: else:
@ -1012,13 +1024,17 @@ class DocumentService:
} }
# check duplicate # check duplicate
if knowledge_config.duplicate: if knowledge_config.duplicate:
document = Document.query.filter_by( document = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id, dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
data_source_type="upload_file", data_source_type="upload_file",
enabled=True, enabled=True,
name=file_name, name=file_name,
).first() )
.first()
)
if document: if document:
document.dataset_process_rule_id = dataset_process_rule.id # type: ignore document.dataset_process_rule_id = dataset_process_rule.id # type: ignore
document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) document.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
@ -1056,12 +1072,16 @@ class DocumentService:
raise ValueError("No notion info list found.") raise ValueError("No notion info list found.")
exist_page_ids = [] exist_page_ids = []
exist_document = {} exist_document = {}
documents = Document.query.filter_by( documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id, dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
data_source_type="notion_import", data_source_type="notion_import",
enabled=True, enabled=True,
).all() )
.all()
)
if documents: if documents:
for document in documents: for document in documents:
data_source_info = json.loads(document.data_source_info) data_source_info = json.loads(document.data_source_info)
@ -1069,14 +1089,18 @@ class DocumentService:
exist_document[data_source_info["notion_page_id"]] = document.id exist_document[data_source_info["notion_page_id"]] = document.id
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
) )
).first() )
.first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")
for page in notion_info.pages: for page in notion_info.pages:
@ -1208,12 +1232,16 @@ class DocumentService:
@staticmethod @staticmethod
def get_tenant_documents_count(): def get_tenant_documents_count():
documents_count = Document.query.filter( documents_count = (
db.session.query(Document)
.filter(
Document.completed_at.isnot(None), Document.completed_at.isnot(None),
Document.enabled == True, Document.enabled == True,
Document.archived == False, Document.archived == False,
Document.tenant_id == current_user.current_tenant_id, Document.tenant_id == current_user.current_tenant_id,
).count() )
.count()
)
return documents_count return documents_count
@staticmethod @staticmethod
@ -1280,14 +1308,18 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list: for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id workspace_id = notion_info.workspace_id
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id, DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
) )
).first() )
.first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")
for page in notion_info.pages: for page in notion_info.pages:
@ -1330,7 +1362,7 @@ class DocumentService:
db.session.commit() db.session.commit()
# update document segment # update document segment
update_params = {DocumentSegment.status: "re_segment"} update_params = {DocumentSegment.status: "re_segment"}
DocumentSegment.query.filter_by(document_id=document.id).update(update_params) db.session.query(DocumentSegment).filter_by(document_id=document.id).update(update_params)
db.session.commit() db.session.commit()
# trigger async task # trigger async task
document_indexing_update_task.delay(document.dataset_id, document.id) document_indexing_update_task.delay(document.dataset_id, document.id)
@ -2013,7 +2045,8 @@ class SegmentService:
@classmethod @classmethod
def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset): def delete_segments(cls, segment_ids: list, document: Document, dataset: Dataset):
index_node_ids = ( index_node_ids = (
DocumentSegment.query.with_entities(DocumentSegment.index_node_id) db.session.query(DocumentSegment)
.with_entities(DocumentSegment.index_node_id)
.filter( .filter(
DocumentSegment.id.in_(segment_ids), DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset.id, DocumentSegment.dataset_id == dataset.id,
@ -2252,20 +2285,28 @@ class SegmentService:
def get_child_chunks( def get_child_chunks(
cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None cls, segment_id: str, document_id: str, dataset_id: str, page: int, limit: int, keyword: Optional[str] = None
): ):
query = ChildChunk.query.filter_by( query = (
select(ChildChunk)
.filter_by(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
dataset_id=dataset_id, dataset_id=dataset_id,
document_id=document_id, document_id=document_id,
segment_id=segment_id, segment_id=segment_id,
).order_by(ChildChunk.position.asc()) )
.order_by(ChildChunk.position.asc())
)
if keyword: if keyword:
query = query.where(ChildChunk.content.ilike(f"%{keyword}%")) query = query.where(ChildChunk.content.ilike(f"%{keyword}%"))
return query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False) return db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
@classmethod @classmethod
def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]: def get_child_chunk_by_id(cls, child_chunk_id: str, tenant_id: str) -> Optional[ChildChunk]:
"""Get a child chunk by its ID.""" """Get a child chunk by its ID."""
result = ChildChunk.query.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id).first() result = (
db.session.query(ChildChunk)
.filter(ChildChunk.id == child_chunk_id, ChildChunk.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, ChildChunk) else None return result if isinstance(result, ChildChunk) else None
@classmethod @classmethod
@ -2279,7 +2320,7 @@ class SegmentService:
limit: int = 20, limit: int = 20,
): ):
"""Get segments for a document with optional filtering.""" """Get segments for a document with optional filtering."""
query = DocumentSegment.query.filter( query = select(DocumentSegment).filter(
DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id DocumentSegment.document_id == document_id, DocumentSegment.tenant_id == tenant_id
) )
@ -2289,9 +2330,8 @@ class SegmentService:
if keyword: if keyword:
query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%")) query = query.filter(DocumentSegment.content.ilike(f"%{keyword}%"))
paginated_segments = query.order_by(DocumentSegment.position.asc()).paginate( query = query.order_by(DocumentSegment.position.asc())
page=page, per_page=limit, max_per_page=100, error_out=False paginated_segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
)
return paginated_segments.items, paginated_segments.total return paginated_segments.items, paginated_segments.total
@ -2331,9 +2371,11 @@ class SegmentService:
raise ValueError(ex.description) raise ValueError(ex.description)
# check segment # check segment
segment = DocumentSegment.query.filter( segment = (
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == user_id)
.first()
)
if not segment: if not segment:
raise NotFound("Segment not found.") raise NotFound("Segment not found.")
@ -2346,9 +2388,11 @@ class SegmentService:
@classmethod @classmethod
def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]: def get_segment_by_id(cls, segment_id: str, tenant_id: str) -> Optional[DocumentSegment]:
"""Get a segment by its ID.""" """Get a segment by its ID."""
result = DocumentSegment.query.filter( result = (
DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id db.session.query(DocumentSegment)
).first() .filter(DocumentSegment.id == segment_id, DocumentSegment.tenant_id == tenant_id)
.first()
)
return result if isinstance(result, DocumentSegment) else None return result if isinstance(result, DocumentSegment) else None

@ -5,6 +5,7 @@ from typing import Any, Optional, Union, cast
from urllib.parse import urlparse from urllib.parse import urlparse
import httpx import httpx
from sqlalchemy import select
from constants import HIDDEN_VALUE from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy from core.helper import ssrf_proxy
@ -24,14 +25,20 @@ from services.errors.dataset import DatasetNameDuplicateError
class ExternalDatasetService: class ExternalDatasetService:
@staticmethod @staticmethod
def get_external_knowledge_apis(page, per_page, tenant_id, search=None) -> tuple[list[ExternalKnowledgeApis], int]: def get_external_knowledge_apis(
query = ExternalKnowledgeApis.query.filter(ExternalKnowledgeApis.tenant_id == tenant_id).order_by( page, per_page, tenant_id, search=None
ExternalKnowledgeApis.created_at.desc() ) -> tuple[list[ExternalKnowledgeApis], int | None]:
query = (
select(ExternalKnowledgeApis)
.filter(ExternalKnowledgeApis.tenant_id == tenant_id)
.order_by(ExternalKnowledgeApis.created_at.desc())
) )
if search: if search:
query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%")) query = query.filter(ExternalKnowledgeApis.name.ilike(f"%{search}%"))
external_knowledge_apis = query.paginate(page=page, per_page=per_page, max_per_page=100, error_out=False) external_knowledge_apis = db.paginate(
select=query, page=page, per_page=per_page, max_per_page=100, error_out=False
)
return external_knowledge_apis.items, external_knowledge_apis.total return external_knowledge_apis.items, external_knowledge_apis.total
@ -92,18 +99,18 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis: def get_external_knowledge_api(external_knowledge_api_id: str) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( external_knowledge_api: Optional[ExternalKnowledgeApis] = (
id=external_knowledge_api_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
return external_knowledge_api return external_knowledge_api
@staticmethod @staticmethod
def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis: def update_external_knowledge_api(tenant_id, user_id, external_knowledge_api_id, args) -> ExternalKnowledgeApis:
external_knowledge_api: Optional[ExternalKnowledgeApis] = ExternalKnowledgeApis.query.filter_by( external_knowledge_api: Optional[ExternalKnowledgeApis] = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE: if args.get("settings") and args.get("settings").get("api_key") == HIDDEN_VALUE:
@ -120,9 +127,9 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str): def delete_external_knowledge_api(tenant_id: str, external_knowledge_api_id: str):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
@ -131,25 +138,29 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]: def external_knowledge_api_use_check(external_knowledge_api_id: str) -> tuple[bool, int]:
count = ExternalKnowledgeBindings.query.filter_by(external_knowledge_api_id=external_knowledge_api_id).count() count = (
db.session.query(ExternalKnowledgeBindings)
.filter_by(external_knowledge_api_id=external_knowledge_api_id)
.count()
)
if count > 0: if count > 0:
return True, count return True, count
return False, 0 return False, 0
@staticmethod @staticmethod
def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings: def get_external_knowledge_binding_with_dataset_id(tenant_id: str, dataset_id: str) -> ExternalKnowledgeBindings:
external_knowledge_binding: Optional[ExternalKnowledgeBindings] = ExternalKnowledgeBindings.query.filter_by( external_knowledge_binding: Optional[ExternalKnowledgeBindings] = (
dataset_id=dataset_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
).first() )
if not external_knowledge_binding: if not external_knowledge_binding:
raise ValueError("external knowledge binding not found") raise ValueError("external knowledge binding not found")
return external_knowledge_binding return external_knowledge_binding
@staticmethod @staticmethod
def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict): def document_create_args_validate(tenant_id: str, external_knowledge_api_id: str, process_parameter: dict):
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_api_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeApis).filter_by(id=external_knowledge_api_id, tenant_id=tenant_id).first()
).first() )
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
settings = json.loads(external_knowledge_api.settings) settings = json.loads(external_knowledge_api.settings)
@ -212,11 +223,13 @@ class ExternalDatasetService:
@staticmethod @staticmethod
def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset: def create_external_dataset(tenant_id: str, user_id: str, args: dict) -> Dataset:
# check if dataset name already exists # check if dataset name already exists
if Dataset.query.filter_by(name=args.get("name"), tenant_id=tenant_id).first(): if db.session.query(Dataset).filter_by(name=args.get("name"), tenant_id=tenant_id).first():
raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.") raise DatasetNameDuplicateError(f"Dataset with name {args.get('name')} already exists.")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=args.get("external_knowledge_api_id"), tenant_id=tenant_id db.session.query(ExternalKnowledgeApis)
).first() .filter_by(id=args.get("external_knowledge_api_id"), tenant_id=tenant_id)
.first()
)
if external_knowledge_api is None: if external_knowledge_api is None:
raise ValueError("api template not found") raise ValueError("api template not found")
@ -254,15 +267,17 @@ class ExternalDatasetService:
external_retrieval_parameters: dict, external_retrieval_parameters: dict,
metadata_condition: Optional[MetadataCondition] = None, metadata_condition: Optional[MetadataCondition] = None,
) -> list: ) -> list:
external_knowledge_binding = ExternalKnowledgeBindings.query.filter_by( external_knowledge_binding = (
dataset_id=dataset_id, tenant_id=tenant_id db.session.query(ExternalKnowledgeBindings).filter_by(dataset_id=dataset_id, tenant_id=tenant_id).first()
).first() )
if not external_knowledge_binding: if not external_knowledge_binding:
raise ValueError("external knowledge binding not found") raise ValueError("external knowledge binding not found")
external_knowledge_api = ExternalKnowledgeApis.query.filter_by( external_knowledge_api = (
id=external_knowledge_binding.external_knowledge_api_id db.session.query(ExternalKnowledgeApis)
).first() .filter_by(id=external_knowledge_binding.external_knowledge_api_id)
.first()
)
if not external_knowledge_api: if not external_knowledge_api:
raise ValueError("external api template not found") raise ValueError("external api template not found")

@ -69,6 +69,7 @@ class HitTestingService:
query: str, query: str,
account: Account, account: Account,
external_retrieval_model: dict, external_retrieval_model: dict,
metadata_filtering_conditions: dict,
) -> dict: ) -> dict:
if dataset.provider != "external": if dataset.provider != "external":
return { return {
@ -82,6 +83,7 @@ class HitTestingService:
dataset_id=dataset.id, dataset_id=dataset.id,
query=cls.escape_query_for_search(query), query=cls.escape_query_for_search(query),
external_retrieval_model=external_retrieval_model, external_retrieval_model=external_retrieval_model,
metadata_filtering_conditions=metadata_filtering_conditions,
) )
end = time.perf_counter() end = time.perf_counter()

@ -177,6 +177,21 @@ class MessageService:
return feedback return feedback
@classmethod
def get_all_messages_feedbacks(cls, app_model: App, page: int, limit: int):
"""Get all feedbacks of an app"""
offset = (page - 1) * limit
feedbacks = (
db.session.query(MessageFeedback)
.filter(MessageFeedback.app_id == app_model.id)
.order_by(MessageFeedback.created_at.desc(), MessageFeedback.id.desc())
.limit(limit)
.offset(offset)
.all()
)
return [record.to_dict() for record in feedbacks]
@classmethod @classmethod
def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str): def get_message(cls, app_model: App, user: Optional[Union[Account, EndUser]], message_id: str):
message = ( message = (

@ -20,9 +20,11 @@ class MetadataService:
@staticmethod @staticmethod
def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata: def create_metadata(dataset_id: str, metadata_args: MetadataArgs) -> DatasetMetadata:
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by( if (
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name db.session.query(DatasetMetadata)
).first(): .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=metadata_args.name)
.first()
):
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == metadata_args.name: if field.value == metadata_args.name:
@ -42,16 +44,18 @@ class MetadataService:
def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore def update_metadata_name(dataset_id: str, metadata_id: str, name: str) -> DatasetMetadata: # type: ignore
lock_key = f"dataset_metadata_lock_{dataset_id}" lock_key = f"dataset_metadata_lock_{dataset_id}"
# check if metadata name already exists # check if metadata name already exists
if DatasetMetadata.query.filter_by( if (
tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name db.session.query(DatasetMetadata)
).first(): .filter_by(tenant_id=current_user.current_tenant_id, dataset_id=dataset_id, name=name)
.first()
):
raise ValueError("Metadata name already exists.") raise ValueError("Metadata name already exists.")
for field in BuiltInField: for field in BuiltInField:
if field.value == name: if field.value == name:
raise ValueError("Metadata name already exists in Built-in fields.") raise ValueError("Metadata name already exists in Built-in fields.")
try: try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None: if metadata is None:
raise ValueError("Metadata not found.") raise ValueError("Metadata not found.")
old_name = metadata.name old_name = metadata.name
@ -60,7 +64,9 @@ class MetadataService:
metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) metadata.updated_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
# update related documents # update related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings: if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings] document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids) documents = DocumentService.get_document_by_ids(document_ids)
@ -82,13 +88,15 @@ class MetadataService:
lock_key = f"dataset_metadata_lock_{dataset_id}" lock_key = f"dataset_metadata_lock_{dataset_id}"
try: try:
MetadataService.knowledge_base_metadata_lock_check(dataset_id, None) MetadataService.knowledge_base_metadata_lock_check(dataset_id, None)
metadata = DatasetMetadata.query.filter_by(id=metadata_id).first() metadata = db.session.query(DatasetMetadata).filter_by(id=metadata_id).first()
if metadata is None: if metadata is None:
raise ValueError("Metadata not found.") raise ValueError("Metadata not found.")
db.session.delete(metadata) db.session.delete(metadata)
# deal related documents # deal related documents
dataset_metadata_bindings = DatasetMetadataBinding.query.filter_by(metadata_id=metadata_id).all() dataset_metadata_bindings = (
db.session.query(DatasetMetadataBinding).filter_by(metadata_id=metadata_id).all()
)
if dataset_metadata_bindings: if dataset_metadata_bindings:
document_ids = [binding.document_id for binding in dataset_metadata_bindings] document_ids = [binding.document_id for binding in dataset_metadata_bindings]
documents = DocumentService.get_document_by_ids(document_ids) documents = DocumentService.get_document_by_ids(document_ids)
@ -193,7 +201,7 @@ class MetadataService:
db.session.add(document) db.session.add(document)
db.session.commit() db.session.commit()
# deal metadata binding # deal metadata binding
DatasetMetadataBinding.query.filter_by(document_id=operation.document_id).delete() db.session.query(DatasetMetadataBinding).filter_by(document_id=operation.document_id).delete()
for metadata_value in operation.metadata_list: for metadata_value in operation.metadata_list:
dataset_metadata_binding = DatasetMetadataBinding( dataset_metadata_binding = DatasetMetadataBinding(
tenant_id=current_user.current_tenant_id, tenant_id=current_user.current_tenant_id,
@ -230,9 +238,9 @@ class MetadataService:
"id": item.get("id"), "id": item.get("id"),
"name": item.get("name"), "name": item.get("name"),
"type": item.get("type"), "type": item.get("type"),
"count": DatasetMetadataBinding.query.filter_by( "count": db.session.query(DatasetMetadataBinding)
metadata_id=item.get("id"), dataset_id=dataset.id .filter_by(metadata_id=item.get("id"), dataset_id=dataset.id)
).count(), .count(),
} }
for item in dataset.doc_metadata or [] for item in dataset.doc_metadata or []
if item.get("id") != "built-in" if item.get("id") != "built-in"

@ -1,3 +1,4 @@
import logging
from typing import Optional from typing import Optional
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
@ -12,17 +13,27 @@ from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegm
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode from services.entities.knowledge_entities.knowledge_entities import ParentMode
_logger = logging.getLogger(__name__)
class VectorService: class VectorService:
@classmethod @classmethod
def create_segments_vector( def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
): ):
documents = [] documents: list[Document] = []
document: Document | None = None
for segment in segments: for segment in segments:
if doc_form == IndexType.PARENT_CHILD_INDEX: if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first() document = db.session.query(DatasetDocument).filter_by(id=segment.document_id).first()
if not document:
_logger.warning(
"Expected DatasetDocument record to exist, but none was found, document_id=%s, segment_id=%s",
segment.document_id,
segment.id,
)
continue
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
db.session.query(DatasetProcessRule) db.session.query(DatasetProcessRule)

@ -41,7 +41,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "indexing", DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit() db.session.commit()
document = Document( document = Document(
page_content=segment.content, page_content=segment.content,
@ -78,7 +78,7 @@ def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]]
DocumentSegment.status: "completed", DocumentSegment.status: "completed",
DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None), DocumentSegment.completed_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
} }
DocumentSegment.query.filter_by(id=segment.id).update(update_params) db.session.query(DocumentSegment).filter_by(id=segment.id).update(update_params)
db.session.commit() db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()

@ -24,7 +24,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
start_at = time.perf_counter() start_at = time.perf_counter()
try: try:
dataset = Dataset.query.filter_by(id=dataset_id).first() dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset: if not dataset:
raise Exception("Dataset not found") raise Exception("Dataset not found")

@ -44,14 +44,18 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"] page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"] page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"] page_edited_time = data_source_info["last_edited_time"]
data_source_binding = DataSourceOauthBinding.query.filter( data_source_binding = (
db.session.query(DataSourceOauthBinding)
.filter(
db.and_( db.and_(
DataSourceOauthBinding.tenant_id == document.tenant_id, DataSourceOauthBinding.tenant_id == document.tenant_id,
DataSourceOauthBinding.provider == "notion", DataSourceOauthBinding.provider == "notion",
DataSourceOauthBinding.disabled == False, DataSourceOauthBinding.disabled == False,
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"', DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
) )
).first() )
.first()
)
if not data_source_binding: if not data_source_binding:
raise ValueError("Data source binding not found.") raise ValueError("Data source binding not found.")

@ -2,9 +2,10 @@ import os
import pytest import pytest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from api.core.rag.datasource.vdb.field import Field
from elasticsearch import Elasticsearch from elasticsearch import Elasticsearch
from core.rag.datasource.vdb.field import Field
class MockIndicesClient: class MockIndicesClient:
def __init__(self): def __init__(self):

@ -1,49 +1,28 @@
import os import os
from textwrap import dedent
import pytest
from flask import Flask from flask import Flask
from yarl import URL from yarl import URL
from configs.app_config import DifyConfig from configs.app_config import DifyConfig
EXAMPLE_ENV_FILENAME = ".env"
def test_dify_config(monkeypatch):
@pytest.fixture
def example_env_file(tmp_path, monkeypatch) -> str:
monkeypatch.chdir(tmp_path)
file_path = tmp_path.joinpath(EXAMPLE_ENV_FILENAME)
file_path.write_text(
dedent(
"""
CONSOLE_API_URL=https://example.com
CONSOLE_WEB_URL=https://example.com
HTTP_REQUEST_MAX_WRITE_TIMEOUT=30
"""
)
)
return str(file_path)
def test_dify_config_undefined_entry(example_env_file):
# NOTE: See https://github.com/microsoft/pylance-release/issues/6099 for more details about this type error.
# load dotenv file with pydantic-settings
config = DifyConfig(_env_file=example_env_file)
# entries not defined in app settings
with pytest.raises(TypeError):
# TypeError: 'AppSettings' object is not subscriptable
assert config["LOG_LEVEL"] == "INFO"
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_dify_config(example_env_file):
# clear system environment variables # clear system environment variables
os.environ.clear() os.environ.clear()
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("HTTP_REQUEST_MAX_READ_TIMEOUT", "600")
# load dotenv file with pydantic-settings # load dotenv file with pydantic-settings
config = DifyConfig(_env_file=example_env_file) config = DifyConfig()
# constant values # constant values
assert config.COMMIT_SHA == "" assert config.COMMIT_SHA == ""
@ -54,7 +33,7 @@ def test_dify_config(example_env_file):
assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0 assert config.SENTRY_TRACES_SAMPLE_RATE == 1.0
# annotated field with default value # annotated field with default value
assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 60 assert config.HTTP_REQUEST_MAX_READ_TIMEOUT == 600
# annotated field with configured value # annotated field with configured value
assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30 assert config.HTTP_REQUEST_MAX_WRITE_TIMEOUT == 30
@ -64,11 +43,24 @@ def test_dify_config(example_env_file):
# NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected. # NOTE: If there is a `.env` file in your Workspace, this test might not succeed as expected.
# This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`. # This is due to `pymilvus` loading all the variables from the `.env` file into `os.environ`.
def test_flask_configs(example_env_file): def test_flask_configs(monkeypatch):
flask_app = Flask("app") flask_app = Flask("app")
# clear system environment variables # clear system environment variables
os.environ.clear() os.environ.clear()
flask_app.config.from_mapping(DifyConfig(_env_file=example_env_file).model_dump()) # pyright: ignore
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("WEB_API_CORS_ALLOW_ORIGINS", "http://127.0.0.1:3000,*")
monkeypatch.setenv("CODE_EXECUTION_ENDPOINT", "http://127.0.0.1:8194/")
flask_app.config.from_mapping(DifyConfig().model_dump()) # pyright: ignore
config = flask_app.config config = flask_app.config
# configs read from pydantic-settings # configs read from pydantic-settings
@ -83,7 +75,7 @@ def test_flask_configs(example_env_file):
# fallback to alias choices value as CONSOLE_API_URL # fallback to alias choices value as CONSOLE_API_URL
assert config["FILES_URL"] == "https://example.com" assert config["FILES_URL"] == "https://example.com"
assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:@localhost:5432/dify" assert config["SQLALCHEMY_DATABASE_URI"] == "postgresql://postgres:postgres@localhost:5432/dify"
assert config["SQLALCHEMY_ENGINE_OPTIONS"] == { assert config["SQLALCHEMY_ENGINE_OPTIONS"] == {
"connect_args": { "connect_args": {
"options": "-c timezone=UTC", "options": "-c timezone=UTC",
@ -96,13 +88,47 @@ def test_flask_configs(example_env_file):
assert config["CONSOLE_WEB_URL"] == "https://example.com" assert config["CONSOLE_WEB_URL"] == "https://example.com"
assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"] assert config["CONSOLE_CORS_ALLOW_ORIGINS"] == ["https://example.com"]
assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["*"] assert config["WEB_API_CORS_ALLOW_ORIGINS"] == ["http://127.0.0.1:3000", "*"]
assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/" assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://127.0.0.1:8194/"
assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1" assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://127.0.0.1:8194/v1"
def test_inner_api_config_exist(): def test_inner_api_config_exist(monkeypatch):
# Set environment variables using monkeypatch
monkeypatch.setenv("CONSOLE_API_URL", "https://example.com")
monkeypatch.setenv("CONSOLE_WEB_URL", "https://example.com")
monkeypatch.setenv("HTTP_REQUEST_MAX_WRITE_TIMEOUT", "30")
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("INNER_API_KEY", "test-inner-api-key")
config = DifyConfig() config = DifyConfig()
assert config.INNER_API is False assert config.INNER_API is False
assert config.INNER_API_KEY is None assert isinstance(config.INNER_API_KEY, str)
assert len(config.INNER_API_KEY) > 0
def test_db_extras_options_merging(monkeypatch):
"""Test that DB_EXTRAS options are properly merged with default timezone setting"""
# Set environment variables
monkeypatch.setenv("DB_USERNAME", "postgres")
monkeypatch.setenv("DB_PASSWORD", "postgres")
monkeypatch.setenv("DB_HOST", "localhost")
monkeypatch.setenv("DB_PORT", "5432")
monkeypatch.setenv("DB_DATABASE", "dify")
monkeypatch.setenv("DB_EXTRAS", "options=-c search_path=myschema")
# Create config
config = DifyConfig()
# Get engine options
engine_options = config.SQLALCHEMY_ENGINE_OPTIONS
# Verify options contains both search_path and timezone
options = engine_options["connect_args"]["options"]
assert "search_path=myschema" in options
assert "timezone=UTC" in options

File diff suppressed because it is too large Load Diff

@ -0,0 +1,10 @@
#!/bin/bash
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/.."
uv --directory api run \
flask run --host 0.0.0.0 --port=5001 --debug

@ -0,0 +1,11 @@
#!/bin/bash
set -x
SCRIPT_DIR="$(dirname "$(realpath "$0")")"
cd "$SCRIPT_DIR/.."
uv --directory api run \
celery -A app.celery worker \
-P gevent -c 1 --loglevel INFO -Q dataset,generation,mail,ops_trace,app_deletion

@ -6,10 +6,12 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED
# different from api or web app domain. # different from api or web app domain.
# example: http://cloud.dify.ai/console/api # example: http://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
NEXT_PUBLIC_WEB_PREFIX=http://localhost:3000
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from # The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain. # console or api domain.
# example: http://udify.app/api # example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
NEXT_PUBLIC_PUBLIC_WEB_PREFIX=http://localhost:3000
# The API PREFIX for MARKETPLACE # The API PREFIX for MARKETPLACE
NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1 NEXT_PUBLIC_MARKETPLACE_API_PREFIX=https://marketplace.dify.ai/api/v1
# The URL for MARKETPLACE # The URL for MARKETPLACE

@ -31,10 +31,12 @@ NEXT_PUBLIC_EDITION=SELF_HOSTED
# different from api or web app domain. # different from api or web app domain.
# example: http://cloud.dify.ai/console/api # example: http://cloud.dify.ai/console/api
NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
NEXT_PUBLIC_WEB_PREFIX=http://localhost:3000
# The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from # The URL for Web APP, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain. # console or api domain.
# example: http://udify.app/api # example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
NEXT_PUBLIC_PUBLIC_WEB_PREFIX=http://localhost:3000
# SENTRY # SENTRY
NEXT_PUBLIC_SENTRY_DSN= NEXT_PUBLIC_SENTRY_DSN=

@ -16,7 +16,7 @@ import AppsContext, { useAppContext } from '@/context/app-context'
import type { HtmlContentProps } from '@/app/components/base/popover' import type { HtmlContentProps } from '@/app/components/base/popover'
import CustomPopover from '@/app/components/base/popover' import CustomPopover from '@/app/components/base/popover'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { basePath } from '@/utils/var' import { WEB_PREFIX } from '@/config'
import { getRedirection } from '@/utils/app-redirection' import { getRedirection } from '@/utils/app-redirection'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { NEED_REFRESH_APP_LIST_KEY } from '@/config' import { NEED_REFRESH_APP_LIST_KEY } from '@/config'
@ -217,7 +217,7 @@ const AppCard = ({ app, onRefresh }: AppCardProps) => {
try { try {
const { installed_apps }: any = await fetchInstalledAppList(app.id) || {} const { installed_apps }: any = await fetchInstalledAppList(app.id) || {}
if (installed_apps?.length > 0) if (installed_apps?.length > 0)
window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank') window.open(`${WEB_PREFIX}/explore/installed/${installed_apps[0].id}`, '_blank')
else else
throw new Error('No app found in Explore') throw new Error('No app found in Explore')
} }

@ -85,7 +85,7 @@ const Container = () => {
return ( return (
<div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'> <div ref={containerRef} className='scroll-container relative flex grow flex-col overflow-y-auto bg-background-body'>
<div className='sticky top-0 z-10 flex flex-wrap justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'> <div className='sticky top-0 z-10 flex flex-wrap items-center justify-between gap-y-2 bg-background-body px-12 pb-2 pt-4 leading-[56px]'>
<TabSliderNew <TabSliderNew
value={activeTab} value={activeTab}
onChange={newActiveTab => setActiveTab(newActiveTab)} onChange={newActiveTab => setActiveTab(newActiveTab)}

@ -121,7 +121,7 @@ const Doc = ({ apiBaseUrl }: DocProps) => {
</button> </button>
)} )}
</div> </div>
<article className={cn('prose-xl prose mx-1 rounded-t-xl bg-background-default px-4 pt-16 sm:mx-12', theme === Theme.dark && 'dark:prose-invert')}> <article className={cn('prose-xl prose mx-1 rounded-t-xl bg-background-default px-4 pt-16 sm:mx-12', theme === Theme.dark && 'prose-invert')}>
{Template} {Template}
</article> </article>
</div> </div>

@ -1,6 +1,6 @@
'use client' 'use client'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { basePath } from '@/utils/var' import Link from 'next/link'
import { import {
RiAddLine, RiAddLine,
RiArrowRightLine, RiArrowRightLine,
@ -18,7 +18,7 @@ const CreateAppCard = (
<div className='bg-background-default-dimm flex min-h-[160px] flex-col rounded-xl border-[0.5px] <div className='bg-background-default-dimm flex min-h-[160px] flex-col rounded-xl border-[0.5px]
border-components-panel-border transition-all duration-200 ease-in-out' border-components-panel-border transition-all duration-200 ease-in-out'
> >
<a ref={ref} className='group flex grow cursor-pointer items-start p-4' href={`${basePath}/datasets/create`}> <Link ref={ref} className='group flex grow cursor-pointer items-start p-4' href={'/datasets/create'}>
<div className='flex items-center gap-3'> <div className='flex items-center gap-3'>
<div className='flex h-10 w-10 items-center justify-center rounded-lg border border-dashed border-divider-regular bg-background-default-lighter <div className='flex h-10 w-10 items-center justify-center rounded-lg border border-dashed border-divider-regular bg-background-default-lighter
p-2 group-hover:border-solid group-hover:border-effects-highlight group-hover:bg-background-default-dodge' p-2 group-hover:border-solid group-hover:border-effects-highlight group-hover:bg-background-default-dodge'
@ -27,12 +27,12 @@ const CreateAppCard = (
</div> </div>
<div className='system-md-semibold text-text-secondary group-hover:text-text-accent'>{t('dataset.createDataset')}</div> <div className='system-md-semibold text-text-secondary group-hover:text-text-accent'>{t('dataset.createDataset')}</div>
</div> </div>
</a> </Link>
<div className='system-xs-regular p-4 pt-0 text-text-tertiary'>{t('dataset.createDatasetIntro')}</div> <div className='system-xs-regular p-4 pt-0 text-text-tertiary'>{t('dataset.createDatasetIntro')}</div>
<a className='group flex cursor-pointer items-center gap-1 rounded-b-xl border-t-[0.5px] border-divider-subtle p-4' href={`${basePath}/datasets/connect`}> <Link className='group flex cursor-pointer items-center gap-1 rounded-b-xl border-t-[0.5px] border-divider-subtle p-4' href={'datasets/connect'}>
<div className='system-xs-medium text-text-tertiary group-hover:text-text-accent'>{t('dataset.connectDataset')}</div> <div className='system-xs-medium text-text-tertiary group-hover:text-text-accent'>{t('dataset.connectDataset')}</div>
<RiArrowRightLine className='h-3.5 w-3.5 text-text-tertiary group-hover:text-text-accent' /> <RiArrowRightLine className='h-3.5 w-3.5 text-text-tertiary group-hover:text-text-accent' />
</a> </Link>
</div> </div>
) )
} }

@ -314,7 +314,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
</Property> </Property>
<Property name='indexing_technique' type='string' key='indexing_technique'> <Property name='indexing_technique' type='string' key='indexing_technique'>
Index technique (optional) Index technique (optional)
If this is not set, embedding_model, embedding_provider_name and retrieval_model will be set to null If this is not set, embedding_model, embedding_model_provider and retrieval_model will be set to null
- <code>high_quality</code> High quality - <code>high_quality</code> High quality
- <code>economy</code> Economy - <code>economy</code> Economy
</Property> </Property>
@ -338,7 +338,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='str' key='embedding_model'> <Property name='embedding_model' type='str' key='embedding_model'>
Embedding model name (optional) Embedding model name (optional)
</Property> </Property>
<Property name='embedding_provider_name' type='str' key='embedding_provider_name'> <Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
Embedding model provider name (optional) Embedding model provider name (optional)
</Property> </Property>
<Property name='retrieval_model' type='object' key='retrieval_model'> <Property name='retrieval_model' type='object' key='retrieval_model'>
@ -1040,10 +1040,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1392,10 +1390,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1677,10 +1673,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>

@ -337,7 +337,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='str' key='embedding_model'> <Property name='embedding_model' type='str' key='embedding_model'>
埋め込みモデル名(任意) 埋め込みモデル名(任意)
</Property> </Property>
<Property name='embedding_provider_name' type='str' key='embedding_provider_name'> <Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
埋め込みモデルのプロバイダ名(任意) 埋め込みモデルのプロバイダ名(任意)
</Property> </Property>
<Property name='retrieval_model' type='object' key='retrieval_model'> <Property name='retrieval_model' type='object' key='retrieval_model'>
@ -501,7 +501,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="レスポンス"> <CodeGroup title="レスポンス">
```text {{ title: 'Response' }} ```text {{ title: 'レスポンス' }}
204 No Content 204 No Content
``` ```
</CodeGroup> </CodeGroup>
@ -797,10 +797,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="レスポンス"> <CodeGroup title="レスポンス">
```json {{ title: 'Response' }} ```text {{ title: 'レスポンス' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1149,10 +1147,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="レスポンス"> <CodeGroup title="レスポンス">
```json {{ title: 'Response' }} ```text {{ title: 'レスポンス' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1434,10 +1430,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="レスポンス"> <CodeGroup title="レスポンス">
```json {{ title: 'Response' }} ```text {{ title: 'レスポンス' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>

@ -341,7 +341,7 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
<Property name='embedding_model' type='str' key='embedding_model'> <Property name='embedding_model' type='str' key='embedding_model'>
Embedding 模型名称 Embedding 模型名称
</Property> </Property>
<Property name='embedding_provider_name' type='str' key='embedding_provider_name'> <Property name='embedding_model_provider' type='str' key='embedding_model_provider'>
Embedding 模型供应商 Embedding 模型供应商
</Property> </Property>
<Property name='retrieval_model' type='object' key='retrieval_model'> <Property name='retrieval_model' type='object' key='retrieval_model'>
@ -1047,10 +1047,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1399,10 +1397,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>
@ -1685,10 +1681,8 @@ import { Row, Col, Properties, Property, Heading, SubProperty, PropertyInstructi
``` ```
</CodeGroup> </CodeGroup>
<CodeGroup title="Response"> <CodeGroup title="Response">
```json {{ title: 'Response' }} ```text {{ title: 'Response' }}
{ 204 No Content
"result": "success"
}
``` ```
</CodeGroup> </CodeGroup>
</Col> </Col>

@ -6,9 +6,11 @@ import {
RiDeleteBinLine, RiDeleteBinLine,
RiEditLine, RiEditLine,
RiEqualizer2Line, RiEqualizer2Line,
RiExchange2Line,
RiFileCopy2Line, RiFileCopy2Line,
RiFileDownloadLine, RiFileDownloadLine,
RiFileUploadLine, RiFileUploadLine,
RiMoreLine,
} from '@remixicon/react' } from '@remixicon/react'
import AppIcon from '../base/app-icon' import AppIcon from '../base/app-icon'
import SwitchAppModal from '../app/switch-app-modal' import SwitchAppModal from '../app/switch-app-modal'
@ -32,6 +34,7 @@ import { fetchWorkflowDraft } from '@/service/workflow'
import ContentDialog from '@/app/components/base/content-dialog' import ContentDialog from '@/app/components/base/content-dialog'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView' import CardView from '@/app/(commonLayout)/app/(appDetailLayout)/[appId]/overview/cardView'
import { PortalToFollowElem, PortalToFollowElemContent, PortalToFollowElemTrigger } from '../base/portal-to-follow-elem'
export type IAppInfoProps = { export type IAppInfoProps = {
expand: boolean expand: boolean
@ -179,6 +182,11 @@ const AppInfo = ({ expand }: IAppInfoProps) => {
const { isCurrentWorkspaceEditor } = useAppContext() const { isCurrentWorkspaceEditor } = useAppContext()
const [showMore, setShowMore] = useState(false)
const handleTriggerMore = useCallback(() => {
setShowMore(true)
}, [setShowMore])
if (!appDetail) if (!appDetail)
return null return null
@ -276,23 +284,51 @@ const AppInfo = ({ expand }: IAppInfoProps) => {
<RiFileDownloadLine className='h-3.5 w-3.5 text-components-button-secondary-text' /> <RiFileDownloadLine className='h-3.5 w-3.5 text-components-button-secondary-text' />
<span className='system-xs-medium text-components-button-secondary-text'>{t('app.export')}</span> <span className='system-xs-medium text-components-button-secondary-text'>{t('app.export')}</span>
</Button> </Button>
{ {appDetail.mode !== 'agent-chat' && <PortalToFollowElem
(appDetail.mode === 'advanced-chat' || appDetail.mode === 'workflow') && ( open={showMore}
onOpenChange={setShowMore}
placement='bottom-end'
offset={{
mainAxis: 4,
}}>
<PortalToFollowElemTrigger onClick={handleTriggerMore}>
<Button <Button
size={'small'} size={'small'}
variant={'secondary'} variant={'secondary'}
className='gap-[1px]' className='gap-[1px]'
>
<RiMoreLine className='h-3.5 w-3.5 text-components-button-secondary-text' />
<span className='system-xs-medium text-components-button-secondary-text'>{t('common.operation.more')}</span>
</Button>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className='z-[21]'>
<div className='flex w-[264px] flex-col rounded-[12px] border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg backdrop-blur-[5px]'>
{
(appDetail.mode === 'advanced-chat' || appDetail.mode === 'workflow')
&& <div className='flex h-8 cursor-pointer items-center gap-x-1 rounded-lg p-1.5 hover:bg-state-base-hover'
onClick={() => { onClick={() => {
setOpen(false) setOpen(false)
setShowImportDSLModal(true) setShowImportDSLModal(true)
}} }}>
> <RiFileUploadLine className='h-4 w-4 text-text-tertiary' />
<RiFileUploadLine className='h-3.5 w-3.5 text-components-button-secondary-text' /> <span className='system-md-regular text-text-secondary'>{t('workflow.common.importDSL')}</span>
<span className='system-xs-medium text-components-button-secondary-text'>{t('workflow.common.importDSL')}</span> </div>
</Button> }
) {
(appDetail.mode === 'completion' || appDetail.mode === 'chat')
&& <div className='flex h-8 cursor-pointer items-center gap-x-1 rounded-lg p-1.5 hover:bg-state-base-hover'
onClick={() => {
setOpen(false)
setShowSwitchModal(true)
}}>
<RiExchange2Line className='h-4 w-4 text-text-tertiary' />
<span className='system-md-regular text-text-secondary'>{t('app.switch')}</span>
</div>
} }
</div> </div>
</PortalToFollowElemContent>
</PortalToFollowElem>}
</div>
</div> </div>
<div className='flex flex-1'> <div className='flex flex-1'>
<CardView <CardView

@ -24,7 +24,7 @@ import {
PortalToFollowElemContent, PortalToFollowElemContent,
PortalToFollowElemTrigger, PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem' } from '@/app/components/base/portal-to-follow-elem'
import { basePath } from '@/utils/var' import { WEB_PREFIX } from '@/config'
import { fetchInstalledAppList } from '@/service/explore' import { fetchInstalledAppList } from '@/service/explore'
import EmbeddedModal from '@/app/components/app/overview/embedded' import EmbeddedModal from '@/app/components/app/overview/embedded'
import { useStore as useAppStore } from '@/app/components/app/store' import { useStore as useAppStore } from '@/app/components/app/store'
@ -76,7 +76,7 @@ const AppPublisher = ({
const appDetail = useAppStore(state => state.appDetail) const appDetail = useAppStore(state => state.appDetail)
const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {} const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {}
const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode
const appURL = `${appBaseURL}${basePath}/${appMode}/${accessToken}` const appURL = `${appBaseURL}/${appMode}/${accessToken}`
const isChatApp = ['chat', 'agent-chat', 'completion'].includes(appDetail?.mode || '') const isChatApp = ['chat', 'agent-chat', 'completion'].includes(appDetail?.mode || '')
const language = useGetLanguage() const language = useGetLanguage()
@ -121,7 +121,7 @@ const AppPublisher = ({
try { try {
const { installed_apps }: any = await fetchInstalledAppList(appDetail?.id) || {} const { installed_apps }: any = await fetchInstalledAppList(appDetail?.id) || {}
if (installed_apps?.length > 0) if (installed_apps?.length > 0)
window.open(`${basePath}/explore/installed/${installed_apps[0].id}`, '_blank') window.open(`${WEB_PREFIX}/explore/installed/${installed_apps[0].id}`, '_blank')
else else
throw new Error('No app found in Explore') throw new Error('No app found in Explore')
} }

@ -14,7 +14,6 @@ import Loading from '@/app/components/base/loading'
import Badge from '@/app/components/base/badge' import Badge from '@/app/components/base/badge'
import { useKnowledge } from '@/hooks/use-knowledge' import { useKnowledge } from '@/hooks/use-knowledge'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { basePath } from '@/utils/var'
export type ISelectDataSetProps = { export type ISelectDataSetProps = {
isShow: boolean isShow: boolean
@ -112,7 +111,7 @@ const SelectDataSet: FC<ISelectDataSetProps> = ({
}} }}
> >
<span className='text-text-tertiary'>{t('appDebug.feature.dataSet.noDataSet')}</span> <span className='text-text-tertiary'>{t('appDebug.feature.dataSet.noDataSet')}</span>
<Link href={`${basePath}/datasets/create`} className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link> <Link href={'/datasets/create'} className='font-normal text-text-accent'>{t('appDebug.feature.dataSet.toCreate')}</Link>
</div> </div>
)} )}

@ -14,7 +14,7 @@ import type { AppIconSelection } from '../../base/app-icon-picker'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { basePath } from '@/utils/var' import { WEB_PREFIX } from '@/config'
import AppsContext, { useAppContext } from '@/context/app-context' import AppsContext, { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context' import { useProviderContext } from '@/context/provider-context'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
@ -353,11 +353,11 @@ function AppScreenShot({ mode, show }: { mode: AppMode; show: boolean }) {
'workflow': 'Workflow', 'workflow': 'Workflow',
} }
return <picture> return <picture>
<source media="(resolution: 1x)" srcSet={`${basePath}/screenshots/${theme}/${modeToImageMap[mode]}.png`} /> <source media="(resolution: 1x)" srcSet={`${WEB_PREFIX}/screenshots/${theme}/${modeToImageMap[mode]}.png`} />
<source media="(resolution: 2x)" srcSet={`${basePath}/screenshots/${theme}/${modeToImageMap[mode]}@2x.png`} /> <source media="(resolution: 2x)" srcSet={`${WEB_PREFIX}/screenshots/${theme}/${modeToImageMap[mode]}@2x.png`} />
<source media="(resolution: 3x)" srcSet={`${basePath}/screenshots/${theme}/${modeToImageMap[mode]}@3x.png`} /> <source media="(resolution: 3x)" srcSet={`${WEB_PREFIX}/screenshots/${theme}/${modeToImageMap[mode]}@3x.png`} />
<Image className={show ? '' : 'hidden'} <Image className={show ? '' : 'hidden'}
src={`${basePath}/screenshots/${theme}/${modeToImageMap[mode]}.png`} src={`${WEB_PREFIX}/screenshots/${theme}/${modeToImageMap[mode]}.png`}
alt='App Screen Shot' alt='App Screen Shot'
width={664} height={448} /> width={664} height={448} />
</picture> </picture>

@ -262,7 +262,7 @@ const CreateFromDSLModal = ({ show, onSuccess, onClose, activeTab = CreateFromDS
{ {
currentTab === CreateFromDSLModalTab.FROM_URL && ( currentTab === CreateFromDSLModalTab.FROM_URL && (
<div> <div>
<div className='system-md-semibold leading6 mb-1'>DSL URL</div> <div className='system-md-semibold mb-1 text-text-secondary'>DSL URL</div>
<Input <Input
placeholder={t('app.importFromDSLUrlPlaceholder') || ''} placeholder={t('app.importFromDSLUrlPlaceholder') || ''}
value={dslUrlValue} value={dslUrlValue}

@ -3,6 +3,7 @@ import type { FC } from 'react'
import React, { useEffect, useRef, useState } from 'react' import React, { useEffect, useRef, useState } from 'react'
import { import {
RiDeleteBinLine, RiDeleteBinLine,
RiUploadCloud2Line,
} from '@remixicon/react' } from '@remixicon/react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector' import { useContext } from 'use-context-selector'
@ -10,8 +11,7 @@ import { formatFileSize } from '@/utils/format'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files' import { Yaml as YamlIcon } from '@/app/components/base/icons/src/public/files'
import { ToastContext } from '@/app/components/base/toast' import { ToastContext } from '@/app/components/base/toast'
import { UploadCloud01 } from '@/app/components/base/icons/src/vender/line/general' import ActionButton from '@/app/components/base/action-button'
import Button from '@/app/components/base/button'
export type Props = { export type Props = {
file: File | undefined file: File | undefined
@ -102,19 +102,19 @@ const Uploader: FC<Props> = ({
/> />
<div ref={dropRef}> <div ref={dropRef}>
{!file && ( {!file && (
<div className={cn('flex h-12 items-center rounded-xl border border-dashed border-gray-200 bg-gray-50 text-sm font-normal', dragging && 'border border-[#B2CCFF] bg-[#F5F8FF]')}> <div className={cn('flex h-12 items-center rounded-[10px] border border-dashed border-components-dropzone-border bg-components-dropzone-bg text-sm font-normal', dragging && 'border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
<div className='flex w-full items-center justify-center space-x-2'> <div className='flex w-full items-center justify-center space-x-2'>
<UploadCloud01 className='mr-2 h-6 w-6' /> <RiUploadCloud2Line className='h-6 w-6 text-text-tertiary' />
<div className='text-gray-500'> <div className='text-text-tertiary'>
{t('datasetCreation.stepOne.uploader.button')} {t('datasetCreation.stepOne.uploader.button')}
<span className='cursor-pointer pl-1 text-[#155eef]' onClick={selectHandle}>{t('datasetDocuments.list.batchModal.browse')}</span> <span className='cursor-pointer pl-1 text-text-accent' onClick={selectHandle}>{t('datasetDocuments.list.batchModal.browse')}</span>
</div> </div>
</div> </div>
{dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />} {dragging && <div ref={dragRef} className='absolute left-0 top-0 h-full w-full' />}
</div> </div>
)} )}
{file && ( {file && (
<div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', 'hover:border-[#B2CCFF] hover:bg-[#F5F8FF]')}> <div className={cn('group flex items-center rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-on-panel-item-bg shadow-xs', ' hover:bg-components-panel-on-panel-item-bg-hover')}>
<div className='flex items-center justify-center p-3'> <div className='flex items-center justify-center p-3'>
<YamlIcon className="h-6 w-6 shrink-0" /> <YamlIcon className="h-6 w-6 shrink-0" />
</div> </div>
@ -126,12 +126,10 @@ const Uploader: FC<Props> = ({
<span>{formatFileSize(file.size)}</span> <span>{formatFileSize(file.size)}</span>
</div> </div>
</div> </div>
<div className='hidden items-center group-hover:flex'> <div className='hidden items-center pr-3 group-hover:flex'>
<Button onClick={selectHandle}>{t('datasetCreation.stepOne.uploader.change')}</Button> <ActionButton onClick={removeFile}>
<div className='mx-2 h-4 w-px bg-gray-200' />
<div className='cursor-pointer p-2' onClick={removeFile}>
<RiDeleteBinLine className='h-4 w-4 text-text-tertiary' /> <RiDeleteBinLine className='h-4 w-4 text-text-tertiary' />
</div> </ActionButton>
</div> </div>
</div> </div>
)} )}

@ -7,7 +7,6 @@ import { usePathname } from 'next/navigation'
import { useDebounce } from 'ahooks' import { useDebounce } from 'ahooks'
import { omit } from 'lodash-es' import { omit } from 'lodash-es'
import dayjs from 'dayjs' import dayjs from 'dayjs'
import { basePath } from '@/utils/var'
import { Trans, useTranslation } from 'react-i18next' import { Trans, useTranslation } from 'react-i18next'
import List from './list' import List from './list'
import Filter, { TIME_PERIOD_MAPPING } from './filter' import Filter, { TIME_PERIOD_MAPPING } from './filter'
@ -110,7 +109,7 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
? <Loading type='app' /> ? <Loading type='app' />
: total > 0 : total > 0
? <List logs={isChatMode ? chatConversations : completionConversations} appDetail={appDetail} onRefresh={isChatMode ? mutateChatList : mutateCompletionList} /> ? <List logs={isChatMode ? chatConversations : completionConversations} appDetail={appDetail} onRefresh={isChatMode ? mutateChatList : mutateCompletionList} />
: <EmptyElement appUrl={`${appDetail.site.app_base_url}${basePath}/${getWebAppType(appDetail.mode)}/${appDetail.site.access_token}`} /> : <EmptyElement appUrl={`${appDetail.site.app_base_url}/${getWebAppType(appDetail.mode)}/${appDetail.site.access_token}`} />
} }
{/* Show Pagination only if the total is more than the limit */} {/* Show Pagination only if the total is more than the limit */}
{(total && total > APP_PAGE_LIMIT) {(total && total > APP_PAGE_LIMIT)

@ -17,7 +17,6 @@ import type { ConfigParams } from './settings'
import Tooltip from '@/app/components/base/tooltip' import Tooltip from '@/app/components/base/tooltip'
import AppBasic from '@/app/components/app-sidebar/basic' import AppBasic from '@/app/components/app-sidebar/basic'
import { asyncRunSafe, randomString } from '@/utils' import { asyncRunSafe, randomString } from '@/utils'
import { basePath } from '@/utils/var'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Switch from '@/app/components/base/switch' import Switch from '@/app/components/base/switch'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
@ -89,7 +88,7 @@ function AppCard({
const runningStatus = isApp ? appInfo.enable_site : appInfo.enable_api const runningStatus = isApp ? appInfo.enable_site : appInfo.enable_api
const { app_base_url, access_token } = appInfo.site ?? {} const { app_base_url, access_token } = appInfo.site ?? {}
const appMode = (appInfo.mode !== 'completion' && appInfo.mode !== 'workflow') ? 'chat' : appInfo.mode const appMode = (appInfo.mode !== 'completion' && appInfo.mode !== 'workflow') ? 'chat' : appInfo.mode
const appUrl = `${app_base_url}${basePath}/${appMode}/${access_token}` const appUrl = `${app_base_url}/${appMode}/${access_token}`
const apiUrl = appInfo?.api_base_url const apiUrl = appInfo?.api_base_url
const genClickFuncByName = (opName: string) => { const genClickFuncByName = (opName: string) => {

@ -13,7 +13,6 @@ import { IS_CE_EDITION } from '@/config'
import type { SiteInfo } from '@/models/share' import type { SiteInfo } from '@/models/share'
import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context' import { useThemeContext } from '@/app/components/base/chat/embedded-chatbot/theme/theme-context'
import ActionButton from '@/app/components/base/action-button' import ActionButton from '@/app/components/base/action-button'
import { basePath } from '@/utils/var'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
type Props = { type Props = {
@ -29,7 +28,7 @@ const OPTION_MAP = {
iframe: { iframe: {
getContent: (url: string, token: string) => getContent: (url: string, token: string) =>
`<iframe `<iframe
src="${url}${basePath}/chatbot/${token}" src="${url}/chatbot/${token}"
style="width: 100%; height: 100%; min-height: 700px" style="width: 100%; height: 100%; min-height: 700px"
frameborder="0" frameborder="0"
allow="microphone"> allow="microphone">
@ -44,7 +43,7 @@ const OPTION_MAP = {
isDev: true` isDev: true`
: ''}${IS_CE_EDITION : ''}${IS_CE_EDITION
? `, ? `,
baseUrl: '${url}${basePath}'` baseUrl: '${url}'`
: ''}, : ''},
systemVariables: { systemVariables: {
// user_id: 'YOU CAN DEFINE USER ID HERE', // user_id: 'YOU CAN DEFINE USER ID HERE',
@ -53,7 +52,7 @@ const OPTION_MAP = {
} }
</script> </script>
<script <script
src="${url}${basePath}/embed.min.js" src="${url}/embed.min.js"
id="${token}" id="${token}"
defer> defer>
</script> </script>
@ -68,7 +67,7 @@ const OPTION_MAP = {
</style>`, </style>`,
}, },
chromePlugin: { chromePlugin: {
getContent: (url: string, token: string) => `ChatBot URL: ${url}${basePath}/chatbot/${token}`, getContent: (url: string, token: string) => `ChatBot URL: ${url}/chatbot/${token}`,
}, },
} }
const prefixEmbedded = 'appOverview.overview.appInfo.embedded' const prefixEmbedded = 'appOverview.overview.appInfo.embedded'

@ -11,7 +11,6 @@ import timezone from 'dayjs/plugin/timezone'
import { Trans, useTranslation } from 'react-i18next' import { Trans, useTranslation } from 'react-i18next'
import Link from 'next/link' import Link from 'next/link'
import List from './list' import List from './list'
import { basePath } from '@/utils/var'
import Filter, { TIME_PERIOD_MAPPING } from './filter' import Filter, { TIME_PERIOD_MAPPING } from './filter'
import Pagination from '@/app/components/base/pagination' import Pagination from '@/app/components/base/pagination'
import Loading from '@/app/components/base/loading' import Loading from '@/app/components/base/loading'
@ -101,7 +100,7 @@ const Logs: FC<ILogsProps> = ({ appDetail }) => {
? <Loading type='app' /> ? <Loading type='app' />
: total > 0 : total > 0
? <List logs={workflowLogs} appDetail={appDetail} onRefresh={mutate} /> ? <List logs={workflowLogs} appDetail={appDetail} onRefresh={mutate} />
: <EmptyElement appUrl={`${appDetail.site.app_base_url}${basePath}/${getWebAppType(appDetail.mode)}/${appDetail.site.access_token}`} /> : <EmptyElement appUrl={`${appDetail.site.app_base_url}/${getWebAppType(appDetail.mode)}/${appDetail.site.access_token}`} />
} }
{/* Show Pagination only if the total is more than the limit */} {/* Show Pagination only if the total is more than the limit */}
{(total && total > APP_PAGE_LIMIT) {(total && total > APP_PAGE_LIMIT)

@ -1,4 +1,4 @@
import React, { useCallback } from 'react' import React, { memo, useCallback } from 'react'
import { useTranslation } from 'react-i18next' import { useTranslation } from 'react-i18next'
import { useChatWithHistoryContext } from '../context' import { useChatWithHistoryContext } from '../context'
import Input from '@/app/components/base/input' import Input from '@/app/components/base/input'
@ -112,4 +112,4 @@ const InputsFormContent = ({ showTip }: Props) => {
) )
} }
export default InputsFormContent export default memo(InputsFormContent)

@ -424,6 +424,8 @@ export const useChat = (
const response = responseItem as any const response = responseItem as any
if (thought.message_id && !hasSetResponseId) if (thought.message_id && !hasSetResponseId)
response.id = thought.message_id response.id = thought.message_id
if (thought.conversation_id)
response.conversationId = thought.conversation_id
if (response.agent_thoughts.length === 0) { if (response.agent_thoughts.length === 0) {
response.agent_thoughts.push(thought) response.agent_thoughts.push(thought)

@ -5,6 +5,8 @@ import type {
import { import {
memo, memo,
useCallback, useCallback,
useEffect,
useRef,
useState, useState,
} from 'react' } from 'react'
import type { ChatItem } from '../types' import type { ChatItem } from '../types'
@ -52,6 +54,8 @@ const Question: FC<QuestionProps> = ({
const [isEditing, setIsEditing] = useState(false) const [isEditing, setIsEditing] = useState(false)
const [editedContent, setEditedContent] = useState(content) const [editedContent, setEditedContent] = useState(content)
const [contentWidth, setContentWidth] = useState(0)
const contentRef = useRef<HTMLDivElement>(null)
const handleEdit = useCallback(() => { const handleEdit = useCallback(() => {
setIsEditing(true) setIsEditing(true)
@ -75,14 +79,31 @@ const Question: FC<QuestionProps> = ({
item.nextSibling && switchSibling?.(item.nextSibling) item.nextSibling && switchSibling?.(item.nextSibling)
}, [switchSibling, item.prevSibling, item.nextSibling]) }, [switchSibling, item.prevSibling, item.nextSibling])
const getContentWidth = () => {
if (contentRef.current)
setContentWidth(contentRef.current?.clientWidth)
}
useEffect(() => {
if (!contentRef.current)
return
const resizeObserver = new ResizeObserver(() => {
getContentWidth()
})
resizeObserver.observe(contentRef.current)
return () => {
resizeObserver.disconnect()
}
}, [])
return ( return (
<div className='mb-2 flex justify-end pl-14 last:mb-0'> <div className='mb-2 flex justify-end last:mb-0'>
<div className={cn('group relative mr-4 flex max-w-full items-start', isEditing && 'flex-1')}> <div className={cn('group relative mr-4 flex max-w-full items-start pl-14', isEditing && 'flex-1')}>
<div className={cn('mr-2 gap-1', isEditing ? 'hidden' : 'flex')}> <div className={cn('mr-2 gap-1', isEditing ? 'hidden' : 'flex')}>
<div className=" <div
absolutegap-0.5 hidden rounded-[10px] border-[0.5px] border-components-actionbar-border className="absolute hidden gap-0.5 rounded-[10px] border-[0.5px] border-components-actionbar-border bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm group-hover:flex"
bg-components-actionbar-bg p-0.5 shadow-md backdrop-blur-sm group-hover:flex style={{ right: contentWidth + 8 }}
"> >
<ActionButton onClick={() => { <ActionButton onClick={() => {
copy(content) copy(content)
Toast.notify({ type: 'success', message: t('common.actionMsg.copySuccessfully') }) Toast.notify({ type: 'success', message: t('common.actionMsg.copySuccessfully') })
@ -95,6 +116,7 @@ const Question: FC<QuestionProps> = ({
</div> </div>
</div> </div>
<div <div
ref={contentRef}
className='w-full rounded-2xl bg-[#D1E9FF]/50 px-4 py-3 text-sm text-gray-900' className='w-full rounded-2xl bg-[#D1E9FF]/50 px-4 py-3 text-sm text-gray-900'
style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}} style={theme?.chatBubbleColorStyle ? CssTransform(theme.chatBubbleColorStyle) : {}}
> >

@ -41,6 +41,7 @@ export type ThoughtItem = {
tool_input: string tool_input: string
tool_labels?: { [key: string]: TypeWithI18N } tool_labels?: { [key: string]: TypeWithI18N }
message_id: string message_id: string
conversation_id: string
observation: string observation: string
position: number position: number
files?: string[] files?: string[]

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save