Feat/dataset service api (#1245)
Co-authored-by: jyong <jyong@dify.ai> Co-authored-by: StyleZhang <jasonapring2015@outlook.com>pull/1248/head
parent
54ff03c35d
commit
46154c6705
@ -0,0 +1,84 @@
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.login.login import current_user
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
def _validate_name(name):
|
||||
if not name or len(name) < 1 or len(name) > 40:
|
||||
raise ValueError('Name must be between 1 to 40 characters.')
|
||||
return name
|
||||
|
||||
|
||||
class DatasetApi(DatasetApiResource):
|
||||
"""Resource for get datasets."""
|
||||
|
||||
def get(self, tenant_id):
|
||||
page = request.args.get('page', default=1, type=int)
|
||||
limit = request.args.get('limit', default=20, type=int)
|
||||
provider = request.args.get('provider', default="vendor")
|
||||
datasets, total = DatasetService.get_datasets(page, limit, provider,
|
||||
tenant_id, current_user)
|
||||
# check embedding setting
|
||||
provider_service = ProviderService()
|
||||
valid_model_list = provider_service.get_valid_model_list(current_user.current_tenant_id,
|
||||
ModelType.EMBEDDINGS.value)
|
||||
model_names = []
|
||||
for valid_model in valid_model_list:
|
||||
model_names.append(f"{valid_model['model_name']}:{valid_model['model_provider']['provider_name']}")
|
||||
data = marshal(datasets, dataset_detail_fields)
|
||||
for item in data:
|
||||
if item['indexing_technique'] == 'high_quality':
|
||||
item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
|
||||
if item_model in model_names:
|
||||
item['embedding_available'] = True
|
||||
else:
|
||||
item['embedding_available'] = False
|
||||
else:
|
||||
item['embedding_available'] = True
|
||||
response = {
|
||||
'data': data,
|
||||
'has_more': len(datasets) == limit,
|
||||
'limit': limit,
|
||||
'total': total,
|
||||
'page': page
|
||||
}
|
||||
return response, 200
|
||||
|
||||
"""Resource for datasets."""
|
||||
|
||||
def post(self, tenant_id):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('name', nullable=False, required=True,
|
||||
help='type is required. Name must be between 1 to 40 characters.',
|
||||
type=_validate_name)
|
||||
parser.add_argument('indexing_technique', type=str, location='json',
|
||||
choices=('high_quality', 'economy'),
|
||||
help='Invalid indexing technique.')
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
dataset = DatasetService.create_empty_dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=args['name'],
|
||||
indexing_technique=args['indexing_technique'],
|
||||
account=current_user
|
||||
)
|
||||
except services.errors.dataset.DatasetNameDuplicateError:
|
||||
raise DatasetNameDuplicateError()
|
||||
|
||||
return marshal(dataset, dataset_detail_fields), 200
|
||||
|
||||
|
||||
api.add_resource(DatasetApi, '/datasets')
|
||||
|
||||
@ -1,20 +1,73 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class NoFileUploadedError(BaseHTTPException):
|
||||
error_code = 'no_file_uploaded'
|
||||
description = "Please upload your file."
|
||||
code = 400
|
||||
|
||||
|
||||
class TooManyFilesError(BaseHTTPException):
|
||||
error_code = 'too_many_files'
|
||||
description = "Only one file is allowed."
|
||||
code = 400
|
||||
|
||||
|
||||
class FileTooLargeError(BaseHTTPException):
|
||||
error_code = 'file_too_large'
|
||||
description = "File size exceeded. {message}"
|
||||
code = 413
|
||||
|
||||
|
||||
class UnsupportedFileTypeError(BaseHTTPException):
|
||||
error_code = 'unsupported_file_type'
|
||||
description = "File type not allowed."
|
||||
code = 415
|
||||
|
||||
|
||||
class HighQualityDatasetOnlyError(BaseHTTPException):
|
||||
error_code = 'high_quality_dataset_only'
|
||||
description = "Current operation only supports 'high-quality' datasets."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitializedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_initialized'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 400
|
||||
|
||||
|
||||
class ArchivedDocumentImmutableError(BaseHTTPException):
|
||||
error_code = 'archived_document_immutable'
|
||||
description = "Cannot operate when document was archived."
|
||||
description = "The archived document is not editable."
|
||||
code = 403
|
||||
|
||||
|
||||
class DatasetNameDuplicateError(BaseHTTPException):
|
||||
error_code = 'dataset_name_duplicate'
|
||||
description = "The dataset name already exists. Please modify your dataset name."
|
||||
code = 409
|
||||
|
||||
|
||||
class InvalidActionError(BaseHTTPException):
|
||||
error_code = 'invalid_action'
|
||||
description = "Invalid action."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentAlreadyFinishedError(BaseHTTPException):
|
||||
error_code = 'document_already_finished'
|
||||
description = "The document has been processed. Please refresh the page or go to the document details."
|
||||
code = 400
|
||||
|
||||
|
||||
class DocumentIndexingError(BaseHTTPException):
|
||||
error_code = 'document_indexing'
|
||||
description = "Cannot operate document during indexing."
|
||||
code = 403
|
||||
description = "The document is being processed and cannot be edited."
|
||||
code = 400
|
||||
|
||||
|
||||
class DatasetNotInitedError(BaseHTTPException):
|
||||
error_code = 'dataset_not_inited'
|
||||
description = "The dataset is still being initialized or indexing. Please wait a moment."
|
||||
code = 403
|
||||
class InvalidMetadataError(BaseHTTPException):
|
||||
error_code = 'invalid_metadata'
|
||||
description = "The metadata content is incorrect. Please check and verify."
|
||||
code = 400
|
||||
|
||||
@ -0,0 +1,59 @@
|
||||
from flask_login import current_user
|
||||
from flask_restful import reqparse, marshal
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_database import db
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DocumentService, SegmentService
|
||||
|
||||
|
||||
class SegmentApi(DatasetApiResource):
|
||||
"""Resource for segments."""
|
||||
def post(self, tenant_id, dataset_id, document_id):
|
||||
"""Create single segment."""
|
||||
# check dataset
|
||||
dataset_id = str(dataset_id)
|
||||
tenant_id = str(tenant_id)
|
||||
dataset = db.session.query(Dataset).filter(
|
||||
Dataset.tenant_id == tenant_id,
|
||||
Dataset.id == dataset_id
|
||||
).first()
|
||||
# check document
|
||||
document_id = str(document_id)
|
||||
document = DocumentService.get_document(dataset.id, document_id)
|
||||
if not document:
|
||||
raise NotFound('Document not found.')
|
||||
# check embedding model setting
|
||||
if dataset.indexing_technique == 'high_quality':
|
||||
try:
|
||||
ModelFactory.get_embedding_model(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
model_provider_name=dataset.embedding_model_provider,
|
||||
model_name=dataset.embedding_model
|
||||
)
|
||||
except LLMBadRequestError:
|
||||
raise ProviderNotInitializeError(
|
||||
f"No Embedding Model available. Please configure a valid provider "
|
||||
f"in the Settings -> Model Provider.")
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
# validate args
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument('segments', type=list, required=False, nullable=True, location='json')
|
||||
args = parser.parse_args()
|
||||
for args_item in args['segments']:
|
||||
SegmentService.segment_create_args_validate(args_item, document)
|
||||
segments = SegmentService.multi_create_segment(args['segments'], document, dataset)
|
||||
return {
|
||||
'data': marshal(segments, segment_fields),
|
||||
'doc_form': document.doc_form
|
||||
}, 200
|
||||
|
||||
|
||||
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
|
||||
@ -0,0 +1,138 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
app_detail_kernel_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
}
|
||||
|
||||
related_app_list = {
|
||||
'data': fields.List(fields.Nested(app_detail_kernel_fields)),
|
||||
'total': fields.Integer,
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw(attribute='suggested_questions_list'),
|
||||
'suggested_questions_after_answer': fields.Raw(attribute='suggested_questions_after_answer_dict'),
|
||||
'speech_to_text': fields.Raw(attribute='speech_to_text_dict'),
|
||||
'retriever_resource': fields.Raw(attribute='retriever_resource_dict'),
|
||||
'more_like_this': fields.Raw(attribute='more_like_this_dict'),
|
||||
'sensitive_word_avoidance': fields.Raw(attribute='sensitive_word_avoidance_dict'),
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'user_input_form': fields.Raw(attribute='user_input_form_list'),
|
||||
'dataset_query_variable': fields.String,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw(attribute='agent_mode_dict'),
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
prompt_config_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
model_config_partial_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
app_partial_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_partial_fields, attribute='app_model_config'),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(app_partial_fields), attribute='items')
|
||||
}
|
||||
|
||||
template_fields = {
|
||||
'name': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'mode': fields.String,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
}
|
||||
|
||||
template_list_fields = {
|
||||
'data': fields.List(fields.Nested(template_fields)),
|
||||
}
|
||||
|
||||
site_fields = {
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean,
|
||||
'app_base_url': fields.String,
|
||||
}
|
||||
|
||||
app_detail_fields_with_site = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'enable_site': fields.Boolean,
|
||||
'enable_api': fields.Boolean,
|
||||
'api_rpm': fields.Integer,
|
||||
'api_rph': fields.Integer,
|
||||
'is_demo': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields, attribute='app_model_config'),
|
||||
'site': fields.Nested(site_fields),
|
||||
'api_base_url': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
app_site_fields = {
|
||||
'app_id': fields.String,
|
||||
'access_token': fields.String(attribute='code'),
|
||||
'code': fields.String,
|
||||
'title': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String,
|
||||
'description': fields.String,
|
||||
'default_language': fields.String,
|
||||
'customize_domain': fields.String,
|
||||
'copyright': fields.String,
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean
|
||||
}
|
||||
@ -0,0 +1,182 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
|
||||
class MessageTextField(fields.Raw):
|
||||
def format(self, value):
|
||||
return value[0]['text'] if value else ''
|
||||
|
||||
|
||||
account_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'email': fields.String
|
||||
}
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String,
|
||||
'content': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account': fields.Nested(account_fields, allow_null=True),
|
||||
}
|
||||
|
||||
annotation_fields = {
|
||||
'content': fields.String,
|
||||
'account': fields.Nested(account_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': fields.Raw,
|
||||
'message_tokens': fields.Integer,
|
||||
'answer': fields.String,
|
||||
'answer_tokens': fields.Integer,
|
||||
'provider_response_latency': fields.Float,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'feedbacks': fields.List(fields.Nested(feedback_fields)),
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
feedback_stat_fields = {
|
||||
'like': fields.Integer,
|
||||
'dislike': fields.Integer
|
||||
}
|
||||
|
||||
model_config_fields = {
|
||||
'opening_statement': fields.String,
|
||||
'suggested_questions': fields.Raw,
|
||||
'model': fields.Raw,
|
||||
'user_input_form': fields.Raw,
|
||||
'pre_prompt': fields.String,
|
||||
'agent_mode': fields.Raw,
|
||||
}
|
||||
|
||||
simple_configs_fields = {
|
||||
'prompt_template': fields.String,
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
simple_message_detail_fields = {
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'message': MessageTextField,
|
||||
'answer': fields.String,
|
||||
}
|
||||
|
||||
conversation_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String(),
|
||||
'from_account_id': fields.String,
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotation': fields.Nested(annotation_fields, allow_null=True),
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'message': fields.Nested(simple_message_detail_fields, attribute='first_message')
|
||||
}
|
||||
|
||||
conversation_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_fields), attribute='items')
|
||||
}
|
||||
|
||||
conversation_message_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message': fields.Nested(message_detail_fields, attribute='first_message'),
|
||||
}
|
||||
|
||||
simple_model_config_fields = {
|
||||
'model': fields.Raw(attribute='model_dict'),
|
||||
'pre_prompt': fields.String,
|
||||
}
|
||||
|
||||
conversation_with_summary_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_end_user_session_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'summary': fields.String(attribute='summary_or_query'),
|
||||
'read_at': TimestampField,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(simple_model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
conversation_with_summary_pagination_fields = {
|
||||
'page': fields.Integer,
|
||||
'limit': fields.Integer(attribute='per_page'),
|
||||
'total': fields.Integer,
|
||||
'has_more': fields.Boolean(attribute='has_next'),
|
||||
'data': fields.List(fields.Nested(conversation_with_summary_fields), attribute='items')
|
||||
}
|
||||
|
||||
conversation_detail_fields = {
|
||||
'id': fields.String,
|
||||
'status': fields.String,
|
||||
'from_source': fields.String,
|
||||
'from_end_user_id': fields.String,
|
||||
'from_account_id': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'annotated': fields.Boolean,
|
||||
'model_config': fields.Nested(model_config_fields),
|
||||
'message_count': fields.Integer,
|
||||
'user_feedback_stats': fields.Nested(feedback_stat_fields),
|
||||
'admin_feedback_stats': fields.Nested(feedback_stat_fields)
|
||||
}
|
||||
|
||||
simple_conversation_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'status': fields.String,
|
||||
'introduction': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
conversation_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(simple_conversation_fields))
|
||||
}
|
||||
|
||||
conversation_with_model_config_fields = {
|
||||
**simple_conversation_fields,
|
||||
'model_config': fields.Raw,
|
||||
}
|
||||
|
||||
conversation_with_model_config_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
|
||||
}
|
||||
@ -0,0 +1,65 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'is_bound': fields.Boolean,
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields))
|
||||
}
|
||||
|
||||
integrate_notion_info_list_fields = {
|
||||
'notion_info': fields.List(fields.Nested(integrate_workspace_fields)),
|
||||
}
|
||||
|
||||
integrate_icon_fields = {
|
||||
'type': fields.String,
|
||||
'url': fields.String,
|
||||
'emoji': fields.String
|
||||
}
|
||||
|
||||
integrate_page_fields = {
|
||||
'page_name': fields.String,
|
||||
'page_id': fields.String,
|
||||
'page_icon': fields.Nested(integrate_icon_fields, allow_null=True),
|
||||
'parent_id': fields.String,
|
||||
'type': fields.String
|
||||
}
|
||||
|
||||
integrate_workspace_fields = {
|
||||
'workspace_name': fields.String,
|
||||
'workspace_id': fields.String,
|
||||
'workspace_icon': fields.String,
|
||||
'pages': fields.List(fields.Nested(integrate_page_fields)),
|
||||
'total': fields.Integer
|
||||
}
|
||||
|
||||
integrate_fields = {
|
||||
'id': fields.String,
|
||||
'provider': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'is_bound': fields.Boolean,
|
||||
'disabled': fields.Boolean,
|
||||
'link': fields.String,
|
||||
'source_info': fields.Nested(integrate_workspace_fields)
|
||||
}
|
||||
|
||||
integrate_list_fields = {
|
||||
'data': fields.List(fields.Nested(integrate_fields)),
|
||||
}
|
||||
@ -0,0 +1,43 @@
|
||||
from flask_restful import fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
dataset_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
|
||||
dataset_detail_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'description': fields.String,
|
||||
'provider': fields.String,
|
||||
'permission': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'indexing_technique': fields.String,
|
||||
'app_count': fields.Integer,
|
||||
'document_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'updated_by': fields.String,
|
||||
'updated_at': TimestampField,
|
||||
'embedding_model': fields.String,
|
||||
'embedding_model_provider': fields.String,
|
||||
'embedding_available': fields.Boolean
|
||||
}
|
||||
|
||||
dataset_query_detail_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"source": fields.String,
|
||||
"source_app_id": fields.String,
|
||||
"created_by_role": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_at": TimestampField
|
||||
}
|
||||
@ -0,0 +1,76 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from fields.dataset_fields import dataset_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'doc_form': fields.String,
|
||||
}
|
||||
|
||||
document_with_segments_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'data_source_type': fields.String,
|
||||
'data_source_info': fields.Raw(attribute='data_source_info_dict'),
|
||||
'dataset_process_rule_id': fields.String,
|
||||
'name': fields.String,
|
||||
'created_from': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'tokens': fields.Integer,
|
||||
'indexing_status': fields.String,
|
||||
'error': fields.String,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'archived': fields.Boolean,
|
||||
'display_status': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'hit_count': fields.Integer,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer
|
||||
}
|
||||
|
||||
dataset_and_document_fields = {
|
||||
'dataset': fields.Nested(dataset_fields),
|
||||
'documents': fields.List(fields.Nested(document_fields)),
|
||||
'batch': fields.String
|
||||
}
|
||||
|
||||
document_status_fields = {
|
||||
'id': fields.String,
|
||||
'indexing_status': fields.String,
|
||||
'processing_started_at': TimestampField,
|
||||
'parsing_completed_at': TimestampField,
|
||||
'cleaning_completed_at': TimestampField,
|
||||
'splitting_completed_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'paused_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'completed_segments': fields.Integer,
|
||||
'total_segments': fields.Integer,
|
||||
}
|
||||
|
||||
document_status_fields_list = {
|
||||
'data': fields.List(fields.Nested(document_status_fields))
|
||||
}
|
||||
@ -0,0 +1,18 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
upload_config_fields = {
|
||||
'file_size_limit': fields.Integer,
|
||||
'batch_count_limit': fields.Integer
|
||||
}
|
||||
|
||||
file_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'size': fields.Integer,
|
||||
'extension': fields.String,
|
||||
'mime_type': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
}
|
||||
@ -0,0 +1,41 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
document_fields = {
|
||||
'id': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'name': fields.String,
|
||||
'doc_type': fields.String,
|
||||
}
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField,
|
||||
'document': fields.Nested(document_fields),
|
||||
}
|
||||
|
||||
hit_testing_record_fields = {
|
||||
'segment': fields.Nested(segment_fields),
|
||||
'score': fields.Float,
|
||||
'tsne_position': fields.Raw
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
app_fields = {
|
||||
'id': fields.String,
|
||||
'name': fields.String,
|
||||
'mode': fields.String,
|
||||
'icon': fields.String,
|
||||
'icon_background': fields.String
|
||||
}
|
||||
|
||||
installed_app_fields = {
|
||||
'id': fields.String,
|
||||
'app': fields.Nested(app_fields),
|
||||
'app_owner_tenant_id': fields.String,
|
||||
'is_pinned': fields.Boolean,
|
||||
'last_used_at': TimestampField,
|
||||
'editable': fields.Boolean,
|
||||
'uninstallable': fields.Boolean,
|
||||
}
|
||||
|
||||
installed_app_list_fields = {
|
||||
'installed_apps': fields.List(fields.Nested(installed_app_fields))
|
||||
}
|
||||
@ -0,0 +1,43 @@
|
||||
from flask_restful import fields
|
||||
|
||||
from libs.helper import TimestampField
|
||||
|
||||
feedback_fields = {
|
||||
'rating': fields.String
|
||||
}
|
||||
|
||||
retriever_resource_fields = {
|
||||
'id': fields.String,
|
||||
'message_id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'dataset_id': fields.String,
|
||||
'dataset_name': fields.String,
|
||||
'document_id': fields.String,
|
||||
'document_name': fields.String,
|
||||
'data_source_type': fields.String,
|
||||
'segment_id': fields.String,
|
||||
'score': fields.Float,
|
||||
'hit_count': fields.Integer,
|
||||
'word_count': fields.Integer,
|
||||
'segment_position': fields.Integer,
|
||||
'index_node_hash': fields.String,
|
||||
'content': fields.String,
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_fields = {
|
||||
'id': fields.String,
|
||||
'conversation_id': fields.String,
|
||||
'inputs': fields.Raw,
|
||||
'query': fields.String,
|
||||
'answer': fields.String,
|
||||
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
|
||||
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
|
||||
'created_at': TimestampField
|
||||
}
|
||||
|
||||
message_infinite_scroll_pagination_fields = {
|
||||
'limit': fields.Integer,
|
||||
'has_more': fields.Boolean,
|
||||
'data': fields.List(fields.Nested(message_fields))
|
||||
}
|
||||
@ -0,0 +1,32 @@
|
||||
from flask_restful import fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
segment_fields = {
|
||||
'id': fields.String,
|
||||
'position': fields.Integer,
|
||||
'document_id': fields.String,
|
||||
'content': fields.String,
|
||||
'answer': fields.String,
|
||||
'word_count': fields.Integer,
|
||||
'tokens': fields.Integer,
|
||||
'keywords': fields.List(fields.String),
|
||||
'index_node_id': fields.String,
|
||||
'index_node_hash': fields.String,
|
||||
'hit_count': fields.Integer,
|
||||
'enabled': fields.Boolean,
|
||||
'disabled_at': TimestampField,
|
||||
'disabled_by': fields.String,
|
||||
'status': fields.String,
|
||||
'created_by': fields.String,
|
||||
'created_at': TimestampField,
|
||||
'indexing_at': TimestampField,
|
||||
'completed_at': TimestampField,
|
||||
'error': fields.String,
|
||||
'stopped_at': TimestampField
|
||||
}
|
||||
|
||||
segment_list_response = {
|
||||
'data': fields.List(fields.Nested(segment_fields)),
|
||||
'has_more': fields.Boolean,
|
||||
'limit': fields.Integer
|
||||
}
|
||||
@ -0,0 +1,36 @@
|
||||
"""add_tenant_id_in_api_token
|
||||
|
||||
Revision ID: 2e9819ca5b28
|
||||
Revises: 6e2cfb077b04
|
||||
Create Date: 2023-09-22 15:41:01.243183
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2e9819ca5b28'
|
||||
down_revision = 'ab23c11305d4'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('tenant_id', postgresql.UUID(), nullable=True))
|
||||
batch_op.create_index('api_token_tenant_idx', ['tenant_id', 'type'], unique=False)
|
||||
batch_op.drop_column('dataset_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('api_tokens', schema=None) as batch_op:
|
||||
batch_op.add_column(sa.Column('dataset_id', postgresql.UUID(), autoincrement=False, nullable=True))
|
||||
batch_op.drop_index('api_token_tenant_idx')
|
||||
batch_op.drop_column('tenant_id')
|
||||
|
||||
# ### end Alembic commands ###
|
||||
@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
__all__ = [
|
||||
'base', 'conversation', 'message', 'index', 'app_model_config', 'account', 'document', 'dataset',
|
||||
'app', 'completion', 'audio'
|
||||
'app', 'completion', 'audio', 'file'
|
||||
]
|
||||
|
||||
from . import *
|
||||
|
||||
@ -0,0 +1,123 @@
|
||||
import datetime
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from werkzeug.datastructures import FileStorage
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from core.data_loader.file_extractor import FileExtractor
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.ext_database import db
|
||||
from models.model import UploadFile
|
||||
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
|
||||
|
||||
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
|
||||
PREVIEW_WORDS_LIMIT = 3000
|
||||
cache = TTLCache(maxsize=None, ttl=30)
|
||||
|
||||
|
||||
class FileService:
|
||||
|
||||
@staticmethod
|
||||
def upload_file(file: FileStorage) -> UploadFile:
|
||||
# read file content
|
||||
file_content = file.read()
|
||||
# get file size
|
||||
file_size = len(file_content)
|
||||
|
||||
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
|
||||
if file_size > file_size_limit:
|
||||
message = f'File size exceeded. {file_size} > {file_size_limit}'
|
||||
raise FileTooLargeError(message)
|
||||
|
||||
extension = file.filename.split('.')[-1]
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, file_content)
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=file.filename,
|
||||
size=file_size,
|
||||
extension=extension,
|
||||
mime_type=file.mimetype,
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=False,
|
||||
hash=hashlib.sha3_256(file_content).hexdigest()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def upload_text(text: str, text_name: str) -> UploadFile:
|
||||
# user uuid as file name
|
||||
file_uuid = str(uuid.uuid4())
|
||||
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.txt'
|
||||
|
||||
# save file to storage
|
||||
storage.save(file_key, text.encode('utf-8'))
|
||||
|
||||
# save file to db
|
||||
config = current_app.config
|
||||
upload_file = UploadFile(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
storage_type=config['STORAGE_TYPE'],
|
||||
key=file_key,
|
||||
name=text_name + '.txt',
|
||||
size=len(text),
|
||||
extension='txt',
|
||||
mime_type='text/plain',
|
||||
created_by=current_user.id,
|
||||
created_at=datetime.datetime.utcnow(),
|
||||
used=True,
|
||||
used_by=current_user.id,
|
||||
used_at=datetime.datetime.utcnow()
|
||||
)
|
||||
|
||||
db.session.add(upload_file)
|
||||
db.session.commit()
|
||||
|
||||
return upload_file
|
||||
|
||||
@staticmethod
|
||||
def get_file_preview(file_id: str) -> str:
|
||||
# get file storage key
|
||||
key = file_id + request.path
|
||||
cached_response = cache.get(key)
|
||||
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
|
||||
return cached_response['response']
|
||||
|
||||
upload_file = db.session.query(UploadFile) \
|
||||
.filter(UploadFile.id == file_id) \
|
||||
.first()
|
||||
|
||||
if not upload_file:
|
||||
raise NotFound("File not found")
|
||||
|
||||
# extract text from file
|
||||
extension = upload_file.extension
|
||||
if extension.lower() not in ALLOWED_EXTENSIONS:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
text = FileExtractor.load(upload_file, return_text=True)
|
||||
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
|
||||
|
||||
return text
|
||||
Loading…
Reference in New Issue