diff --git a/api/config.py b/api/config.py index f81527da61..1d6de39abb 100644 --- a/api/config.py +++ b/api/config.py @@ -185,6 +185,9 @@ class Config: # For temp use only # set default LLM provider, default is 'openai', support `azure_openai` self.DEFAULT_LLM_PROVIDER = get_env('DEFAULT_LLM_PROVIDER') + # notion import setting + self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID') + self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET') class CloudEditionConfig(Config): diff --git a/api/controllers/console/__init__.py b/api/controllers/console/__init__.py index 971e489971..7426d84699 100644 --- a/api/controllers/console/__init__.py +++ b/api/controllers/console/__init__.py @@ -9,10 +9,10 @@ api = ExternalApi(bp) from .app import app, site, explore, completion, model_config, statistic, conversation, message # Import auth controllers -from .auth import login, oauth +from .auth import login, oauth, data_source_oauth # Import datasets controllers -from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing +from .datasets import datasets, datasets_document, datasets_segments, file, hit_testing, data_source # Import other controllers from . import setup, version, apikey diff --git a/api/controllers/console/auth/data_source_oauth.py b/api/controllers/console/auth/data_source_oauth.py index 576e8c0d0d..9dd9136705 100644 --- a/api/controllers/console/auth/data_source_oauth.py +++ b/api/controllers/console/auth/data_source_oauth.py @@ -9,7 +9,7 @@ from flask_login import current_user, login_required from flask_restful import Resource from werkzeug.exceptions import Forbidden from libs.oauth_data_source import NotionOAuth -from .. import api +from controllers.console import api from ..setup import setup_required from ..wraps import account_initialization_required @@ -29,9 +29,6 @@ def get_oauth_providers(): class OAuthDataSource(Resource): - @setup_required - @login_required - @account_initialization_required def get(self, provider: str): # The role of the current user in the table must be admin or owner if current_user.current_tenant.current_role not in ['admin', 'owner']: @@ -66,5 +63,5 @@ class OAuthDataSourceCallback(Resource): return redirect(f'{current_app.config.get("CONSOLE_URL")}?oauth_data_source=success') -api.add_resource(OAuthDataSource, '/oauth/data-source/') -api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') +api.add_resource(OAuthDataSource, '/oauth/data-source/') +api.add_resource(OAuthDataSourceCallback, '/oauth/data-source/callback/') diff --git a/api/controllers/console/datasets/data_source.py b/api/controllers/console/datasets/data_source.py index 2c2b9a1fd8..d67aad3d61 100644 --- a/api/controllers/console/datasets/data_source.py +++ b/api/controllers/console/datasets/data_source.py @@ -1,32 +1,22 @@ import datetime -import hashlib import json -import tempfile -import time -import uuid -from pathlib import Path + from cachetools import TTLCache from flask import request, current_app from flask_login import login_required, current_user -from flask_restful import Resource, marshal_with, fields, reqparse +from flask_restful import Resource, marshal_with, fields, reqparse, marshal from werkzeug.exceptions import NotFound from controllers.console import api -from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \ - UnsupportedFileTypeError from controllers.console.setup import setup_required from controllers.console.wraps import account_initialization_required from core.data_source.notion import NotionPageReader -from core.index.readers.html_parser import HTMLParser -from core.index.readers.pdf_parser import PDFParser from core.indexing_runner import IndexingRunner -from extensions.ext_storage import storage -from libs.helper import TimestampField from extensions.ext_database import db +from libs.helper import TimestampField from libs.oauth_data_source import NotionOAuth from models.dataset import Document -from models.model import UploadFile from models.source import DataSourceBinding from services.dataset_service import DatasetService, DocumentService @@ -39,9 +29,35 @@ PREVIEW_WORDS_LIMIT = 3000 class DataSourceApi(Resource): + integrate_page_fields = { + 'page_name': fields.String, + 'page_id': fields.String, + 'page_icon': fields.String, + 'total': fields.Integer + } + integrate_workspace_fields = { + 'workspace_name': fields.String, + 'workspace_id': fields.String, + 'workspace_icon': fields.String, + 'pages': fields.List(fields.Nested(integrate_page_fields)) + } + integrate_fields = { + 'id': fields.String, + 'provider': fields.String, + 'created_at': TimestampField, + 'is_bound': fields.Boolean, + 'disabled': fields.Boolean, + 'link': fields.String, + 'source_info': fields.Nested(integrate_workspace_fields) + } + integrate_list_fields = { + 'data': fields.List(fields.Nested(integrate_fields)), + } + @setup_required @login_required @account_initialization_required + @marshal_with(integrate_list_fields) def get(self): # get workspace data source integrates data_source_integrates = db.session.query(DataSourceBinding).filter( @@ -76,8 +92,7 @@ class DataSourceApi(Resource): 'disabled': None, 'link': f'{base_url}{data_source_oauth_base_path}/{provider}' }) - - return {'data': integrate_data} + return {'data': integrate_data}, 200 @setup_required @login_required @@ -110,10 +125,25 @@ class DataSourceApi(Resource): class DataSourceNotionListApi(Resource): + integrate_page_fields = { + 'page_name': fields.String, + 'page_id': fields.String, + 'page_icon': fields.String + } + integrate_workspace_fields = { + 'workspace_name': fields.String, + 'workspace_id': fields.String, + 'workspace_icon': fields.String, + 'pages': fields.List(fields.Nested(integrate_page_fields)) + } + integrate_notion_info_list_fields = { + 'notion_info': fields.List(fields.Nested(integrate_workspace_fields)), + } @setup_required @login_required @account_initialization_required + @marshal_with(integrate_notion_info_list_fields) def get(self): dataset_id = request.args.get('dataset_id', default=None, type=str) exist_page_ids = [] @@ -143,9 +173,14 @@ class DataSourceNotionListApi(Resource): raise NotFound('Data source binding not found.') pre_import_info_list = [] for data_source_binding in data_source_bindings: - pages = NotionOAuth.get_authorized_pages(data_source_binding.access_token) + notion_oauth = NotionOAuth(client_id=current_app.config.get('NOTION_CLIENT_ID'), + client_secret=current_app.config.get( + 'NOTION_CLIENT_SECRET'), + redirect_uri=current_app.config.get( + 'CONSOLE_URL') + '/console/api/oauth/data-source/authorize/notion') + pages = notion_oauth.get_authorized_pages(data_source_binding.access_token) # Filter out already bound pages - filter_pages = filter(lambda page: page['page_id'] not in exist_page_ids, pages) + filter_pages = [page for page in pages if page['page_id'] not in exist_page_ids] source_info = json.loads(data_source_binding.source_info) pre_import_info = { 'workspace_name': source_info['workspace_name'], @@ -165,12 +200,14 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def get(self, workspace_id, page_id): + workspace_id = str(workspace_id) + page_id = str(page_id) data_source_binding = DataSourceBinding.query.filter( db.and_( DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.provider == 'notion', DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == workspace_id + DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: @@ -185,9 +222,8 @@ class DataSourceNotionApi(Resource): @login_required @account_initialization_required def post(self): - notion_import_info = request.get_json() parser = reqparse.RequestParser() - parser.add_argument('notion_info_list', type=dict, required=True, nullable=True, location='json') + parser.add_argument('notion_info_list', type=list, required=True, nullable=True, location='json') parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json') args = parser.parse_args() # validate args @@ -197,7 +233,7 @@ class DataSourceNotionApi(Resource): return response, 200 -api.add_resource(DataSourceApi, '/oauth/data-source/integrates') -api.add_resource(DataSourceApi, '/oauth/data-source/integrates//') +api.add_resource(DataSourceApi, '/data-source/integrates', '/data-source/integrates//') api.add_resource(DataSourceNotionListApi, '/notion/pre-import/pages') -api.add_resource(DataSourceNotionApi, '/notion/workspaces//pages//preview') +api.add_resource(DataSourceNotionApi, '/notion/workspaces//pages//preview', + '/datasets/notion-indexing-estimate') diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 55193da432..873d439fcc 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -251,7 +251,7 @@ class DatasetDocumentListApi(Resource): class DatasetInitApi(Resource): dataset_and_document_fields = { 'dataset': fields.Nested(dataset_fields), - 'document': fields.Nested(document_fields) + 'documents': fields.List(fields.Nested(document_fields)) } @setup_required diff --git a/api/core/data_source/notion.py b/api/core/data_source/notion.py index efbaf9137f..7e18814c37 100644 --- a/api/core/data_source/notion.py +++ b/api/core/data_source/notion.py @@ -126,6 +126,7 @@ class NotionPageReader(BaseReader): cur_result_text_arr.append(children_text) cur_result_text = "\n".join(cur_result_text_arr) + cur_result_text += "\n\n" result_lines_arr.append(cur_result_text) if data["next_cursor"] is None: @@ -204,11 +205,11 @@ class NotionPageReader(BaseReader): page_ids = self.query_database(database_id) for page_id in page_ids: page_text = self.read_page(page_id) - docs.append(Document(page_text, extra_info={"page_id": page_id})) + docs.append(Document(page_text)) else: for page_id in page_ids: page_text = self.read_page(page_id) - docs.append(Document(page_text, extra_info={"page_id": page_id})) + docs.append(Document(page_text)) return docs @@ -223,12 +224,12 @@ class NotionPageReader(BaseReader): page_ids = self.query_database(database_id) for page_id in page_ids: page_text = self.read_page(page_id) - docs.append(Document(page_text, extra_info={"page_id": page_id})) + docs.append(Document(page_text)) else: for page_id in page_ids: page_text_list = self.read_page_as_documents(page_id) for page_text in page_text_list: - docs.append(Document(page_text, extra_info={"page_id": page_id})) + docs.append(Document(page_text)) return docs diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index f9dd7e99ff..2d855a3b96 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -215,12 +215,13 @@ class IndexingRunner: preview_texts = [] total_segments = 0 for notion_info in notion_info_list: + workspace_id = notion_info['workspace_id'] data_source_binding = DataSourceBinding.query.filter( db.and_( DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.provider == 'notion', DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == notion_info['workspace_id'] + DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: @@ -228,7 +229,7 @@ class IndexingRunner: reader = NotionPageReader(integration_token=data_source_binding.access_token) for page in notion_info['pages']: page_ids = [page['page_id']] - documents = reader.load_data(page_ids=page_ids) + documents = reader.load_data_as_documents(page_ids=page_ids) processing_rule = DatasetProcessRule( mode=tmp_processing_rule["mode"], @@ -279,7 +280,7 @@ class IndexingRunner: if not data_source_info or 'notion_page_id' not in data_source_info \ or 'notion_workspace_id' not in data_source_info: raise ValueError("no notion page found") - text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id']) + text_docs = self._load_data_from_notion(data_source_info['notion_workspace_id'], data_source_info['notion_page_id'], document.tenant_id) # update document status to splitting self._update_document_index_status( document_id=document.id, @@ -319,13 +320,13 @@ class IndexingRunner: return text_docs - def _load_data_from_notion(self, workspace_id: str, page_id: str) -> List[Document]: + def _load_data_from_notion(self, workspace_id: str, page_id: str, tenant_id: str) -> List[Document]: data_source_binding = DataSourceBinding.query.filter( db.and_( - DataSourceBinding.tenant_id == current_user.current_tenant_id, + DataSourceBinding.tenant_id == tenant_id, DataSourceBinding.provider == 'notion', DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == workspace_id + DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: diff --git a/api/libs/oauth_data_source.py b/api/libs/oauth_data_source.py index d8ea7dd58d..d3ddd114cf 100644 --- a/api/libs/oauth_data_source.py +++ b/api/libs/oauth_data_source.py @@ -65,7 +65,7 @@ class NotionOAuth(OAuthDataSource): data_source_binding = DataSourceBinding( tenant_id=current_user.current_tenant_id, access_token=access_token, - source_info=json.dumps(source_info), + source_info=source_info, provider='notion' ) db.session.add(data_source_binding) diff --git a/api/models/dataset.py b/api/models/dataset.py index 29588c1f38..fa81af5faa 100644 --- a/api/models/dataset.py +++ b/api/models/dataset.py @@ -190,7 +190,7 @@ class Document(db.Model): doc_type = db.Column(db.String(40), nullable=True) doc_metadata = db.Column(db.JSON, nullable=True) - DATA_SOURCES = ['upload_file'] + DATA_SOURCES = ['upload_file', 'notion_import'] @property def display_status(self): diff --git a/api/models/source.py b/api/models/source.py index 399a7d50ab..c7c04075bc 100644 --- a/api/models/source.py +++ b/api/models/source.py @@ -7,7 +7,8 @@ class DataSourceBinding(db.Model): __tablename__ = 'data_source_bindings' __table_args__ = ( db.PrimaryKeyConstraint('id', name='source_binding_pkey'), - db.Index('source_binding_tenant_id_idx', 'tenant_id') + db.Index('source_binding_tenant_id_idx', 'tenant_id'), + db.Index('source_info_idx', "source_info", postgresql_using='gin') ) id = db.Column(UUID, server_default=db.text('uuid_generate_v4()')) diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index bcf72ed12b..1fce9413bf 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -423,7 +423,7 @@ class DocumentService: DataSourceBinding.tenant_id == current_user.current_tenant_id, DataSourceBinding.provider == 'notion', DataSourceBinding.disabled == False, - DataSourceBinding.source_info['workspace_id'] == workspace_id + DataSourceBinding.source_info['workspace_id'] == f'"{workspace_id}"' ) ).first() if not data_source_binding: @@ -581,7 +581,7 @@ class DocumentService: if 'notion_info_list' not in args or not args['notion_info_list']: raise ValueError("Notion info is required") - if not isinstance(args['notion_info_list'], dict): + if not isinstance(args['notion_info_list'], list): raise ValueError("Notion info is invalid") if 'process_rule' not in args or not args['process_rule']: