Merge branch 'main' into fix/chore-fix

pull/12372/head
Yeuoly 1 year ago
commit 78664c8903

@ -65,7 +65,7 @@ OPENDAL_FS_ROOT=storage
# S3 Storage configuration # S3 Storage configuration
S3_USE_AWS_MANAGED_IAM=false S3_USE_AWS_MANAGED_IAM=false
S3_ENDPOINT=https://your-bucket-name.storage.s3.clooudflare.com S3_ENDPOINT=https://your-bucket-name.storage.s3.cloudflare.com
S3_BUCKET_NAME=your-bucket-name S3_BUCKET_NAME=your-bucket-name
S3_ACCESS_KEY=your-access-key S3_ACCESS_KEY=your-access-key
S3_SECRET_KEY=your-secret-key S3_SECRET_KEY=your-secret-key
@ -74,7 +74,7 @@ S3_REGION=your-region
# Azure Blob Storage configuration # Azure Blob Storage configuration
AZURE_BLOB_ACCOUNT_NAME=your-account-name AZURE_BLOB_ACCOUNT_NAME=your-account-name
AZURE_BLOB_ACCOUNT_KEY=your-account-key AZURE_BLOB_ACCOUNT_KEY=your-account-key
AZURE_BLOB_CONTAINER_NAME=yout-container-name AZURE_BLOB_CONTAINER_NAME=your-container-name
AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net AZURE_BLOB_ACCOUNT_URL=https://<your_account_name>.blob.core.windows.net
# Aliyun oss Storage configuration # Aliyun oss Storage configuration
@ -88,7 +88,7 @@ ALIYUN_OSS_REGION=your-region
ALIYUN_OSS_PATH=your-path ALIYUN_OSS_PATH=your-path
# Google Storage configuration # Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=yout-bucket-name GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string GOOGLE_STORAGE_SERVICE_ACCOUNT_JSON_BASE64=your-google-service-account-json-base64-string
# Tencent COS Storage configuration # Tencent COS Storage configuration

@ -67,7 +67,7 @@ ignore = [
"SIM105", # suppressible-exception "SIM105", # suppressible-exception
"SIM107", # return-in-try-except-finally "SIM107", # return-in-try-except-finally
"SIM108", # if-else-block-instead-of-if-exp "SIM108", # if-else-block-instead-of-if-exp
"SIM113", # eumerate-for-loop "SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements "SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false "SIM210", # if-expr-with-true-false
] ]

@ -563,8 +563,13 @@ def create_tenant(email: str, language: Optional[str] = None, name: Optional[str
new_password = secrets.token_urlsafe(16) new_password = secrets.token_urlsafe(16)
# register account # register account
account = RegisterService.register(email=email, name=account_name, password=new_password, language=language) account = RegisterService.register(
email=email,
name=account_name,
password=new_password,
language=language,
create_workspace_required=False,
)
TenantService.create_owner_tenant_if_not_exist(account, name) TenantService.create_owner_tenant_if_not_exist(account, name)
click.echo( click.echo(
@ -584,7 +589,7 @@ def upgrade_db():
click.echo(click.style("Starting database migration.", fg="green")) click.echo(click.style("Starting database migration.", fg="green"))
# run db migration # run db migration
import flask_migrate import flask_migrate # type: ignore
flask_migrate.upgrade() flask_migrate.upgrade()

@ -659,7 +659,7 @@ class RagEtlConfig(BaseSettings):
UNSTRUCTURED_API_KEY: Optional[str] = Field( UNSTRUCTURED_API_KEY: Optional[str] = Field(
description="API key for Unstructured.io service", description="API key for Unstructured.io service",
default=None, default="",
) )
SCARF_NO_ANALYTICS: Optional[str] = Field( SCARF_NO_ANALYTICS: Optional[str] = Field(

@ -232,7 +232,7 @@ class DataSourceNotionApi(Resource):
args["doc_form"], args["doc_form"],
args["doc_language"], args["doc_language"],
) )
return response, 200 return response.model_dump(), 200
class DataSourceNotionDatasetSyncApi(Resource): class DataSourceNotionDatasetSyncApi(Resource):

@ -464,7 +464,7 @@ class DatasetIndexingEstimateApi(Resource):
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response, 200 return response.model_dump(), 200
class DatasetRelatedAppListApi(Resource): class DatasetRelatedAppListApi(Resource):
@ -733,6 +733,18 @@ class DatasetPermissionUserListApi(Resource):
}, 200 }, 200
class DatasetAutoDisableLogApi(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id):
dataset_id_str = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id_str)
if dataset is None:
raise NotFound("Dataset not found.")
return DatasetService.get_dataset_auto_disable_logs(dataset_id_str), 200
api.add_resource(DatasetListApi, "/datasets") api.add_resource(DatasetListApi, "/datasets")
api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>") api.add_resource(DatasetApi, "/datasets/<uuid:dataset_id>")
api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check") api.add_resource(DatasetUseCheckApi, "/datasets/<uuid:dataset_id>/use-check")
@ -747,3 +759,4 @@ api.add_resource(DatasetApiBaseUrlApi, "/datasets/api-base-info")
api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting") api.add_resource(DatasetRetrievalSettingApi, "/datasets/retrieval-setting")
api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>") api.add_resource(DatasetRetrievalSettingMockApi, "/datasets/retrieval-setting/<string:vector_type>")
api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users") api.add_resource(DatasetPermissionUserListApi, "/datasets/<uuid:dataset_id>/permission-part-users")
api.add_resource(DatasetAutoDisableLogApi, "/datasets/<uuid:dataset_id>/auto-disable-logs")

@ -52,6 +52,7 @@ from fields.document_fields import (
from libs.login import login_required from libs.login import login_required
from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile from models import Dataset, DatasetProcessRule, Document, DocumentSegment, UploadFile
from services.dataset_service import DatasetService, DocumentService from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from tasks.add_document_to_index_task import add_document_to_index_task from tasks.add_document_to_index_task import add_document_to_index_task
from tasks.remove_document_from_index_task import remove_document_from_index_task from tasks.remove_document_from_index_task import remove_document_from_index_task
@ -267,20 +268,22 @@ class DatasetDocumentListApi(Resource):
parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json") parser.add_argument("duplicate", type=bool, default=True, nullable=False, location="json")
parser.add_argument("original_document_id", type=str, required=False, location="json") parser.add_argument("original_document_id", type=str, required=False, location="json")
parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json") parser.add_argument("doc_form", type=str, default="text_model", required=False, nullable=False, location="json")
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
parser.add_argument( parser.add_argument(
"doc_language", type=str, default="English", required=False, nullable=False, location="json" "doc_language", type=str, default="English", required=False, nullable=False, location="json"
) )
parser.add_argument("retrieval_model", type=dict, required=False, nullable=False, location="json")
args = parser.parse_args() args = parser.parse_args()
knowledge_config = KnowledgeConfig(**args)
if not dataset.indexing_technique and not args["indexing_technique"]: if not dataset.indexing_technique and not knowledge_config.indexing_technique:
raise ValueError("indexing_technique is required.") raise ValueError("indexing_technique is required.")
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id(dataset, args, current_user) documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, current_user)
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except QuotaExceededError: except QuotaExceededError:
@ -290,6 +293,25 @@ class DatasetDocumentListApi(Resource):
return {"documents": documents, "batch": batch} return {"documents": documents, "batch": batch}
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id):
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if dataset is None:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
try:
document_ids = request.args.getlist("document_id")
DocumentService.delete_documents(dataset, document_ids)
except services.errors.document.DocumentIndexingError:
raise DocumentIndexingError("Cannot delete document during indexing.")
return {"result": "success"}, 204
class DatasetInitApi(Resource): class DatasetInitApi(Resource):
@setup_required @setup_required
@ -325,9 +347,9 @@ class DatasetInitApi(Resource):
# The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator # The role of the current user in the ta table must be admin, owner, or editor, or dataset_operator
if not current_user.is_dataset_editor: if not current_user.is_dataset_editor:
raise Forbidden() raise Forbidden()
knowledge_config = KnowledgeConfig(**args)
if args["indexing_technique"] == "high_quality": if knowledge_config.indexing_technique == "high_quality":
if args["embedding_model"] is None or args["embedding_model_provider"] is None: if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None:
raise ValueError("embedding model and embedding model provider are required for high quality indexing.") raise ValueError("embedding model and embedding model provider are required for high quality indexing.")
try: try:
model_manager = ModelManager() model_manager = ModelManager()
@ -346,11 +368,11 @@ class DatasetInitApi(Resource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
dataset, documents, batch = DocumentService.save_document_without_dataset_id( dataset, documents, batch = DocumentService.save_document_without_dataset_id(
tenant_id=current_user.current_tenant_id, document_data=args, account=current_user tenant_id=current_user.current_tenant_id, knowledge_config=knowledge_config, account=current_user
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
@ -403,7 +425,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
indexing_runner = IndexingRunner() indexing_runner = IndexingRunner()
try: try:
response = indexing_runner.indexing_estimate( estimate_response = indexing_runner.indexing_estimate(
current_user.current_tenant_id, current_user.current_tenant_id,
[extract_setting], [extract_setting],
data_process_rule_dict, data_process_rule_dict,
@ -411,6 +433,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
"English", "English",
dataset_id, dataset_id,
) )
return estimate_response.model_dump(), 200
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
@ -423,7 +446,7 @@ class DocumentIndexingEstimateApi(DocumentResource):
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response return response, 200
class DocumentBatchIndexingEstimateApi(DocumentResource): class DocumentBatchIndexingEstimateApi(DocumentResource):
@ -434,9 +457,8 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
batch = str(batch) batch = str(batch)
documents = self.get_batch_documents(dataset_id, batch) documents = self.get_batch_documents(dataset_id, batch)
response = {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}
if not documents: if not documents:
return response return {"tokens": 0, "total_price": 0, "currency": "USD", "total_segments": 0, "preview": []}, 200
data_process_rule = documents[0].dataset_process_rule data_process_rule = documents[0].dataset_process_rule
data_process_rule_dict = data_process_rule.to_dict() data_process_rule_dict = data_process_rule.to_dict()
info_list = [] info_list = []
@ -514,6 +536,7 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
"English", "English",
dataset_id, dataset_id,
) )
return response.model_dump(), 200
except LLMBadRequestError: except LLMBadRequestError:
raise ProviderNotInitializeError( raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider " "No Embedding Model available. Please configure a valid provider "
@ -525,7 +548,6 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
except Exception as e: except Exception as e:
raise IndexingEstimateError(str(e)) raise IndexingEstimateError(str(e))
return response
class DocumentBatchIndexingStatusApi(DocumentResource): class DocumentBatchIndexingStatusApi(DocumentResource):
@ -598,7 +620,8 @@ class DocumentDetailApi(DocumentResource):
if metadata == "only": if metadata == "only":
response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata} response = {"id": document.id, "doc_type": document.doc_type, "doc_metadata": document.doc_metadata}
elif metadata == "without": elif metadata == "without":
process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -606,7 +629,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"data_source_info": data_source_info, "data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules, "dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
@ -629,7 +653,8 @@ class DocumentDetailApi(DocumentResource):
"doc_language": document.doc_language, "doc_language": document.doc_language,
} }
else: else:
process_rules = DatasetService.get_process_rules(dataset_id) dataset_process_rules = DatasetService.get_process_rules(dataset_id)
document_process_rules = document.dataset_process_rule.to_dict()
data_source_info = document.data_source_detail_dict data_source_info = document.data_source_detail_dict
response = { response = {
"id": document.id, "id": document.id,
@ -637,7 +662,8 @@ class DocumentDetailApi(DocumentResource):
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"data_source_info": data_source_info, "data_source_info": data_source_info,
"dataset_process_rule_id": document.dataset_process_rule_id, "dataset_process_rule_id": document.dataset_process_rule_id,
"dataset_process_rule": process_rules, "dataset_process_rule": dataset_process_rules,
"document_process_rule": document_process_rules,
"name": document.name, "name": document.name,
"created_from": document.created_from, "created_from": document.created_from,
"created_by": document.created_by, "created_by": document.created_by,
@ -773,9 +799,8 @@ class DocumentStatusApi(DocumentResource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, action): def patch(self, dataset_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
document_id = str(document_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if dataset is None: if dataset is None:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
@ -790,84 +815,79 @@ class DocumentStatusApi(DocumentResource):
# check user's permission # check user's permission
DatasetService.check_dataset_permission(dataset, current_user) DatasetService.check_dataset_permission(dataset, current_user)
document = self.get_document(dataset_id, document_id) document_ids = request.args.getlist("document_id")
for document_id in document_ids:
document = self.get_document(dataset_id, document_id)
indexing_cache_key = "document_{}_indexing".format(document.id) indexing_cache_key = "document_{}_indexing".format(document.id)
cache_result = redis_client.get(indexing_cache_key) cache_result = redis_client.get(indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later") raise InvalidActionError(f"Document:{document.name} is being indexed, please try again later")
if action == "enable": if action == "enable":
if document.enabled: if document.enabled:
raise InvalidActionError("Document already enabled.") continue
document.enabled = True
document.disabled_at = None
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
document.enabled = True # Set cache to prevent indexing the same document multiple times
document.disabled_at = None redis_client.setex(indexing_cache_key, 600, 1)
document.disabled_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times add_document_to_index_task.delay(document_id)
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id) elif action == "disable":
if not document.completed_at or document.indexing_status != "completed":
raise InvalidActionError(f"Document: {document.name} is not completed.")
if not document.enabled:
continue
return {"result": "success"}, 200 document.enabled = False
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
elif action == "disable": # Set cache to prevent indexing the same document multiple times
if not document.completed_at or document.indexing_status != "completed": redis_client.setex(indexing_cache_key, 600, 1)
raise InvalidActionError("Document is not completed.")
if not document.enabled:
raise InvalidActionError("Document already disabled.")
document.enabled = False remove_document_from_index_task.delay(document_id)
document.disabled_at = datetime.now(UTC).replace(tzinfo=None)
document.disabled_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times elif action == "archive":
redis_client.setex(indexing_cache_key, 600, 1) if document.archived:
continue
remove_document_from_index_task.delay(document_id) document.archived = True
document.archived_at = datetime.now(UTC).replace(tzinfo=None)
document.archived_by = current_user.id
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return {"result": "success"}, 200 if document.enabled:
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
elif action == "archive": remove_document_from_index_task.delay(document_id)
if document.archived:
raise InvalidActionError("Document already archived.")
document.archived = True elif action == "un_archive":
document.archived_at = datetime.now(UTC).replace(tzinfo=None) if not document.archived:
document.archived_by = current_user.id continue
document.updated_at = datetime.now(UTC).replace(tzinfo=None) document.archived = False
db.session.commit() document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
if document.enabled:
# Set cache to prevent indexing the same document multiple times # Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1) redis_client.setex(indexing_cache_key, 600, 1)
remove_document_from_index_task.delay(document_id) add_document_to_index_task.delay(document_id)
return {"result": "success"}, 200
elif action == "un_archive":
if not document.archived:
raise InvalidActionError("Document is not archived.")
document.archived = False
document.archived_at = None
document.archived_by = None
document.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
# Set cache to prevent indexing the same document multiple times
redis_client.setex(indexing_cache_key, 600, 1)
add_document_to_index_task.delay(document_id) else:
raise InvalidActionError()
return {"result": "success"}, 200 return {"result": "success"}, 200
else:
raise InvalidActionError()
class DocumentPauseApi(DocumentResource): class DocumentPauseApi(DocumentResource):
@ -1038,7 +1058,7 @@ api.add_resource(
) )
api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>") api.add_resource(DocumentDeleteApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>")
api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata") api.add_resource(DocumentMetadataApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/metadata")
api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/status/<string:action>") api.add_resource(DocumentStatusApi, "/datasets/<uuid:dataset_id>/documents/status/<string:action>/batch")
api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause") api.add_resource(DocumentPauseApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/pause")
api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume") api.add_resource(DocumentRecoverApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/processing/resume")
api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry") api.add_resource(DocumentRetryApi, "/datasets/<uuid:dataset_id>/retry")

@ -1,5 +1,4 @@
import uuid import uuid
from datetime import UTC, datetime
import pandas as pd import pandas as pd
from flask import request from flask import request
@ -10,7 +9,13 @@ from werkzeug.exceptions import Forbidden, NotFound
import services import services
from controllers.console import api from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import InvalidActionError, NoFileUploadedError, TooManyFilesError from controllers.console.datasets.error import (
ChildChunkDeleteIndexError,
ChildChunkIndexingError,
InvalidActionError,
NoFileUploadedError,
TooManyFilesError,
)
from controllers.console.wraps import ( from controllers.console.wraps import (
account_initialization_required, account_initialization_required,
cloud_edition_billing_knowledge_limit_check, cloud_edition_billing_knowledge_limit_check,
@ -20,15 +25,15 @@ from controllers.console.wraps import (
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields from fields.segment_fields import child_chunk_fields, segment_fields
from libs.login import login_required from libs.login import login_required
from models import DocumentSegment from models.dataset import ChildChunk, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import ChildChunkUpdateArgs, SegmentUpdateArgs
from services.errors.chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi(Resource): class DatasetDocumentSegmentListApi(Resource):
@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound("Document not found.") raise NotFound("Document not found.")
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("last_id", type=str, default=None, location="args")
parser.add_argument("limit", type=int, default=20, location="args") parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("status", type=str, action="append", default=[], location="args") parser.add_argument("status", type=str, action="append", default=[], location="args")
parser.add_argument("hit_count_gte", type=int, default=None, location="args") parser.add_argument("hit_count_gte", type=int, default=None, location="args")
parser.add_argument("enabled", type=str, default="all", location="args") parser.add_argument("enabled", type=str, default="all", location="args")
parser.add_argument("keyword", type=str, default=None, location="args") parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args() args = parser.parse_args()
last_id = args["last_id"] page = args["page"]
limit = min(args["limit"], 100) limit = min(args["limit"], 100)
status_list = args["status"] status_list = args["status"]
hit_count_gte = args["hit_count_gte"] hit_count_gte = args["hit_count_gte"]
@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = DocumentSegment.query.filter( query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id DocumentSegment.document_id == str(document_id), DocumentSegment.tenant_id == current_user.current_tenant_id
) ).order_by(DocumentSegment.position.asc())
if last_id is not None:
last_segment = db.session.get(DocumentSegment, str(last_id))
if last_segment:
query = query.filter(DocumentSegment.position > last_segment.position)
else:
return {"data": [], "has_more": False, "limit": limit}, 200
if status_list: if status_list:
query = query.filter(DocumentSegment.status.in_(status_list)) query = query.filter(DocumentSegment.status.in_(status_list))
@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
elif args["enabled"].lower() == "false": elif args["enabled"].lower() == "false":
query = query.filter(DocumentSegment.enabled == False) query = query.filter(DocumentSegment.enabled == False)
total = query.count() segments = query.paginate(page=page, per_page=limit, max_per_page=100, error_out=False)
segments = query.order_by(DocumentSegment.position).limit(limit + 1).all()
has_more = False
if len(segments) > limit:
has_more = True
segments = segments[:-1]
return { response = {
"data": marshal(segments, segment_fields), "data": marshal(segments.items, segment_fields),
"doc_form": document.doc_form,
"has_more": has_more,
"limit": limit, "limit": limit,
"total": total, "total": segments.total,
}, 200 "total_pages": segments.pages,
"page": page,
}
return response, 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
segment_ids = request.args.getlist("segment_id")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
SegmentService.delete_segments(segment_ids, document, dataset)
return {"result": "success"}, 200
class DatasetDocumentSegmentApi(Resource): class DatasetDocumentSegmentApi(Resource):
@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
@login_required @login_required
@account_initialization_required @account_initialization_required
@cloud_edition_billing_resource_check("vector_space") @cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, segment_id, action): def patch(self, dataset_id, document_id, action):
dataset_id = str(dataset_id) dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id) dataset = DatasetService.get_dataset(dataset_id)
if not dataset: if not dataset:
raise NotFound("Dataset not found.") raise NotFound("Dataset not found.")
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check user's model setting # check user's model setting
DatasetService.check_dataset_model_setting(dataset) DatasetService.check_dataset_model_setting(dataset)
# The role of the current user in the ta table must be admin, owner, or editor # The role of the current user in the ta table must be admin, owner, or editor
@ -147,59 +173,17 @@ class DatasetDocumentSegmentApi(Resource):
) )
except ProviderTokenNotInitError as ex: except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description) raise ProviderNotInitializeError(ex.description)
segment_ids = request.args.getlist("segment_id")
segment = DocumentSegment.query.filter( document_indexing_cache_key = "document_{}_indexing".format(document.id)
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if segment.status != "completed":
raise NotFound("Segment is not completed, enable or disable function is not allowed")
document_indexing_cache_key = "document_{}_indexing".format(segment.document_id)
cache_result = redis_client.get(document_indexing_cache_key) cache_result = redis_client.get(document_indexing_cache_key)
if cache_result is not None: if cache_result is not None:
raise InvalidActionError("Document is being indexed, please try again later") raise InvalidActionError("Document is being indexed, please try again later")
try:
indexing_cache_key = "segment_{}_indexing".format(segment.id) SegmentService.update_segments_status(segment_ids, action, dataset, document)
cache_result = redis_client.get(indexing_cache_key) except Exception as e:
if cache_result is not None: raise InvalidActionError(str(e))
raise InvalidActionError("Segment is being indexed, please try again later") return {"result": "success"}, 200
if action == "enable":
if segment.enabled:
raise InvalidActionError("Segment is already enabled.")
segment.enabled = True
segment.disabled_at = None
segment.disabled_by = None
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
enable_segment_to_index_task.delay(segment.id)
return {"result": "success"}, 200
elif action == "disable":
if not segment.enabled:
raise InvalidActionError("Segment is already disabled.")
segment.enabled = False
segment.disabled_at = datetime.now(UTC).replace(tzinfo=None)
segment.disabled_by = current_user.id
db.session.commit()
# Set cache to prevent indexing the same segment multiple times
redis_client.setex(indexing_cache_key, 600, 1)
disable_segment_from_index_task.delay(segment.id)
return {"result": "success"}, 200
else:
raise InvalidActionError()
class DatasetDocumentSegmentAddApi(Resource): class DatasetDocumentSegmentAddApi(Resource):
@ -307,9 +291,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser.add_argument("content", type=str, required=True, nullable=False, location="json") parser.add_argument("content", type=str, required=True, nullable=False, location="json")
parser.add_argument("answer", type=str, required=False, nullable=True, location="json") parser.add_argument("answer", type=str, required=False, nullable=True, location="json")
parser.add_argument("keywords", type=list, required=False, nullable=True, location="json") parser.add_argument("keywords", type=list, required=False, nullable=True, location="json")
parser.add_argument(
"regenerate_child_chunks", type=bool, required=False, nullable=True, default=False, location="json"
)
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args, document) SegmentService.segment_create_args_validate(args, document)
segment = SegmentService.update_segment(args, segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs(**args), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required @setup_required
@ -412,8 +399,248 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return {"job_id": job_id, "job_status": cache_result.decode()}, 200 return {"job_id": job_id, "job_status": cache_result.decode()}, 200
class ChildChunkAddApi(Resource):
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
@cloud_edition_billing_knowledge_limit_check("add_segment")
def post(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
if not current_user.is_editor:
raise Forbidden()
# check embedding model setting
if dataset.indexing_technique == "high_quality":
try:
model_manager = ModelManager()
model_manager.get_model_instance(
tenant_id=current_user.current_tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
"No Embedding Model available. Please configure a valid provider "
"in the Settings -> Model Provider."
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.create_child_chunk(args.get("content"), segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
parser = reqparse.RequestParser()
parser.add_argument("limit", type=int, default=20, location="args")
parser.add_argument("keyword", type=str, default=None, location="args")
parser.add_argument("page", type=int, default=1, location="args")
args = parser.parse_args()
page = args["page"]
limit = min(args["limit"], 100)
keyword = args["keyword"]
child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
return {
"data": marshal(child_chunks.items, child_chunk_fields),
"total": child_chunks.total,
"total_pages": child_chunks.pages,
"page": page,
"limit": limit,
}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# The role of the current user in the ta table must be admin, owner, or editor
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("chunks", type=list, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get("chunks")]
child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
try:
SegmentService.delete_child_chunk(child_chunk, dataset)
except ChildChunkDeleteIndexServiceError as e:
raise ChildChunkDeleteIndexError(str(e))
return {"result": "success"}, 200
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("vector_space")
def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
# check dataset
dataset_id = str(dataset_id)
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound("Document not found.")
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id), DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound("Segment not found.")
# check child chunk
child_chunk_id = str(child_chunk_id)
child_chunk = ChildChunk.query.filter(
ChildChunk.id == str(child_chunk_id), ChildChunk.tenant_id == current_user.current_tenant_id
).first()
if not child_chunk:
raise NotFound("Child chunk not found.")
# The role of the current user in the ta table must be admin or owner
if not current_user.is_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# validate args
parser = reqparse.RequestParser()
parser.add_argument("content", type=str, required=True, nullable=False, location="json")
args = parser.parse_args()
try:
child_chunk = SegmentService.update_child_chunk(
args.get("content"), child_chunk, segment, document, dataset
)
except ChildChunkIndexingServiceError as e:
raise ChildChunkIndexingError(str(e))
return {"data": marshal(child_chunk, child_chunk_fields)}, 200
api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments") api.add_resource(DatasetDocumentSegmentListApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments")
api.add_resource(DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action>") api.add_resource(
DatasetDocumentSegmentApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action>"
)
api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment") api.add_resource(DatasetDocumentSegmentAddApi, "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment")
api.add_resource( api.add_resource(
DatasetDocumentSegmentUpdateApi, DatasetDocumentSegmentUpdateApi,
@ -424,3 +651,11 @@ api.add_resource(
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import", "/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import",
"/datasets/batch_import_status/<uuid:job_id>", "/datasets/batch_import_status/<uuid:job_id>",
) )
api.add_resource(
ChildChunkAddApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks",
)
api.add_resource(
ChildChunkUpdateApi,
"/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id>",
)

@ -89,3 +89,15 @@ class IndexingEstimateError(BaseHTTPException):
error_code = "indexing_estimate_error" error_code = "indexing_estimate_error"
description = "Knowledge indexing estimate failed: {message}" description = "Knowledge indexing estimate failed: {message}"
code = 500 code = 500
class ChildChunkIndexingError(BaseHTTPException):
error_code = "child_chunk_indexing_error"
description = "Create child chunk index failed: {message}"
code = 500
class ChildChunkDeleteIndexError(BaseHTTPException):
error_code = "child_chunk_delete_index_error"
description = "Delete child chunk index failed: {message}"
code = 500

@ -66,10 +66,17 @@ class MessageFeedbackApi(InstalledAppResource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json")
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, current_user, args["rating"], args["content"]) MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=current_user,
rating=args.get("rating"),
content=args.get("content"),
)
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -108,7 +108,13 @@ class MessageFeedbackApi(Resource):
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
)
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -8,12 +8,16 @@ from werkzeug.exceptions import NotFound
import services.dataset_service import services.dataset_service
from controllers.common.errors import FilenameNotExistsError from controllers.common.errors import FilenameNotExistsError
from controllers.service_api import api from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError from controllers.service_api.app.error import (
FileTooLargeError,
NoFileUploadedError,
ProviderNotInitializeError,
TooManyFilesError,
UnsupportedFileTypeError,
)
from controllers.service_api.dataset.error import ( from controllers.service_api.dataset.error import (
ArchivedDocumentImmutableError, ArchivedDocumentImmutableError,
DocumentIndexingError, DocumentIndexingError,
NoFileUploadedError,
TooManyFilesError,
) )
from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check from controllers.service_api.wraps import DatasetApiResource, cloud_edition_billing_resource_check
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
@ -22,6 +26,7 @@ from fields.document_fields import document_fields, document_status_fields
from libs.login import current_user from libs.login import current_user
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.dataset_service import DocumentService from services.dataset_service import DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig
from services.file_service import FileService from services.file_service import FileService
@ -67,13 +72,14 @@ class DocumentAddByTextApi(DatasetApiResource):
"info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}}, "info_list": {"data_source_type": "upload_file", "file_info_list": {"file_ids": [upload_file.id]}},
} }
args["data_source"] = data_source args["data_source"] = data_source
knowledge_config = KnowledgeConfig(**args)
# validate args # validate args
DocumentService.document_create_args_validate(args) DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=current_user, account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -122,12 +128,13 @@ class DocumentUpdateByTextApi(DatasetApiResource):
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args) knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=current_user, account=current_user,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -186,12 +193,13 @@ class DocumentAddByFileApi(DatasetApiResource):
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
DocumentService.document_create_args_validate(args) knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",
@ -234,23 +242,30 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if not file.filename: if not file.filename:
raise FilenameNotExistsError raise FilenameNotExistsError
upload_file = FileService.upload_file( try:
filename=file.filename, upload_file = FileService.upload_file(
content=file.read(), filename=file.filename,
mimetype=file.mimetype, content=file.read(),
user=current_user, mimetype=file.mimetype,
source="datasets", user=current_user,
) source="datasets",
)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}} data_source = {"type": "upload_file", "info_list": {"file_info_list": {"file_ids": [upload_file.id]}}}
args["data_source"] = data_source args["data_source"] = data_source
# validate args # validate args
args["original_document_id"] = str(document_id) args["original_document_id"] = str(document_id)
DocumentService.document_create_args_validate(args)
knowledge_config = KnowledgeConfig(**args)
DocumentService.document_create_args_validate(knowledge_config)
try: try:
documents, batch = DocumentService.save_document_with_dataset_id( documents, batch = DocumentService.save_document_with_dataset_id(
dataset=dataset, dataset=dataset,
document_data=args, knowledge_config=knowledge_config,
account=dataset.created_by_account, account=dataset.created_by_account,
dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None, dataset_process_rule=dataset.latest_process_rule if "process_rule" not in args else None,
created_from="api", created_from="api",

@ -16,6 +16,7 @@ from extensions.ext_database import db
from fields.segment_fields import segment_fields from fields.segment_fields import segment_fields
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService from services.dataset_service import DatasetService, DocumentService, SegmentService
from services.entities.knowledge_entities.knowledge_entities import SegmentUpdateArgs
class SegmentApi(DatasetApiResource): class SegmentApi(DatasetApiResource):
@ -193,7 +194,7 @@ class DatasetSegmentApi(DatasetApiResource):
args = parser.parse_args() args = parser.parse_args()
SegmentService.segment_create_args_validate(args["segment"], document) SegmentService.segment_create_args_validate(args["segment"], document)
segment = SegmentService.update_segment(args["segment"], segment, document, dataset) segment = SegmentService.update_segment(SegmentUpdateArgs(**args["segment"]), segment, document, dataset)
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200 return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200

@ -105,10 +105,17 @@ class MessageFeedbackApi(WebApiResource):
parser = reqparse.RequestParser() parser = reqparse.RequestParser()
parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json") parser.add_argument("rating", type=str, choices=["like", "dislike", None], location="json")
parser.add_argument("content", type=str, location="json", default=None)
args = parser.parse_args() args = parser.parse_args()
try: try:
MessageService.create_feedback(app_model, message_id, end_user, args["rating"], args["content"]) MessageService.create_feedback(
app_model=app_model,
message_id=message_id,
user=end_user,
rating=args.get("rating"),
content=args.get("content"),
)
except services.errors.message.MessageNotExistsError: except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.") raise NotFound("Message Not Exists.")

@ -393,7 +393,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
try: try:
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}") logger.exception(f"Failed to process generate task pipeline, conversation_id: {conversation.id}")

@ -5,6 +5,9 @@ from collections.abc import Generator, Mapping
from threading import Thread from threading import Thread
from typing import Any, Optional, Union from typing import Any, Optional, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -66,7 +69,6 @@ from models.enums import CreatedByRole
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
WorkflowNodeExecution, WorkflowNodeExecution,
WorkflowRun,
WorkflowRunStatus, WorkflowRunStatus,
) )
@ -80,8 +82,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: AdvancedChatAppGenerateEntity _application_generate_entity: AdvancedChatAppGenerateEntity
_workflow: Workflow
_user: Union[Account, EndUser]
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution] _wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_conversation_name_generate_thread: Optional[Thread] = None _conversation_name_generate_thread: Optional[Thread] = None
@ -97,32 +97,37 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
stream: bool, stream: bool,
dialogue_count: int, dialogue_count: int,
) -> None: ) -> None:
""" super().__init__(
Initialize AdvancedChatAppGenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param workflow: workflow stream=stream,
:param queue_manager: queue manager )
:param conversation: conversation
:param message: message
:param user: user
:param stream: stream
:param dialogue_count: dialogue count
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser): if isinstance(user, EndUser):
user_id = self._user.session_id self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else: else:
user_id = self._user.id raise NotImplementedError(f"User type not supported: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._workflow = workflow
self._conversation = conversation
self._message = message
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.QUERY: message.query, SystemVariableKey.QUERY: message.query,
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.CONVERSATION_ID: conversation.id, SystemVariableKey.CONVERSATION_ID: conversation.id,
SystemVariableKey.USER_ID: user_id, SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.DIALOGUE_COUNT: dialogue_count, SystemVariableKey.DIALOGUE_COUNT: dialogue_count,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
@ -135,19 +140,16 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._conversation_name_generate_thread = None self._conversation_name_generate_thread = None
self._recorded_files: list[Mapping[str, Any]] = [] self._recorded_files: list[Mapping[str, Any]] = []
self._workflow_run_id = ""
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]: def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
""" """
Process generate task pipeline. Process generate task pipeline.
:return: :return:
""" """
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query conversation_id=self._conversation_id, query=self._application_generate_entity.query
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@ -173,12 +175,12 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return ChatbotAppBlockingResponse( return ChatbotAppBlockingResponse(
task_id=stream_response.task_id, task_id=stream_response.task_id,
data=ChatbotAppBlockingResponse.Data( data=ChatbotAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
answer=self._task_state.answer, answer=self._task_state.answer,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -196,9 +198,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
for stream_response in generator: for stream_response in generator:
yield ChatbotAppStreamResponse( yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
@ -216,7 +218,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow_features_dict
if ( if (
features_dict.get("text_to_speech") features_dict.get("text_to_speech")
@ -268,7 +270,6 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
""" """
# init fake graph runtime state # init fake graph runtime state
graph_runtime_state: Optional[GraphRuntimeState] = None graph_runtime_state: Optional[GraphRuntimeState] = None
workflow_run: Optional[WorkflowRun] = None
for queue_message in self._queue_manager.listen(): for queue_message in self._queue_manager.listen():
event = queue_message.event event = queue_message.event
@ -276,237 +277,303 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message) with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
# init workflow run with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_start() # init workflow run
workflow_run = self._handle_workflow_run_start(
self._refetch_message() session=session,
self._message.workflow_run_id = workflow_run.id workflow_id=self._workflow_id,
user_id=self._user_id,
db.session.commit() created_by_role=self._created_by_role,
db.session.refresh(self._message) )
db.session.close() self._workflow_run_id = workflow_run.id
message = self._get_message(session=session)
if not message:
raise ValueError(f"Message not found: {self._message_id}")
message.workflow_run_id = workflow_run.id
workflow_start_resp = self._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield self._workflow_start_to_stream_response( yield workflow_start_resp
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
elif isinstance( elif isinstance(
event, event,
QueueNodeRetryEvent, QueueNodeRetryEvent,
): ):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_run=workflow_run, event=event
)
response = self._workflow_node_retry_to_stream_response( with Session(db.engine) as session:
event=event, workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
task_id=self._application_generate_entity.task_id, workflow_node_execution = self._handle_workflow_node_execution_retried(
workflow_node_execution=workflow_node_execution, session=session, workflow_run=workflow_run, event=event
) )
node_retry_resp = self._workflow_node_retry_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if response: if node_retry_resp:
yield response yield node_retry_resp
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
workflow_node_execution = self._handle_node_execution_start(
session=session, workflow_run=workflow_run, event=event
)
response_start = self._workflow_node_start_to_stream_response( node_start_resp = self._workflow_node_start_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id, event=event,
workflow_node_execution=workflow_node_execution, task_id=self._application_generate_entity.task_id,
) workflow_node_execution=workflow_node_execution,
)
session.commit()
if response_start: if node_start_resp:
yield response_start yield node_start_resp
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event)
# Record files if it's an answer node or end node # Record files if it's an answer node or end node
if event.node_type in [NodeType.ANSWER, NodeType.END]: if event.node_type in [NodeType.ANSWER, NodeType.END]:
self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {})) self._recorded_files.extend(self._fetch_files_from_node_outputs(event.outputs or {}))
response_finish = self._workflow_node_finish_to_stream_response( with Session(db.engine) as session:
event=event, workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if response_finish:
yield response_finish
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event)
response_finish = self._workflow_node_finish_to_stream_response( node_finish_resp = self._workflow_node_finish_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id, event=event,
workflow_node_execution=workflow_node_execution, task_id=self._application_generate_entity.task_id,
) workflow_node_execution=workflow_node_execution,
)
session.commit()
if response_finish: if node_finish_resp:
yield response_finish yield node_finish_resp
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(session=session, event=event)
node_finish_resp = self._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_finish_resp:
yield node_finish_resp
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_start_resp = self._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent): elif isinstance(event, QueueIterationNextEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_next_resp = self._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent): elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_run = self._handle_workflow_run_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run_id=self._workflow_run_id,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
conversation_id=self._conversation.id, total_steps=graph_runtime_state.node_run_steps,
trace_manager=trace_manager, outputs=event.outputs,
) conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
yield self._workflow_finish_to_stream_response( workflow_finish_resp = self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_partial_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run_id=self._workflow_run_id,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
exceptions_count=event.exceptions_count, total_steps=graph_runtime_state.node_run_steps,
conversation_id=None, outputs=event.outputs,
trace_manager=trace_manager, exceptions_count=event.exceptions_count,
) conversation_id=None,
trace_manager=trace_manager,
yield self._workflow_finish_to_stream_response( )
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run workflow_finish_resp = self._workflow_finish_to_stream_response(
) session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE) self._queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
elif isinstance(event, QueueWorkflowFailedEvent): elif isinstance(event, QueueWorkflowFailedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed( with Session(db.engine) as session:
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED,
error=event.error,
conversation_id=self._conversation.id,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
)
yield self._workflow_finish_to_stream_response(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
yield self._error_to_stream_response(self._handle_error(err_event, self._message))
break
elif isinstance(event, QueueStopEvent):
if workflow_run and graph_runtime_state:
workflow_run = self._handle_workflow_run_failed( workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run, session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at, start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens, total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps, total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED, status=WorkflowRunStatus.FAILED,
error=event.get_stop_reason(), error=event.error,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
trace_manager=trace_manager, trace_manager=trace_manager,
exceptions_count=event.exceptions_count,
) )
workflow_finish_resp = self._workflow_finish_to_stream_response(
yield self._workflow_finish_to_stream_response( session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) )
err_event = QueueErrorEvent(error=ValueError(f"Run failed: {workflow_run.error}"))
err = self._handle_error(event=err_event, session=session, message_id=self._message_id)
session.commit()
# Save message yield workflow_finish_resp
self._save_message(graph_runtime_state=graph_runtime_state) yield self._error_to_stream_response(err)
break
elif isinstance(event, QueueStopEvent):
if self._workflow_run_id and graph_runtime_state:
with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_failed(
session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.STOPPED,
error=event.get_stop_reason(),
conversation_id=self._conversation_id,
trace_manager=trace_manager,
)
workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
# Save message
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
yield workflow_finish_resp
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
break break
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._handle_retriever_resources(event)
self._refetch_message() with Session(db.engine) as session:
message = self._get_message(session=session)
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit()
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
self._handle_annotation_reply(event) self._handle_annotation_reply(event)
self._refetch_message() with Session(db.engine) as session:
message = self._get_message(session=session)
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
session.commit()
db.session.commit()
db.session.refresh(self._message)
db.session.close()
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
if delta_text is None: if delta_text is None:
@ -523,7 +590,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
self._task_state.answer += delta_text self._task_state.answer += delta_text
yield self._message_to_stream_response( yield self._message_to_stream_response(
answer=delta_text, message_id=self._message.id, from_variable_selector=event.from_variable_selector answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
) )
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
# published by moderation # published by moderation
@ -538,7 +605,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message # Save message
self._save_message(graph_runtime_state=graph_runtime_state) with Session(db.engine) as session:
self._save_message(session=session, graph_runtime_state=graph_runtime_state)
session.commit()
yield self._message_end_to_stream_response() yield self._message_end_to_stream_response()
elif isinstance(event, QueueAgentLogEvent): elif isinstance(event, QueueAgentLogEvent):
@ -553,54 +622,46 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None: def _save_message(self, *, session: Session, graph_runtime_state: Optional[GraphRuntimeState] = None) -> None:
self._refetch_message() message = self._get_message(session=session)
message.answer = self._task_state.answer
self._message.answer = self._task_state.answer message.provider_response_latency = time.perf_counter() - self._start_at
self._message.provider_response_latency = time.perf_counter() - self._start_at message.message_metadata = (
self._message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
message_files = [ message_files = [
MessageFile( MessageFile(
message_id=self._message.id, message_id=message.id,
type=file["type"], type=file["type"],
transfer_method=file["transfer_method"], transfer_method=file["transfer_method"],
url=file["remote_url"], url=file["remote_url"],
belongs_to="assistant", belongs_to="assistant",
upload_file_id=file["related_id"], upload_file_id=file["related_id"],
created_by_role=CreatedByRole.ACCOUNT created_by_role=CreatedByRole.ACCOUNT
if self._message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else CreatedByRole.END_USER, else CreatedByRole.END_USER,
created_by=self._message.from_account_id or self._message.from_end_user_id or "", created_by=message.from_account_id or message.from_end_user_id or "",
) )
for file in self._recorded_files for file in self._recorded_files
] ]
db.session.add_all(message_files) session.add_all(message_files)
if graph_runtime_state and graph_runtime_state.llm_usage: if graph_runtime_state and graph_runtime_state.llm_usage:
usage = graph_runtime_state.llm_usage usage = graph_runtime_state.llm_usage
self._message.message_tokens = usage.prompt_tokens message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit message.message_price_unit = usage.prompt_price_unit
self._message.answer_tokens = usage.completion_tokens message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
self._message.total_price = usage.total_price message.total_price = usage.total_price
self._message.currency = usage.currency message.currency = usage.currency
self._task_state.metadata["usage"] = jsonable_encoder(usage) self._task_state.metadata["usage"] = jsonable_encoder(usage)
else: else:
self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage()) self._task_state.metadata["usage"] = jsonable_encoder(LLMUsage.empty_usage())
db.session.commit()
message_was_created.send( message_was_created.send(
self._message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
) )
def _message_end_to_stream_response(self) -> MessageEndStreamResponse: def _message_end_to_stream_response(self) -> MessageEndStreamResponse:
@ -617,7 +678,7 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message.id, id=self._message_id,
files=self._recorded_files, files=self._recorded_files,
metadata=extras.get("metadata", {}), metadata=extras.get("metadata", {}),
) )
@ -645,11 +706,9 @@ class AdvancedChatAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCyc
return False return False
def _refetch_message(self) -> None: def _get_message(self, *, session: Session):
""" stmt = select(Message).where(Message.id == self._message_id)
Refetch message. message = session.scalar(stmt)
:return: if not message:
""" raise ValueError(f"Message not found: {self._message_id}")
message = db.session.query(Message).filter(Message.id == self._message.id).first() return message
if message:
self._message = message

@ -70,14 +70,13 @@ class MessageBasedAppGenerator(BaseAppGenerator):
queue_manager=queue_manager, queue_manager=queue_manager,
conversation=conversation, conversation=conversation,
message=message, message=message,
user=user,
stream=stream, stream=stream,
) )
try: try:
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception(f"Failed to handle response, conversation_id: {conversation.id}") logger.exception(f"Failed to handle response, conversation_id: {conversation.id}")

@ -325,7 +325,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
try: try:
return generate_task_pipeline.process() return generate_task_pipeline.process()
except ValueError as e: except ValueError as e:
if e.args[0] == "I/O operation on closed file.": # ignore this error if len(e.args) > 0 and e.args[0] == "I/O operation on closed file.": # ignore this error
raise GenerateTaskStoppedError() raise GenerateTaskStoppedError()
else: else:
logger.exception( logger.exception(

@ -3,6 +3,8 @@ import time
from collections.abc import Generator from collections.abc import Generator
from typing import Any, Optional, Union from typing import Any, Optional, Union
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
@ -51,6 +53,7 @@ from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole
from models.model import EndUser from models.model import EndUser
from models.workflow import ( from models.workflow import (
Workflow, Workflow,
@ -69,8 +72,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application. WorkflowAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
""" """
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_application_generate_entity: WorkflowAppGenerateEntity _application_generate_entity: WorkflowAppGenerateEntity
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
@ -84,44 +85,42 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
user: Union[Account, EndUser], user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" super().__init__(
Initialize GenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param workflow: workflow stream=stream,
:param queue_manager: queue manager )
:param user: user
:param stream: is streamed
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
if isinstance(self._user, EndUser): if isinstance(user, EndUser):
user_id = self._user.session_id self._user_id = user.id
user_session_id = user.session_id
self._created_by_role = CreatedByRole.END_USER
elif isinstance(user, Account):
self._user_id = user.id
user_session_id = user.id
self._created_by_role = CreatedByRole.ACCOUNT
else: else:
user_id = self._user.id raise ValueError(f"Invalid user type: {type(user)}")
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow = workflow
self._workflow_system_variables = { self._workflow_system_variables = {
SystemVariableKey.FILES: application_generate_entity.files, SystemVariableKey.FILES: application_generate_entity.files,
SystemVariableKey.USER_ID: user_id, SystemVariableKey.USER_ID: user_session_id,
SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id, SystemVariableKey.APP_ID: application_generate_entity.app_config.app_id,
SystemVariableKey.WORKFLOW_ID: workflow.id, SystemVariableKey.WORKFLOW_ID: workflow.id,
SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id, SystemVariableKey.WORKFLOW_RUN_ID: application_generate_entity.workflow_run_id,
} }
self._task_state = WorkflowTaskState() self._task_state = WorkflowTaskState()
self._wip_workflow_node_executions = {} self._workflow_run_id = ""
self._wip_workflow_agent_logs = {}
self.total_tokens: int = 0
def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]: def process(self) -> Union[WorkflowAppBlockingResponse, Generator[WorkflowAppStreamResponse, None, None]]:
""" """
Process generate task pipeline. Process generate task pipeline.
:return: :return:
""" """
db.session.refresh(self._workflow)
db.session.refresh(self._user)
db.session.close()
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
if self._stream: if self._stream:
return self._to_stream_response(generator) return self._to_stream_response(generator)
@ -188,7 +187,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
tts_publisher = None tts_publisher = None
task_id = self._application_generate_entity.task_id task_id = self._application_generate_entity.task_id
tenant_id = self._application_generate_entity.app_config.tenant_id tenant_id = self._application_generate_entity.app_config.tenant_id
features_dict = self._workflow.features_dict features_dict = self._workflow_features_dict
if ( if (
features_dict.get("text_to_speech") features_dict.get("text_to_speech")
@ -237,7 +236,6 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
:return: :return:
""" """
graph_runtime_state = None graph_runtime_state = None
workflow_run = None
for queue_message in self._queue_manager.listen(): for queue_message in self._queue_manager.listen():
event = queue_message.event event = queue_message.event
@ -245,180 +243,261 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if isinstance(event, QueuePingEvent): if isinstance(event, QueuePingEvent):
yield self._ping_stream_response() yield self._ping_stream_response()
elif isinstance(event, QueueErrorEvent): elif isinstance(event, QueueErrorEvent):
err = self._handle_error(event) err = self._handle_error(event=event)
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueWorkflowStartedEvent): elif isinstance(event, QueueWorkflowStartedEvent):
# override graph runtime state # override graph runtime state
graph_runtime_state = event.graph_runtime_state graph_runtime_state = event.graph_runtime_state
# init workflow run with Session(db.engine) as session:
workflow_run = self._handle_workflow_run_start() # init workflow run
yield self._workflow_start_to_stream_response( workflow_run = self._handle_workflow_run_start(
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run session=session,
) workflow_id=self._workflow_id,
user_id=self._user_id,
created_by_role=self._created_by_role,
)
self._workflow_run_id = workflow_run.id
start_resp = self._workflow_start_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield start_resp
elif isinstance( elif isinstance(
event, event,
QueueNodeRetryEvent, QueueNodeRetryEvent,
): ):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_workflow_node_execution_retried( with Session(db.engine) as session:
workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) workflow_node_execution = self._handle_workflow_node_execution_retried(
session=session, workflow_run=workflow_run, event=event
response = self._workflow_node_retry_to_stream_response( )
event=event, response = self._workflow_node_retry_to_stream_response(
task_id=self._application_generate_entity.task_id, session=session,
workflow_node_execution=workflow_node_execution, event=event,
) task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if response: if response:
yield response yield response
elif isinstance(event, QueueNodeStartedEvent): elif isinstance(event, QueueNodeStartedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
workflow_node_execution = self._handle_node_execution_start(workflow_run=workflow_run, event=event) with Session(db.engine) as session:
workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
node_start_response = self._workflow_node_start_to_stream_response( workflow_node_execution = self._handle_node_execution_start(
event=event, session=session, workflow_run=workflow_run, event=event
task_id=self._application_generate_entity.task_id, )
workflow_node_execution=workflow_node_execution, node_start_response = self._workflow_node_start_to_stream_response(
) session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_start_response: if node_start_response:
yield node_start_response yield node_start_response
elif isinstance(event, QueueNodeSucceededEvent): elif isinstance(event, QueueNodeSucceededEvent):
workflow_node_execution = self._handle_workflow_node_execution_success(event) with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_success(session=session, event=event)
node_success_response = self._workflow_node_finish_to_stream_response( node_success_response = self._workflow_node_finish_to_stream_response(
event=event, session=session,
task_id=self._application_generate_entity.task_id, event=event,
workflow_node_execution=workflow_node_execution, task_id=self._application_generate_entity.task_id,
) workflow_node_execution=workflow_node_execution,
)
session.commit()
if node_success_response: if node_success_response:
yield node_success_response yield node_success_response
elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent): elif isinstance(event, QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent):
workflow_node_execution = self._handle_workflow_node_execution_failed(event) with Session(db.engine) as session:
workflow_node_execution = self._handle_workflow_node_execution_failed(
session=session,
event=event,
)
node_failed_response = self._workflow_node_finish_to_stream_response(
session=session,
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
session.commit()
node_failed_response = self._workflow_node_finish_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
workflow_node_execution=workflow_node_execution,
)
if node_failed_response: if node_failed_response:
yield node_failed_response yield node_failed_response
elif isinstance(event, QueueParallelBranchRunStartedEvent): elif isinstance(event, QueueParallelBranchRunStartedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_start_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) parallel_start_resp = self._workflow_parallel_branch_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield parallel_start_resp
elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent): elif isinstance(event, QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_parallel_branch_finished_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) parallel_finish_resp = self._workflow_parallel_branch_finished_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield parallel_finish_resp
elif isinstance(event, QueueIterationStartEvent): elif isinstance(event, QueueIterationStartEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_start_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_start_resp = self._workflow_iteration_start_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_start_resp
elif isinstance(event, QueueIterationNextEvent): elif isinstance(event, QueueIterationNextEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_next_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_next_resp = self._workflow_iteration_next_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_next_resp
elif isinstance(event, QueueIterationCompletedEvent): elif isinstance(event, QueueIterationCompletedEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
yield self._workflow_iteration_completed_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run, event=event workflow_run = self._get_workflow_run(session=session, workflow_run_id=self._workflow_run_id)
) iter_finish_resp = self._workflow_iteration_completed_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
event=event,
)
yield iter_finish_resp
elif isinstance(event, QueueWorkflowSucceededEvent): elif isinstance(event, QueueWorkflowSucceededEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run_id=self._workflow_run_id,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
conversation_id=None, total_steps=graph_runtime_state.node_run_steps,
trace_manager=trace_manager, outputs=event.outputs,
) conversation_id=None,
trace_manager=trace_manager,
# save workflow app log )
self._save_workflow_app_log(workflow_run)
# save workflow app log
yield self._workflow_finish_to_stream_response( self._save_workflow_app_log(session=session, workflow_run=workflow_run)
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session,
task_id=self._application_generate_entity.task_id,
workflow_run=workflow_run,
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowPartialSuccessEvent): elif isinstance(event, QueueWorkflowPartialSuccessEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_partial_success( with Session(db.engine) as session:
workflow_run=workflow_run, workflow_run = self._handle_workflow_run_partial_success(
start_at=graph_runtime_state.start_at, session=session,
total_tokens=graph_runtime_state.total_tokens, workflow_run_id=self._workflow_run_id,
total_steps=graph_runtime_state.node_run_steps, start_at=graph_runtime_state.start_at,
outputs=event.outputs, total_tokens=graph_runtime_state.total_tokens,
exceptions_count=event.exceptions_count, total_steps=graph_runtime_state.node_run_steps,
conversation_id=None, outputs=event.outputs,
trace_manager=trace_manager, exceptions_count=event.exceptions_count,
) conversation_id=None,
trace_manager=trace_manager,
# save workflow app log )
self._save_workflow_app_log(workflow_run)
# save workflow app log
yield self._workflow_finish_to_stream_response( self._save_workflow_app_log(session=session, workflow_run=workflow_run)
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
) workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent): elif isinstance(event, QueueWorkflowFailedEvent | QueueStopEvent):
if not workflow_run: if not self._workflow_run_id:
raise ValueError("workflow run not initialized.") raise ValueError("workflow run not initialized.")
if not graph_runtime_state: if not graph_runtime_state:
raise ValueError("graph runtime state not initialized.") raise ValueError("graph runtime state not initialized.")
workflow_run = self._handle_workflow_run_failed(
workflow_run=workflow_run,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(workflow_run)
yield self._workflow_finish_to_stream_response( with Session(db.engine) as session:
task_id=self._application_generate_entity.task_id, workflow_run=workflow_run workflow_run = self._handle_workflow_run_failed(
) session=session,
workflow_run_id=self._workflow_run_id,
start_at=graph_runtime_state.start_at,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
status=WorkflowRunStatus.FAILED
if isinstance(event, QueueWorkflowFailedEvent)
else WorkflowRunStatus.STOPPED,
error=event.error if isinstance(event, QueueWorkflowFailedEvent) else event.get_stop_reason(),
conversation_id=None,
trace_manager=trace_manager,
exceptions_count=event.exceptions_count if isinstance(event, QueueWorkflowFailedEvent) else 0,
)
# save workflow app log
self._save_workflow_app_log(session=session, workflow_run=workflow_run)
workflow_finish_resp = self._workflow_finish_to_stream_response(
session=session, task_id=self._application_generate_entity.task_id, workflow_run=workflow_run
)
session.commit()
yield workflow_finish_resp
elif isinstance(event, QueueTextChunkEvent): elif isinstance(event, QueueTextChunkEvent):
delta_text = event.text delta_text = event.text
if delta_text is None: if delta_text is None:
@ -440,7 +519,7 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
if tts_publisher: if tts_publisher:
tts_publisher.publish(None) tts_publisher.publish(None)
def _save_workflow_app_log(self, workflow_run: WorkflowRun) -> None: def _save_workflow_app_log(self, *, session: Session, workflow_run: WorkflowRun) -> None:
""" """
Save workflow app log. Save workflow app log.
:return: :return:
@ -462,12 +541,10 @@ class WorkflowAppGenerateTaskPipeline(BasedGenerateTaskPipeline, WorkflowCycleMa
workflow_app_log.workflow_id = workflow_run.workflow_id workflow_app_log.workflow_id = workflow_run.workflow_id
workflow_app_log.workflow_run_id = workflow_run.id workflow_app_log.workflow_run_id = workflow_run.id
workflow_app_log.created_from = created_from.value workflow_app_log.created_from = created_from.value
workflow_app_log.created_by_role = "account" if isinstance(self._user, Account) else "end_user" workflow_app_log.created_by_role = self._created_by_role
workflow_app_log.created_by = self._user.id workflow_app_log.created_by = self._user_id
db.session.add(workflow_app_log) session.add(workflow_app_log)
db.session.commit()
db.session.close()
def _text_chunk_to_stream_response( def _text_chunk_to_stream_response(
self, text: str, from_variable_selector: Optional[list[str]] = None self, text: str, from_variable_selector: Optional[list[str]] = None

@ -1,6 +1,9 @@
import logging import logging
import time import time
from typing import Optional, Union from typing import Optional
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import ( from core.app.entities.app_invoke_entities import (
@ -17,9 +20,7 @@ from core.app.entities.task_entities import (
from core.errors.error import QuotaExceededError from core.errors.error import QuotaExceededError
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.moderation.output_moderation import ModerationRule, OutputModeration from core.moderation.output_moderation import ModerationRule, OutputModeration
from extensions.ext_database import db from models.model import Message
from models.account import Account
from models.model import EndUser, Message
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -36,7 +37,6 @@ class BasedGenerateTaskPipeline:
self, self,
application_generate_entity: AppGenerateEntity, application_generate_entity: AppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" """
@ -48,18 +48,11 @@ class BasedGenerateTaskPipeline:
""" """
self._application_generate_entity = application_generate_entity self._application_generate_entity = application_generate_entity
self._queue_manager = queue_manager self._queue_manager = queue_manager
self._user = user
self._start_at = time.perf_counter() self._start_at = time.perf_counter()
self._output_moderation_handler = self._init_output_moderation() self._output_moderation_handler = self._init_output_moderation()
self._stream = stream self._stream = stream
def _handle_error(self, event: QueueErrorEvent, message: Optional[Message] = None): def _handle_error(self, *, event: QueueErrorEvent, session: Session | None = None, message_id: str = ""):
"""
Handle error event.
:param event: event
:param message: message
:return:
"""
logger.debug("error: %s", event.error) logger.debug("error: %s", event.error)
e = event.error e = event.error
err: Exception err: Exception
@ -71,16 +64,17 @@ class BasedGenerateTaskPipeline:
else: else:
err = Exception(e.description if getattr(e, "description", None) is not None else str(e)) err = Exception(e.description if getattr(e, "description", None) is not None else str(e))
if message: if not message_id or not session:
refetch_message = db.session.query(Message).filter(Message.id == message.id).first() return err
if refetch_message:
err_desc = self._error_to_desc(err)
refetch_message.status = "error"
refetch_message.error = err_desc
db.session.commit() stmt = select(Message).where(Message.id == message_id)
message = session.scalar(stmt)
if not message:
return err
err_desc = self._error_to_desc(err)
message.status = "error"
message.error = err_desc
return err return err
def _error_to_desc(self, e: Exception) -> str: def _error_to_desc(self, e: Exception) -> str:

@ -5,6 +5,9 @@ from collections.abc import Generator
from threading import Thread from threading import Thread
from typing import Optional, Union, cast from typing import Optional, Union, cast
from sqlalchemy import select
from sqlalchemy.orm import Session
from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME
from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk from core.app.apps.advanced_chat.app_generator_tts_publisher import AppGeneratorTTSPublisher, AudioTrunk
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
@ -55,8 +58,7 @@ from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created from events.message_event import message_was_created
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, EndUser, Message, MessageAgentThought
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -77,23 +79,21 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
user: Union[Account, EndUser],
stream: bool, stream: bool,
) -> None: ) -> None:
""" super().__init__(
Initialize GenerateTaskPipeline. application_generate_entity=application_generate_entity,
:param application_generate_entity: application generate entity queue_manager=queue_manager,
:param queue_manager: queue manager stream=stream,
:param conversation: conversation )
:param message: message
:param user: user
:param stream: stream
"""
super().__init__(application_generate_entity, queue_manager, user, stream)
self._model_config = application_generate_entity.model_conf self._model_config = application_generate_entity.model_conf
self._app_config = application_generate_entity.app_config self._app_config = application_generate_entity.app_config
self._conversation = conversation
self._message = message self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
self._message_created_at = int(message.created_at.timestamp())
self._task_state = EasyUITaskState( self._task_state = EasyUITaskState(
llm_result=LLMResult( llm_result=LLMResult(
@ -113,18 +113,10 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
CompletionAppBlockingResponse, CompletionAppBlockingResponse,
Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None], Generator[Union[ChatbotAppStreamResponse, CompletionAppStreamResponse], None, None],
]: ]:
"""
Process generate task pipeline.
:return:
"""
db.session.refresh(self._conversation)
db.session.refresh(self._message)
db.session.close()
if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION: if self._application_generate_entity.app_config.app_mode != AppMode.COMPLETION:
# start generate conversation name thread # start generate conversation name thread
self._conversation_name_generate_thread = self._generate_conversation_name( self._conversation_name_generate_thread = self._generate_conversation_name(
self._conversation, self._application_generate_entity.query or "" conversation_id=self._conversation_id, query=self._application_generate_entity.query or ""
) )
generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager) generator = self._wrapper_process_stream_response(trace_manager=self._application_generate_entity.trace_manager)
@ -148,15 +140,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._task_state.metadata: if self._task_state.metadata:
extras["metadata"] = self._task_state.metadata extras["metadata"] = self._task_state.metadata
response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse] response: Union[ChatbotAppBlockingResponse, CompletionAppBlockingResponse]
if self._conversation.mode == AppMode.COMPLETION.value: if self._conversation_mode == AppMode.COMPLETION.value:
response = CompletionAppBlockingResponse( response = CompletionAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
data=CompletionAppBlockingResponse.Data( data=CompletionAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
message_id=self._message.id, message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content), answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -164,12 +156,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
response = ChatbotAppBlockingResponse( response = ChatbotAppBlockingResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
data=ChatbotAppBlockingResponse.Data( data=ChatbotAppBlockingResponse.Data(
id=self._message.id, id=self._message_id,
mode=self._conversation.mode, mode=self._conversation_mode,
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content), answer=cast(str, self._task_state.llm_result.message.content),
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
**extras, **extras,
), ),
) )
@ -190,15 +182,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
for stream_response in generator: for stream_response in generator:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity): if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
yield CompletionAppStreamResponse( yield CompletionAppStreamResponse(
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
else: else:
yield ChatbotAppStreamResponse( yield ChatbotAppStreamResponse(
conversation_id=self._conversation.id, conversation_id=self._conversation_id,
message_id=self._message.id, message_id=self._message_id,
created_at=int(self._message.created_at.timestamp()), created_at=self._message_created_at,
stream_response=stream_response, stream_response=stream_response,
) )
@ -265,7 +257,9 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
event = message.event event = message.event
if isinstance(event, QueueErrorEvent): if isinstance(event, QueueErrorEvent):
err = self._handle_error(event, self._message) with Session(db.engine) as session:
err = self._handle_error(event=event, session=session, message_id=self._message_id)
session.commit()
yield self._error_to_stream_response(err) yield self._error_to_stream_response(err)
break break
elif isinstance(event, QueueStopEvent | QueueMessageEndEvent): elif isinstance(event, QueueStopEvent | QueueMessageEndEvent):
@ -283,10 +277,12 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = output_moderation_answer self._task_state.llm_result.message.content = output_moderation_answer
yield self._message_replace_to_stream_response(answer=output_moderation_answer) yield self._message_replace_to_stream_response(answer=output_moderation_answer)
# Save message with Session(db.engine) as session:
self._save_message(trace_manager) # Save message
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response() session.commit()
message_end_resp = self._message_end_to_stream_response()
yield message_end_resp
elif isinstance(event, QueueRetrieverResourcesEvent): elif isinstance(event, QueueRetrieverResourcesEvent):
self._handle_retriever_resources(event) self._handle_retriever_resources(event)
elif isinstance(event, QueueAnnotationReplyEvent): elif isinstance(event, QueueAnnotationReplyEvent):
@ -320,9 +316,15 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
self._task_state.llm_result.message.content = current_content self._task_state.llm_result.message.content = current_content
if isinstance(event, QueueLLMChunkEvent): if isinstance(event, QueueLLMChunkEvent):
yield self._message_to_stream_response(cast(str, delta_text), self._message.id) yield self._message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
else: else:
yield self._agent_message_to_stream_response(cast(str, delta_text), self._message.id) yield self._agent_message_to_stream_response(
answer=cast(str, delta_text),
message_id=self._message_id,
)
elif isinstance(event, QueueMessageReplaceEvent): elif isinstance(event, QueueMessageReplaceEvent):
yield self._message_replace_to_stream_response(answer=event.text) yield self._message_replace_to_stream_response(answer=event.text)
elif isinstance(event, QueuePingEvent): elif isinstance(event, QueuePingEvent):
@ -334,7 +336,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
if self._conversation_name_generate_thread: if self._conversation_name_generate_thread:
self._conversation_name_generate_thread.join() self._conversation_name_generate_thread.join()
def _save_message(self, trace_manager: Optional[TraceQueueManager] = None) -> None: def _save_message(self, *, session: Session, trace_manager: Optional[TraceQueueManager] = None) -> None:
""" """
Save message. Save message.
:return: :return:
@ -342,53 +344,46 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
llm_result = self._task_state.llm_result llm_result = self._task_state.llm_result
usage = llm_result.usage usage = llm_result.usage
message = db.session.query(Message).filter(Message.id == self._message.id).first() message_stmt = select(Message).where(Message.id == self._message_id)
message = session.scalar(message_stmt)
if not message: if not message:
raise Exception(f"Message {self._message.id} not found") raise ValueError(f"message {self._message_id} not found")
self._message = message conversation_stmt = select(Conversation).where(Conversation.id == self._conversation_id)
conversation = db.session.query(Conversation).filter(Conversation.id == self._conversation.id).first() conversation = session.scalar(conversation_stmt)
if not conversation: if not conversation:
raise Exception(f"Conversation {self._conversation.id} not found") raise ValueError(f"Conversation {self._conversation_id} not found")
self._conversation = conversation
self._message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving( message.message = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
self._model_config.mode, self._task_state.llm_result.prompt_messages self._model_config.mode, self._task_state.llm_result.prompt_messages
) )
self._message.message_tokens = usage.prompt_tokens message.message_tokens = usage.prompt_tokens
self._message.message_unit_price = usage.prompt_unit_price message.message_unit_price = usage.prompt_unit_price
self._message.message_price_unit = usage.prompt_price_unit message.message_price_unit = usage.prompt_price_unit
self._message.answer = ( message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip()) PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
if llm_result.message.content if llm_result.message.content
else "" else ""
) )
self._message.answer_tokens = usage.completion_tokens message.answer_tokens = usage.completion_tokens
self._message.answer_unit_price = usage.completion_unit_price message.answer_unit_price = usage.completion_unit_price
self._message.answer_price_unit = usage.completion_price_unit message.answer_price_unit = usage.completion_price_unit
self._message.provider_response_latency = time.perf_counter() - self._start_at message.provider_response_latency = time.perf_counter() - self._start_at
self._message.total_price = usage.total_price message.total_price = usage.total_price
self._message.currency = usage.currency message.currency = usage.currency
self._message.message_metadata = ( message.message_metadata = (
json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None json.dumps(jsonable_encoder(self._task_state.metadata)) if self._task_state.metadata else None
) )
db.session.commit()
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation.id, message_id=self._message.id TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
) )
) )
message_was_created.send( message_was_created.send(
self._message, message,
application_generate_entity=self._application_generate_entity, application_generate_entity=self._application_generate_entity,
conversation=self._conversation,
is_first_message=self._application_generate_entity.app_config.app_mode in {AppMode.AGENT_CHAT, AppMode.CHAT}
and hasattr(self._application_generate_entity, "conversation_id")
and self._application_generate_entity.conversation_id is None,
extras=self._application_generate_entity.extras,
) )
def _handle_stop(self, event: QueueStopEvent) -> None: def _handle_stop(self, event: QueueStopEvent) -> None:
@ -434,7 +429,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline, MessageCycleMan
return MessageEndStreamResponse( return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id, task_id=self._application_generate_entity.task_id,
id=self._message.id, id=self._message_id,
metadata=extras.get("metadata", {}), metadata=extras.get("metadata", {}),
) )

@ -36,7 +36,7 @@ class MessageCycleManage:
] ]
_task_state: Union[EasyUITaskState, WorkflowTaskState] _task_state: Union[EasyUITaskState, WorkflowTaskState]
def _generate_conversation_name(self, conversation: Conversation, query: str) -> Optional[Thread]: def _generate_conversation_name(self, *, conversation_id: str, query: str) -> Optional[Thread]:
""" """
Generate conversation name. Generate conversation name.
:param conversation: conversation :param conversation: conversation
@ -56,7 +56,7 @@ class MessageCycleManage:
target=self._generate_conversation_name_worker, target=self._generate_conversation_name_worker,
kwargs={ kwargs={
"flask_app": current_app._get_current_object(), # type: ignore "flask_app": current_app._get_current_object(), # type: ignore
"conversation_id": conversation.id, "conversation_id": conversation_id,
"query": query, "query": query,
}, },
) )

@ -5,6 +5,7 @@ from datetime import UTC, datetime
from typing import Any, Optional, Union, cast from typing import Any, Optional, Union, cast
from uuid import uuid4 from uuid import uuid4
from sqlalchemy import func, select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
@ -47,7 +48,6 @@ from core.workflow.enums import SystemVariableKey
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.tool.entities import ToolNodeData from core.workflow.nodes.tool.entities import ToolNodeData
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.enums import CreatedByRole, WorkflowRunTriggeredFrom from models.enums import CreatedByRole, WorkflowRunTriggeredFrom
from models.model import EndUser from models.model import EndUser
@ -65,28 +65,33 @@ from .exc import WorkflowNodeExecutionNotFoundError, WorkflowRunNotFoundError
class WorkflowCycleManage: class WorkflowCycleManage:
_application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity] _application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity]
_workflow: Workflow
_user: Union[Account, EndUser]
_task_state: WorkflowTaskState _task_state: WorkflowTaskState
_workflow_system_variables: dict[SystemVariableKey, Any] _workflow_system_variables: dict[SystemVariableKey, Any]
_wip_workflow_node_executions: dict[str, WorkflowNodeExecution]
_wip_workflow_agent_logs: dict[str, list[AgentLogStreamResponse.Data]] def _handle_workflow_run_start(
self,
def _handle_workflow_run_start(self) -> WorkflowRun: *,
max_sequence = ( session: Session,
db.session.query(db.func.max(WorkflowRun.sequence_number)) workflow_id: str,
.filter(WorkflowRun.tenant_id == self._workflow.tenant_id) user_id: str,
.filter(WorkflowRun.app_id == self._workflow.app_id) created_by_role: CreatedByRole,
.scalar() ) -> WorkflowRun:
or 0 workflow_stmt = select(Workflow).where(Workflow.id == workflow_id)
workflow = session.scalar(workflow_stmt)
if not workflow:
raise ValueError(f"Workflow not found: {workflow_id}")
max_sequence_stmt = select(func.max(WorkflowRun.sequence_number)).where(
WorkflowRun.tenant_id == workflow.tenant_id,
WorkflowRun.app_id == workflow.app_id,
) )
max_sequence = session.scalar(max_sequence_stmt) or 0
new_sequence_number = max_sequence + 1 new_sequence_number = max_sequence + 1
inputs = {**self._application_generate_entity.inputs} inputs = {**self._application_generate_entity.inputs}
for key, value in (self._workflow_system_variables or {}).items(): for key, value in (self._workflow_system_variables or {}).items():
if key.value == "conversation": if key.value == "conversation":
continue continue
inputs[f"sys.{key.value}"] = value inputs[f"sys.{key.value}"] = value
triggered_from = ( triggered_from = (
@ -99,34 +104,33 @@ class WorkflowCycleManage:
inputs = dict(WorkflowEntry.handle_special_values(inputs) or {}) inputs = dict(WorkflowEntry.handle_special_values(inputs) or {})
# init workflow run # init workflow run
with Session(db.engine, expire_on_commit=False) as session: workflow_run_id = str(self._workflow_system_variables.get(SystemVariableKey.WORKFLOW_RUN_ID, uuid4()))
workflow_run = WorkflowRun()
system_id = self._workflow_system_variables[SystemVariableKey.WORKFLOW_RUN_ID] workflow_run = WorkflowRun()
workflow_run.id = system_id or str(uuid4()) workflow_run.id = workflow_run_id
workflow_run.tenant_id = self._workflow.tenant_id workflow_run.tenant_id = workflow.tenant_id
workflow_run.app_id = self._workflow.app_id workflow_run.app_id = workflow.app_id
workflow_run.sequence_number = new_sequence_number workflow_run.sequence_number = new_sequence_number
workflow_run.workflow_id = self._workflow.id workflow_run.workflow_id = workflow.id
workflow_run.type = self._workflow.type workflow_run.type = workflow.type
workflow_run.triggered_from = triggered_from.value workflow_run.triggered_from = triggered_from.value
workflow_run.version = self._workflow.version workflow_run.version = workflow.version
workflow_run.graph = self._workflow.graph workflow_run.graph = workflow.graph
workflow_run.inputs = json.dumps(inputs) workflow_run.inputs = json.dumps(inputs)
workflow_run.status = WorkflowRunStatus.RUNNING workflow_run.status = WorkflowRunStatus.RUNNING
workflow_run.created_by_role = ( workflow_run.created_by_role = created_by_role
CreatedByRole.ACCOUNT if isinstance(self._user, Account) else CreatedByRole.END_USER workflow_run.created_by = user_id
) workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.created_by = self._user.id
workflow_run.created_at = datetime.now(UTC).replace(tzinfo=None) session.add(workflow_run)
session.add(workflow_run)
session.commit()
return workflow_run return workflow_run
def _handle_workflow_run_success( def _handle_workflow_run_success(
self, self,
workflow_run: WorkflowRun, *,
session: Session,
workflow_run_id: str,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
total_steps: int, total_steps: int,
@ -144,7 +148,7 @@ class WorkflowCycleManage:
:param conversation_id: conversation id :param conversation_id: conversation id
:return: :return:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
outputs = WorkflowEntry.handle_special_values(outputs) outputs = WorkflowEntry.handle_special_values(outputs)
@ -155,9 +159,6 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
TraceTask( TraceTask(
@ -168,13 +169,13 @@ class WorkflowCycleManage:
) )
) )
db.session.close()
return workflow_run return workflow_run
def _handle_workflow_run_partial_success( def _handle_workflow_run_partial_success(
self, self,
workflow_run: WorkflowRun, *,
session: Session,
workflow_run_id: str,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
total_steps: int, total_steps: int,
@ -183,18 +184,7 @@ class WorkflowCycleManage:
conversation_id: Optional[str] = None, conversation_id: Optional[str] = None,
trace_manager: Optional[TraceQueueManager] = None, trace_manager: Optional[TraceQueueManager] = None,
) -> WorkflowRun: ) -> WorkflowRun:
""" workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
Workflow run success
:param workflow_run: workflow run
:param start_at: start time
:param total_tokens: total tokens
:param total_steps: total steps
:param outputs: outputs
:param conversation_id: conversation id
:return:
"""
workflow_run = self._refetch_workflow_run(workflow_run.id)
outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None) outputs = WorkflowEntry.handle_special_values(dict(outputs) if outputs else None)
workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value workflow_run.status = WorkflowRunStatus.PARTIAL_SUCCESSED.value
@ -204,8 +194,6 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
db.session.commit()
db.session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -217,13 +205,13 @@ class WorkflowCycleManage:
) )
) )
db.session.close()
return workflow_run return workflow_run
def _handle_workflow_run_failed( def _handle_workflow_run_failed(
self, self,
workflow_run: WorkflowRun, *,
session: Session,
workflow_run_id: str,
start_at: float, start_at: float,
total_tokens: int, total_tokens: int,
total_steps: int, total_steps: int,
@ -243,7 +231,7 @@ class WorkflowCycleManage:
:param error: error message :param error: error message
:return: :return:
""" """
workflow_run = self._refetch_workflow_run(workflow_run.id) workflow_run = self._get_workflow_run(session=session, workflow_run_id=workflow_run_id)
workflow_run.status = status.value workflow_run.status = status.value
workflow_run.error = error workflow_run.error = error
@ -252,21 +240,18 @@ class WorkflowCycleManage:
workflow_run.total_steps = total_steps workflow_run.total_steps = total_steps
workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None) workflow_run.finished_at = datetime.now(UTC).replace(tzinfo=None)
workflow_run.exceptions_count = exceptions_count workflow_run.exceptions_count = exceptions_count
db.session.commit()
stmt = select(WorkflowNodeExecution).where(
running_workflow_node_executions = ( WorkflowNodeExecution.tenant_id == workflow_run.tenant_id,
db.session.query(WorkflowNodeExecution) WorkflowNodeExecution.app_id == workflow_run.app_id,
.filter( WorkflowNodeExecution.workflow_id == workflow_run.workflow_id,
WorkflowNodeExecution.tenant_id == workflow_run.tenant_id, WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.app_id == workflow_run.app_id, WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.workflow_id == workflow_run.workflow_id, WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value,
WorkflowNodeExecution.workflow_run_id == workflow_run.id,
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING.value,
)
.all()
) )
running_workflow_node_executions = session.scalars(stmt).all()
for workflow_node_execution in running_workflow_node_executions: for workflow_node_execution in running_workflow_node_executions:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
@ -274,13 +259,6 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = ( workflow_node_execution.elapsed_time = (
workflow_node_execution.finished_at - workflow_node_execution.created_at workflow_node_execution.finished_at - workflow_node_execution.created_at
).total_seconds() ).total_seconds()
db.session.commit()
db.session.close()
# with Session(db.engine, expire_on_commit=False) as session:
# session.add(workflow_run)
# session.refresh(workflow_run)
if trace_manager: if trace_manager:
trace_manager.add_trace_task( trace_manager.add_trace_task(
@ -295,79 +273,49 @@ class WorkflowCycleManage:
return workflow_run return workflow_run
def _handle_node_execution_start( def _handle_node_execution_start(
self, workflow_run: WorkflowRun, event: QueueNodeStartedEvent self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeStartedEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
# init workflow node execution workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = event.node_execution_id
with Session(db.engine, expire_on_commit=False) as session: workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.workflow_id = workflow_run.workflow_id
workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value
workflow_node_execution.workflow_id = workflow_run.workflow_id workflow_node_execution.workflow_run_id = workflow_run.id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN.value workflow_node_execution.predecessor_node_id = event.predecessor_node_id
workflow_node_execution.workflow_run_id = workflow_run.id workflow_node_execution.index = event.node_run_index
workflow_node_execution.predecessor_node_id = event.predecessor_node_id workflow_node_execution.node_execution_id = event.node_execution_id
workflow_node_execution.index = event.node_run_index workflow_node_execution.node_id = event.node_id
workflow_node_execution.node_execution_id = event.node_execution_id workflow_node_execution.node_type = event.node_type.value
workflow_node_execution.node_id = event.node_id workflow_node_execution.title = event.node_data.title
workflow_node_execution.node_type = event.node_type.value workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value
workflow_node_execution.title = event.node_data.title workflow_node_execution.created_by_role = workflow_run.created_by_role
workflow_node_execution.status = WorkflowNodeExecutionStatus.RUNNING.value workflow_node_execution.created_by = workflow_run.created_by
workflow_node_execution.created_by_role = workflow_run.created_by_role workflow_node_execution.execution_metadata = json.dumps(
workflow_node_execution.created_by = workflow_run.created_by {
workflow_node_execution.execution_metadata = json.dumps( NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id,
{ NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id,
NodeRunMetadataKey.PARALLEL_MODE_RUN_ID: event.parallel_mode_run_id, }
NodeRunMetadataKey.ITERATION_ID: event.in_iteration_id, )
} workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
)
workflow_node_execution.created_at = datetime.now(UTC).replace(tzinfo=None)
session.add(workflow_node_execution)
session.commit()
session.refresh(workflow_node_execution)
self._wip_workflow_node_executions[workflow_node_execution.node_execution_id] = workflow_node_execution session.add(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_success(self, event: QueueNodeSucceededEvent) -> WorkflowNodeExecution: def _handle_workflow_node_execution_success(
""" self, *, session: Session, event: QueueNodeSucceededEvent
Workflow node execution success ) -> WorkflowNodeExecution:
:param event: queue node succeeded event workflow_node_execution = self._get_workflow_node_execution(
:return: session=session, node_execution_id=event.node_execution_id
""" )
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
execution_metadata_dict = dict(event.execution_metadata or {}) execution_metadata_dict = dict(event.execution_metadata or {})
if self._wip_workflow_agent_logs.get(workflow_node_execution.id):
if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
workflow_node_execution.id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: WorkflowNodeExecutionStatus.SUCCEEDED.value,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.execution_metadata: execution_metadata,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
}
)
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.SUCCEEDED.value
@ -378,54 +326,31 @@ class WorkflowCycleManage:
workflow_node_execution.finished_at = finished_at workflow_node_execution.finished_at = finished_at
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_failed( def _handle_workflow_node_execution_failed(
self, event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent self,
*,
session: Session,
event: QueueNodeFailedEvent | QueueNodeInIterationFailedEvent | QueueNodeExceptionEvent,
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
:param event: queue node failed event :param event: queue node failed event
:return: :return:
""" """
workflow_node_execution = self._refetch_workflow_node_execution(event.node_execution_id) workflow_node_execution = self._get_workflow_node_execution(
session=session, node_execution_id=event.node_execution_id
)
inputs = WorkflowEntry.handle_special_values(event.inputs) inputs = WorkflowEntry.handle_special_values(event.inputs)
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
outputs = WorkflowEntry.handle_special_values(event.outputs) outputs = WorkflowEntry.handle_special_values(event.outputs)
finished_at = datetime.now(UTC).replace(tzinfo=None) finished_at = datetime.now(UTC).replace(tzinfo=None)
elapsed_time = (finished_at - event.start_at).total_seconds() elapsed_time = (finished_at - event.start_at).total_seconds()
execution_metadata_dict = dict(event.execution_metadata or {}) execution_metadata = (
if self._wip_workflow_agent_logs.get(workflow_node_execution.id): json.dumps(jsonable_encoder(event.execution_metadata)) if event.execution_metadata else None
if not execution_metadata_dict:
execution_metadata_dict = {}
execution_metadata_dict[NodeRunMetadataKey.AGENT_LOG] = self._wip_workflow_agent_logs.get(
workflow_node_execution.id, []
)
execution_metadata = json.dumps(jsonable_encoder(execution_metadata_dict)) if execution_metadata_dict else None
db.session.query(WorkflowNodeExecution).filter(WorkflowNodeExecution.id == workflow_node_execution.id).update(
{
WorkflowNodeExecution.status: (
WorkflowNodeExecutionStatus.FAILED.value
if not isinstance(event, QueueNodeExceptionEvent)
else WorkflowNodeExecutionStatus.EXCEPTION.value
),
WorkflowNodeExecution.error: event.error,
WorkflowNodeExecution.inputs: json.dumps(inputs) if inputs else None,
WorkflowNodeExecution.process_data: json.dumps(process_data) if process_data else None,
WorkflowNodeExecution.outputs: json.dumps(outputs) if outputs else None,
WorkflowNodeExecution.finished_at: finished_at,
WorkflowNodeExecution.elapsed_time: elapsed_time,
WorkflowNodeExecution.execution_metadata: execution_metadata,
}
) )
db.session.commit()
db.session.close()
process_data = WorkflowEntry.handle_special_values(event.process_data) process_data = WorkflowEntry.handle_special_values(event.process_data)
workflow_node_execution.status = ( workflow_node_execution.status = (
WorkflowNodeExecutionStatus.FAILED.value WorkflowNodeExecutionStatus.FAILED.value
@ -440,12 +365,10 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
self._wip_workflow_node_executions.pop(workflow_node_execution.node_execution_id)
return workflow_node_execution return workflow_node_execution
def _handle_workflow_node_execution_retried( def _handle_workflow_node_execution_retried(
self, workflow_run: WorkflowRun, event: QueueNodeRetryEvent self, *, session: Session, workflow_run: WorkflowRun, event: QueueNodeRetryEvent
) -> WorkflowNodeExecution: ) -> WorkflowNodeExecution:
""" """
Workflow node execution failed Workflow node execution failed
@ -469,6 +392,7 @@ class WorkflowCycleManage:
execution_metadata = json.dumps(merged_metadata) execution_metadata = json.dumps(merged_metadata)
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = event.node_execution_id
workflow_node_execution.tenant_id = workflow_run.tenant_id workflow_node_execution.tenant_id = workflow_run.tenant_id
workflow_node_execution.app_id = workflow_run.app_id workflow_node_execution.app_id = workflow_run.app_id
workflow_node_execution.workflow_id = workflow_run.workflow_id workflow_node_execution.workflow_id = workflow_run.workflow_id
@ -491,10 +415,7 @@ class WorkflowCycleManage:
workflow_node_execution.execution_metadata = execution_metadata workflow_node_execution.execution_metadata = execution_metadata
workflow_node_execution.index = event.node_run_index workflow_node_execution.index = event.node_run_index
db.session.add(workflow_node_execution) session.add(workflow_node_execution)
db.session.commit()
db.session.refresh(workflow_node_execution)
return workflow_node_execution return workflow_node_execution
################################################# #################################################
@ -502,14 +423,14 @@ class WorkflowCycleManage:
################################################# #################################################
def _workflow_start_to_stream_response( def _workflow_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowStartStreamResponse: ) -> WorkflowStartStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow start to stream response. _ = session
:param task_id: task id
:param workflow_run: workflow run
:return:
"""
return WorkflowStartStreamResponse( return WorkflowStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -523,36 +444,32 @@ class WorkflowCycleManage:
) )
def _workflow_finish_to_stream_response( def _workflow_finish_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun self,
*,
session: Session,
task_id: str,
workflow_run: WorkflowRun,
) -> WorkflowFinishStreamResponse: ) -> WorkflowFinishStreamResponse:
"""
Workflow finish to stream response.
:param task_id: task id
:param workflow_run: workflow run
:return:
"""
# Attach WorkflowRun to an active session so "created_by_role" can be accessed.
workflow_run = db.session.merge(workflow_run)
# Refresh to ensure any expired attributes are fully loaded
db.session.refresh(workflow_run)
created_by = None created_by = None
if workflow_run.created_by_role == CreatedByRole.ACCOUNT.value: if workflow_run.created_by_role == CreatedByRole.ACCOUNT:
created_by_account = workflow_run.created_by_account stmt = select(Account).where(Account.id == workflow_run.created_by)
if created_by_account: account = session.scalar(stmt)
if account:
created_by = { created_by = {
"id": created_by_account.id, "id": account.id,
"name": created_by_account.name, "name": account.name,
"email": created_by_account.email, "email": account.email,
} }
else: elif workflow_run.created_by_role == CreatedByRole.END_USER:
created_by_end_user = workflow_run.created_by_end_user stmt = select(EndUser).where(EndUser.id == workflow_run.created_by)
if created_by_end_user: end_user = session.scalar(stmt)
if end_user:
created_by = { created_by = {
"id": created_by_end_user.id, "id": end_user.id,
"user": created_by_end_user.session_id, "user": end_user.session_id,
} }
else:
raise NotImplementedError(f"unknown created_by_role: {workflow_run.created_by_role}")
return WorkflowFinishStreamResponse( return WorkflowFinishStreamResponse(
task_id=task_id, task_id=task_id,
@ -576,17 +493,20 @@ class WorkflowCycleManage:
) )
def _workflow_node_start_to_stream_response( def _workflow_node_start_to_stream_response(
self, event: QueueNodeStartedEvent, task_id: str, workflow_node_execution: WorkflowNodeExecution self,
*,
session: Session,
event: QueueNodeStartedEvent,
task_id: str,
workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeStartStreamResponse]: ) -> Optional[NodeStartStreamResponse]:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow node start to stream response. _ = session
:param event: queue node started event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id:
return None
response = NodeStartStreamResponse( response = NodeStartStreamResponse(
task_id=task_id, task_id=task_id,
@ -622,6 +542,8 @@ class WorkflowCycleManage:
def _workflow_node_finish_to_stream_response( def _workflow_node_finish_to_stream_response(
self, self,
*,
session: Session,
event: QueueNodeSucceededEvent event: QueueNodeSucceededEvent
| QueueNodeFailedEvent | QueueNodeFailedEvent
| QueueNodeInIterationFailedEvent | QueueNodeInIterationFailedEvent
@ -629,15 +551,14 @@ class WorkflowCycleManage:
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[NodeFinishStreamResponse]: ) -> Optional[NodeFinishStreamResponse]:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow node finish to stream response. _ = session
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeFinishStreamResponse( return NodeFinishStreamResponse(
task_id=task_id, task_id=task_id,
@ -669,19 +590,20 @@ class WorkflowCycleManage:
def _workflow_node_retry_to_stream_response( def _workflow_node_retry_to_stream_response(
self, self,
*,
session: Session,
event: QueueNodeRetryEvent, event: QueueNodeRetryEvent,
task_id: str, task_id: str,
workflow_node_execution: WorkflowNodeExecution, workflow_node_execution: WorkflowNodeExecution,
) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]: ) -> Optional[Union[NodeRetryStreamResponse, NodeFinishStreamResponse]]:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow node finish to stream response. _ = session
:param event: queue node succeeded or failed event
:param task_id: task id
:param workflow_node_execution: workflow node execution
:return:
"""
if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}: if workflow_node_execution.node_type in {NodeType.ITERATION.value, NodeType.LOOP.value}:
return None return None
if not workflow_node_execution.workflow_run_id:
return None
if not workflow_node_execution.finished_at:
return None
return NodeRetryStreamResponse( return NodeRetryStreamResponse(
task_id=task_id, task_id=task_id,
@ -713,15 +635,10 @@ class WorkflowCycleManage:
) )
def _workflow_parallel_branch_start_to_stream_response( def _workflow_parallel_branch_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueParallelBranchRunStartedEvent
) -> ParallelBranchStartStreamResponse: ) -> ParallelBranchStartStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow parallel branch start to stream response _ = session
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run started event
:return:
"""
return ParallelBranchStartStreamResponse( return ParallelBranchStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -737,17 +654,14 @@ class WorkflowCycleManage:
def _workflow_parallel_branch_finished_to_stream_response( def _workflow_parallel_branch_finished_to_stream_response(
self, self,
*,
session: Session,
task_id: str, task_id: str,
workflow_run: WorkflowRun, workflow_run: WorkflowRun,
event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent, event: QueueParallelBranchRunSucceededEvent | QueueParallelBranchRunFailedEvent,
) -> ParallelBranchFinishedStreamResponse: ) -> ParallelBranchFinishedStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow parallel branch finished to stream response _ = session
:param task_id: task id
:param workflow_run: workflow run
:param event: parallel branch run succeeded or failed event
:return:
"""
return ParallelBranchFinishedStreamResponse( return ParallelBranchFinishedStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -764,15 +678,10 @@ class WorkflowCycleManage:
) )
def _workflow_iteration_start_to_stream_response( def _workflow_iteration_start_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationStartEvent
) -> IterationNodeStartStreamResponse: ) -> IterationNodeStartStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow iteration start to stream response _ = session
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration start event
:return:
"""
return IterationNodeStartStreamResponse( return IterationNodeStartStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -791,15 +700,10 @@ class WorkflowCycleManage:
) )
def _workflow_iteration_next_to_stream_response( def _workflow_iteration_next_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationNextEvent
) -> IterationNodeNextStreamResponse: ) -> IterationNodeNextStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow iteration next to stream response _ = session
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration next event
:return:
"""
return IterationNodeNextStreamResponse( return IterationNodeNextStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -820,15 +724,10 @@ class WorkflowCycleManage:
) )
def _workflow_iteration_completed_to_stream_response( def _workflow_iteration_completed_to_stream_response(
self, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent self, *, session: Session, task_id: str, workflow_run: WorkflowRun, event: QueueIterationCompletedEvent
) -> IterationNodeCompletedStreamResponse: ) -> IterationNodeCompletedStreamResponse:
""" # receive session to make sure the workflow_run won't be expired, need a more elegant way to handle this
Workflow iteration completed to stream response _ = session
:param task_id: task id
:param workflow_run: workflow run
:param event: iteration completed event
:return:
"""
return IterationNodeCompletedStreamResponse( return IterationNodeCompletedStreamResponse(
task_id=task_id, task_id=task_id,
workflow_run_id=workflow_run.id, workflow_run_id=workflow_run.id,
@ -912,27 +811,22 @@ class WorkflowCycleManage:
return None return None
def _refetch_workflow_run(self, workflow_run_id: str) -> WorkflowRun: def _get_workflow_run(self, *, session: Session, workflow_run_id: str) -> WorkflowRun:
""" """
Refetch workflow run Refetch workflow run
:param workflow_run_id: workflow run id :param workflow_run_id: workflow run id
:return: :return:
""" """
workflow_run = db.session.query(WorkflowRun).filter(WorkflowRun.id == workflow_run_id).first() stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
workflow_run = session.scalar(stmt)
if not workflow_run: if not workflow_run:
raise WorkflowRunNotFoundError(workflow_run_id) raise WorkflowRunNotFoundError(workflow_run_id)
return workflow_run return workflow_run
def _refetch_workflow_node_execution(self, node_execution_id: str) -> WorkflowNodeExecution: def _get_workflow_node_execution(self, session: Session, node_execution_id: str) -> WorkflowNodeExecution:
""" stmt = select(WorkflowNodeExecution).where(WorkflowNodeExecution.id == node_execution_id)
Refetch workflow node execution workflow_node_execution = session.scalar(stmt)
:param node_execution_id: workflow node execution id
:return:
"""
workflow_node_execution = self._wip_workflow_node_executions.get(node_execution_id)
if not workflow_node_execution: if not workflow_node_execution:
raise WorkflowNodeExecutionNotFoundError(node_execution_id) raise WorkflowNodeExecutionNotFoundError(node_execution_id)
@ -945,41 +839,10 @@ class WorkflowCycleManage:
:param event: agent log event :param event: agent log event
:return: :return:
""" """
node_execution = self._wip_workflow_node_executions.get(event.node_execution_id)
if not node_execution:
raise Exception(f"Workflow node execution not found: {event.node_execution_id}")
node_execution_id = node_execution.id
original_agent_logs = self._wip_workflow_agent_logs.get(node_execution_id, [])
# try to find the log with the same id
for log in original_agent_logs:
if log.id == event.id:
# update the log
log.status = event.status
log.error = event.error
log.data = event.data
break
else:
# append the log
original_agent_logs.append(
AgentLogStreamResponse.Data(
id=event.id,
parent_id=event.parent_id,
node_execution_id=node_execution_id,
error=event.error,
status=event.status,
data=event.data,
label=event.label,
)
)
self._wip_workflow_agent_logs[node_execution_id] = original_agent_logs
return AgentLogStreamResponse( return AgentLogStreamResponse(
task_id=task_id, task_id=task_id,
data=AgentLogStreamResponse.Data( data=AgentLogStreamResponse.Data(
node_execution_id=node_execution_id, node_execution_id=event.node_execution_id,
id=event.id, id=event.id,
parent_id=event.parent_id, parent_id=event.parent_id,
label=event.label, label=event.label,

@ -0,0 +1,19 @@
from typing import Optional
from pydantic import BaseModel
class PreviewDetail(BaseModel):
content: str
child_chunks: Optional[list[str]] = None
class QAPreviewDetail(BaseModel):
question: str
answer: str
class IndexingEstimate(BaseModel):
total_segments: int
preview: list[PreviewDetail]
qa_preview: Optional[list[QAPreviewDetail]] = None

@ -881,7 +881,7 @@ class ProviderConfiguration(BaseModel):
# if llm name not in restricted llm list, remove it # if llm name not in restricted llm list, remove it
restrict_model_names = [rm.model for rm in restrict_models] restrict_model_names = [rm.model for rm in restrict_models]
for model in provider_models: for model in provider_models:
if model.model_type == ModelType.LLM and m.model not in restrict_model_names: if model.model_type == ModelType.LLM and model.model not in restrict_model_names:
model.status = ModelStatus.NO_PERMISSION model.status = ModelStatus.NO_PERMISSION
elif not quota_configuration.is_valid: elif not quota_configuration.is_valid:
model.status = ModelStatus.QUOTA_EXCEEDED model.status = ModelStatus.QUOTA_EXCEEDED

@ -8,34 +8,34 @@ import time
import uuid import uuid
from typing import Any, Optional, cast from typing import Any, Optional, cast
from flask import Flask, current_app from flask import current_app
from flask_login import current_user # type: ignore from flask_login import current_user # type: ignore
from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.orm.exc import ObjectDeletedError
from configs import dify_config from configs import dify_config
from core.entities.knowledge_entities import IndexingEstimate, PreviewDetail, QAPreviewDetail
from core.errors.error import ProviderTokenNotInitError from core.errors.error import ProviderTokenNotInitError
from core.llm_generator.llm_generator import LLMGenerator
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.docstore.dataset_docstore import DatasetDocumentStore from core.rag.docstore.dataset_docstore import DatasetDocumentStore
from core.rag.extractor.entity.extract_setting import ExtractSetting from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from core.rag.splitter.fixed_text_splitter import ( from core.rag.splitter.fixed_text_splitter import (
EnhanceRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter,
FixedRecursiveCharacterTextSplitter, FixedRecursiveCharacterTextSplitter,
) )
from core.rag.splitter.text_splitter import TextSplitter from core.rag.splitter.text_splitter import TextSplitter
from core.tools.utils.rag_web_reader import get_image_upload_file_ids from core.tools.utils.rag_web_reader import get_image_upload_file_ids
from core.tools.utils.text_processing_utils import remove_leading_symbols
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from extensions.ext_storage import storage from extensions.ext_storage import storage
from libs import helper from libs import helper
from models.dataset import Dataset, DatasetProcessRule, DocumentSegment from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.model import UploadFile from models.model import UploadFile
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -115,6 +115,9 @@ class IndexingRunner:
for document_segment in document_segments: for document_segment in document_segments:
db.session.delete(document_segment) db.session.delete(document_segment)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
# delete child chunks
db.session.query(ChildChunk).filter(ChildChunk.segment_id == document_segment.id).delete()
db.session.commit() db.session.commit()
# get the process rule # get the process rule
processing_rule = ( processing_rule = (
@ -183,7 +186,22 @@ class IndexingRunner:
"dataset_id": document_segment.dataset_id, "dataset_id": document_segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = document_segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": document_segment.document_id,
"dataset_id": document_segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document) documents.append(document)
# build index # build index
@ -222,7 +240,7 @@ class IndexingRunner:
doc_language: str = "English", doc_language: str = "English",
dataset_id: Optional[str] = None, dataset_id: Optional[str] = None,
indexing_technique: str = "economy", indexing_technique: str = "economy",
) -> dict: ) -> IndexingEstimate:
""" """
Estimate the indexing for the document. Estimate the indexing for the document.
""" """
@ -258,31 +276,38 @@ class IndexingRunner:
tenant_id=tenant_id, tenant_id=tenant_id,
model_type=ModelType.TEXT_EMBEDDING, model_type=ModelType.TEXT_EMBEDDING,
) )
preview_texts: list[str] = [] preview_texts = [] # type: ignore
total_segments = 0 total_segments = 0
index_type = doc_form index_type = doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
all_text_docs = []
for extract_setting in extract_settings: for extract_setting in extract_settings:
# extract # extract
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
all_text_docs.extend(text_docs)
processing_rule = DatasetProcessRule( processing_rule = DatasetProcessRule(
mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"]) mode=tmp_processing_rule["mode"], rules=json.dumps(tmp_processing_rule["rules"])
) )
text_docs = index_processor.extract(extract_setting, process_rule_mode=tmp_processing_rule["mode"])
# get splitter documents = index_processor.transform(
splitter = self._get_splitter(processing_rule, embedding_model_instance) text_docs,
embedding_model_instance=embedding_model_instance,
# split to documents process_rule=processing_rule.to_dict(),
documents = self._split_to_documents_for_estimate( tenant_id=current_user.current_tenant_id,
text_docs=text_docs, splitter=splitter, processing_rule=processing_rule doc_language=doc_language,
preview=True,
) )
total_segments += len(documents) total_segments += len(documents)
for document in documents: for document in documents:
if len(preview_texts) < 5: if len(preview_texts) < 10:
preview_texts.append(document.page_content) if doc_form and doc_form == "qa_model":
preview_detail = QAPreviewDetail(
question=document.page_content, answer=document.metadata.get("answer") or ""
)
preview_texts.append(preview_detail)
else:
preview_detail = PreviewDetail(content=document.page_content) # type: ignore
if document.children:
preview_detail.child_chunks = [child.page_content for child in document.children] # type: ignore
preview_texts.append(preview_detail)
# delete image files and related db records # delete image files and related db records
image_upload_file_ids = get_image_upload_file_ids(document.page_content) image_upload_file_ids = get_image_upload_file_ids(document.page_content)
@ -299,15 +324,8 @@ class IndexingRunner:
db.session.delete(image_file) db.session.delete(image_file)
if doc_form and doc_form == "qa_model": if doc_form and doc_form == "qa_model":
if len(preview_texts) > 0: return IndexingEstimate(total_segments=total_segments * 20, qa_preview=preview_texts, preview=[])
# qa model document return IndexingEstimate(total_segments=total_segments, preview=preview_texts) # type: ignore
response = LLMGenerator.generate_qa_document(
current_user.current_tenant_id, preview_texts[0], doc_language
)
document_qa_list = self.format_split_text(response)
return {"total_segments": total_segments * 20, "qa_preview": document_qa_list, "preview": preview_texts}
return {"total_segments": total_segments, "preview": preview_texts}
def _extract( def _extract(
self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict self, index_processor: BaseIndexProcessor, dataset_document: DatasetDocument, process_rule: dict
@ -401,31 +419,26 @@ class IndexingRunner:
@staticmethod @staticmethod
def _get_splitter( def _get_splitter(
processing_rule: DatasetProcessRule, embedding_model_instance: Optional[ModelInstance] processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter: ) -> TextSplitter:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
character_splitter: TextSplitter if processing_rule_mode in ["custom", "hierarchical"]:
if processing_rule.mode == "custom":
# The user-defined segmentation rule # The user-defined segmentation rule
rules = json.loads(processing_rule.rules)
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator: if separator:
separator = separator.replace("\\n", "\n") separator = separator.replace("\\n", "\n")
if segmentation.get("chunk_overlap"):
chunk_overlap = segmentation["chunk_overlap"]
else:
chunk_overlap = 0
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"], chunk_size=max_tokens,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""], separators=["\n\n", "", ". ", " ", ""],
@ -441,143 +454,7 @@ class IndexingRunner:
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
return character_splitter return character_splitter # type: ignore
def _step_split(
self,
text_docs: list[Document],
splitter: TextSplitter,
dataset: Dataset,
dataset_document: DatasetDocument,
processing_rule: DatasetProcessRule,
) -> list[Document]:
"""
Split the text documents into documents and save them to the document segment.
"""
documents = self._split_to_documents(
text_docs=text_docs,
splitter=splitter,
processing_rule=processing_rule,
tenant_id=dataset.tenant_id,
document_form=dataset_document.doc_form,
document_language=dataset_document.doc_language,
)
# save node to document segment
doc_store = DatasetDocumentStore(
dataset=dataset, user_id=dataset_document.created_by, document_id=dataset_document.id
)
# add document segments
doc_store.add_documents(documents)
# update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
self._update_document_index_status(
document_id=dataset_document.id,
after_indexing_status="indexing",
extra_update_params={
DatasetDocument.cleaning_completed_at: cur_time,
DatasetDocument.splitting_completed_at: cur_time,
},
)
# update segment status to indexing
self._update_segments_by_document(
dataset_document_id=dataset_document.id,
update_params={
DocumentSegment.status: "indexing",
DocumentSegment.indexing_at: datetime.datetime.now(datetime.UTC).replace(tzinfo=None),
},
)
return documents
def _split_to_documents(
self,
text_docs: list[Document],
splitter: TextSplitter,
processing_rule: DatasetProcessRule,
tenant_id: str,
document_form: str,
document_language: str,
) -> list[Document]:
"""
Split the text documents into nodes.
"""
all_documents: list[Document] = []
all_qa_documents: list[Document] = []
for text_doc in text_docs:
# document clean
document_text = self._document_clean(text_doc.page_content, processing_rule)
text_doc.page_content = document_text
# parse document to nodes
documents = splitter.split_documents([text_doc])
split_documents = []
for document_node in documents:
if document_node.page_content.strip():
if document_node.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
document_node.page_content = remove_leading_symbols(page_content)
if document_node.page_content:
split_documents.append(document_node)
all_documents.extend(split_documents)
# processing qa document
if document_form == "qa_model":
for i in range(0, len(all_documents), 10):
threads = []
sub_documents = all_documents[i : i + 10]
for doc in sub_documents:
document_format_thread = threading.Thread(
target=self.format_qa_document,
kwargs={
"flask_app": current_app._get_current_object(), # type: ignore
"tenant_id": tenant_id,
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": document_language,
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents
return all_documents
def format_qa_document(self, flask_app: Flask, tenant_id: str, document_node, all_qa_documents, document_language):
format_documents = []
if document_node.page_content is None or not document_node.page_content.strip():
return
with flask_app.app_context():
try:
# qa model document
response = LLMGenerator.generate_qa_document(tenant_id, document_node.page_content, document_language)
document_qa_list = self.format_split_text(response)
qa_documents = []
for result in document_qa_list:
qa_document = Document(
page_content=result["question"], metadata=document_node.metadata.model_copy()
)
if qa_document.metadata is not None:
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(result["question"])
qa_document.metadata["answer"] = result["answer"]
qa_document.metadata["doc_id"] = doc_id
qa_document.metadata["doc_hash"] = hash
qa_documents.append(qa_document)
format_documents.extend(qa_documents)
except Exception as e:
logging.exception("Failed to format qa document")
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate( def _split_to_documents_for_estimate(
self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule self, text_docs: list[Document], splitter: TextSplitter, processing_rule: DatasetProcessRule
@ -624,11 +501,11 @@ class IndexingRunner:
return document_text return document_text
@staticmethod @staticmethod
def format_split_text(text): def format_split_text(text: str) -> list[QAPreviewDetail]:
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)" regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q\d+:|$)"
matches = re.findall(regex, text, re.UNICODE) matches = re.findall(regex, text, re.UNICODE)
return [{"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())} for q, a in matches if q and a] return [QAPreviewDetail(question=q, answer=re.sub(r"\n\s*", "\n", a.strip())) for q, a in matches if q and a]
def _load( def _load(
self, self,
@ -654,13 +531,14 @@ class IndexingRunner:
indexing_start_at = time.perf_counter() indexing_start_at = time.perf_counter()
tokens = 0 tokens = 0
chunk_size = 10 chunk_size = 10
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
# create keyword index
create_keyword_thread = threading.Thread(
target=self._process_keyword_index,
args=(current_app._get_current_object(), dataset.id, dataset_document.id, documents), # type: ignore
)
create_keyword_thread.start()
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [] futures = []
@ -680,8 +558,8 @@ class IndexingRunner:
for future in futures: for future in futures:
tokens += future.result() tokens += future.result()
if dataset_document.doc_form != IndexType.PARENT_CHILD_INDEX:
create_keyword_thread.join() create_keyword_thread.join()
indexing_end_at = time.perf_counter() indexing_end_at = time.perf_counter()
# update document status to completed # update document status to completed
@ -791,28 +669,6 @@ class IndexingRunner:
DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params) DocumentSegment.query.filter_by(document_id=dataset_document_id).update(update_params)
db.session.commit() db.session.commit()
@staticmethod
def batch_add_segments(segments: list[DocumentSegment], dataset: Dataset):
"""
Batch add segments index processing
"""
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents)
def _transform( def _transform(
self, self,
index_processor: BaseIndexProcessor, index_processor: BaseIndexProcessor,
@ -854,7 +710,7 @@ class IndexingRunner:
) )
# add document segments # add document segments
doc_store.add_documents(documents) doc_store.add_documents(docs=documents, save_child=dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX)
# update document status to indexing # update document status to indexing
cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None) cur_time = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)

@ -9,6 +9,8 @@ from typing import Any, Optional, Union
from uuid import UUID, uuid4 from uuid import UUID, uuid4
from flask import current_app from flask import current_app
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token from core.helper.encrypter import decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import ( from core.ops.entities.config_entity import (
@ -329,15 +331,15 @@ class TraceTask:
): ):
self.trace_type = trace_type self.trace_type = trace_type
self.message_id = message_id self.message_id = message_id
self.workflow_run = workflow_run self.workflow_run_id = workflow_run.id if workflow_run else None
self.conversation_id = conversation_id self.conversation_id = conversation_id
self.user_id = user_id self.user_id = user_id
self.timer = timer self.timer = timer
self.kwargs = kwargs
self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001") self.file_base_url = os.getenv("FILES_URL", "http://127.0.0.1:5001")
self.app_id = None self.app_id = None
self.kwargs = kwargs
def execute(self): def execute(self):
return self.preprocess() return self.preprocess()
@ -345,19 +347,23 @@ class TraceTask:
preprocess_map = { preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
self.workflow_run, self.conversation_id, self.user_id workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(self.message_id),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace( TraceTaskName.SUGGESTED_QUESTION_TRACE: lambda: self.suggested_question_trace(
self.message_id, self.timer, **self.kwargs message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace( TraceTaskName.DATASET_RETRIEVAL_TRACE: lambda: self.dataset_retrieval_trace(
self.message_id, self.timer, **self.kwargs message_id=self.message_id, timer=self.timer, **self.kwargs
),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
), ),
TraceTaskName.TOOL_TRACE: lambda: self.tool_trace(self.message_id, self.timer, **self.kwargs),
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace( TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
self.conversation_id, self.timer, **self.kwargs conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
), ),
} }
@ -367,86 +373,100 @@ class TraceTask:
def conversation_trace(self, **kwargs): def conversation_trace(self, **kwargs):
return kwargs return kwargs
def workflow_trace(self, workflow_run: WorkflowRun | None, conversation_id, user_id): def workflow_trace(
if not workflow_run: self,
raise ValueError("Workflow run not found") *,
workflow_run_id: str | None,
db.session.merge(workflow_run) conversation_id: str | None,
db.session.refresh(workflow_run) user_id: str | None,
):
workflow_id = workflow_run.workflow_id if not workflow_run_id:
tenant_id = workflow_run.tenant_id return {}
workflow_run_id = workflow_run.id
workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_status = workflow_run.status
workflow_run_inputs = workflow_run.inputs_dict
workflow_run_outputs = workflow_run.outputs_dict
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
# get workflow_app_log_id
workflow_app_log_data = (
db.session.query(WorkflowAppLog)
.filter_by(tenant_id=tenant_id, app_id=workflow_run.app_id, workflow_run_id=workflow_run.id)
.first()
)
workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None
# get message_id
message_data = (
db.session.query(Message.id)
.filter_by(conversation_id=conversation_id, workflow_run_id=workflow_run_id)
.first()
)
message_id = str(message_data.id) if message_data else None
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_form": workflow_run.triggered_from,
"user_id": user_id,
}
workflow_trace_info = WorkflowTraceInfo( with Session(db.engine) as session:
workflow_data=workflow_run.to_dict(), workflow_run_stmt = select(WorkflowRun).where(WorkflowRun.id == workflow_run_id)
conversation_id=conversation_id, workflow_run = session.scalars(workflow_run_stmt).first()
workflow_id=workflow_id, if not workflow_run:
tenant_id=tenant_id, raise ValueError("Workflow run not found")
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time, workflow_id = workflow_run.workflow_id
workflow_run_status=workflow_run_status, tenant_id = workflow_run.tenant_id
workflow_run_inputs=workflow_run_inputs, workflow_run_id = workflow_run.id
workflow_run_outputs=workflow_run_outputs, workflow_run_elapsed_time = workflow_run.elapsed_time
workflow_run_version=workflow_run_version, workflow_run_status = workflow_run.status
error=error, workflow_run_inputs = workflow_run.inputs_dict
total_tokens=total_tokens, workflow_run_outputs = workflow_run.outputs_dict
file_list=file_list, workflow_run_version = workflow_run.version
query=query, error = workflow_run.error or ""
metadata=metadata,
workflow_app_log_id=workflow_app_log_id, total_tokens = workflow_run.total_tokens
message_id=message_id,
start_time=workflow_run.created_at, file_list = workflow_run_inputs.get("sys.file") or []
end_time=workflow_run.finished_at, query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
)
# get workflow_app_log_id
workflow_app_log_data_stmt = select(WorkflowAppLog.id).where(
WorkflowAppLog.tenant_id == tenant_id,
WorkflowAppLog.app_id == workflow_run.app_id,
WorkflowAppLog.workflow_run_id == workflow_run.id,
)
workflow_app_log_id = session.scalar(workflow_app_log_data_stmt)
# get message_id
message_id = None
if conversation_id:
message_data_stmt = select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_run_id,
)
message_id = session.scalar(message_data_stmt)
metadata = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
"tenant_id": tenant_id,
"elapsed_time": workflow_run_elapsed_time,
"status": workflow_run_status,
"version": workflow_run_version,
"total_tokens": total_tokens,
"file_list": file_list,
"triggered_form": workflow_run.triggered_from,
"user_id": user_id,
}
workflow_trace_info = WorkflowTraceInfo(
workflow_data=workflow_run.to_dict(),
conversation_id=conversation_id,
workflow_id=workflow_id,
tenant_id=tenant_id,
workflow_run_id=workflow_run_id,
workflow_run_elapsed_time=workflow_run_elapsed_time,
workflow_run_status=workflow_run_status,
workflow_run_inputs=workflow_run_inputs,
workflow_run_outputs=workflow_run_outputs,
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
file_list=file_list,
query=query,
metadata=metadata,
workflow_app_log_id=workflow_app_log_id,
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
)
return workflow_trace_info return workflow_trace_info
def message_trace(self, message_id): def message_trace(self, message_id: str | None):
if not message_id:
return {}
message_data = get_message_data(message_id) message_data = get_message_data(message_id)
if not message_data: if not message_data:
return {} return {}
conversation_mode = db.session.query(Conversation.mode).filter_by(id=message_data.conversation_id).first() conversation_mode_stmt = select(Conversation.mode).where(Conversation.id == message_data.conversation_id)
conversation_mode = db.session.scalars(conversation_mode_stmt).all()
if not conversation_mode or len(conversation_mode) == 0:
return {}
conversation_mode = conversation_mode[0] conversation_mode = conversation_mode[0]
created_at = message_data.created_at created_at = message_data.created_at
inputs = message_data.message inputs = message_data.message

@ -18,7 +18,7 @@ def filter_none_values(data: dict):
return new_data return new_data
def get_message_data(message_id): def get_message_data(message_id: str):
return db.session.query(Message).filter(Message.id == message_id).first() return db.session.query(Message).filter(Message.id == message_id).first()

@ -1,5 +1,5 @@
import re import re
from typing import Optional from typing import Optional, cast
class JiebaKeywordTableHandler: class JiebaKeywordTableHandler:
@ -8,18 +8,20 @@ class JiebaKeywordTableHandler:
from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS from core.rag.datasource.keyword.jieba.stopwords import STOPWORDS
jieba.analyse.default_tfidf.stop_words = STOPWORDS jieba.analyse.default_tfidf.stop_words = STOPWORDS # type: ignore
def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]: def extract_keywords(self, text: str, max_keywords_per_chunk: Optional[int] = 10) -> set[str]:
"""Extract keywords with JIEBA tfidf.""" """Extract keywords with JIEBA tfidf."""
import jieba # type: ignore import jieba.analyse # type: ignore
keywords = jieba.analyse.extract_tags( keywords = jieba.analyse.extract_tags(
sentence=text, sentence=text,
topK=max_keywords_per_chunk, topK=max_keywords_per_chunk,
) )
# jieba.analyse.extract_tags returns list[Any] when withFlag is False by default.
keywords = cast(list[str], keywords)
return set(self._expand_tokens_with_subtokens(keywords)) return set(self._expand_tokens_with_subtokens(set(keywords)))
def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]: def _expand_tokens_with_subtokens(self, tokens: set[str]) -> set[str]:
"""Get subtokens from a list of tokens., filtering for stopwords.""" """Get subtokens from a list of tokens., filtering for stopwords."""

@ -6,11 +6,14 @@ from flask import Flask, current_app
from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.data_post_processor import DataPostProcessor
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.embedding.retrieval import RetrievalSegments
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.models.document import Document from core.rag.models.document import Document
from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.rerank_type import RerankMode
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset from models.dataset import ChildChunk, Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.external_knowledge_service import ExternalDatasetService from services.external_knowledge_service import ExternalDatasetService
default_retrieval_model = { default_retrieval_model = {
@ -248,3 +251,89 @@ class RetrievalService:
@staticmethod @staticmethod
def escape_query_for_search(query: str) -> str: def escape_query_for_search(query: str) -> str:
return query.replace('"', '\\"') return query.replace('"', '\\"')
@staticmethod
def format_retrieval_documents(documents: list[Document]) -> list[RetrievalSegments]:
records = []
include_segment_ids = []
segment_child_map = {}
for document in documents:
document_id = document.metadata.get("document_id")
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if dataset_document:
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_index_node_id = document.metadata.get("doc_id")
result = (
db.session.query(ChildChunk, DocumentSegment)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
ChildChunk.index_node_id == child_index_node_id,
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
)
.first()
)
if result:
child_chunk, segment = result
if not segment:
continue
if segment.id not in include_segment_ids:
include_segment_ids.append(segment.id)
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
map_detail = {
"max_score": document.metadata.get("score", 0.0),
"child_chunks": [child_chunk_detail],
}
segment_child_map[segment.id] = map_detail
record = {
"segment": segment,
}
records.append(record)
else:
child_chunk_detail = {
"id": child_chunk.id,
"content": child_chunk.content,
"position": child_chunk.position,
"score": document.metadata.get("score", 0.0),
}
segment_child_map[segment.id]["child_chunks"].append(child_chunk_detail)
segment_child_map[segment.id]["max_score"] = max(
segment_child_map[segment.id]["max_score"], document.metadata.get("score", 0.0)
)
else:
continue
else:
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset_document.dataset_id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
include_segment_ids.append(segment.id)
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
for record in records:
if record["segment"].id in segment_child_map:
record["child_chunks"] = segment_child_map[record["segment"].id].get("child_chunks", None)
record["score"] = segment_child_map[record["segment"].id]["max_score"]
return [RetrievalSegments(**record) for record in records]

@ -7,7 +7,7 @@ from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.rag.models.document import Document from core.rag.models.document import Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import ChildChunk, Dataset, DocumentSegment
class DatasetDocumentStore: class DatasetDocumentStore:
@ -60,7 +60,7 @@ class DatasetDocumentStore:
return output return output
def add_documents(self, docs: Sequence[Document], allow_update: bool = True) -> None: def add_documents(self, docs: Sequence[Document], allow_update: bool = True, save_child: bool = False) -> None:
max_position = ( max_position = (
db.session.query(func.max(DocumentSegment.position)) db.session.query(func.max(DocumentSegment.position))
.filter(DocumentSegment.document_id == self._document_id) .filter(DocumentSegment.document_id == self._document_id)
@ -120,13 +120,55 @@ class DatasetDocumentStore:
segment_document.answer = doc.metadata.pop("answer", "") segment_document.answer = doc.metadata.pop("answer", "")
db.session.add(segment_document) db.session.add(segment_document)
db.session.flush()
if save_child:
if doc.children:
for postion, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=postion,
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
else: else:
segment_document.content = doc.page_content segment_document.content = doc.page_content
if doc.metadata.get("answer"): if doc.metadata.get("answer"):
segment_document.answer = doc.metadata.pop("answer", "") segment_document.answer = doc.metadata.pop("answer", "")
segment_document.index_node_hash = doc.metadata["doc_hash"] segment_document.index_node_hash = doc.metadata.get("doc_hash")
segment_document.word_count = len(doc.page_content) segment_document.word_count = len(doc.page_content)
segment_document.tokens = tokens segment_document.tokens = tokens
if save_child and doc.children:
# delete the existing child chunks
db.session.query(ChildChunk).filter(
ChildChunk.tenant_id == self._dataset.tenant_id,
ChildChunk.dataset_id == self._dataset.id,
ChildChunk.document_id == self._document_id,
ChildChunk.segment_id == segment_document.id,
).delete()
# add new child chunks
for position, child in enumerate(doc.children, start=1):
child_segment = ChildChunk(
tenant_id=self._dataset.tenant_id,
dataset_id=self._dataset.id,
document_id=self._document_id,
segment_id=segment_document.id,
position=position,
index_node_id=child.metadata.get("doc_id"),
index_node_hash=child.metadata.get("doc_hash"),
content=child.page_content,
word_count=len(child.page_content),
type="automatic",
created_by=self._user_id,
)
db.session.add(child_segment)
db.session.commit() db.session.commit()

@ -0,0 +1,23 @@
from typing import Optional
from pydantic import BaseModel
from models.dataset import DocumentSegment
class RetrievalChildChunk(BaseModel):
"""Retrieval segments."""
id: str
content: str
score: float
position: int
class RetrievalSegments(BaseModel):
"""Retrieval segments."""
model_config = {"arbitrary_types_allowed": True}
segment: DocumentSegment
child_chunks: Optional[list[RetrievalChildChunk]] = None
score: Optional[float] = None

@ -4,7 +4,7 @@ import os
from typing import Optional, cast from typing import Optional, cast
import pandas as pd import pandas as pd
from openpyxl import load_workbook from openpyxl import load_workbook # type: ignore
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document

@ -24,7 +24,6 @@ from core.rag.extractor.unstructured.unstructured_markdown_extractor import Unst
from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor from core.rag.extractor.unstructured.unstructured_msg_extractor import UnstructuredMsgExtractor
from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor from core.rag.extractor.unstructured.unstructured_ppt_extractor import UnstructuredPPTExtractor
from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor from core.rag.extractor.unstructured.unstructured_pptx_extractor import UnstructuredPPTXExtractor
from core.rag.extractor.unstructured.unstructured_text_extractor import UnstructuredTextExtractor
from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor from core.rag.extractor.unstructured.unstructured_xml_extractor import UnstructuredXmlExtractor
from core.rag.extractor.word_extractor import WordExtractor from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -103,12 +102,11 @@ class ExtractProcessor:
input_file = Path(file_path) input_file = Path(file_path)
file_extension = input_file.suffix.lower() file_extension = input_file.suffix.lower()
etl_type = dify_config.ETL_TYPE etl_type = dify_config.ETL_TYPE
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY
assert unstructured_api_url is not None, "unstructured_api_url is required"
assert unstructured_api_key is not None, "unstructured_api_key is required"
extractor: Optional[BaseExtractor] = None extractor: Optional[BaseExtractor] = None
if etl_type == "Unstructured": if etl_type == "Unstructured":
unstructured_api_url = dify_config.UNSTRUCTURED_API_URL
unstructured_api_key = dify_config.UNSTRUCTURED_API_KEY or ""
if file_extension in {".xlsx", ".xls"}: if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path) extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf": elif file_extension == ".pdf":
@ -141,11 +139,7 @@ class ExtractProcessor:
extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key) extractor = UnstructuredEpubExtractor(file_path, unstructured_api_url, unstructured_api_key)
else: else:
# txt # txt
extractor = ( extractor = TextExtractor(file_path, autodetect_encoding=True)
UnstructuredTextExtractor(file_path, unstructured_api_url)
if is_automatic
else TextExtractor(file_path, autodetect_encoding=True)
)
else: else:
if file_extension in {".xlsx", ".xls"}: if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path) extractor = ExcelExtractor(file_path)

@ -1,5 +1,6 @@
import base64 import base64
import logging import logging
from typing import Optional
from bs4 import BeautifulSoup # type: ignore from bs4 import BeautifulSoup # type: ignore
@ -15,7 +16,7 @@ class UnstructuredEmailExtractor(BaseExtractor):
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -19,7 +19,7 @@ class UnstructuredEpubExtractor(BaseExtractor):
self, self,
file_path: str, file_path: str,
api_url: Optional[str] = None, api_url: Optional[str] = None,
api_key: Optional[str] = None, api_key: str = "",
): ):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
@ -30,9 +30,6 @@ class UnstructuredEpubExtractor(BaseExtractor):
if self._api_url: if self._api_url:
from unstructured.partition.api import partition_via_api from unstructured.partition.api import partition_via_api
if self._api_key is None:
raise ValueError("api_key is required")
elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key) elements = partition_via_api(filename=self._file_path, api_url=self._api_url, api_key=self._api_key)
else: else:
from unstructured.partition.epub import partition_epub from unstructured.partition.epub import partition_epub

@ -1,4 +1,5 @@
import logging import logging
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -24,7 +25,7 @@ class UnstructuredMarkdownExtractor(BaseExtractor):
if the specified encoding fails. if the specified encoding fails.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -1,4 +1,5 @@
import logging import logging
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -14,7 +15,7 @@ class UnstructuredMsgExtractor(BaseExtractor):
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -1,4 +1,5 @@
import logging import logging
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -14,7 +15,7 @@ class UnstructuredPPTExtractor(BaseExtractor):
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -1,4 +1,5 @@
import logging import logging
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -14,7 +15,7 @@ class UnstructuredPPTXExtractor(BaseExtractor):
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -1,4 +1,5 @@
import logging import logging
from typing import Optional
from core.rag.extractor.extractor_base import BaseExtractor from core.rag.extractor.extractor_base import BaseExtractor
from core.rag.models.document import Document from core.rag.models.document import Document
@ -14,7 +15,7 @@ class UnstructuredXmlExtractor(BaseExtractor):
file_path: Path to the file to load. file_path: Path to the file to load.
""" """
def __init__(self, file_path: str, api_url: str, api_key: str): def __init__(self, file_path: str, api_url: Optional[str] = None, api_key: str = ""):
"""Initialize with file path.""" """Initialize with file path."""
self._file_path = file_path self._file_path = file_path
self._api_url = api_url self._api_url = api_url

@ -267,8 +267,10 @@ class WordExtractor(BaseExtractor):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0) para = paragraphs.pop(0)
parsed_paragraph = parse_paragraph(para) parsed_paragraph = parse_paragraph(para)
if parsed_paragraph: if parsed_paragraph.strip():
content.append(parsed_paragraph) content.append(parsed_paragraph)
else:
content.append("\n")
elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table elif isinstance(element.tag, str) and element.tag.endswith("tbl"): # table
table = tables.pop(0) table = tables.pop(0)
content.append(self._table_to_markdown(table, image_map)) content.append(self._table_to_markdown(table, image_map))

@ -1,8 +1,7 @@
from enum import Enum from enum import Enum
class IndexType(Enum): class IndexType(str, Enum):
PARAGRAPH_INDEX = "text_model" PARAGRAPH_INDEX = "text_model"
QA_INDEX = "qa_model" QA_INDEX = "qa_model"
PARENT_CHILD_INDEX = "parent_child_index" PARENT_CHILD_INDEX = "hierarchical_model"
SUMMARY_INDEX = "summary_index"

@ -27,10 +27,10 @@ class BaseIndexProcessor(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
raise NotImplementedError raise NotImplementedError
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
@ -45,26 +45,29 @@ class BaseIndexProcessor(ABC):
) -> list[Document]: ) -> list[Document]:
raise NotImplementedError raise NotImplementedError
def _get_splitter(self, processing_rule: dict, embedding_model_instance: Optional[ModelInstance]) -> TextSplitter: def _get_splitter(
self,
processing_rule_mode: str,
max_tokens: int,
chunk_overlap: int,
separator: str,
embedding_model_instance: Optional[ModelInstance],
) -> TextSplitter:
""" """
Get the NodeParser object according to the processing rule. Get the NodeParser object according to the processing rule.
""" """
character_splitter: TextSplitter if processing_rule_mode in ["custom", "hierarchical"]:
if processing_rule["mode"] == "custom":
# The user-defined segmentation rule # The user-defined segmentation rule
rules = processing_rule["rules"]
segmentation = rules["segmentation"]
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
if segmentation["max_tokens"] < 50 or segmentation["max_tokens"] > max_segmentation_tokens_length: if max_tokens < 50 or max_tokens > max_segmentation_tokens_length:
raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.") raise ValueError(f"Custom segment length should be between 50 and {max_segmentation_tokens_length}.")
separator = segmentation["separator"]
if separator: if separator:
separator = separator.replace("\\n", "\n") separator = separator.replace("\\n", "\n")
character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder( character_splitter = FixedRecursiveCharacterTextSplitter.from_encoder(
chunk_size=segmentation["max_tokens"], chunk_size=max_tokens,
chunk_overlap=segmentation.get("chunk_overlap", 0) or 0, chunk_overlap=chunk_overlap,
fixed_separator=separator, fixed_separator=separator,
separators=["\n\n", "", ". ", " ", ""], separators=["\n\n", "", ". ", " ", ""],
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
@ -78,4 +81,4 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance, embedding_model_instance=embedding_model_instance,
) )
return character_splitter return character_splitter # type: ignore

@ -3,6 +3,7 @@
from core.rag.index_processor.constant.index_type import IndexType from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor
from core.rag.index_processor.processor.parent_child_index_processor import ParentChildIndexProcessor
from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor from core.rag.index_processor.processor.qa_index_processor import QAIndexProcessor
@ -18,9 +19,11 @@ class IndexProcessorFactory:
if not self._index_type: if not self._index_type:
raise ValueError("Index type must be specified.") raise ValueError("Index type must be specified.")
if self._index_type == IndexType.PARAGRAPH_INDEX.value: if self._index_type == IndexType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor() return ParagraphIndexProcessor()
elif self._index_type == IndexType.QA_INDEX.value: elif self._index_type == IndexType.QA_INDEX:
return QAIndexProcessor() return QAIndexProcessor()
elif self._index_type == IndexType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else: else:
raise ValueError(f"Index type {self._index_type} is not supported.") raise ValueError(f"Index type {self._index_type} is not supported.")

@ -13,21 +13,40 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import Document from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset, DatasetProcessRule
from services.entities.knowledge_entities.knowledge_entities import Rule
class ParagraphIndexProcessor(BaseIndexProcessor): class ParagraphIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
) )
return text_docs return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if process_rule.get("mode") == "automatic":
automatic_rule = DatasetProcessRule.AUTOMATIC_RULES
rules = Rule(**automatic_rule)
else:
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
# Split the text documents into nodes. # Split the text documents into nodes.
if not rules.segmentation:
raise ValueError("No segmentation found in rules.")
splitter = self._get_splitter( splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule", {}), processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"), embedding_model_instance=kwargs.get("embedding_model_instance"),
) )
all_documents = [] all_documents = []
@ -53,15 +72,19 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
all_documents.extend(split_documents) all_documents.extend(split_documents)
return all_documents return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
if with_keywords: if with_keywords:
keywords_list = kwargs.get("keywords_list")
keyword = Keyword(dataset) keyword = Keyword(dataset)
keyword.create(documents) if keywords_list and len(keywords_list) > 0:
keyword.add_texts(documents, keywords_list=keywords_list)
else:
keyword.add_texts(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:

@ -0,0 +1,195 @@
"""Paragraph index processor."""
import uuid
from typing import Optional
from core.model_manager import ModelInstance
from core.rag.cleaner.clean_processor import CleanProcessor
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.extractor.entity.extract_setting import ExtractSetting
from core.rag.extractor.extract_processor import ExtractProcessor
from core.rag.index_processor.index_processor_base import BaseIndexProcessor
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from libs import helper
from models.dataset import ChildChunk, Dataset, DocumentSegment
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
class ParentChildIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract(
extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
)
return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]:
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
all_documents = [] # type: ignore
if rules.parent_mode == ParentMode.PARAGRAPH:
# Split the text documents into nodes.
splitter = self._get_splitter(
processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens,
chunk_overlap=rules.segmentation.chunk_overlap,
separator=rules.segmentation.separator,
embedding_model_instance=kwargs.get("embedding_model_instance"),
)
for document in documents:
# document clean
document_text = CleanProcessor.clean(document.page_content, process_rule)
document.page_content = document_text
# parse document to nodes
document_nodes = splitter.split_documents([document])
split_documents = []
for document_node in document_nodes:
if document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document_node.page_content)
document_node.metadata["doc_id"] = doc_id
document_node.metadata["doc_hash"] = hash
# delete Splitter character
page_content = document_node.page_content
if page_content.startswith(".") or page_content.startswith(""):
page_content = page_content[1:].strip()
else:
page_content = page_content
if len(page_content) > 0:
document_node.page_content = page_content
# parse document to child nodes
child_nodes = self._split_child_nodes(
document_node, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document_node.children = child_nodes
split_documents.append(document_node)
all_documents.extend(split_documents)
elif rules.parent_mode == ParentMode.FULL_DOC:
page_content = "\n".join([document.page_content for document in documents])
document = Document(page_content=page_content, metadata=documents[0].metadata)
# parse document to child nodes
child_nodes = self._split_child_nodes(
document, rules, process_rule.get("mode"), kwargs.get("embedding_model_instance")
)
document.children = child_nodes
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(document.page_content)
document.metadata["doc_id"] = doc_id
document.metadata["doc_hash"] = hash
all_documents.append(document)
return all_documents
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality":
vector = Vector(dataset)
for document in documents:
child_documents = document.children
if child_documents:
formatted_child_documents = [
Document(**child_document.model_dump()) for child_document in child_documents
]
vector.create(formatted_child_documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
# node_ids is segment's node_ids
if dataset.indexing_technique == "high_quality":
delete_child_chunks = kwargs.get("delete_child_chunks") or False
vector = Vector(dataset)
if node_ids:
child_node_ids = (
db.session.query(ChildChunk.index_node_id)
.join(DocumentSegment, ChildChunk.segment_id == DocumentSegment.id)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.index_node_id.in_(node_ids),
ChildChunk.dataset_id == dataset.id,
)
.all()
)
child_node_ids = [child_node_id[0] for child_node_id in child_node_ids]
vector.delete_by_ids(child_node_ids)
if delete_child_chunks:
db.session.query(ChildChunk).filter(
ChildChunk.dataset_id == dataset.id, ChildChunk.index_node_id.in_(child_node_ids)
).delete()
db.session.commit()
else:
vector.delete()
if delete_child_chunks:
db.session.query(ChildChunk).filter(ChildChunk.dataset_id == dataset.id).delete()
db.session.commit()
def retrieve(
self,
retrieval_method: str,
query: str,
dataset: Dataset,
top_k: int,
score_threshold: float,
reranking_model: dict,
) -> list[Document]:
# Set search parameters.
results = RetrievalService.retrieve(
retrieval_method=retrieval_method,
dataset_id=dataset.id,
query=query,
top_k=top_k,
score_threshold=score_threshold,
reranking_model=reranking_model,
)
# Organize results.
docs = []
for result in results:
metadata = result.metadata
metadata["score"] = result.score
if result.score > score_threshold:
doc = Document(page_content=result.page_content, metadata=metadata)
docs.append(doc)
return docs
def _split_child_nodes(
self,
document_node: Document,
rules: Rule,
process_rule_mode: str,
embedding_model_instance: Optional[ModelInstance],
) -> list[ChildDocument]:
if not rules.subchunk_segmentation:
raise ValueError("No subchunk segmentation found in rules.")
child_splitter = self._get_splitter(
processing_rule_mode=process_rule_mode,
max_tokens=rules.subchunk_segmentation.max_tokens,
chunk_overlap=rules.subchunk_segmentation.chunk_overlap,
separator=rules.subchunk_segmentation.separator,
embedding_model_instance=embedding_model_instance,
)
# parse document to child nodes
child_nodes = []
child_documents = child_splitter.split_documents([document_node])
for child_document_node in child_documents:
if child_document_node.page_content.strip():
doc_id = str(uuid.uuid4())
hash = helper.generate_text_hash(child_document_node.page_content)
child_document = ChildDocument(
page_content=child_document_node.page_content, metadata=document_node.metadata
)
child_document.metadata["doc_id"] = doc_id
child_document.metadata["doc_hash"] = hash
child_page_content = child_document.page_content
if child_page_content.startswith(".") or child_page_content.startswith(""):
child_page_content = child_page_content[1:].strip()
if len(child_page_content) > 0:
child_document.page_content = child_page_content
child_nodes.append(child_document)
return child_nodes

@ -21,18 +21,32 @@ from core.rag.models.document import Document
from core.tools.utils.text_processing_utils import remove_leading_symbols from core.tools.utils.text_processing_utils import remove_leading_symbols
from libs import helper from libs import helper
from models.dataset import Dataset from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import Rule
class QAIndexProcessor(BaseIndexProcessor): class QAIndexProcessor(BaseIndexProcessor):
def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]:
text_docs = ExtractProcessor.extract( text_docs = ExtractProcessor.extract(
extract_setting=extract_setting, is_automatic=kwargs.get("process_rule_mode") == "automatic" extract_setting=extract_setting,
is_automatic=(
kwargs.get("process_rule_mode") == "automatic" or kwargs.get("process_rule_mode") == "hierarchical"
),
) )
return text_docs return text_docs
def transform(self, documents: list[Document], **kwargs) -> list[Document]: def transform(self, documents: list[Document], **kwargs) -> list[Document]:
preview = kwargs.get("preview")
process_rule = kwargs.get("process_rule")
if not process_rule:
raise ValueError("No process rule found.")
if not process_rule.get("rules"):
raise ValueError("No rules found in process rule.")
rules = Rule(**process_rule.get("rules"))
splitter = self._get_splitter( splitter = self._get_splitter(
processing_rule=kwargs.get("process_rule") or {}, processing_rule_mode=process_rule.get("mode"),
max_tokens=rules.segmentation.max_tokens if rules.segmentation else 0,
chunk_overlap=rules.segmentation.chunk_overlap if rules.segmentation else 0,
separator=rules.segmentation.separator if rules.segmentation else "",
embedding_model_instance=kwargs.get("embedding_model_instance"), embedding_model_instance=kwargs.get("embedding_model_instance"),
) )
@ -59,24 +73,33 @@ class QAIndexProcessor(BaseIndexProcessor):
document_node.page_content = remove_leading_symbols(page_content) document_node.page_content = remove_leading_symbols(page_content)
split_documents.append(document_node) split_documents.append(document_node)
all_documents.extend(split_documents) all_documents.extend(split_documents)
for i in range(0, len(all_documents), 10): if preview:
threads = [] self._format_qa_document(
sub_documents = all_documents[i : i + 10] current_app._get_current_object(), # type: ignore
for doc in sub_documents: kwargs.get("tenant_id"), # type: ignore
document_format_thread = threading.Thread( all_documents[0],
target=self._format_qa_document, all_qa_documents,
kwargs={ kwargs.get("doc_language", "English"),
"flask_app": current_app._get_current_object(), # type: ignore )
"tenant_id": kwargs.get("tenant_id"), else:
"document_node": doc, for i in range(0, len(all_documents), 10):
"all_qa_documents": all_qa_documents, threads = []
"document_language": kwargs.get("doc_language", "English"), sub_documents = all_documents[i : i + 10]
}, for doc in sub_documents:
) document_format_thread = threading.Thread(
threads.append(document_format_thread) target=self._format_qa_document,
document_format_thread.start() kwargs={
for thread in threads: "flask_app": current_app._get_current_object(), # type: ignore
thread.join() "tenant_id": kwargs.get("tenant_id"), # type: ignore
"document_node": doc,
"all_qa_documents": all_qa_documents,
"document_language": kwargs.get("doc_language", "English"),
},
)
threads.append(document_format_thread)
document_format_thread.start()
for thread in threads:
thread.join()
return all_qa_documents return all_qa_documents
def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]: def format_by_template(self, file: FileStorage, **kwargs) -> list[Document]:
@ -98,12 +121,12 @@ class QAIndexProcessor(BaseIndexProcessor):
raise ValueError(str(e)) raise ValueError(str(e))
return text_docs return text_docs
def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True): def load(self, dataset: Dataset, documents: list[Document], with_keywords: bool = True, **kwargs):
if dataset.indexing_technique == "high_quality": if dataset.indexing_technique == "high_quality":
vector = Vector(dataset) vector = Vector(dataset)
vector.create(documents) vector.create(documents)
def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True): def clean(self, dataset: Dataset, node_ids: Optional[list[str]], with_keywords: bool = True, **kwargs):
vector = Vector(dataset) vector = Vector(dataset)
if node_ids: if node_ids:
vector.delete_by_ids(node_ids) vector.delete_by_ids(node_ids)

@ -2,7 +2,20 @@ from abc import ABC, abstractmethod
from collections.abc import Sequence from collections.abc import Sequence
from typing import Any, Optional from typing import Any, Optional
from pydantic import BaseModel, Field from pydantic import BaseModel
class ChildDocument(BaseModel):
"""Class for storing a piece of text and associated metadata."""
page_content: str
vector: Optional[list[float]] = None
"""Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.).
"""
metadata: dict = {}
class Document(BaseModel): class Document(BaseModel):
@ -15,10 +28,12 @@ class Document(BaseModel):
"""Arbitrary metadata about the page content (e.g., source, relationships to other """Arbitrary metadata about the page content (e.g., source, relationships to other
documents, etc.). documents, etc.).
""" """
metadata: Optional[dict] = Field(default_factory=dict) metadata: dict = {}
provider: Optional[str] = "dify" provider: Optional[str] = "dify"
children: Optional[list[ChildDocument]] = None
class BaseDocumentTransformer(ABC): class BaseDocumentTransformer(ABC):
"""Abstract base class for document transformation systems. """Abstract base class for document transformation systems.

@ -164,43 +164,29 @@ class DatasetRetrieval:
"content": item.page_content, "content": item.page_content,
} }
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
document_score_list = {}
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
for item in dify_documents: records = RetrievalService.format_retrieval_documents(dify_documents)
if item.metadata.get("score"): if records:
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] for record in records:
segment = record.segment
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
if segment.answer: if segment.answer:
document_context_list.append( document_context_list.append(
DocumentContext( DocumentContext(
content=f"question:{segment.get_sign_content()} answer:{segment.answer}", content=f"question:{segment.get_sign_content()} answer:{segment.answer}",
score=document_score_list.get(segment.index_node_id, None), score=record.score,
) )
) )
else: else:
document_context_list.append( document_context_list.append(
DocumentContext( DocumentContext(
content=segment.get_sign_content(), content=segment.get_sign_content(),
score=document_score_list.get(segment.index_node_id, None), score=record.score,
) )
) )
if show_retrieve_source: if show_retrieve_source:
for segment in sorted_segments: for record in records:
segment = record.segment
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = DatasetDocument.query.filter( document = DatasetDocument.query.filter(
DatasetDocument.id == segment.document_id, DatasetDocument.id == segment.document_id,
@ -216,7 +202,7 @@ class DatasetRetrieval:
"data_source_type": document.data_source_type, "data_source_type": document.data_source_type,
"segment_id": segment.id, "segment_id": segment.id,
"retriever_from": invoke_from.to_source(), "retriever_from": invoke_from.to_source(),
"score": document_score_list.get(segment.index_node_id, 0.0), "score": record.score or 0.0,
} }
if invoke_from.to_source() == "dev": if invoke_from.to_source() == "dev":

@ -267,6 +267,7 @@ class ToolParameter(PluginParameter):
:param options: the options of the parameter :param options: the options of the parameter
""" """
# convert options to ToolParameterOption # convert options to ToolParameterOption
# FIXME fix the type error
if options: if options:
option_objs = [ option_objs = [
PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option)) PluginParameterOption(value=option, label=I18nObject(en_US=option, zh_Hans=option))

@ -139,7 +139,7 @@ class ToolEngine:
error_response = f"tool invoke error: {e}" error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
except ToolEngineInvokeError as e: except ToolEngineInvokeError as e:
meta = e.args[0] meta = e.meta
error_response = f"tool invoke error: {meta.error}" error_response = f"tool invoke error: {meta.error}"
agent_tool_callback.on_tool_error(e) agent_tool_callback.on_tool_error(e)
return error_response, [], meta return error_response, [], meta

@ -12,5 +12,6 @@ def remove_leading_symbols(text: str) -> str:
str: The text with leading punctuation or symbols removed. str: The text with leading punctuation or symbols removed.
""" """
# Match Unicode ranges for punctuation and symbols # Match Unicode ranges for punctuation and symbols
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,\-./:;<=>?@\[\]^_`{|}~]+" # FIXME this pattern is confused quick fix for #11868 maybe refactor it later
pattern = r"^[\u2000-\u206F\u2E00-\u2E7F\u3000-\u303F!\"#$%&'()*+,./:;<=>?@^_`~]+"
return re.sub(pattern, "", text) return re.sub(pattern, "", text)

@ -613,10 +613,10 @@ class Graph(BaseModel):
for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items(): for (node_id, node_id2), branch_node_ids in duplicate_end_node_ids.items():
# check which node is after # check which node is after
if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping): if cls._is_node2_after_node1(node1_id=node_id, node2_id=node_id2, edge_mapping=edge_mapping):
if node_id in merge_branch_node_ids: if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id2] del merge_branch_node_ids[node_id2]
elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping): elif cls._is_node2_after_node1(node1_id=node_id2, node2_id=node_id, edge_mapping=edge_mapping):
if node_id2 in merge_branch_node_ids: if node_id in merge_branch_node_ids and node_id2 in merge_branch_node_ids:
del merge_branch_node_ids[node_id] del merge_branch_node_ids[node_id]
branches_merge_node_ids: dict[str, str] = {} branches_merge_node_ids: dict[str, str] = {}

@ -48,9 +48,11 @@ class StreamProcessor(ABC):
# we remove the node maybe shortcut the answer node, so comment this code for now # we remove the node maybe shortcut the answer node, so comment this code for now
# there is not effect on the answer node and the workflow, when we have a better solution # there is not effect on the answer node and the workflow, when we have a better solution
# we can open this code. Issues: #11542 #9560 #10638 #10564 # we can open this code. Issues: #11542 #9560 #10638 #10564
ids = self._fetch_node_ids_in_reachable_branch(edge.target_node_id)
# reachable_node_ids.extend(self._fetch_node_ids_in_reachable_branch(edge.target_node_id)) if "answer" in ids:
continue continue
else:
reachable_node_ids.extend(ids)
else: else:
unreachable_first_node_ids.append(edge.target_node_id) unreachable_first_node_ids.append(edge.target_node_id)

@ -20,3 +20,7 @@ class ResponseSizeError(HttpRequestNodeError):
class RequestBodyError(HttpRequestNodeError): class RequestBodyError(HttpRequestNodeError):
"""Raised when the request body is invalid.""" """Raised when the request body is invalid."""
class InvalidURLError(HttpRequestNodeError):
"""Raised when the URL is invalid."""

@ -23,6 +23,7 @@ from .exc import (
FileFetchError, FileFetchError,
HttpRequestNodeError, HttpRequestNodeError,
InvalidHttpMethodError, InvalidHttpMethodError,
InvalidURLError,
RequestBodyError, RequestBodyError,
ResponseSizeError, ResponseSizeError,
) )
@ -66,6 +67,12 @@ class Executor:
node_data.authorization.config.api_key node_data.authorization.config.api_key
).text ).text
# check if node_data.url is a valid URL
if not node_data.url:
raise InvalidURLError("url is required")
if not node_data.url.startswith(("http://", "https://")):
raise InvalidURLError("url should start with http:// or https://")
self.url: str = node_data.url self.url: str = node_data.url
self.method = node_data.method self.method = node_data.method
self.auth = node_data.authorization self.auth = node_data.authorization

@ -11,6 +11,7 @@ from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelFeature, ModelType from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import StringSegment from core.variables import StringSegment
@ -18,7 +19,7 @@ from core.workflow.entities.node_entities import NodeRunResult
from core.workflow.nodes.base import BaseNode from core.workflow.nodes.base import BaseNode
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document
from models.workflow import WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionStatus
from .entities import KnowledgeRetrievalNodeData from .entities import KnowledgeRetrievalNodeData
@ -211,29 +212,12 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"content": item.page_content, "content": item.page_content,
} }
retrieval_resource_list.append(source) retrieval_resource_list.append(source)
document_score_list: dict[str, float] = {}
# deal with dify documents # deal with dify documents
if dify_documents: if dify_documents:
document_score_list = {} records = RetrievalService.format_retrieval_documents(dify_documents)
for item in dify_documents: if records:
if item.metadata.get("score"): for record in records:
document_score_list[item.metadata["doc_id"]] = item.metadata["score"] segment = record.segment
index_node_ids = [document.metadata["doc_id"] for document in dify_documents]
segments = DocumentSegment.query.filter(
DocumentSegment.dataset_id.in_(dataset_ids),
DocumentSegment.completed_at.isnot(None),
DocumentSegment.status == "completed",
DocumentSegment.enabled == True,
DocumentSegment.index_node_id.in_(index_node_ids),
).all()
if segments:
index_node_id_to_position = {id: position for position, id in enumerate(index_node_ids)}
sorted_segments = sorted(
segments, key=lambda segment: index_node_id_to_position.get(segment.index_node_id, float("inf"))
)
for segment in sorted_segments:
dataset = Dataset.query.filter_by(id=segment.dataset_id).first() dataset = Dataset.query.filter_by(id=segment.dataset_id).first()
document = Document.query.filter( document = Document.query.filter(
Document.id == segment.document_id, Document.id == segment.document_id,
@ -251,7 +235,7 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
"document_data_source_type": document.data_source_type, "document_data_source_type": document.data_source_type,
"segment_id": segment.id, "segment_id": segment.id,
"retriever_from": "workflow", "retriever_from": "workflow",
"score": document_score_list.get(segment.index_node_id, None), "score": record.score or 0.0,
"segment_hit_count": segment.hit_count, "segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count, "segment_word_count": segment.word_count,
"segment_position": segment.position, "segment_position": segment.position,
@ -270,10 +254,8 @@ class KnowledgeRetrievalNode(BaseNode[KnowledgeRetrievalNodeData]):
key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0, key=lambda x: x["metadata"]["score"] if x["metadata"].get("score") is not None else 0.0,
reverse=True, reverse=True,
) )
position = 1 for position, item in enumerate(retrieval_resource_list, start=1):
for item in retrieval_resource_list:
item["metadata"]["position"] = position item["metadata"]["position"] = position
position += 1
return retrieval_resource_list return retrieval_resource_list
@classmethod @classmethod

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
from core.file import File, FileTransferMethod, FileType from core.file import File, FileTransferMethod
from core.plugin.manager.exc import PluginDaemonClientSideError from core.plugin.manager.exc import PluginDaemonClientSideError
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
from core.tools.tool_engine import ToolEngine from core.tools.tool_engine import ToolEngine
@ -189,10 +189,12 @@ class ToolNode(BaseNode[ToolNodeData]):
conversation_id=None, conversation_id=None,
) )
files: list[File] = []
text = "" text = ""
files: list[File] = []
json: list[dict] = [] json: list[dict] = []
agent_logs: list[AgentLog] = []
variables: dict[str, Any] = {} variables: dict[str, Any] = {}
for message in message_stream: for message in message_stream:
@ -239,14 +241,16 @@ class ToolNode(BaseNode[ToolNodeData]):
tool_file = session.scalar(stmt) tool_file = session.scalar(stmt)
if tool_file is None: if tool_file is None:
raise ToolFileError(f"tool file {tool_file_id} not exists") raise ToolFileError(f"tool file {tool_file_id} not exists")
mapping = {
"tool_file_id": tool_file_id,
"transfer_method": FileTransferMethod.TOOL_FILE,
}
files.append( files.append(
File( file_factory.build_from_mapping(
mapping=mapping,
tenant_id=self.tenant_id, tenant_id=self.tenant_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=tool_file_id,
extension=None,
mime_type=message.meta.get("mime_type", "application/octet-stream"),
) )
) )
elif message.type == ToolInvokeMessage.MessageType.TEXT: elif message.type == ToolInvokeMessage.MessageType.TEXT:

@ -5,7 +5,7 @@ from dify_app import DifyApp
def init_app(app: DifyApp): def init_app(app: DifyApp):
# register blueprint routers # register blueprint routers
from flask_cors import CORS from flask_cors import CORS # type: ignore
from controllers.console import bp as console_app_bp from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp from controllers.files import bp as files_bp

@ -69,6 +69,7 @@ def init_app(app: DifyApp) -> Celery:
"schedule.create_tidb_serverless_task", "schedule.create_tidb_serverless_task",
"schedule.update_tidb_serverless_status_task", "schedule.update_tidb_serverless_status_task",
"schedule.clean_messages", "schedule.clean_messages",
"schedule.mail_clean_document_notify_task",
] ]
day = dify_config.CELERY_BEAT_SCHEDULER_TIME day = dify_config.CELERY_BEAT_SCHEDULER_TIME
beat_schedule = { beat_schedule = {
@ -92,6 +93,11 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.clean_messages.clean_messages", "task": "schedule.clean_messages.clean_messages",
"schedule": timedelta(days=day), "schedule": timedelta(days=day),
}, },
# every Monday
"mail_clean_document_notify_task": {
"task": "schedule.mail_clean_document_notify_task.mail_clean_document_notify_task",
"schedule": crontab(minute="0", hour="10", day_of_week="1"),
},
} }
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)

@ -1,4 +1,5 @@
import mimetypes import mimetypes
import uuid
from collections.abc import Callable, Mapping, Sequence from collections.abc import Callable, Mapping, Sequence
from typing import Any, cast from typing import Any, cast
@ -119,6 +120,11 @@ def _build_from_local_file(
upload_file_id = mapping.get("upload_file_id") upload_file_id = mapping.get("upload_file_id")
if not upload_file_id: if not upload_file_id:
raise ValueError("Invalid upload file id") raise ValueError("Invalid upload file id")
# check if upload_file_id is a valid uuid
try:
uuid.UUID(upload_file_id)
except ValueError:
raise ValueError("Invalid upload file id format")
stmt = select(UploadFile).where( stmt = select(UploadFile).where(
UploadFile.id == upload_file_id, UploadFile.id == upload_file_id,
UploadFile.tenant_id == tenant_id, UploadFile.tenant_id == tenant_id,

@ -73,6 +73,7 @@ dataset_detail_fields = {
"embedding_available": fields.Boolean, "embedding_available": fields.Boolean,
"retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields), "retrieval_model_dict": fields.Nested(dataset_retrieval_model_fields),
"tags": fields.List(fields.Nested(tag_fields)), "tags": fields.List(fields.Nested(tag_fields)),
"doc_form": fields.String,
"external_knowledge_info": fields.Nested(external_knowledge_info_fields), "external_knowledge_info": fields.Nested(external_knowledge_info_fields),
"external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True), "external_retrieval_model": fields.Nested(external_retrieval_model_fields, allow_null=True),
} }

@ -34,6 +34,7 @@ document_with_segments_fields = {
"data_source_info": fields.Raw(attribute="data_source_info_dict"), "data_source_info": fields.Raw(attribute="data_source_info_dict"),
"data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"), "data_source_detail_dict": fields.Raw(attribute="data_source_detail_dict"),
"dataset_process_rule_id": fields.String, "dataset_process_rule_id": fields.String,
"process_rule_dict": fields.Raw(attribute="process_rule_dict"),
"name": fields.String, "name": fields.String,
"created_from": fields.String, "created_from": fields.String,
"created_by": fields.String, "created_by": fields.String,

@ -34,8 +34,16 @@ segment_fields = {
"document": fields.Nested(document_fields), "document": fields.Nested(document_fields),
} }
child_chunk_fields = {
"id": fields.String,
"content": fields.String,
"position": fields.Integer,
"score": fields.Float,
}
hit_testing_record_fields = { hit_testing_record_fields = {
"segment": fields.Nested(segment_fields), "segment": fields.Nested(segment_fields),
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
"score": fields.Float, "score": fields.Float,
"tsne_position": fields.Raw, "tsne_position": fields.Raw,
} }

@ -2,6 +2,17 @@ from flask_restful import fields # type: ignore
from libs.helper import TimestampField from libs.helper import TimestampField
child_chunk_fields = {
"id": fields.String,
"segment_id": fields.String,
"content": fields.String,
"position": fields.Integer,
"word_count": fields.Integer,
"type": fields.String,
"created_at": TimestampField,
"updated_at": TimestampField,
}
segment_fields = { segment_fields = {
"id": fields.String, "id": fields.String,
"position": fields.Integer, "position": fields.Integer,
@ -20,10 +31,13 @@ segment_fields = {
"status": fields.String, "status": fields.String,
"created_by": fields.String, "created_by": fields.String,
"created_at": TimestampField, "created_at": TimestampField,
"updated_at": TimestampField,
"updated_by": fields.String,
"indexing_at": TimestampField, "indexing_at": TimestampField,
"completed_at": TimestampField, "completed_at": TimestampField,
"error": fields.String, "error": fields.String,
"stopped_at": TimestampField, "stopped_at": TimestampField,
"child_chunks": fields.List(fields.Nested(child_chunk_fields)),
} }
segment_list_response = { segment_list_response = {

@ -0,0 +1,55 @@
"""parent-child-index
Revision ID: e19037032219
Revises: 01d6889832f7
Create Date: 2024-11-22 07:01:17.550037
"""
from alembic import op
import models as models
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'e19037032219'
down_revision = 'd7999dfa4aae'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('child_chunks',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('segment_id', models.types.StringUUID(), nullable=False),
sa.Column('position', sa.Integer(), nullable=False),
sa.Column('content', sa.Text(), nullable=False),
sa.Column('word_count', sa.Integer(), nullable=False),
sa.Column('index_node_id', sa.String(length=255), nullable=True),
sa.Column('index_node_hash', sa.String(length=255), nullable=True),
sa.Column('type', sa.String(length=255), server_default=sa.text("'automatic'::character varying"), nullable=False),
sa.Column('created_by', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('updated_by', models.types.StringUUID(), nullable=True),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.Column('indexing_at', sa.DateTime(), nullable=True),
sa.Column('completed_at', sa.DateTime(), nullable=True),
sa.Column('error', sa.Text(), nullable=True),
sa.PrimaryKeyConstraint('id', name='child_chunk_pkey')
)
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.create_index('child_chunk_dataset_id_idx', ['tenant_id', 'dataset_id', 'document_id', 'segment_id', 'index_node_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('child_chunks', schema=None) as batch_op:
batch_op.drop_index('child_chunk_dataset_id_idx')
op.drop_table('child_chunks')
# ### end Alembic commands ###

@ -0,0 +1,47 @@
"""add_auto_disabled_dataset_logs
Revision ID: 923752d42eb6
Revises: e19037032219
Create Date: 2024-12-25 11:37:55.467101
"""
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '923752d42eb6'
down_revision = 'e19037032219'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('dataset_auto_disable_logs',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('dataset_id', models.types.StringUUID(), nullable=False),
sa.Column('document_id', models.types.StringUUID(), nullable=False),
sa.Column('notified', sa.Boolean(), server_default=sa.text('false'), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='dataset_auto_disable_log_pkey')
)
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.create_index('dataset_auto_disable_log_created_atx', ['created_at'], unique=False)
batch_op.create_index('dataset_auto_disable_log_dataset_idx', ['dataset_id'], unique=False)
batch_op.create_index('dataset_auto_disable_log_tenant_idx', ['tenant_id'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('dataset_auto_disable_logs', schema=None) as batch_op:
batch_op.drop_index('dataset_auto_disable_log_tenant_idx')
batch_op.drop_index('dataset_auto_disable_log_dataset_idx')
batch_op.drop_index('dataset_auto_disable_log_created_atx')
op.drop_table('dataset_auto_disable_logs')
# ### end Alembic commands ###

@ -23,7 +23,7 @@ class Account(UserMixin, Base):
__tablename__ = "accounts" __tablename__ = "accounts"
__table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email")) __table_args__ = (db.PrimaryKeyConstraint("id", name="account_pkey"), db.Index("account_email_idx", "email"))
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
email = db.Column(db.String(255), nullable=False) email = db.Column(db.String(255), nullable=False)
password = db.Column(db.String(255), nullable=True) password = db.Column(db.String(255), nullable=True)

@ -17,6 +17,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from configs import dify_config from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_storage import storage from extensions.ext_storage import storage
from services.entities.knowledge_entities.knowledge_entities import ParentMode, Rule
from .account import Account from .account import Account
from .engine import db from .engine import db
@ -215,7 +216,7 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
MODES = ["automatic", "custom"] MODES = ["automatic", "custom", "hierarchical"]
PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"] PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
AUTOMATIC_RULES: dict[str, Any] = { AUTOMATIC_RULES: dict[str, Any] = {
"pre_processing_rules": [ "pre_processing_rules": [
@ -231,8 +232,6 @@ class DatasetProcessRule(db.Model): # type: ignore[name-defined]
"dataset_id": self.dataset_id, "dataset_id": self.dataset_id,
"mode": self.mode, "mode": self.mode,
"rules": self.rules_dict, "rules": self.rules_dict,
"created_by": self.created_by,
"created_at": self.created_at,
} }
@property @property
@ -396,6 +395,12 @@ class Document(db.Model): # type: ignore[name-defined]
.scalar() .scalar()
) )
@property
def process_rule_dict(self):
if self.dataset_process_rule_id:
return self.dataset_process_rule.to_dict()
return None
def to_dict(self): def to_dict(self):
return { return {
"id": self.id, "id": self.id,
@ -560,6 +565,24 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
.first() .first()
) )
@property
def child_chunks(self):
process_rule = self.document.dataset_process_rule
if process_rule.mode == "hierarchical":
rules = Rule(**process_rule.rules_dict)
if rules.parent_mode and rules.parent_mode != ParentMode.FULL_DOC:
child_chunks = (
db.session.query(ChildChunk)
.filter(ChildChunk.segment_id == self.id)
.order_by(ChildChunk.position.asc())
.all()
)
return child_chunks or []
else:
return []
else:
return []
def get_sign_content(self): def get_sign_content(self):
signed_urls = [] signed_urls = []
text = self.content text = self.content
@ -605,6 +628,47 @@ class DocumentSegment(db.Model): # type: ignore[name-defined]
return text return text
class ChildChunk(db.Model): # type: ignore[name-defined]
__tablename__ = "child_chunks"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="child_chunk_pkey"),
db.Index("child_chunk_dataset_id_idx", "tenant_id", "dataset_id", "document_id", "segment_id", "index_node_id"),
)
# initial fields
id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
segment_id = db.Column(StringUUID, nullable=False)
position = db.Column(db.Integer, nullable=False)
content = db.Column(db.Text, nullable=False)
word_count = db.Column(db.Integer, nullable=False)
# indexing fields
index_node_id = db.Column(db.String(255), nullable=True)
index_node_hash = db.Column(db.String(255), nullable=True)
type = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
indexing_at = db.Column(db.DateTime, nullable=True)
completed_at = db.Column(db.DateTime, nullable=True)
error = db.Column(db.Text, nullable=True)
@property
def dataset(self):
return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()
@property
def document(self):
return db.session.query(Document).filter(Document.id == self.document_id).first()
@property
def segment(self):
return db.session.query(DocumentSegment).filter(DocumentSegment.id == self.segment_id).first()
class AppDatasetJoin(db.Model): # type: ignore[name-defined] class AppDatasetJoin(db.Model): # type: ignore[name-defined]
__tablename__ = "app_dataset_joins" __tablename__ = "app_dataset_joins"
__table_args__ = ( __table_args__ = (
@ -844,3 +908,20 @@ class ExternalKnowledgeBindings(db.Model): # type: ignore[name-defined]
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_by = db.Column(StringUUID, nullable=True) updated_by = db.Column(StringUUID, nullable=True)
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
class DatasetAutoDisableLog(db.Model): # type: ignore[name-defined]
__tablename__ = "dataset_auto_disable_logs"
__table_args__ = (
db.PrimaryKeyConstraint("id", name="dataset_auto_disable_log_pkey"),
db.Index("dataset_auto_disable_log_tenant_idx", "tenant_id"),
db.Index("dataset_auto_disable_log_dataset_idx", "dataset_id"),
db.Index("dataset_auto_disable_log_created_atx", "created_at"),
)
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False)
dataset_id = db.Column(StringUUID, nullable=False)
document_id = db.Column(StringUUID, nullable=False)
notified = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

@ -611,13 +611,13 @@ class Conversation(Base):
db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"), db.Index("conversation_app_from_user_idx", "app_id", "from_source", "from_end_user_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
app_model_config_id = db.Column(StringUUID, nullable=True) app_model_config_id = db.Column(StringUUID, nullable=True)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
override_model_configs = db.Column(db.Text) override_model_configs = db.Column(db.Text)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
mode = db.Column(db.String(255), nullable=False) mode: Mapped[str] = mapped_column(db.String(255))
name = db.Column(db.String(255), nullable=False) name = db.Column(db.String(255), nullable=False)
summary = db.Column(db.Text) summary = db.Column(db.Text)
_inputs: Mapped[dict] = mapped_column("inputs", db.JSON) _inputs: Mapped[dict] = mapped_column("inputs", db.JSON)
@ -851,7 +851,7 @@ class Message(Base):
Index("message_created_at_idx", "created_at"), Index("message_created_at_idx", "created_at"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
app_id = db.Column(StringUUID, nullable=False) app_id = db.Column(StringUUID, nullable=False)
model_provider = db.Column(db.String(255), nullable=True) model_provider = db.Column(db.String(255), nullable=True)
model_id = db.Column(db.String(255), nullable=True) model_id = db.Column(db.String(255), nullable=True)
@ -878,7 +878,7 @@ class Message(Base):
from_source = db.Column(db.String(255), nullable=False) from_source = db.Column(db.String(255), nullable=False)
from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID) from_end_user_id: Mapped[Optional[str]] = db.Column(StringUUID)
from_account_id: Mapped[Optional[str]] = db.Column(StringUUID) from_account_id: Mapped[Optional[str]] = db.Column(StringUUID)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false")) agent_based = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
workflow_run_id = db.Column(StringUUID) workflow_run_id = db.Column(StringUUID)
@ -1403,7 +1403,7 @@ class EndUser(Base, UserMixin):
external_user_id = db.Column(db.String(255), nullable=True) external_user_id = db.Column(db.String(255), nullable=True)
name = db.Column(db.String(255)) name = db.Column(db.String(255))
is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true")) is_anonymous = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
session_id = db.Column(db.String(255), nullable=False) session_id: Mapped[str] = mapped_column()
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())

@ -1,6 +1,7 @@
import json import json
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from typing import Any
import sqlalchemy as sa import sqlalchemy as sa
from deprecated import deprecated from deprecated import deprecated
@ -256,8 +257,8 @@ class ToolConversationVariables(Base):
updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) updated_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
@property @property
def variables(self) -> dict: def variables(self) -> Any:
return dict(json.loads(self.variables_str)) return json.loads(self.variables_str)
class ToolFile(Base): class ToolFile(Base):

@ -402,23 +402,23 @@ class WorkflowRun(Base):
db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"), db.Index("workflow_run_tenant_app_sequence_idx", "tenant_id", "app_id", "sequence_number"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID)
sequence_number = db.Column(db.Integer, nullable=False) sequence_number: Mapped[int] = mapped_column()
workflow_id = db.Column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID)
type = db.Column(db.String(255), nullable=False) type: Mapped[str] = mapped_column(db.String(255))
triggered_from = db.Column(db.String(255), nullable=False) triggered_from: Mapped[str] = mapped_column(db.String(255))
version = db.Column(db.String(255), nullable=False) version: Mapped[str] = mapped_column(db.String(255))
graph = db.Column(db.Text) graph: Mapped[Optional[str]] = mapped_column(db.Text)
inputs = db.Column(db.Text) inputs: Mapped[Optional[str]] = mapped_column(db.Text)
status = db.Column(db.String(255), nullable=False) # running, succeeded, failed, stopped, partial-succeeded status: Mapped[str] = mapped_column(db.String(255)) # running, succeeded, failed, stopped, partial-succeeded
outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}") outputs: Mapped[Optional[str]] = mapped_column(sa.Text, default="{}")
error = db.Column(db.Text) error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0"))
total_tokens = db.Column(db.Integer, nullable=False, server_default=db.text("0")) total_tokens: Mapped[int] = mapped_column(server_default=db.text("0"))
total_steps = db.Column(db.Integer, server_default=db.text("0")) total_steps = db.Column(db.Integer, server_default=db.text("0"))
created_by_role = db.Column(db.String(255), nullable=False) # account, end_user created_by_role: Mapped[str] = mapped_column(db.String(255)) # account, end_user
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp())
finished_at = db.Column(db.DateTime) finished_at = db.Column(db.DateTime)
@ -631,29 +631,29 @@ class WorkflowNodeExecution(Base):
), ),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False) workflow_id: Mapped[str] = mapped_column(StringUUID)
triggered_from = db.Column(db.String(255), nullable=False) triggered_from: Mapped[str] = mapped_column(db.String(255))
workflow_run_id = db.Column(StringUUID) workflow_run_id: Mapped[Optional[str]] = mapped_column(StringUUID)
index = db.Column(db.Integer, nullable=False) index: Mapped[int] = mapped_column(db.Integer)
predecessor_node_id = db.Column(db.String(255)) predecessor_node_id: Mapped[Optional[str]] = mapped_column(db.String(255))
node_execution_id = db.Column(db.String(255), nullable=True) node_execution_id: Mapped[Optional[str]] = mapped_column(db.String(255))
node_id = db.Column(db.String(255), nullable=False) node_id: Mapped[str] = mapped_column(db.String(255))
node_type = db.Column(db.String(255), nullable=False) node_type: Mapped[str] = mapped_column(db.String(255))
title = db.Column(db.String(255), nullable=False) title: Mapped[str] = mapped_column(db.String(255))
inputs = db.Column(db.Text) inputs: Mapped[Optional[str]] = mapped_column(db.Text)
process_data = db.Column(db.Text) process_data: Mapped[Optional[str]] = mapped_column(db.Text)
outputs = db.Column(db.Text) outputs: Mapped[Optional[str]] = mapped_column(db.Text)
status = db.Column(db.String(255), nullable=False) status: Mapped[str] = mapped_column(db.String(255))
error = db.Column(db.Text) error: Mapped[Optional[str]] = mapped_column(db.Text)
elapsed_time = db.Column(db.Float, nullable=False, server_default=db.text("0")) elapsed_time: Mapped[float] = mapped_column(db.Float, server_default=db.text("0"))
execution_metadata = db.Column(db.Text) execution_metadata: Mapped[Optional[str]] = mapped_column(db.Text)
created_at = db.Column(db.DateTime, nullable=False, server_default=func.current_timestamp()) created_at: Mapped[datetime] = mapped_column(db.DateTime, server_default=func.current_timestamp())
created_by_role = db.Column(db.String(255), nullable=False) created_by_role: Mapped[str] = mapped_column(db.String(255))
created_by = db.Column(StringUUID, nullable=False) created_by: Mapped[str] = mapped_column(StringUUID)
finished_at = db.Column(db.DateTime) finished_at: Mapped[Optional[datetime]] = mapped_column(db.DateTime)
@property @property
def created_by_account(self): def created_by_account(self):
@ -760,11 +760,11 @@ class WorkflowAppLog(Base):
db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"), db.Index("workflow_app_log_app_idx", "tenant_id", "app_id"),
) )
id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()")) id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuid_generate_v4()"))
tenant_id = db.Column(StringUUID, nullable=False) tenant_id: Mapped[str] = mapped_column(StringUUID)
app_id = db.Column(StringUUID, nullable=False) app_id: Mapped[str] = mapped_column(StringUUID)
workflow_id = db.Column(StringUUID, nullable=False) workflow_id = db.Column(StringUUID, nullable=False)
workflow_run_id = db.Column(StringUUID, nullable=False) workflow_run_id: Mapped[str] = mapped_column(StringUUID)
created_from = db.Column(db.String(255), nullable=False) created_from = db.Column(db.String(255), nullable=False)
created_by_role = db.Column(db.String(255), nullable=False) created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(StringUUID, nullable=False) created_by = db.Column(StringUUID, nullable=False)

@ -28,7 +28,6 @@ def clean_messages():
plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta( plan_sandbox_clean_message_day = datetime.datetime.now() - datetime.timedelta(
days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING days=dify_config.PLAN_SANDBOX_CLEAN_MESSAGE_DAY_SETTING
) )
page = 1
while True: while True:
try: try:
# Main query with join and filter # Main query with join and filter
@ -79,4 +78,4 @@ def clean_messages():
db.session.query(Message).filter(Message.id == message.id).delete() db.session.query(Message).filter(Message.id == message.id).delete()
db.session.commit() db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()
click.echo(click.style("Cleaned unused dataset from db success latency: {}".format(end_at - start_at), fg="green")) click.echo(click.style("Cleaned messages from db success latency: {}".format(end_at - start_at), fg="green"))

@ -10,7 +10,7 @@ from configs import dify_config
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import Dataset, DatasetQuery, Document from models.dataset import Dataset, DatasetAutoDisableLog, DatasetQuery, Document
from services.feature_service import FeatureService from services.feature_service import FeatureService
@ -75,6 +75,23 @@ def clean_unused_datasets_task():
) )
if not dataset_query or len(dataset_query) == 0: if not dataset_query or len(dataset_query) == 0:
try: try:
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)
@ -151,6 +168,23 @@ def clean_unused_datasets_task():
else: else:
plan = plan_cache.decode() plan = plan_cache.decode()
if plan == "sandbox": if plan == "sandbox":
# add auto disable log
documents = (
db.session.query(Document)
.filter(
Document.dataset_id == dataset.id,
Document.enabled == True,
Document.archived == False,
)
.all()
)
for document in documents:
dataset_auto_disable_log = DatasetAutoDisableLog(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=document.id,
)
db.session.add(dataset_auto_disable_log)
# remove index # remove index
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor() index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None)

@ -0,0 +1,90 @@
import logging
import time
from collections import defaultdict
import click
from flask import render_template # type: ignore
import app
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_mail import mail
from models.account import Account, Tenant, TenantAccountJoin
from models.dataset import Dataset, DatasetAutoDisableLog
from services.feature_service import FeatureService
@app.celery.task(queue="dataset")
def send_document_clean_notify_task():
"""
Async Send document clean notify mail
Usage: send_document_clean_notify_task.delay()
"""
if not mail.is_inited():
return
logging.info(click.style("Start send document clean notify mail", fg="green"))
start_at = time.perf_counter()
# send document clean notify mail
try:
dataset_auto_disable_logs = DatasetAutoDisableLog.query.filter(DatasetAutoDisableLog.notified == False).all()
# group by tenant_id
dataset_auto_disable_logs_map: dict[str, list[DatasetAutoDisableLog]] = defaultdict(list)
for dataset_auto_disable_log in dataset_auto_disable_logs:
if dataset_auto_disable_log.tenant_id not in dataset_auto_disable_logs_map:
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id] = []
dataset_auto_disable_logs_map[dataset_auto_disable_log.tenant_id].append(dataset_auto_disable_log)
url = f"{dify_config.CONSOLE_WEB_URL}/datasets"
for tenant_id, tenant_dataset_auto_disable_logs in dataset_auto_disable_logs_map.items():
features = FeatureService.get_features(tenant_id)
plan = features.billing.subscription.plan
if plan != "sandbox":
knowledge_details = []
# check tenant
tenant = Tenant.query.filter(Tenant.id == tenant_id).first()
if not tenant:
continue
# check current owner
current_owner_join = TenantAccountJoin.query.filter_by(tenant_id=tenant.id, role="owner").first()
if not current_owner_join:
continue
account = Account.query.filter(Account.id == current_owner_join.account_id).first()
if not account:
continue
dataset_auto_dataset_map = {} # type: ignore
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
if dataset_auto_disable_log.dataset_id not in dataset_auto_dataset_map:
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id] = []
dataset_auto_dataset_map[dataset_auto_disable_log.dataset_id].append(
dataset_auto_disable_log.document_id
)
for dataset_id, document_ids in dataset_auto_dataset_map.items():
dataset = Dataset.query.filter(Dataset.id == dataset_id).first()
if dataset:
document_count = len(document_ids)
knowledge_details.append(rf"Knowledge base {dataset.name}: {document_count} documents")
if knowledge_details:
html_content = render_template(
"clean_document_job_mail_template-US.html",
userName=account.email,
knowledge_details=knowledge_details,
url=url,
)
mail.send(
to=account.email, subject="Dify Knowledge base auto disable notification", html=html_content
)
# update notified to True
for dataset_auto_disable_log in tenant_dataset_auto_disable_logs:
dataset_auto_disable_log.notified = True
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style("Send document clean notify mail succeeded: latency: {}".format(end_at - start_at), fg="green")
)
except Exception:
logging.exception("Send document clean notify mail failed")

@ -798,6 +798,7 @@ class RegisterService:
language: Optional[str] = None, language: Optional[str] = None,
status: Optional[AccountStatus] = None, status: Optional[AccountStatus] = None,
is_setup: Optional[bool] = False, is_setup: Optional[bool] = False,
create_workspace_required: Optional[bool] = True,
) -> Account: ) -> Account:
db.session.begin_nested() db.session.begin_nested()
"""Register account""" """Register account"""
@ -815,7 +816,7 @@ class RegisterService:
if open_id is not None and provider is not None: if open_id is not None and provider is not None:
AccountService.link_account_integrate(provider, open_id, account) AccountService.link_account_integrate(provider, open_id, account)
if FeatureService.get_system_features().is_allow_create_workspace: if FeatureService.get_system_features().is_allow_create_workspace and create_workspace_required:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace") tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner") TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant account.current_tenant = tenant

@ -4,7 +4,7 @@ from enum import StrEnum
from typing import Optional, cast from typing import Optional, cast
from uuid import uuid4 from uuid import uuid4
import yaml import yaml # type: ignore
from packaging import version from packaging import version
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
@ -196,6 +196,9 @@ class AppDslService:
data["kind"] = "app" data["kind"] = "app"
imported_version = data.get("version", "0.1.0") imported_version = data.get("version", "0.1.0")
# check if imported_version is a float-like string
if not isinstance(imported_version, str):
raise ValueError(f"Invalid version type, expected str, got {type(imported_version)}")
status = _check_version_compatibility(imported_version) status = _check_version_compatibility(imported_version)
# Extract app data # Extract app data
@ -524,7 +527,7 @@ class AppDslService:
else: else:
cls._append_model_config_export_data(export_data, app_model) cls._append_model_config_export_data(export_data, app_model)
return yaml.dump(export_data, allow_unicode=True) return yaml.dump(export_data, allow_unicode=True) # type: ignore
@classmethod @classmethod
def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None: def _append_workflow_export_data(cls, *, export_data: dict, app_model: App, include_secret: bool) -> None:

@ -1,5 +1,6 @@
import io import io
import logging import logging
import uuid
from typing import Optional from typing import Optional
from werkzeug.datastructures import FileStorage from werkzeug.datastructures import FileStorage
@ -122,6 +123,10 @@ class AudioService:
raise e raise e
if message_id: if message_id:
try:
uuid.UUID(message_id)
except ValueError:
return None
message = db.session.query(Message).filter(Message.id == message_id).first() message = db.session.query(Message).filter(Message.id == message_id).first()
if message is None: if message is None:
return None return None

@ -2,7 +2,7 @@ import os
from typing import Optional from typing import Optional
import httpx import httpx
from tenacity import retry, retry_if_not_exception_type, stop_before_delay, wait_fixed from tenacity import retry, retry_if_exception_type, stop_before_delay, wait_fixed
from extensions.ext_database import db from extensions.ext_database import db
from models.account import TenantAccountJoin, TenantAccountRole from models.account import TenantAccountJoin, TenantAccountRole
@ -44,7 +44,7 @@ class BillingService:
@retry( @retry(
wait=wait_fixed(2), wait=wait_fixed(2),
stop=stop_before_delay(10), stop=stop_before_delay(10),
retry=retry_if_not_exception_type(httpx.RequestError), retry=retry_if_exception_type(httpx.RequestError),
reraise=True, reraise=True,
) )
def _send_request(cls, method, endpoint, json=None, params=None): def _send_request(cls, method, endpoint, json=None, params=None):

File diff suppressed because it is too large Load Diff

@ -1,4 +1,5 @@
from typing import Optional from enum import Enum
from typing import Literal, Optional
from pydantic import BaseModel from pydantic import BaseModel
@ -8,3 +9,112 @@ class SegmentUpdateEntity(BaseModel):
answer: Optional[str] = None answer: Optional[str] = None
keywords: Optional[list[str]] = None keywords: Optional[list[str]] = None
enabled: Optional[bool] = None enabled: Optional[bool] = None
class ParentMode(str, Enum):
FULL_DOC = "full-doc"
PARAGRAPH = "paragraph"
class NotionIcon(BaseModel):
type: str
url: Optional[str] = None
emoji: Optional[str] = None
class NotionPage(BaseModel):
page_id: str
page_name: str
page_icon: Optional[NotionIcon] = None
type: str
class NotionInfo(BaseModel):
workspace_id: str
pages: list[NotionPage]
class WebsiteInfo(BaseModel):
provider: str
job_id: str
urls: list[str]
only_main_content: bool = True
class FileInfo(BaseModel):
file_ids: list[str]
class InfoList(BaseModel):
data_source_type: Literal["upload_file", "notion_import", "website_crawl"]
notion_info_list: Optional[list[NotionInfo]] = None
file_info_list: Optional[FileInfo] = None
website_info_list: Optional[WebsiteInfo] = None
class DataSource(BaseModel):
info_list: InfoList
class PreProcessingRule(BaseModel):
id: str
enabled: bool
class Segmentation(BaseModel):
separator: str = "\n"
max_tokens: int
chunk_overlap: int = 0
class Rule(BaseModel):
pre_processing_rules: Optional[list[PreProcessingRule]] = None
segmentation: Optional[Segmentation] = None
parent_mode: Optional[Literal["full-doc", "paragraph"]] = None
subchunk_segmentation: Optional[Segmentation] = None
class ProcessRule(BaseModel):
mode: Literal["automatic", "custom", "hierarchical"]
rules: Optional[Rule] = None
class RerankingModel(BaseModel):
reranking_provider_name: Optional[str] = None
reranking_model_name: Optional[str] = None
class RetrievalModel(BaseModel):
search_method: Literal["hybrid_search", "semantic_search", "full_text_search"]
reranking_enable: bool
reranking_model: Optional[RerankingModel] = None
top_k: int
score_threshold_enabled: bool
score_threshold: Optional[float] = None
class KnowledgeConfig(BaseModel):
original_document_id: Optional[str] = None
duplicate: bool = True
indexing_technique: Literal["high_quality", "economy"]
data_source: DataSource
process_rule: Optional[ProcessRule] = None
retrieval_model: Optional[RetrievalModel] = None
doc_form: str = "text_model"
doc_language: str = "English"
embedding_model: Optional[str] = None
embedding_model_provider: Optional[str] = None
name: Optional[str] = None
class SegmentUpdateArgs(BaseModel):
content: Optional[str] = None
answer: Optional[str] = None
keywords: Optional[list[str]] = None
regenerate_child_chunks: bool = False
enabled: Optional[bool] = None
class ChildChunkUpdateArgs(BaseModel):
id: Optional[str] = None
content: str

@ -1,6 +1,6 @@
from typing import Optional from typing import Optional
class BaseServiceError(Exception): class BaseServiceError(ValueError):
def __init__(self, description: Optional[str] = None): def __init__(self, description: Optional[str] = None):
self.description = description self.description = description

@ -0,0 +1,9 @@
from services.errors.base import BaseServiceError
class ChildChunkIndexingError(BaseServiceError):
description = "{message}"
class ChildChunkDeleteIndexError(BaseServiceError):
description = "{message}"

@ -76,7 +76,7 @@ class FeatureService:
cls._fulfill_params_from_env(features) cls._fulfill_params_from_env(features)
if dify_config.BILLING_ENABLED: if dify_config.BILLING_ENABLED and tenant_id:
cls._fulfill_params_from_billing_api(features, tenant_id) cls._fulfill_params_from_billing_api(features, tenant_id)
return features return features

@ -7,7 +7,7 @@ from core.rag.models.document import Document
from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db from extensions.ext_database import db
from models.account import Account from models.account import Account
from models.dataset import Dataset, DatasetQuery, DocumentSegment from models.dataset import Dataset, DatasetQuery
default_retrieval_model = { default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH.value, "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
@ -69,7 +69,7 @@ class HitTestingService:
db.session.add(dataset_query) db.session.add(dataset_query)
db.session.commit() db.session.commit()
return dict(cls.compact_retrieve_response(dataset, query, all_documents)) return cls.compact_retrieve_response(query, all_documents) # type: ignore
@classmethod @classmethod
def external_retrieve( def external_retrieve(
@ -106,41 +106,14 @@ class HitTestingService:
return dict(cls.compact_external_retrieve_response(dataset, query, all_documents)) return dict(cls.compact_external_retrieve_response(dataset, query, all_documents))
@classmethod @classmethod
def compact_retrieve_response(cls, dataset: Dataset, query: str, documents: list[Document]): def compact_retrieve_response(cls, query: str, documents: list[Document]):
records = [] records = RetrievalService.format_retrieval_documents(documents)
for document in documents:
if document.metadata is None:
continue
index_node_id = document.metadata["doc_id"]
segment = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.dataset_id == dataset.id,
DocumentSegment.enabled == True,
DocumentSegment.status == "completed",
DocumentSegment.index_node_id == index_node_id,
)
.first()
)
if not segment:
continue
record = {
"segment": segment,
"score": document.metadata.get("score", None),
}
records.append(record)
return { return {
"query": { "query": {
"content": query, "content": query,
}, },
"records": records, "records": [record.model_dump() for record in records],
} }
@classmethod @classmethod

@ -152,6 +152,7 @@ class MessageService:
@classmethod @classmethod
def create_feedback( def create_feedback(
cls, cls,
*,
app_model: App, app_model: App,
message_id: str, message_id: str,
user: Optional[Union[Account, EndUser]], user: Optional[Union[Account, EndUser]],

@ -64,7 +64,10 @@ class ToolTransformService:
) )
elif isinstance(provider, ToolProviderApiEntity): elif isinstance(provider, ToolProviderApiEntity):
if provider.plugin_id: if provider.plugin_id:
provider.icon = ToolTransformService.get_plugin_icon_url(tenant_id=tenant_id, filename=provider.icon) if isinstance(provider.icon, str):
provider.icon = ToolTransformService.get_plugin_icon_url(
tenant_id=tenant_id, filename=provider.icon
)
else: else:
provider.icon = ToolTransformService.get_tool_provider_icon_url( provider.icon = ToolTransformService.get_tool_provider_icon_url(
provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon provider_type=provider.type.value, provider_name=provider.name, icon=provider.icon

@ -1,40 +1,70 @@
from typing import Optional from typing import Optional
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.keyword.keyword_factory import Keyword
from core.rag.datasource.vdb.vector_factory import Vector from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import Document
from models.dataset import Dataset, DocumentSegment from extensions.ext_database import db
from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment
from models.dataset import Document as DatasetDocument
from services.entities.knowledge_entities.knowledge_entities import ParentMode
class VectorService: class VectorService:
@classmethod @classmethod
def create_segments_vector( def create_segments_vector(
cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset cls, keywords_list: Optional[list[list[str]]], segments: list[DocumentSegment], dataset: Dataset, doc_form: str
): ):
documents = [] documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts(documents, duplicate_check=True)
# save keyword index for segment in segments:
keyword = Keyword(dataset) if doc_form == IndexType.PARENT_CHILD_INDEX:
document = DatasetDocument.query.filter_by(id=segment.document_id).first()
# get the process rule
processing_rule = (
db.session.query(DatasetProcessRule)
.filter(DatasetProcessRule.id == document.dataset_process_rule_id)
.first()
)
if not processing_rule:
raise ValueError("No processing rule found.")
# get embedding model instance
if dataset.indexing_technique == "high_quality":
# check embedding model setting
model_manager = ModelManager()
if keywords_list and len(keywords_list) > 0: if dataset.embedding_model_provider:
keyword.add_texts(documents, keywords_list=keywords_list) embedding_model_instance = model_manager.get_model_instance(
else: tenant_id=dataset.tenant_id,
keyword.add_texts(documents) provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
else:
embedding_model_instance = model_manager.get_default_model_instance(
tenant_id=dataset.tenant_id,
model_type=ModelType.TEXT_EMBEDDING,
)
else:
raise ValueError("The knowledge base index technique is not high quality!")
cls.generate_child_chunks(segment, document, dataset, embedding_model_instance, processing_rule, False)
else:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
if len(documents) > 0:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.load(dataset, documents, with_keywords=True, keywords_list=keywords_list)
@classmethod @classmethod
def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset): def update_segment_vector(cls, keywords: Optional[list[str]], segment: DocumentSegment, dataset: Dataset):
@ -65,3 +95,123 @@ class VectorService:
keyword.add_texts([document], keywords_list=[keywords]) keyword.add_texts([document], keywords_list=[keywords])
else: else:
keyword.add_texts([document]) keyword.add_texts([document])
@classmethod
def generate_child_chunks(
cls,
segment: DocumentSegment,
dataset_document: DatasetDocument,
dataset: Dataset,
embedding_model_instance: ModelInstance,
processing_rule: DatasetProcessRule,
regenerate: bool = False,
):
index_processor = IndexProcessorFactory(dataset.doc_form).init_index_processor()
if regenerate:
# delete child chunks
index_processor.clean(dataset, [segment.index_node_id], with_keywords=True, delete_child_chunks=True)
# generate child chunks
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
# use full doc mode to generate segment's child chunk
processing_rule_dict = processing_rule.to_dict()
processing_rule_dict["rules"]["parent_mode"] = ParentMode.FULL_DOC.value
documents = index_processor.transform(
[document],
embedding_model_instance=embedding_model_instance,
process_rule=processing_rule_dict,
tenant_id=dataset.tenant_id,
doc_language=dataset_document.doc_language,
)
# save child chunks
if documents and documents[0].children:
index_processor.load(dataset, documents)
for position, child_chunk in enumerate(documents[0].children, start=1):
child_segment = ChildChunk(
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
document_id=dataset_document.id,
segment_id=segment.id,
position=position,
index_node_id=child_chunk.metadata["doc_id"],
index_node_hash=child_chunk.metadata["doc_hash"],
content=child_chunk.page_content,
word_count=len(child_chunk.page_content),
type="automatic",
created_by=dataset_document.created_by,
)
db.session.add(child_segment)
db.session.commit()
@classmethod
def create_child_chunk_vector(cls, child_segment: ChildChunk, dataset: Dataset):
child_document = Document(
page_content=child_segment.content,
metadata={
"doc_id": child_segment.index_node_id,
"doc_hash": child_segment.index_node_hash,
"document_id": child_segment.document_id,
"dataset_id": child_segment.dataset_id,
},
)
if dataset.indexing_technique == "high_quality":
# save vector index
vector = Vector(dataset=dataset)
vector.add_texts([child_document], duplicate_check=True)
@classmethod
def update_child_chunk_vector(
cls,
new_child_chunks: list[ChildChunk],
update_child_chunks: list[ChildChunk],
delete_child_chunks: list[ChildChunk],
dataset: Dataset,
):
documents = []
delete_node_ids = []
for new_child_chunk in new_child_chunks:
new_child_document = Document(
page_content=new_child_chunk.content,
metadata={
"doc_id": new_child_chunk.index_node_id,
"doc_hash": new_child_chunk.index_node_hash,
"document_id": new_child_chunk.document_id,
"dataset_id": new_child_chunk.dataset_id,
},
)
documents.append(new_child_document)
for update_child_chunk in update_child_chunks:
child_document = Document(
page_content=update_child_chunk.content,
metadata={
"doc_id": update_child_chunk.index_node_id,
"doc_hash": update_child_chunk.index_node_hash,
"document_id": update_child_chunk.document_id,
"dataset_id": update_child_chunk.dataset_id,
},
)
documents.append(child_document)
delete_node_ids.append(update_child_chunk.index_node_id)
for delete_child_chunk in delete_child_chunks:
delete_node_ids.append(delete_child_chunk.index_node_id)
if dataset.indexing_technique == "high_quality":
# update vector index
vector = Vector(dataset=dataset)
if delete_node_ids:
vector.delete_by_ids(delete_node_ids)
if documents:
vector.add_texts(documents, duplicate_check=True)
@classmethod
def delete_child_chunk_vector(cls, child_chunk: ChildChunk, dataset: Dataset):
vector = Vector(dataset=dataset)
vector.delete_by_ids([child_chunk.index_node_id])

@ -3,6 +3,7 @@ import time
from collections.abc import Callable, Generator, Sequence from collections.abc import Callable, Generator, Sequence
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any, Optional from typing import Any, Optional
from uuid import uuid4
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -333,6 +334,7 @@ class WorkflowService:
error = e.error error = e.error
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.id = str(uuid4())
workflow_node_execution.tenant_id = tenant_id workflow_node_execution.tenant_id = tenant_id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
workflow_node_execution.index = 1 workflow_node_execution.index = 1

@ -6,12 +6,13 @@ import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from models.dataset import DatasetAutoDisableLog, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
from models.dataset import DocumentSegment
@shared_task(queue="dataset") @shared_task(queue="dataset")
@ -53,7 +54,22 @@ def add_document_to_index_task(dataset_document_id: str):
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document) documents.append(document)
dataset = dataset_document.dataset dataset = dataset_document.dataset
@ -65,6 +81,12 @@ def add_document_to_index_task(dataset_document_id: str):
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.load(dataset, documents) index_processor.load(dataset, documents)
# delete auto disable log
db.session.query(DatasetAutoDisableLog).filter(
DatasetAutoDisableLog.document_id == dataset_document.id
).delete()
db.session.commit()
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(
click.style( click.style(

@ -0,0 +1,76 @@
import logging
import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.dataset import Dataset, DocumentSegment
from models.model import UploadFile
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids
:param dataset_id: dataset id
:param doc_form: doc_form
:param file_ids: file ids
Usage: clean_document_task.delay(document_id, dataset_id)
"""
logging.info(click.style("Start batch clean documents when documents deleted", fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id.in_(document_ids)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
for upload_file_id in image_upload_file_ids:
image_file = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).first()
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logging.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: {}".format(upload_file_id)
)
db.session.delete(image_file)
db.session.delete(segment)
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).filter(UploadFile.id.in_(file_ids)).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logging.exception("Delete file failed when document deleted, file_id: {}".format(file.id))
db.session.delete(file)
db.session.commit()
end_at = time.perf_counter()
logging.info(
click.style(
"Cleaned documents when documents deleted latency: {}".format(end_at - start_at),
fg="green",
)
)
except Exception:
logging.exception("Cleaned documents when documents deleted failed")

@ -7,13 +7,13 @@ import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from sqlalchemy import func from sqlalchemy import func
from core.indexing_runner import IndexingRunner
from core.model_manager import ModelManager from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client from extensions.ext_redis import redis_client
from libs import helper from libs import helper
from models.dataset import Dataset, Document, DocumentSegment from models.dataset import Dataset, Document, DocumentSegment
from services.vector_service import VectorService
@shared_task(queue="dataset") @shared_task(queue="dataset")
@ -98,8 +98,7 @@ def batch_create_segment_to_index_task(
dataset_document.word_count += word_count_change dataset_document.word_count += word_count_change
db.session.add(dataset_document) db.session.add(dataset_document)
# add index to db # add index to db
indexing_runner = IndexingRunner() VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
indexing_runner.batch_add_segments(document_segments, dataset)
db.session.commit() db.session.commit()
redis_client.setex(indexing_cache_key, 600, "completed") redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter() end_at = time.perf_counter()

@ -62,7 +62,7 @@ def clean_dataset_task(
if doc_form is None: if doc_form is None:
raise ValueError("Index type must be specified.") raise ValueError("Index type must be specified.")
index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None) index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
for document in documents: for document in documents:
db.session.delete(document) db.session.delete(document)

@ -38,7 +38,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if segments: if segments:
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor() index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content) image_upload_file_ids = get_image_upload_file_ids(segment.content)

@ -37,7 +37,7 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all() segments = db.session.query(DocumentSegment).filter(DocumentSegment.document_id == document_id).all()
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)

@ -4,8 +4,9 @@ import time
import click import click
from celery import shared_task # type: ignore from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import Document from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument from models.dataset import Document as DatasetDocument
@ -105,7 +106,7 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
db.session.commit() db.session.commit()
# clean index # clean index
index_processor.clean(dataset, None, with_keywords=False) index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents: for dataset_document in dataset_documents:
# update from vector index # update from vector index
@ -128,7 +129,22 @@ def deal_dataset_vector_index_task(dataset_id: str, action: str):
"dataset_id": segment.dataset_id, "dataset_id": segment.dataset_id,
}, },
) )
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.child_chunks
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document) documents.append(document)
# save vector index # save vector index
index_processor.load(dataset, documents, with_keywords=False) index_processor.load(dataset, documents, with_keywords=False)

@ -6,48 +6,38 @@ from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, Document from models.dataset import Dataset, Document
@shared_task(queue="dataset") @shared_task(queue="dataset")
def delete_segment_from_index_task(segment_id: str, index_node_id: str, dataset_id: str, document_id: str): def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
""" """
Async Remove segment from index Async Remove segment from index
:param segment_id: :param index_node_ids:
:param index_node_id:
:param dataset_id: :param dataset_id:
:param document_id: :param document_id:
Usage: delete_segment_from_index_task.delay(segment_id) Usage: delete_segment_from_index_task.delay(segment_ids)
""" """
logging.info(click.style("Start delete segment from index: {}".format(segment_id), fg="green")) logging.info(click.style("Start delete segment from index", fg="green"))
start_at = time.perf_counter() start_at = time.perf_counter()
indexing_cache_key = "segment_{}_delete_indexing".format(segment_id)
try: try:
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first() dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset: if not dataset:
logging.info(click.style("Segment {} has no dataset, pass.".format(segment_id), fg="cyan"))
return return
dataset_document = db.session.query(Document).filter(Document.id == document_id).first() dataset_document = db.session.query(Document).filter(Document.id == document_id).first()
if not dataset_document: if not dataset_document:
logging.info(click.style("Segment {} has no document, pass.".format(segment_id), fg="cyan"))
return return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed": if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Segment {} document status is invalid, pass.".format(segment_id), fg="cyan"))
return return
index_type = dataset_document.doc_form index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor() index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, [index_node_id]) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter() end_at = time.perf_counter()
logging.info( logging.info(click.style("Segment deleted from index latency: {}".format(end_at - start_at), fg="green"))
click.style("Segment deleted from index: {} latency: {}".format(segment_id, end_at - start_at), fg="green")
)
except Exception: except Exception:
logging.exception("delete segment from index failed") logging.exception("delete segment from index failed")
finally:
redis_client.delete(indexing_cache_key)

@ -0,0 +1,76 @@
import logging
import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def disable_segments_from_index_task(segment_ids: list, dataset_id: str, document_id: str):
"""
Async disable segments from index
:param segment_ids:
Usage: disable_segments_from_index_task.delay(segment_ids, dataset_id, document_id)
"""
start_at = time.perf_counter()
dataset = db.session.query(Dataset).filter(Dataset.id == dataset_id).first()
if not dataset:
logging.info(click.style("Dataset {} not found, pass.".format(dataset_id), fg="cyan"))
return
dataset_document = db.session.query(DatasetDocument).filter(DatasetDocument.id == document_id).first()
if not dataset_document:
logging.info(click.style("Document {} not found, pass.".format(document_id), fg="cyan"))
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info(click.style("Document {} status is invalid, pass.".format(document_id), fg="cyan"))
return
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
if not segments:
return
try:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(click.style("Segments removed from index latency: {}".format(end_at - start_at), fg="green"))
except Exception:
# update segment error msg
db.session.query(DocumentSegment).filter(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
).update(
{
"disabled_at": None,
"disabled_by": None,
"enabled": True,
}
)
db.session.commit()
finally:
for segment in segments:
indexing_cache_key = "segment_{}_indexing".format(segment.id)
redis_client.delete(indexing_cache_key)

@ -82,7 +82,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_node_ids = [segment.index_node_id for segment in segments] index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index # delete from vector index
index_processor.clean(dataset, index_node_ids) index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments: for segment in segments:
db.session.delete(segment) db.session.delete(segment)

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save