|
|
|
|
@ -23,7 +23,7 @@ import type { DataSet } from '@/models/datasets'
|
|
|
|
|
import { fetchDatasets } from '@/service/datasets'
|
|
|
|
|
import useNodeCrud from '@/app/components/workflow/nodes/_base/hooks/use-node-crud'
|
|
|
|
|
import useOneStepRun from '@/app/components/workflow/nodes/_base/hooks/use-one-step-run'
|
|
|
|
|
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
|
|
|
import { useCurrentProviderAndModel, useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
|
|
|
|
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
|
|
|
|
|
|
|
|
|
|
const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
@ -34,6 +34,8 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
const startNodeId = startNode?.id
|
|
|
|
|
const { inputs, setInputs: doSetInputs } = useNodeCrud<KnowledgeRetrievalNodeType>(id, payload)
|
|
|
|
|
|
|
|
|
|
const inputRef = useRef(inputs)
|
|
|
|
|
|
|
|
|
|
const setInputs = useCallback((s: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
const newInputs = produce(s, (draft) => {
|
|
|
|
|
if (s.retrieval_mode === RETRIEVE_TYPE.multiWay)
|
|
|
|
|
@ -43,13 +45,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
})
|
|
|
|
|
// not work in pass to draft...
|
|
|
|
|
doSetInputs(newInputs)
|
|
|
|
|
inputRef.current = newInputs
|
|
|
|
|
}, [doSetInputs])
|
|
|
|
|
|
|
|
|
|
const inputRef = useRef(inputs)
|
|
|
|
|
useEffect(() => {
|
|
|
|
|
inputRef.current = inputs
|
|
|
|
|
}, [inputs])
|
|
|
|
|
|
|
|
|
|
const handleQueryVarChange = useCallback((newVar: ValueSelector | string) => {
|
|
|
|
|
const newInputs = produce(inputs, (draft) => {
|
|
|
|
|
draft.query_variable_selector = newVar as ValueSelector
|
|
|
|
|
@ -63,9 +61,22 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
|
|
|
|
|
|
|
|
|
|
const {
|
|
|
|
|
modelList: rerankModelList,
|
|
|
|
|
defaultModel: rerankDefaultModel,
|
|
|
|
|
} = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.rerank)
|
|
|
|
|
|
|
|
|
|
const {
|
|
|
|
|
currentModel: currentRerankModel,
|
|
|
|
|
} = useCurrentProviderAndModel(
|
|
|
|
|
rerankModelList,
|
|
|
|
|
rerankDefaultModel
|
|
|
|
|
? {
|
|
|
|
|
...rerankDefaultModel,
|
|
|
|
|
provider: rerankDefaultModel.provider.provider,
|
|
|
|
|
}
|
|
|
|
|
: undefined,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
const handleModelChanged = useCallback((model: { provider: string; modelId: string; mode?: string }) => {
|
|
|
|
|
const newInputs = produce(inputRef.current, (draft) => {
|
|
|
|
|
if (!draft.single_retrieval_config) {
|
|
|
|
|
@ -110,7 +121,7 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
// set defaults models
|
|
|
|
|
useEffect(() => {
|
|
|
|
|
const inputs = inputRef.current
|
|
|
|
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider)
|
|
|
|
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.multiWay && inputs.multiple_retrieval_config?.reranking_model?.provider && currentRerankModel && rerankDefaultModel)
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
if (inputs.retrieval_mode === RETRIEVE_TYPE.oneWay && inputs.single_retrieval_config?.model?.provider)
|
|
|
|
|
@ -130,7 +141,6 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const multipleRetrievalConfig = draft.multiple_retrieval_config
|
|
|
|
|
draft.multiple_retrieval_config = {
|
|
|
|
|
top_k: multipleRetrievalConfig?.top_k || DATASET_DEFAULT.top_k,
|
|
|
|
|
@ -138,6 +148,9 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
reranking_model: multipleRetrievalConfig?.reranking_model,
|
|
|
|
|
reranking_mode: multipleRetrievalConfig?.reranking_mode,
|
|
|
|
|
weights: multipleRetrievalConfig?.weights,
|
|
|
|
|
reranking_enable: multipleRetrievalConfig?.reranking_enable !== undefined
|
|
|
|
|
? multipleRetrievalConfig.reranking_enable
|
|
|
|
|
: Boolean(currentRerankModel && rerankDefaultModel),
|
|
|
|
|
}
|
|
|
|
|
})
|
|
|
|
|
setInputs(newInput)
|
|
|
|
|
@ -194,14 +207,14 @@ const useConfig = (id: string, payload: KnowledgeRetrievalNodeType) => {
|
|
|
|
|
}, [])
|
|
|
|
|
|
|
|
|
|
useEffect(() => {
|
|
|
|
|
const inputs = inputRef.current
|
|
|
|
|
let query_variable_selector: ValueSelector = inputs.query_variable_selector
|
|
|
|
|
if (isChatMode && inputs.query_variable_selector.length === 0 && startNodeId)
|
|
|
|
|
query_variable_selector = [startNodeId, 'sys.query']
|
|
|
|
|
|
|
|
|
|
setInputs({
|
|
|
|
|
...inputs,
|
|
|
|
|
query_variable_selector,
|
|
|
|
|
})
|
|
|
|
|
setInputs(produce(inputs, (draft) => {
|
|
|
|
|
draft.query_variable_selector = query_variable_selector
|
|
|
|
|
}))
|
|
|
|
|
// eslint-disable-next-line react-hooks/exhaustive-deps
|
|
|
|
|
}, [])
|
|
|
|
|
|
|
|
|
|
|