support model params change

pull/12372/head
JzoNg 1 year ago
parent c8fc1deca6
commit e2e2090e0c

@ -2,7 +2,7 @@
import type { FC } from 'react'
import React, { Fragment, useEffect, useState } from 'react'
import { Combobox, Listbox, Transition } from '@headlessui/react'
import { CheckIcon, ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
import { ChevronDownIcon, ChevronUpIcon, XMarkIcon } from '@heroicons/react/20/solid'
import Badge from '../badge/index'
import { RiCheckLine } from '@remixicon/react'
import { useTranslation } from 'react-i18next'
@ -352,7 +352,7 @@ const PortalSelect: FC<PortalSelectProps> = ({
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={`z-20 ${popupClassName}`}>
<div
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md bg-white text-base shadow-lg border-gray-200 border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
className={classNames('px-1 py-1 max-h-60 overflow-auto rounded-md text-base shadow-lg border-components-panel-border bg-components-panel-bg border-[0.5px] focus:outline-none sm:text-sm', popupInnerClassName)}
>
{items.map((item: Item) => (
<div

@ -72,25 +72,15 @@ const Form: FC<FormProps> = ({
onChange({ ...value, [key]: val, ...shouldClearVariable })
}
const handleModelChanged = useCallback((key: string, model: { provider: string; modelId: string; mode?: string }) => {
const handleModelChanged = useCallback((key: string, model: any) => {
const newValue = {
...value[key],
provider: model.provider,
model: model.modelId,
mode: model.mode,
...model,
type: FormTypeEnum.modelSelector,
}
onChange({ ...value, [key]: newValue })
}, [onChange, value])
const handleCompletionParamsChange = useCallback((key: string, newParams: Record<string, any>) => {
const newValue = {
...value[key],
completion_params: newParams,
}
onChange({ ...value, [key]: newValue })
}, [onChange, value])
const renderField = (formSchema: CredentialFormSchema) => {
const tooltip = formSchema.tooltip
const tooltipContent = (tooltip && (
@ -302,12 +292,8 @@ const Form: FC<FormProps> = ({
popupClassName='!w-[387px]'
isAdvancedMode
isInWorkflow
provider={value[variable]?.provider}
modelId={value[variable]?.model}
mode={value[variable]?.mode}
completionParams={value[variable]?.completion_params}
value={value[variable]}
setModel={model => handleModelChanged(variable, model)}
onCompletionParamsChange={params => handleCompletionParamsChange(variable, params)}
readonly={readonly}
scope={scope}
/>

@ -21,6 +21,7 @@ import {
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
import LLMParamsPanel from './llm-params-panel'
import TTSParamsPanel from './tts-params-panel'
import { useProviderContext } from '@/context/provider-context'
import cn from '@/utils/classnames'
@ -28,12 +29,8 @@ export type ModelParameterModalProps = {
popupClassName?: string
portalToFollowElemContentClassName?: string
isAdvancedMode: boolean
mode: string
modelId: string
provider: string
setModel: (model: { modelId: string; provider: string; mode?: string; features?: string[] }) => void
completionParams: FormValue
onCompletionParamsChange: (newParams: FormValue) => void
value: any
setModel: (model: any) => void
renderTrigger?: (v: TriggerProps) => ReactNode
readonly?: boolean
isInWorkflow?: boolean
@ -44,15 +41,12 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
popupClassName,
portalToFollowElemContentClassName,
isAdvancedMode,
modelId,
provider,
value,
setModel,
completionParams,
onCompletionParamsChange,
renderTrigger,
readonly,
isInWorkflow,
scope = 'text-generation',
scope = ModelTypeEnum.textGeneration,
}) => {
const { t } = useTranslation()
const { isAPIKeySet } = useProviderContext()
@ -79,29 +73,29 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
...moderationList,
]
}
if (scopeArray.includes('text-generation'))
if (scopeArray.includes(ModelTypeEnum.textGeneration))
return textGenerationList
if (scopeArray.includes('embedding'))
if (scopeArray.includes(ModelTypeEnum.textEmbedding))
return textEmbeddingList
if (scopeArray.includes('rerank'))
if (scopeArray.includes(ModelTypeEnum.rerank))
return rerankList
if (scopeArray.includes('moderation'))
if (scopeArray.includes(ModelTypeEnum.moderation))
return moderationList
if (scopeArray.includes('stt'))
if (scopeArray.includes(ModelTypeEnum.speech2text))
return sttList
if (scopeArray.includes('tts'))
if (scopeArray.includes(ModelTypeEnum.tts))
return ttsList
return resultList
}, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])
const { currentProvider, currentModel } = useMemo(() => {
const currentProvider = scopedModelList.find(item => item.provider === provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === modelId)
const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
return {
currentProvider,
currentModel,
}
}, [provider, modelId, scopedModelList])
}, [scopedModelList, value?.provider, value?.model])
const hasDeprecated = useMemo(() => {
return !currentProvider || !currentModel
@ -116,11 +110,33 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
const handleChangeModel = ({ provider, model }: DefaultModel) => {
const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
const model_type = targetModelItem?.model_type as string
setModel({
modelId: model,
provider,
mode: targetModelItem?.model_properties.mode as string,
features: targetModelItem?.features || [],
model,
model_type,
...(model_type === ModelTypeEnum.textGeneration ? {
mode: targetModelItem?.model_properties.mode as string,
} : {}),
})
}
const handleLLMParamsChange = (newParams: FormValue) => {
const newValue = {
...(value?.completionParams || {}),
completion_params: newParams,
}
setModel({
...value,
...newValue,
})
}
const handleTTSParamsChange = (language: string, voice: string) => {
setModel({
...value,
language,
voice,
})
}
@ -149,8 +165,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated,
currentProvider,
currentModel,
providerName: provider,
modelId,
providerName: value?.provider,
modelId: value?.model,
})
: (
<Trigger
@ -160,8 +176,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
hasDeprecated={hasDeprecated}
currentProvider={currentProvider}
currentModel={currentModel}
providerName={provider}
modelId={modelId}
providerName={value?.provider}
modelId={value?.model}
/>
)
}
@ -174,7 +190,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
{t('common.modelProvider.model').toLocaleUpperCase()}
</div>
<ModelSelector
defaultModel={(provider || modelId) ? { provider, model: modelId } : undefined}
defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
modelList={scopedModelList}
scopeFeatures={scopeFeatures}
onSelect={handleChangeModel}
@ -185,13 +201,21 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
)}
{currentModel?.model_type === ModelTypeEnum.textGeneration && (
<LLMParamsPanel
provider={provider}
modelId={modelId}
completionParams={completionParams}
onCompletionParamsChange={onCompletionParamsChange}
provider={value?.provider}
modelId={value?.model}
completionParams={value?.completion_params || {}}
onCompletionParamsChange={handleLLMParamsChange}
isAdvancedMode={isAdvancedMode}
/>
)}
{currentModel?.model_type === ModelTypeEnum.tts && (
<TTSParamsPanel
currentModel={currentModel}
language={value?.language}
voice={value?.voice}
onChange={handleTTSParamsChange}
/>
)}
</div>
</div>
</PortalToFollowElemContent>

@ -0,0 +1,67 @@
import React, { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { languages } from '@/i18n/language'
import { PortalSelect } from '@/app/components/base/select'
import cn from '@/utils/classnames'
type Props = {
currentModel: any
language: string
voice: string
onChange: (language: string, voice: string) => void
}
const TTSParamsPanel = ({
currentModel,
language,
voice,
onChange,
}: Props) => {
const { t } = useTranslation()
const voiceList = useMemo(() => {
if (!currentModel)
return []
return currentModel.model_properties.voices.map((item: { mode: any }) => ({
...item,
value: item.mode,
}))
}, [currentModel])
const setLanguage = (language: string) => {
onChange(language, voice)
}
const setVoice = (voice: string) => {
onChange(language, voice)
}
return (
<>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.language')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={language}
items={languages.filter(item => item.supported)}
onSelect={item => setLanguage(item.value as string)}
/>
</div>
<div className='mb-3'>
<div className='mb-1 py-1 flex items-center text-text-secondary system-sm-semibold'>
{t('appDebug.voice.voiceSettings.voice')}
</div>
<PortalSelect
triggerClassName='h-8'
popupClassName={cn('z-[1000]')}
popupInnerClassName={cn('w-[354px]')}
value={voice}
items={voiceList}
onSelect={item => setVoice(item.value as string)}
/>
</div>
</>
)
}
export default TTSParamsPanel

@ -123,24 +123,11 @@ const InputVarList: FC<Props> = ({
}
}, [onChange, value])
const handleModelChange = useCallback((variable: string) => {
return (model: { provider: string; modelId: string; mode?: string }) => {
return (model: any) => {
const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = {
...draft[variable],
provider: model.provider,
model: model.modelId,
mode: model.mode,
} as any
})
onChange(newValue)
}
}, [onChange, value])
const handleModelParamsChange = useCallback((variable: string) => {
return (newParams: Record<string, any>) => {
const newValue = produce(value, (draft: ToolVarInputs) => {
draft[variable] = {
...draft[variable],
completion_params: newParams,
...model,
} as any
})
onChange(newValue)
@ -242,12 +229,8 @@ const InputVarList: FC<Props> = ({
popupClassName='!w-[387px]'
isAdvancedMode
isInWorkflow
provider={(varInput as any)?.provider}
modelId={(varInput as any)?.model}
mode={(varInput as any)?.mode}
completionParams={(varInput as any)?.completion_params}
value={varInput as any}
setModel={handleModelChange(variable)}
onCompletionParamsChange={handleModelParamsChange(variable)}
readonly={readOnly}
scope={scope}
/>

Loading…
Cancel
Save