|
|
|
@ -21,6 +21,7 @@ import {
|
|
|
|
PortalToFollowElemTrigger,
|
|
|
|
PortalToFollowElemTrigger,
|
|
|
|
} from '@/app/components/base/portal-to-follow-elem'
|
|
|
|
} from '@/app/components/base/portal-to-follow-elem'
|
|
|
|
import LLMParamsPanel from './llm-params-panel'
|
|
|
|
import LLMParamsPanel from './llm-params-panel'
|
|
|
|
|
|
|
|
import TTSParamsPanel from './tts-params-panel'
|
|
|
|
import { useProviderContext } from '@/context/provider-context'
|
|
|
|
import { useProviderContext } from '@/context/provider-context'
|
|
|
|
import cn from '@/utils/classnames'
|
|
|
|
import cn from '@/utils/classnames'
|
|
|
|
|
|
|
|
|
|
|
|
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
|
|
|
|
popupClassName?: string
|
|
|
|
popupClassName?: string
|
|
|
|
portalToFollowElemContentClassName?: string
|
|
|
|
portalToFollowElemContentClassName?: string
|
|
|
|
isAdvancedMode: boolean
|
|
|
|
isAdvancedMode: boolean
|
|
|
|
mode: string
|
|
|
|
value: any
|
|
|
|
modelId: string
|
|
|
|
setModel: (model: any) => void
|
|
|
|
provider: string
|
|
|
|
|
|
|
|
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
|
|
|
|
|
|
|
|
completionParams: FormValue
|
|
|
|
|
|
|
|
onCompletionParamsChange: (newParams: FormValue) => void
|
|
|
|
|
|
|
|
renderTrigger?: (v: TriggerProps) => ReactNode
|
|
|
|
renderTrigger?: (v: TriggerProps) => ReactNode
|
|
|
|
readonly?: boolean
|
|
|
|
readonly?: boolean
|
|
|
|
isInWorkflow?: boolean
|
|
|
|
isInWorkflow?: boolean
|
|
|
|
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
popupClassName,
|
|
|
|
popupClassName,
|
|
|
|
portalToFollowElemContentClassName,
|
|
|
|
portalToFollowElemContentClassName,
|
|
|
|
isAdvancedMode,
|
|
|
|
isAdvancedMode,
|
|
|
|
modelId,
|
|
|
|
value,
|
|
|
|
provider,
|
|
|
|
|
|
|
|
setModel,
|
|
|
|
setModel,
|
|
|
|
completionParams,
|
|
|
|
|
|
|
|
onCompletionParamsChange,
|
|
|
|
|
|
|
|
renderTrigger,
|
|
|
|
renderTrigger,
|
|
|
|
readonly,
|
|
|
|
readonly,
|
|
|
|
isInWorkflow,
|
|
|
|
isInWorkflow,
|
|
|
|
scope = 'text-generation',
|
|
|
|
scope = ModelTypeEnum.textGeneration,
|
|
|
|
}) => {
|
|
|
|
}) => {
|
|
|
|
const { t } = useTranslation()
|
|
|
|
const { t } = useTranslation()
|
|
|
|
const { isAPIKeySet } = useProviderContext()
|
|
|
|
const { isAPIKeySet } = useProviderContext()
|
|
|
|
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
...moderationList,
|
|
|
|
...moderationList,
|
|
|
|
]
|
|
|
|
]
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (scopeArray.includes('text-generation'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.textGeneration))
|
|
|
|
return textGenerationList
|
|
|
|
return textGenerationList
|
|
|
|
if (scopeArray.includes('embedding'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
|
|
|
|
return textEmbeddingList
|
|
|
|
return textEmbeddingList
|
|
|
|
if (scopeArray.includes('rerank'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.rerank))
|
|
|
|
return rerankList
|
|
|
|
return rerankList
|
|
|
|
if (scopeArray.includes('moderation'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.moderation))
|
|
|
|
return moderationList
|
|
|
|
return moderationList
|
|
|
|
if (scopeArray.includes('stt'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.speech2text))
|
|
|
|
return sttList
|
|
|
|
return sttList
|
|
|
|
if (scopeArray.includes('tts'))
|
|
|
|
if (scopeArray.includes(ModelTypeEnum.tts))
|
|
|
|
return ttsList
|
|
|
|
return ttsList
|
|
|
|
return resultList
|
|
|
|
return resultList
|
|
|
|
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
|
|
|
|
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
|
|
|
|
|
|
|
|
|
|
|
|
const { currentProvider, currentModel } = useMemo(() => {
|
|
|
|
const { currentProvider, currentModel } = useMemo(() => {
|
|
|
|
const currentProvider = scopedModelList.find(item => item.provider === provider)
|
|
|
|
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
|
|
|
|
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
|
|
|
|
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
|
|
|
|
return {
|
|
|
|
return {
|
|
|
|
currentProvider,
|
|
|
|
currentProvider,
|
|
|
|
currentModel,
|
|
|
|
currentModel,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}, [provider, modelId, scopedModelList])
|
|
|
|
}, [scopedModelList, value?.provider, value?.model])
|
|
|
|
|
|
|
|
|
|
|
|
const hasDeprecated = useMemo(() => {
|
|
|
|
const hasDeprecated = useMemo(() => {
|
|
|
|
return !currentProvider || !currentModel
|
|
|
|
return !currentProvider || !currentModel
|
|
|
|
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
const handleChangeModel = ({ provider, model }: DefaultModel) => {
|
|
|
|
const handleChangeModel = ({ provider, model }: DefaultModel) => {
|
|
|
|
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
|
|
|
|
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
|
|
|
|
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
|
|
|
|
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
|
|
|
|
|
|
|
|
const model_type = targetModelItem?.model_type as string
|
|
|
|
setModel({
|
|
|
|
setModel({
|
|
|
|
modelId: model,
|
|
|
|
|
|
|
|
provider,
|
|
|
|
provider,
|
|
|
|
|
|
|
|
model,
|
|
|
|
|
|
|
|
model_type,
|
|
|
|
|
|
|
|
...(model_type === ModelTypeEnum.textGeneration ? {
|
|
|
|
mode: targetModelItem?.model_properties.mode as string,
|
|
|
|
mode: targetModelItem?.model_properties.mode as string,
|
|
|
|
features: targetModelItem?.features || [],
|
|
|
|
} : {}),
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const handleLLMParamsChange = (newParams: FormValue) => {
|
|
|
|
|
|
|
|
const newValue = {
|
|
|
|
|
|
|
|
...(value?.completionParams || {}),
|
|
|
|
|
|
|
|
completion_params: newParams,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
setModel({
|
|
|
|
|
|
|
|
...value,
|
|
|
|
|
|
|
|
...newValue,
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const handleTTSParamsChange = (language: string, voice: string) => {
|
|
|
|
|
|
|
|
setModel({
|
|
|
|
|
|
|
|
...value,
|
|
|
|
|
|
|
|
language,
|
|
|
|
|
|
|
|
voice,
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
hasDeprecated,
|
|
|
|
hasDeprecated,
|
|
|
|
currentProvider,
|
|
|
|
currentProvider,
|
|
|
|
currentModel,
|
|
|
|
currentModel,
|
|
|
|
providerName: provider,
|
|
|
|
providerName: value?.provider,
|
|
|
|
modelId,
|
|
|
|
modelId: value?.model,
|
|
|
|
})
|
|
|
|
})
|
|
|
|
: (
|
|
|
|
: (
|
|
|
|
<Trigger
|
|
|
|
<Trigger
|
|
|
|
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
hasDeprecated={hasDeprecated}
|
|
|
|
hasDeprecated={hasDeprecated}
|
|
|
|
currentProvider={currentProvider}
|
|
|
|
currentProvider={currentProvider}
|
|
|
|
currentModel={currentModel}
|
|
|
|
currentModel={currentModel}
|
|
|
|
providerName={provider}
|
|
|
|
providerName={value?.provider}
|
|
|
|
modelId={modelId}
|
|
|
|
modelId={value?.model}
|
|
|
|
/>
|
|
|
|
/>
|
|
|
|
)
|
|
|
|
)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
{t('common.modelProvider.model').toLocaleUpperCase()}
|
|
|
|
{t('common.modelProvider.model').toLocaleUpperCase()}
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
<ModelSelector
|
|
|
|
<ModelSelector
|
|
|
|
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
|
|
|
|
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
|
|
|
|
modelList={scopedModelList}
|
|
|
|
modelList={scopedModelList}
|
|
|
|
scopeFeatures={scopeFeatures}
|
|
|
|
scopeFeatures={scopeFeatures}
|
|
|
|
onSelect={handleChangeModel}
|
|
|
|
onSelect={handleChangeModel}
|
|
|
|
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
|
|
|
)}
|
|
|
|
)}
|
|
|
|
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
|
|
|
|
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
|
|
|
|
<LLMParamsPanel
|
|
|
|
<LLMParamsPanel
|
|
|
|
provider={provider}
|
|
|
|
provider={value?.provider}
|
|
|
|
modelId={modelId}
|
|
|
|
modelId={value?.model}
|
|
|
|
completionParams={completionParams}
|
|
|
|
completionParams={value?.completion_params || {}}
|
|
|
|
onCompletionParamsChange={onCompletionParamsChange}
|
|
|
|
onCompletionParamsChange={handleLLMParamsChange}
|
|
|
|
isAdvancedMode={isAdvancedMode}
|
|
|
|
isAdvancedMode={isAdvancedMode}
|
|
|
|
/>
|
|
|
|
/>
|
|
|
|
)}
|
|
|
|
)}
|
|
|
|
|
|
|
|
{currentModel?.model_type === ModelTypeEnum.tts && (
|
|
|
|
|
|
|
|
<TTSParamsPanel
|
|
|
|
|
|
|
|
currentModel={currentModel}
|
|
|
|
|
|
|
|
language={value?.language}
|
|
|
|
|
|
|
|
voice={value?.voice}
|
|
|
|
|
|
|
|
onChange={handleTTSParamsChange}
|
|
|
|
|
|
|
|
/>
|
|
|
|
|
|
|
|
)}
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
</div>
|
|
|
|
</PortalToFollowElemContent>
|
|
|
|
</PortalToFollowElemContent>
|
|
|
|
|