refactor: step 2

pull/12097/head
AkaraChen 1 year ago
parent a77aa169b4
commit dfdc4ed3b1

@ -9,7 +9,6 @@ import {
RiSearchEyeLine, RiSearchEyeLine,
} from '@remixicon/react' } from '@remixicon/react'
import Link from 'next/link' import Link from 'next/link'
import { groupBy } from 'lodash-es'
import Image from 'next/image' import Image from 'next/image'
import SettingCog from '../assets/setting-gear-mod.svg' import SettingCog from '../assets/setting-gear-mod.svg'
import OrangeEffect from '../assets/option-card-effect-orange.svg' import OrangeEffect from '../assets/option-card-effect-orange.svg'
@ -17,23 +16,21 @@ import FamilyMod from '../assets/family-mod.svg'
import Note from '../assets/note-mod.svg' import Note from '../assets/note-mod.svg'
import FileList from '../assets/file-list-3-fill.svg' import FileList from '../assets/file-list-3-fill.svg'
import { indexMethodIcon } from '../icons' import { indexMethodIcon } from '../icons'
import PreviewItem, { PreviewType } from './preview-item'
import s from './index.module.css' import s from './index.module.css'
import unescape from './unescape' import unescape from './unescape'
import escape from './escape' import escape from './escape'
import { OptionCard } from './option-card' import { OptionCard } from './option-card'
import LanguageSelect from './language-select' import LanguageSelect from './language-select'
import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs' import { DelimiterInput, MaxLengthInput, OverlapInput } from './inputs'
import PreviewItem, { PreviewType } from './preview-item'
import cn from '@/utils/classnames' import cn from '@/utils/classnames'
import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, CreateDocumentReq, CustomFile, FullDocumentDetail, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
import { import {
createDocument, createDocument,
createFirstDocument, createFirstDocument,
fetchFileIndexingEstimate as didFetchFileIndexingEstimate,
fetchDefaultProcessRule, fetchDefaultProcessRule,
} from '@/service/datasets' } from '@/service/datasets'
import Button from '@/app/components/base/button' import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import FloatRightContainer from '@/app/components/base/float-right-container' import FloatRightContainer from '@/app/components/base/float-right-container'
import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config' import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-method-config'
import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config' import EconomicalRetrievalMethodConfig from '@/app/components/datasets/common/economical-retrieval-method-config'
@ -58,6 +55,8 @@ import { MessageChatSquare } from '@/app/components/base/icons/src/public/common
import { IS_CE_EDITION } from '@/config' import { IS_CE_EDITION } from '@/config'
import Switch from '@/app/components/base/switch' import Switch from '@/app/components/base/switch'
import Divider from '@/app/components/base/divider' import Divider from '@/app/components/base/divider'
import { getNotionInfo, getWebsiteInfo, useFetchFileIndexingEstimateForFile, useFetchFileIndexingEstimateForNotion, useFetchFileIndexingEstimateForWeb } from '@/service/use-datasets'
import Loading from '@/app/components/base/loading'
const TextLabel: FC<PropsWithChildren> = (props) => { const TextLabel: FC<PropsWithChildren> = (props) => {
return <label className='text-text-secondary text-xs font-semibold leading-none'>{props.children}</label> return <label className='text-text-secondary text-xs font-semibold leading-none'>{props.children}</label>
@ -87,7 +86,7 @@ type StepTwoProps = {
onCancel?: () => void onCancel?: () => void
} }
enum SegmentType { export enum SegmentType {
AUTO = 'automatic', AUTO = 'automatic',
CUSTOM = 'custom', CUSTOM = 'custom',
} }
@ -176,17 +175,92 @@ const StepTwo = ({
) )
const [QATipHide, setQATipHide] = useState(false) const [QATipHide, setQATipHide] = useState(false)
const [previewSwitched, setPreviewSwitched] = useState(false) const [previewSwitched, setPreviewSwitched] = useState(false)
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
const fileIndexingEstimate = segmentationType === SegmentType.AUTO
? automaticFileIndexingEstimate
: customFileIndexingEstimate
const [isCreating, setIsCreating] = useState(false) const [isCreating, setIsCreating] = useState(false)
const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig) const [parentChildConfig, setParentChildConfig] = useState<ParentChildConfig>(defaultParentChildConfig)
const getIndexing_technique = () => indexingType || indexType
const getProcessRule = () => {
const processRule: ProcessRule = {
rules: {} as any, // api will check this. It will be removed after api refactored.
mode: segmentationType,
}
if (segmentationType === SegmentType.CUSTOM) {
const ruleObj = {
pre_processing_rules: rules,
segmentation: {
separator: unescape(segmentIdentifier),
max_tokens: max,
chunk_overlap: overlap,
},
}
processRule.rules = ruleObj
}
return processRule
}
const fileIndexingEstimateQuery = useFetchFileIndexingEstimateForFile({
docForm: docForm as DocForm,
docLanguage,
dataSourceType: DataSourceType.FILE,
files,
indexingTechnique: getIndexing_technique() as any,
processRule: getProcessRule(),
dataset_id: datasetId!,
})
const notionIndexingEstimateQuery = useFetchFileIndexingEstimateForNotion({
docForm: docForm as DocForm,
docLanguage,
dataSourceType: DataSourceType.NOTION,
notionPages,
indexingTechnique: getIndexing_technique() as any,
processRule: getProcessRule(),
dataset_id: datasetId || '',
})
const websiteIndexingEstimateQuery = useFetchFileIndexingEstimateForWeb({
docForm: docForm as DocForm,
docLanguage,
dataSourceType: DataSourceType.WEB,
websitePages,
crawlOptions,
websiteCrawlProvider,
websiteCrawlJobId,
indexingTechnique: getIndexing_technique() as any,
processRule: getProcessRule(),
dataset_id: datasetId || '',
})
const fetchEstimate = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
fileIndexingEstimateQuery.mutate()
if (dataSourceType === DataSourceType.NOTION)
notionIndexingEstimateQuery.mutate()
if (dataSourceType === DataSourceType.WEB)
websiteIndexingEstimateQuery.mutate()
}, [dataSourceType, fileIndexingEstimateQuery, notionIndexingEstimateQuery, websiteIndexingEstimateQuery])
const estimate
= dataSourceType === DataSourceType.FILE
? fileIndexingEstimateQuery.data
: dataSourceType === DataSourceType.NOTION
? notionIndexingEstimateQuery.data
: websiteIndexingEstimateQuery.data
const getIsEstimateReady = useCallback(() => {
if (dataSourceType === DataSourceType.FILE)
return fileIndexingEstimateQuery.isSuccess
if (dataSourceType === DataSourceType.NOTION)
return notionIndexingEstimateQuery.isSuccess
if (dataSourceType === DataSourceType.WEB)
return websiteIndexingEstimateQuery.isSuccess
}, [dataSourceType, fileIndexingEstimateQuery.isSuccess, notionIndexingEstimateQuery.isSuccess, websiteIndexingEstimateQuery.isSuccess])
const getFileName = (name: string) => { const getFileName = (name: string) => {
const arr = name.split('.') const arr = name.split('.')
return arr.slice(0, -1).join('.') return arr.slice(0, -1).join('.')
@ -224,122 +298,15 @@ const StepTwo = ({
setParentChildConfig(defaultParentChildConfig) setParentChildConfig(defaultParentChildConfig)
} }
const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT, language?: string) => {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm, language)!)
if (segmentationType === SegmentType.CUSTOM)
setCustomFileIndexingEstimate(res)
else
setAutomaticFileIndexingEstimate(res)
}
const updatePreview = () => { const updatePreview = () => {
if (segmentationType === SegmentType.CUSTOM && max > 4000) { if (segmentationType === SegmentType.CUSTOM && max > 4000) {
Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') }) Toast.notify({ type: 'error', message: t('datasetCreation.stepTwo.maxLengthCheck') })
return return
} }
setCustomFileIndexingEstimate(null) fetchEstimate()
fetchFileIndexingEstimate()
setPreviewSwitched(false) setPreviewSwitched(false)
} }
const getIndexing_technique = () => indexingType || indexType
const getProcessRule = () => {
const processRule: ProcessRule = {
rules: {} as any, // api will check this. It will be removed after api refactored.
mode: segmentationType,
}
if (segmentationType === SegmentType.CUSTOM) {
const ruleObj = {
pre_processing_rules: rules,
segmentation: {
separator: unescape(segmentIdentifier),
max_tokens: max,
chunk_overlap: overlap,
},
}
processRule.rules = ruleObj
}
return processRule
}
const getNotionInfo = () => {
const workspacesMap = groupBy(notionPages, 'workspace_id')
const workspaces = Object.keys(workspacesMap).map((workspaceId) => {
return {
workspaceId,
pages: workspacesMap[workspaceId],
}
})
return workspaces.map((workspace) => {
return {
workspace_id: workspace.workspaceId,
pages: workspace.pages.map((page) => {
const { page_id, page_name, page_icon, type } = page
return {
page_id,
page_name,
page_icon,
type,
}
}),
}
}) as NotionInfo[]
}
const getWebsiteInfo = () => {
return {
provider: websiteCrawlProvider,
job_id: websiteCrawlJobId,
urls: websitePages.map(page => page.source_url),
only_main_content: crawlOptions?.only_main_content,
}
}
const getFileIndexingEstimateParams = (docForm: DocForm, language?: string): IndexingEstimateParams | undefined => {
if (dataSourceType === DataSourceType.FILE) {
return {
info_list: {
data_source_type: dataSourceType,
file_info_list: {
file_ids: files.map(file => file.id) as string[],
},
},
indexing_technique: getIndexing_technique() as string,
process_rule: getProcessRule(),
doc_form: docForm,
doc_language: language || docLanguage,
dataset_id: datasetId as string,
}
}
if (dataSourceType === DataSourceType.NOTION) {
return {
info_list: {
data_source_type: dataSourceType,
notion_info_list: getNotionInfo(),
},
indexing_technique: getIndexing_technique() as string,
process_rule: getProcessRule(),
doc_form: docForm,
doc_language: language || docLanguage,
dataset_id: datasetId as string,
}
}
if (dataSourceType === DataSourceType.WEB) {
return {
info_list: {
data_source_type: dataSourceType,
website_info_list: getWebsiteInfo(),
},
indexing_technique: getIndexing_technique() as string,
process_rule: getProcessRule(),
doc_form: docForm,
doc_language: language || docLanguage,
dataset_id: datasetId as string,
}
}
}
const { const {
modelList: rerankModelList, modelList: rerankModelList,
defaultModel: rerankDefaultModel, defaultModel: rerankDefaultModel,
@ -423,10 +390,15 @@ const StepTwo = ({
} }
} }
if (dataSourceType === DataSourceType.NOTION) if (dataSourceType === DataSourceType.NOTION)
params.data_source.info_list.notion_info_list = getNotionInfo() params.data_source.info_list.notion_info_list = getNotionInfo(notionPages)
if (dataSourceType === DataSourceType.WEB) if (dataSourceType === DataSourceType.WEB) {
params.data_source.info_list.website_info_list = getWebsiteInfo() params.data_source.info_list.website_info_list = getWebsiteInfo({
websiteCrawlProvider,
websiteCrawlJobId,
websitePages,
})
}
} }
return params return params
} }
@ -519,16 +491,7 @@ const StepTwo = ({
const previewSwitch = async (language?: string) => { const previewSwitch = async (language?: string) => {
setPreviewSwitched(true) setPreviewSwitched(true)
setIsLanguageSelectDisabled(true) setIsLanguageSelectDisabled(true)
if (segmentationType === SegmentType.AUTO) fetchEstimate()
setAutomaticFileIndexingEstimate(null)
else
setCustomFileIndexingEstimate(null)
try {
await fetchFileIndexingEstimate(DocForm.QA, language)
}
finally {
setIsLanguageSelectDisabled(false)
}
} }
const handleSelect = (language: string) => { const handleSelect = (language: string) => {
@ -570,18 +533,6 @@ const StepTwo = ({
setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL) setIndexType(isAPIKeySet ? IndexingType.QUALIFIED : IndexingType.ECONOMICAL)
}, [isAPIKeySet, indexingType, datasetId]) }, [isAPIKeySet, indexingType, datasetId])
useEffect(() => {
if (segmentationType === SegmentType.AUTO) {
setAutomaticFileIndexingEstimate(null)
fetchFileIndexingEstimate()
setPreviewSwitched(false)
}
else {
setCustomFileIndexingEstimate(null)
setPreviewSwitched(false)
}
}, [segmentationType, indexType])
const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || { const [retrievalConfig, setRetrievalConfig] = useState(currentDataset?.retrieval_model_dict || {
search_method: RETRIEVE_METHOD.semantic, search_method: RETRIEVE_METHOD.semantic,
reranking_enable: false, reranking_enable: false,
@ -971,26 +922,26 @@ const StepTwo = ({
)} )}
</div> </div>
<div className='my-4 px-8 space-y-4'> <div className='my-4 px-8 space-y-4'>
{previewSwitched && docForm === DocForm.QA && fileIndexingEstimate?.qa_preview && ( {previewSwitched && docForm === DocForm.QA && estimate?.qa_preview && (
<> <>
{fileIndexingEstimate?.qa_preview.map((item, index) => ( {estimate?.qa_preview.map((item, index) => (
<PreviewItem type={PreviewType.QA} key={item.question} qa={item} index={index + 1} /> <PreviewItem type={PreviewType.QA} key={item.question} qa={item} index={index + 1} />
))} ))}
</> </>
)} )}
{(docForm === DocForm.TEXT || !previewSwitched) && fileIndexingEstimate?.preview && ( {(docForm === DocForm.TEXT || !previewSwitched) && estimate?.preview && (
<> <>
{fileIndexingEstimate?.preview.map((item, index) => ( {estimate?.preview.map((item, index) => (
<PreviewItem type={PreviewType.TEXT} key={item} content={item} index={index + 1} /> <PreviewItem type={PreviewType.TEXT} key={item} content={item} index={index + 1} />
))} ))}
</> </>
)} )}
{previewSwitched && docForm === DocForm.QA && !fileIndexingEstimate?.qa_preview && ( {previewSwitched && docForm === DocForm.QA && !estimate?.qa_preview && (
<div className='flex items-center justify-center h-[200px]'> <div className='flex items-center justify-center h-[200px]'>
<Loading type='area' /> <Loading type='area' />
</div> </div>
)} )}
{!previewSwitched && !fileIndexingEstimate?.preview && ( {!previewSwitched && !estimate?.preview && (
<div className='flex items-center justify-center h-[200px]'> <div className='flex items-center justify-center h-[200px]'>
<Loading type='area' /> <Loading type='area' />
</div> </div>

@ -1,11 +1,12 @@
import groupBy from 'lodash-es/groupBy' import groupBy from 'lodash-es/groupBy'
import type { MutationOptions } from '@tanstack/react-query'
import { useMutation } from '@tanstack/react-query' import { useMutation } from '@tanstack/react-query'
import { fetchFileIndexingEstimate } from './datasets' import { fetchFileIndexingEstimate } from './datasets'
import type { IndexingType } from '@/app/components/datasets/create/step-two' import { type IndexingType } from '@/app/components/datasets/create/step-two'
import type { CrawlOptions, CrawlResultItem, CustomFile, DataSourceType, DocForm, IndexingEstimateParams, NotionInfo, ProcessRule } from '@/models/datasets' import type { CrawlOptions, CrawlResultItem, CustomFile, DataSourceType, DocForm, FileIndexingEstimateResponse, IndexingEstimateParams, NotionInfo, ProcessRule } from '@/models/datasets'
import type { DataSourceProvider, NotionPage } from '@/models/common' import type { DataSourceProvider, NotionPage } from '@/models/common'
const getNotionInfo = ( export const getNotionInfo = (
notionPages: NotionPage[], notionPages: NotionPage[],
) => { ) => {
const workspacesMap = groupBy(notionPages, 'workspace_id') const workspacesMap = groupBy(notionPages, 'workspace_id')
@ -31,7 +32,7 @@ const getNotionInfo = (
}) as NotionInfo[] }) as NotionInfo[]
} }
const getWebsiteInfo = ( export const getWebsiteInfo = (
opts: { opts: {
websiteCrawlProvider: DataSourceProvider websiteCrawlProvider: DataSourceProvider
websiteCrawlJobId: string websiteCrawlJobId: string
@ -152,30 +153,36 @@ const getFileIndexingEstimateParamsForWeb = ({
export const useFetchFileIndexingEstimateForFile = ( export const useFetchFileIndexingEstimateForFile = (
options: GetFileIndexingEstimateParamsOptionFile, options: GetFileIndexingEstimateParamsOptionFile,
mutationOptions: MutationOptions<FileIndexingEstimateResponse> = {},
) => { ) => {
return useMutation({ return useMutation({
mutationFn: async () => { mutationFn: async () => {
return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForFile(options)) return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForFile(options))
}, },
...mutationOptions,
}) })
} }
export const useFetchFileIndexingEstimateForNotion = ( export const useFetchFileIndexingEstimateForNotion = (
options: GetFileIndexingEstimateParamsOptionNotion, options: GetFileIndexingEstimateParamsOptionNotion,
mutationOptions: MutationOptions<FileIndexingEstimateResponse> = {},
) => { ) => {
return useMutation({ return useMutation({
mutationFn: async () => { mutationFn: async () => {
return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForNotion(options)) return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForNotion(options))
}, },
...mutationOptions,
}) })
} }
export const useFetchFileIndexingEstimateForWeb = ( export const useFetchFileIndexingEstimateForWeb = (
options: GetFileIndexingEstimateParamsOptionWeb, options: GetFileIndexingEstimateParamsOptionWeb,
mutationOptions: MutationOptions<FileIndexingEstimateResponse> = {},
) => { ) => {
return useMutation({ return useMutation({
mutationFn: async () => { mutationFn: async () => {
return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForWeb(options)) return fetchFileIndexingEstimate(getFileIndexingEstimateParamsForWeb(options))
}, },
...mutationOptions,
}) })
} }

Loading…
Cancel
Save