@ -1,16 +1,21 @@
import uuid
from datetime import UTC , datetime
import pandas as pd
from flask import request
from flask_login import current_user
from flask_restful import Resource , marshal , reqparse
from flask_login import current_user # type: ignore
from flask_restful import Resource , marshal , reqparse # type: ignore
from werkzeug . exceptions import Forbidden , NotFound
import services
from controllers . console import api
from controllers . console . app . error import ProviderNotInitializeError
from controllers . console . datasets . error import InvalidActionError , NoFileUploadedError , TooManyFilesError
from controllers . console . datasets . error import (
ChildChunkDeleteIndexError ,
ChildChunkIndexingError ,
InvalidActionError ,
NoFileUploadedError ,
TooManyFilesError ,
)
from controllers . console . wraps import (
account_initialization_required ,
cloud_edition_billing_knowledge_limit_check ,
@ -20,15 +25,15 @@ from controllers.console.wraps import (
from core . errors . error import LLMBadRequestError , ProviderTokenNotInitError
from core . model_manager import ModelManager
from core . model_runtime . entities . model_entities import ModelType
from extensions . ext_database import db
from extensions . ext_redis import redis_client
from fields . segment_fields import segment_fields
from fields . segment_fields import child_chunk_fields, segment_fields
from libs . login import login_required
from models import DocumentSegment
from models . dataset import ChildChunk , DocumentSegment
from services . dataset_service import DatasetService , DocumentService , SegmentService
from services . entities . knowledge_entities . knowledge_entities import ChildChunkUpdateArgs , SegmentUpdateArgs
from services . errors . chunk import ChildChunkDeleteIndexError as ChildChunkDeleteIndexServiceError
from services . errors . chunk import ChildChunkIndexingError as ChildChunkIndexingServiceError
from tasks . batch_create_segment_to_index_task import batch_create_segment_to_index_task
from tasks . disable_segment_from_index_task import disable_segment_from_index_task
from tasks . enable_segment_to_index_task import enable_segment_to_index_task
class DatasetDocumentSegmentListApi ( Resource ) :
@ -53,15 +58,16 @@ class DatasetDocumentSegmentListApi(Resource):
raise NotFound ( " Document not found. " )
parser = reqparse . RequestParser ( )
parser . add_argument ( " last_id " , type = str , default = None , location = " args " )
parser . add_argument ( " limit " , type = int , default = 20 , location = " args " )
parser . add_argument ( " status " , type = str , action = " append " , default = [ ] , location = " args " )
parser . add_argument ( " hit_count_gte " , type = int , default = None , location = " args " )
parser . add_argument ( " enabled " , type = str , default = " all " , location = " args " )
parser . add_argument ( " keyword " , type = str , default = None , location = " args " )
parser . add_argument ( " page " , type = int , default = 1 , location = " args " )
args = parser . parse_args ( )
last_id = args [ " last_id " ]
page = args [ " page " ]
limit = min ( args [ " limit " ] , 100 )
status_list = args [ " status " ]
hit_count_gte = args [ " hit_count_gte " ]
@ -69,14 +75,7 @@ class DatasetDocumentSegmentListApi(Resource):
query = DocumentSegment . query . filter (
DocumentSegment . document_id == str ( document_id ) , DocumentSegment . tenant_id == current_user . current_tenant_id
)
if last_id is not None :
last_segment = db . session . get ( DocumentSegment , str ( last_id ) )
if last_segment :
query = query . filter ( DocumentSegment . position > last_segment . position )
else :
return { " data " : [ ] , " has_more " : False , " limit " : limit } , 200
) . order_by ( DocumentSegment . position . asc ( ) )
if status_list :
query = query . filter ( DocumentSegment . status . in_ ( status_list ) )
@ -93,21 +92,44 @@ class DatasetDocumentSegmentListApi(Resource):
elif args [ " enabled " ] . lower ( ) == " false " :
query = query . filter ( DocumentSegment . enabled == False )
total = query . count ( )
segments = query . order_by ( DocumentSegment . position ) . limit ( limit + 1 ) . all ( )
segments = query . paginate ( page = page , per_page = limit , max_per_page = 100 , error_out = False )
has_more = False
if len ( segments ) > limit :
has_more = True
segments = segments [ : - 1 ]
return {
" data " : marshal ( segments , segment_fields ) ,
" doc_form " : document . doc_form ,
" has_more " : has_more ,
response = {
" data " : marshal ( segments . items , segment_fields ) ,
" limit " : limit ,
" total " : total ,
} , 200
" total " : segments . total ,
" total_pages " : segments . pages ,
" page " : page ,
}
return response , 200
@setup_required
@login_required
@account_initialization_required
def delete(self, dataset_id, document_id):
    """Delete the segments listed in the ``segment_id`` query parameters.

    The dataset and document are resolved first, then the caller's editor
    flag and dataset permission are verified, and only then is the deletion
    handed to ``SegmentService.delete_segments``.

    Raises:
        NotFound: the dataset or the document does not exist.
        Forbidden: the current user is not an editor, or
            ``DatasetService.check_dataset_permission`` rejects the user.
    """
    dataset_key, document_key = str(dataset_id), str(document_id)

    # Resolve the dataset and validate its model setting before anything else.
    dataset = DatasetService.get_dataset(dataset_key)
    if not dataset:
        raise NotFound(" Dataset not found. ")
    DatasetService.check_dataset_model_setting(dataset)

    # Resolve the document inside that dataset.
    document = DocumentService.get_document(dataset_key, document_key)
    if not document:
        raise NotFound(" Document not found. ")

    # Deleting requires the editor flag on the current user plus dataset-level permission.
    if not current_user.is_editor:
        raise Forbidden()
    try:
        DatasetService.check_dataset_permission(dataset, current_user)
    except services.errors.account.NoPermissionError as denied:
        raise Forbidden(str(denied))

    # Every repeated ``segment_id`` query parameter names one segment to delete.
    targets = request.args.getlist(" segment_id ")
    SegmentService.delete_segments(targets, document, dataset)
    return {" result ": " success "}, 200
class DatasetDocumentSegmentApi ( Resource ) :
@ -115,11 +137,15 @@ class DatasetDocumentSegmentApi(Resource):
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check ( " vector_space " )
def patch ( self , dataset_id , seg ment_id, action ) :
def patch ( self , dataset_id , docu ment_id, action ) :
dataset_id = str ( dataset_id )
dataset = DatasetService . get_dataset ( dataset_id )
if not dataset :
raise NotFound ( " Dataset not found. " )
document_id = str ( document_id )
document = DocumentService . get_document ( dataset_id , document_id )
if not document :
raise NotFound ( " Document not found. " )
# check user's model setting
DatasetService . check_dataset_model_setting ( dataset )
# The role of the current user in the ta table must be admin, owner, or editor
@ -142,64 +168,21 @@ class DatasetDocumentSegmentApi(Resource):
)
except LLMBadRequestError :
raise ProviderNotInitializeError (
" No Embedding Model available. Please configure a valid provider "
" in the Settings -> Model Provider. "
" No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider. "
)
except ProviderTokenNotInitError as ex :
raise ProviderNotInitializeError ( ex . description )
segment_ids = request . args . getlist ( " segment_id " )
segment = DocumentSegment . query . filter (
DocumentSegment . id == str ( segment_id ) , DocumentSegment . tenant_id == current_user . current_tenant_id
) . first ( )
if not segment :
raise NotFound ( " Segment not found. " )
if segment . status != " completed " :
raise NotFound ( " Segment is not completed, enable or disable function is not allowed " )
document_indexing_cache_key = " document_ {} _indexing " . format ( segment . document_id )
document_indexing_cache_key = " document_ {} _indexing " . format ( document . id )
cache_result = redis_client . get ( document_indexing_cache_key )
if cache_result is not None :
raise InvalidActionError ( " Document is being indexed, please try again later " )
indexing_cache_key = " segment_ {} _indexing " . format ( segment . id )
cache_result = redis_client . get ( indexing_cache_key )
if cache_result is not None :
raise InvalidActionError ( " Segment is being indexed, please try again later " )
if action == " enable " :
if segment . enabled :
raise InvalidActionError ( " Segment is already enabled. " )
segment . enabled = True
segment . disabled_at = None
segment . disabled_by = None
db . session . commit ( )
# Set cache to prevent indexing the same segment multiple times
redis_client . setex ( indexing_cache_key , 600 , 1 )
enable_segment_to_index_task . delay ( segment . id )
return { " result " : " success " } , 200
elif action == " disable " :
if not segment . enabled :
raise InvalidActionError ( " Segment is already disabled. " )
segment . enabled = False
segment . disabled_at = datetime . now ( UTC ) . replace ( tzinfo = None )
segment . disabled_by = current_user . id
db . session . commit ( )
# Set cache to prevent indexing the same segment multiple times
redis_client . setex ( indexing_cache_key , 600 , 1 )
disable_segment_from_index_task . delay ( segment . id )
try :
SegmentService . update_segments_status ( segment_ids , action , dataset , document )
except Exception as e :
raise InvalidActionError ( str ( e ) )
return { " result " : " success " } , 200
else :
raise InvalidActionError ( )
class DatasetDocumentSegmentAddApi ( Resource ) :
@ -233,8 +216,7 @@ class DatasetDocumentSegmentAddApi(Resource):
)
except LLMBadRequestError :
raise ProviderNotInitializeError (
" No Embedding Model available. Please configure a valid provider "
" in the Settings -> Model Provider. "
" No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider. "
)
except ProviderTokenNotInitError as ex :
raise ProviderNotInitializeError ( ex . description )
@ -283,8 +265,7 @@ class DatasetDocumentSegmentUpdateApi(Resource):
)
except LLMBadRequestError :
raise ProviderNotInitializeError (
" No Embedding Model available. Please configure a valid provider "
" in the Settings -> Model Provider. "
" No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider. "
)
except ProviderTokenNotInitError as ex :
raise ProviderNotInitializeError ( ex . description )
@ -307,9 +288,12 @@ class DatasetDocumentSegmentUpdateApi(Resource):
parser . add_argument ( " content " , type = str , required = True , nullable = False , location = " json " )
parser . add_argument ( " answer " , type = str , required = False , nullable = True , location = " json " )
parser . add_argument ( " keywords " , type = list , required = False , nullable = True , location = " json " )
parser . add_argument (
" regenerate_child_chunks " , type = bool , required = False , nullable = True , default = False , location = " json "
)
args = parser . parse_args ( )
SegmentService . segment_create_args_validate ( args , document )
segment = SegmentService . update_segment ( args , segment , document , dataset )
segment = SegmentService . update_segment ( SegmentUpdateArgs( * * args) , segment , document , dataset )
return { " data " : marshal ( segment , segment_fields ) , " doc_form " : document . doc_form } , 200
@setup_required
@ -381,9 +365,9 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
result = [ ]
for index , row in df . iterrows ( ) :
if document . doc_form == " qa_model " :
data = { " content " : row [ 0 ] , " answer " : row [ 1 ] }
data = { " content " : row . iloc [ 0 ] , " answer " : row . iloc [ 1 ] }
else :
data = { " content " : row [ 0 ] }
data = { " content " : row . iloc [ 0 ] }
result . append ( data )
if len ( result ) == 0 :
raise ValueError ( " The CSV file is empty. " )
@ -412,8 +396,247 @@ class DatasetDocumentSegmentBatchImportApi(Resource):
return { " job_id " : job_id , " job_status " : cache_result . decode ( ) } , 200
class ChildChunkAddApi(Resource):
    """Collection endpoint for the child chunks of one document segment.

    ``post`` creates a child chunk, ``get`` returns one page of child chunks,
    and ``patch`` submits a list of child-chunk updates for the segment.
    """

    @staticmethod
    def _get_segment(document_id, segment_id):
        """Return the segment addressed by the URL or raise ``NotFound``.

        The lookup is restricted to the current tenant and to the document
        named in the URL, so a segment id that belongs to another document is
        reported as missing instead of being operated on through this route.
        """
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id,
        ).first()
        if not segment:
            raise NotFound(" Segment not found. ")
        return segment

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check(" vector_space ")
    @cloud_edition_billing_knowledge_limit_check(" add_segment ")
    def post(self, dataset_id, document_id, segment_id):
        """Create a child chunk under the addressed segment.

        Expects a JSON body with a required ``content`` string.

        Raises:
            NotFound: dataset, document or segment cannot be resolved.
            Forbidden: the user is not an editor or lacks dataset permission.
            ProviderNotInitializeError: the embedding model cannot be obtained.
            ChildChunkIndexingError: the service reports an indexing failure.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound(" Dataset not found. ")
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound(" Document not found. ")
        # check segment (must belong to this document)
        segment = self._get_segment(document_id, segment_id)
        if not current_user.is_editor:
            raise Forbidden()
        # check embedding model setting
        if dataset.indexing_technique == " high_quality ":
            try:
                model_manager = ModelManager()
                model_manager.get_model_instance(
                    tenant_id=current_user.current_tenant_id,
                    provider=dataset.embedding_model_provider,
                    model_type=ModelType.TEXT_EMBEDDING,
                    model=dataset.embedding_model,
                )
            except LLMBadRequestError as e:
                raise ProviderNotInitializeError(
                    " No Embedding Model available. Please configure a valid provider in the Settings -> Model Provider. "
                ) from e
            except ProviderTokenNotInitError as ex:
                raise ProviderNotInitializeError(ex.description) from ex
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument(" content ", type=str, required=True, nullable=False, location=" json ")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.create_child_chunk(args.get(" content "), segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            # Translate the service-layer error into the controller-level error type.
            raise ChildChunkIndexingError(str(e)) from e
        return {" data ": marshal(child_chunk, child_chunk_fields)}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id, document_id, segment_id):
        """Return one page of the segment's child chunks.

        Query arguments: ``limit`` (default 20, capped at 100), ``keyword``
        (optional) and ``page`` (default 1).

        Raises:
            NotFound: dataset, document or segment cannot be resolved.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound(" Dataset not found. ")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound(" Document not found. ")
        # check segment (must belong to this document)
        segment_id = str(segment_id)
        self._get_segment(document_id, segment_id)
        parser = reqparse.RequestParser()
        parser.add_argument(" limit ", type=int, default=20, location=" args ")
        parser.add_argument(" keyword ", type=str, default=None, location=" args ")
        parser.add_argument(" page ", type=int, default=1, location=" args ")
        args = parser.parse_args()
        page = args[" page "]
        # Never hand the service a page size above 100, whatever the client asked for.
        limit = min(args[" limit "], 100)
        keyword = args[" keyword "]
        child_chunks = SegmentService.get_child_chunks(segment_id, document_id, dataset_id, page, limit, keyword)
        return {
            " data ": marshal(child_chunks.items, child_chunk_fields),
            " total ": child_chunks.total,
            " total_pages ": child_chunks.pages,
            " page ": page,
            " limit ": limit,
        }, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check(" vector_space ")
    def patch(self, dataset_id, document_id, segment_id):
        """Apply a list of child-chunk updates to the addressed segment.

        Expects a JSON body with a required ``chunks`` list; every entry is
        expanded into ``ChildChunkUpdateArgs``.

        Raises:
            NotFound: dataset, document or segment cannot be resolved.
            Forbidden: the user is not an editor or lacks dataset permission.
            ChildChunkIndexingError: the service reports an indexing failure.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound(" Dataset not found. ")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound(" Document not found. ")
        # check segment (must belong to this document)
        segment = self._get_segment(document_id, segment_id)
        # Updating requires the editor flag on the current user.
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument(" chunks ", type=list, required=True, nullable=False, location=" json ")
        args = parser.parse_args()
        try:
            # NOTE(review): an entry that is not a mapping, or that fails
            # ChildChunkUpdateArgs validation, still propagates unhandled here.
            chunks = [ChildChunkUpdateArgs(**chunk) for chunk in args.get(" chunks ")]
            child_chunks = SegmentService.update_child_chunks(chunks, segment, document, dataset)
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e)) from e
        return {" data ": marshal(child_chunks, child_chunk_fields)}, 200
class ChildChunkUpdateApi(Resource):
    """Item endpoint for a single child chunk of a document segment.

    ``delete`` removes the child chunk and ``patch`` replaces its content.
    """

    @staticmethod
    def _get_segment(document_id, segment_id):
        """Return the segment addressed by the URL or raise ``NotFound``.

        The lookup is restricted to the current tenant and to the document
        named in the URL, so a segment id that belongs to another document is
        reported as missing.
        """
        segment = DocumentSegment.query.filter(
            DocumentSegment.id == str(segment_id),
            DocumentSegment.document_id == str(document_id),
            DocumentSegment.tenant_id == current_user.current_tenant_id,
        ).first()
        if not segment:
            raise NotFound(" Segment not found. ")
        return segment

    @staticmethod
    def _get_child_chunk(segment, child_chunk_id):
        """Return the child chunk addressed by the URL or raise ``NotFound``.

        The lookup is restricted to the current tenant and to ``segment``, so
        a child chunk of a different segment cannot be changed via this URL.
        """
        # NOTE(review): relies on the ChildChunk model exposing a ``segment_id``
        # column - confirm against models.dataset before merging.
        child_chunk = ChildChunk.query.filter(
            ChildChunk.id == str(child_chunk_id),
            ChildChunk.segment_id == segment.id,
            ChildChunk.tenant_id == current_user.current_tenant_id,
        ).first()
        if not child_chunk:
            raise NotFound(" Child chunk not found. ")
        return child_chunk

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id, document_id, segment_id, child_chunk_id):
        """Delete one child chunk.

        Raises:
            NotFound: dataset, document, segment or child chunk cannot be resolved.
            Forbidden: the user is not an editor or lacks dataset permission.
            ChildChunkDeleteIndexError: the service reports an index deletion failure.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound(" Dataset not found. ")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound(" Document not found. ")
        # check segment (must belong to this document)
        segment = self._get_segment(document_id, segment_id)
        # check child chunk (must belong to this segment)
        child_chunk = self._get_child_chunk(segment, child_chunk_id)
        # Deleting requires the editor flag on the current user.
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e
        try:
            SegmentService.delete_child_chunk(child_chunk, dataset)
        except ChildChunkDeleteIndexServiceError as e:
            # Translate the service-layer error into the controller-level error type.
            raise ChildChunkDeleteIndexError(str(e)) from e
        return {" result ": " success "}, 200

    @setup_required
    @login_required
    @account_initialization_required
    @cloud_edition_billing_resource_check(" vector_space ")
    def patch(self, dataset_id, document_id, segment_id, child_chunk_id):
        """Replace the content of one child chunk.

        Expects a JSON body with a required ``content`` string.

        Raises:
            NotFound: dataset, document, segment or child chunk cannot be resolved.
            Forbidden: the user is not an editor or lacks dataset permission.
            ChildChunkIndexingError: the service reports an indexing failure.
        """
        # check dataset
        dataset_id = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id)
        if not dataset:
            raise NotFound(" Dataset not found. ")
        # check user's model setting
        DatasetService.check_dataset_model_setting(dataset)
        # check document
        document_id = str(document_id)
        document = DocumentService.get_document(dataset_id, document_id)
        if not document:
            raise NotFound(" Document not found. ")
        # check segment (must belong to this document)
        segment = self._get_segment(document_id, segment_id)
        # check child chunk (must belong to this segment)
        child_chunk = self._get_child_chunk(segment, child_chunk_id)
        # Updating requires the editor flag on the current user.
        if not current_user.is_editor:
            raise Forbidden()
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e)) from e
        # validate args
        parser = reqparse.RequestParser()
        parser.add_argument(" content ", type=str, required=True, nullable=False, location=" json ")
        args = parser.parse_args()
        try:
            child_chunk = SegmentService.update_child_chunk(
                args.get(" content "), child_chunk, segment, document, dataset
            )
        except ChildChunkIndexingServiceError as e:
            raise ChildChunkIndexingError(str(e)) from e
        return {" data ": marshal(child_chunk, child_chunk_fields)}, 200
api . add_resource ( DatasetDocumentSegmentListApi , " /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments " )
api . add_resource ( DatasetDocumentSegmentApi , " /datasets/<uuid:dataset_id>/segments/<uuid:segment_id>/<string:action> " )
api . add_resource (
DatasetDocumentSegmentApi , " /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment/<string:action> "
)
api . add_resource ( DatasetDocumentSegmentAddApi , " /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segment " )
api . add_resource (
DatasetDocumentSegmentUpdateApi ,
@ -424,3 +647,11 @@ api.add_resource(
" /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/batch_import " ,
" /datasets/batch_import_status/<uuid:job_id> " ,
)
# Collection route for the child chunks of one segment; ChildChunkAddApi
# defines the post, get and patch handlers for it.
api . add_resource (
ChildChunkAddApi ,
" /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks " ,
)
# Item route for a single child chunk; ChildChunkUpdateApi defines the
# delete and patch handlers for it.
api . add_resource (
ChildChunkUpdateApi ,
" /datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>/child_chunks/<uuid:child_chunk_id> " ,
)