Merge branch 'feat/rag-pipeline' into deploy/rag-dev
commit
3c2ce07f38
@ -0,0 +1,156 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import QuotaUnit
|
||||
from core.file.models import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.plugin.entities.plugin import ModelProviderID
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
|
||||
from core.workflow.entities.variable_pool import VariablePool
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from models import db
|
||||
from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
||||
|
||||
|
||||
def fetch_model_config(
|
||||
tenant_id: str, node_data_model: ModelConfig
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
)
|
||||
|
||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
||||
|
||||
# check model
|
||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
# model config
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
return model, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=model.provider_model_bundle,
|
||||
credentials=model.credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
|
||||
|
||||
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
return []
|
||||
elif isinstance(variable, FileSegment):
|
||||
return [variable.value]
|
||||
elif isinstance(variable, ArrayFileSegment):
|
||||
return variable.value
|
||||
elif isinstance(variable, NoneSegment | ArrayAnySegment):
|
||||
return []
|
||||
raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}")
|
||||
|
||||
|
||||
def fetch_memory(
|
||||
variable_pool: VariablePool, app_id: str, node_data_memory: Optional[MemoryConfig], model_instance: ModelInstance
|
||||
) -> Optional[TokenBufferMemory]:
|
||||
if not node_data_memory:
|
||||
return None
|
||||
|
||||
# get conversation id
|
||||
conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID.value])
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
return memory
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
|
||||
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
|
||||
return
|
||||
|
||||
system_configuration = provider_configuration.system_configuration
|
||||
|
||||
quota_unit = None
|
||||
for quota_configuration in system_configuration.quota_configurations:
|
||||
if quota_configuration.quota_type == system_configuration.current_quota_type:
|
||||
quota_unit = quota_configuration.quota_unit
|
||||
|
||||
if quota_configuration.quota_limit == -1:
|
||||
return
|
||||
|
||||
break
|
||||
|
||||
used_quota = None
|
||||
if quota_unit:
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
if used_quota is not None and system_configuration.current_quota_type is not None:
|
||||
with Session(db.engine) as session:
|
||||
stmt = (
|
||||
update(Provider)
|
||||
.where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
# TODO: Use provider name with prefix after the data migration.
|
||||
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
|
||||
Provider.provider_type == ProviderType.SYSTEM.value,
|
||||
Provider.quota_type == system_configuration.current_quota_type.value,
|
||||
Provider.quota_limit > Provider.quota_used,
|
||||
)
|
||||
.values(
|
||||
quota_used=Provider.quota_used + used_quota,
|
||||
last_used=datetime.now(tz=UTC).replace(tzinfo=None),
|
||||
)
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,150 @@
|
||||
'use client'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiCloseLine } from '@remixicon/react'
|
||||
import AppIconPicker from '@/app/components/base/app-icon-picker'
|
||||
import type { AppIconSelection } from '@/app/components/base/app-icon-picker'
|
||||
import Modal from '@/app/components/base/modal'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Input from '@/app/components/base/input'
|
||||
import Textarea from '@/app/components/base/textarea'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import { noop } from 'lodash-es'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import type { IconInfo } from '@/models/datasets'
|
||||
|
||||
type PublishAsKnowledgePipelineModalProps = {
|
||||
confirmDisabled?: boolean
|
||||
onCancel: () => void
|
||||
onConfirm: (
|
||||
name: string,
|
||||
icon: IconInfo,
|
||||
description?: string,
|
||||
) => Promise<void>
|
||||
}
|
||||
const PublishAsKnowledgePipelineModal = ({
|
||||
confirmDisabled,
|
||||
onCancel,
|
||||
onConfirm,
|
||||
}: PublishAsKnowledgePipelineModalProps) => {
|
||||
const { t } = useTranslation()
|
||||
const knowledgeName = useStore(s => s.knowledgeName)
|
||||
const knowledgeIcon = useStore(s => s.knowledgeIcon)
|
||||
const [pipelineName, setPipelineName] = useState(knowledgeName!)
|
||||
const [pipelineIcon, setPipelineIcon] = useState(knowledgeIcon!)
|
||||
const [description, setDescription] = useState('')
|
||||
const [showAppIconPicker, setShowAppIconPicker] = useState(false)
|
||||
|
||||
const handleSelectIcon = useCallback((item: AppIconSelection) => {
|
||||
if (item.type === 'image') {
|
||||
setPipelineIcon({
|
||||
icon_type: 'image',
|
||||
icon_url: item.url,
|
||||
icon_background: '',
|
||||
icon: '',
|
||||
})
|
||||
}
|
||||
|
||||
if (item.type === 'emoji') {
|
||||
setPipelineIcon({
|
||||
icon_type: 'emoji',
|
||||
icon: item.icon,
|
||||
icon_background: item.background,
|
||||
icon_url: '',
|
||||
})
|
||||
}
|
||||
setShowAppIconPicker(false)
|
||||
}, [])
|
||||
const handleCloseIconPicker = useCallback(() => {
|
||||
setPipelineIcon({
|
||||
icon_type: pipelineIcon.icon_type,
|
||||
icon: pipelineIcon.icon,
|
||||
icon_background: pipelineIcon.icon_background,
|
||||
icon_url: pipelineIcon.icon_url,
|
||||
})
|
||||
setShowAppIconPicker(false)
|
||||
}, [pipelineIcon])
|
||||
|
||||
const handleConfirm = () => {
|
||||
if (confirmDisabled)
|
||||
return
|
||||
|
||||
onConfirm(
|
||||
pipelineName?.trim() || '',
|
||||
pipelineIcon,
|
||||
description?.trim(),
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<Modal
|
||||
isShow
|
||||
onClose={noop}
|
||||
className='relative !w-[520px] !p-0'
|
||||
>
|
||||
<div className='title-2xl-semi-bold relative flex items-center p-6 pb-3 pr-14 text-text-primary'>
|
||||
{t('pipeline.common.publishAs')}
|
||||
<div className='absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center' onClick={onCancel}>
|
||||
<RiCloseLine className='h-4 w-4 text-text-tertiary' />
|
||||
</div>
|
||||
</div>
|
||||
<div className='px-6 py-3'>
|
||||
<div className='mb-5 flex'>
|
||||
<div className='mr-3 grow'>
|
||||
<div className='system-sm-medium mb-1 flex h-6 items-center text-text-secondary'>
|
||||
{t('pipeline.common.publishAsPipeline.name')}
|
||||
</div>
|
||||
<Input
|
||||
value={pipelineName}
|
||||
onChange={e => setPipelineName(e.target.value)}
|
||||
placeholder={t('pipeline.common.publishAsPipeline.namePlaceholder') || ''}
|
||||
/>
|
||||
</div>
|
||||
<AppIcon
|
||||
size='xxl'
|
||||
onClick={() => { setShowAppIconPicker(true) }}
|
||||
className='mt-2 shrink-0 cursor-pointer'
|
||||
iconType={pipelineIcon?.icon_type}
|
||||
icon={pipelineIcon?.icon}
|
||||
background={pipelineIcon?.icon_background}
|
||||
imageUrl={pipelineIcon?.icon_url}
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<div className='system-sm-medium mb-1 flex h-6 items-center text-text-secondary '>
|
||||
{t('pipeline.common.publishAsPipeline.description')}
|
||||
</div>
|
||||
<Textarea
|
||||
className='resize-none'
|
||||
placeholder={t('pipeline.common.publishAsPipeline.descriptionPlaceholder') || ''}
|
||||
value={description}
|
||||
onChange={e => setDescription(e.target.value)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex items-center justify-end px-6 py-5'>
|
||||
<Button
|
||||
className='mr-2'
|
||||
onClick={onCancel}
|
||||
>
|
||||
{t('common.operation.cancel')}
|
||||
</Button>
|
||||
<Button
|
||||
disabled={!pipelineName?.trim() || confirmDisabled}
|
||||
variant='primary'
|
||||
onClick={() => handleConfirm()}
|
||||
>
|
||||
{t('workflow.common.publish')}
|
||||
</Button>
|
||||
</div>
|
||||
</Modal>
|
||||
{showAppIconPicker && <AppIconPicker
|
||||
onSelect={handleSelectIcon}
|
||||
onClose={handleCloseIconPicker}
|
||||
/>}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default PublishAsKnowledgePipelineModal
|
||||
Loading…
Reference in New Issue