@@ -2,11 +2,14 @@ import tempfile
 from binascii import hexlify, unhexlify
 from collections.abc import Generator
 
+from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelManager
 from core.model_runtime.entities.llm_entities import (
     LLMResult,
     LLMResultChunk,
     LLMResultChunkDelta,
+    LLMResultChunkWithStructuredOutput,
+    LLMResultWithStructuredOutput,
 )
 from core.model_runtime.entities.message_entities import (
     PromptMessage,
@@ -16,6 +19,7 @@ from core.model_runtime.entities.message_entities import (
 from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
 from core.plugin.entities.request import (
     RequestInvokeLLM,
+    RequestInvokeLLMWithStructuredOutput,
     RequestInvokeModeration,
     RequestInvokeRerank,
     RequestInvokeSpeech2Text,
@@ -85,6 +89,72 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
 
         return handle_non_streaming(response)
 
+    @classmethod
+    def invoke_llm_with_structured_output(
+        cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLMWithStructuredOutput
+    ):
+        """
+        invoke llm with structured output
+        """
+        model_instance = ModelManager().get_model_instance(
+            tenant_id=tenant.id,
+            provider=payload.provider,
+            model_type=payload.model_type,
+            model=payload.model,
+        )
+
+        model_schema = model_instance.model_type_instance.get_model_schema(payload.model, model_instance.credentials)
+
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {payload.model}")
+
+        response = invoke_llm_with_structured_output(
+            provider=payload.provider,
+            model_schema=model_schema,
+            model_instance=model_instance,
+            prompt_messages=payload.prompt_messages,
+            json_schema=payload.structured_output_schema,
+            tools=payload.tools,
+            stop=payload.stop,
+            stream=True if payload.stream is None else payload.stream,
+            user=user_id,
+            model_parameters=payload.completion_params,
+        )
+
+        if isinstance(response, Generator):
+
+            def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                for chunk in response:
+                    if chunk.delta.usage:
+                        llm_utils.deduct_llm_quota(
+                            tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
+                        )
+                    chunk.prompt_messages = []
+                    yield chunk
+
+            return handle()
+        else:
+            if response.usage:
+                llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
+
+            def handle_non_streaming(
+                response: LLMResultWithStructuredOutput,
+            ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
+                yield LLMResultChunkWithStructuredOutput(
+                    model=response.model,
+                    prompt_messages=[],
+                    system_fingerprint=response.system_fingerprint,
+                    structured_output=response.structured_output,
+                    delta=LLMResultChunkDelta(
+                        index=0,
+                        message=response.message,
+                        usage=response.usage,
+                        finish_reason="",
+                    ),
+                )
+
+            return handle_non_streaming(response)
+
     @classmethod
     def invoke_text_embedding(cls, user_id: str, tenant: Tenant, payload: RequestInvokeTextEmbedding):
         """
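
For reference: because the non-streaming branch wraps its LLMResultWithStructuredOutput into a single LLMResultChunkWithStructuredOutput, both branches hand the caller a chunk generator, so streaming and non-streaming responses can be consumed through one code path. Below is a minimal consumption sketch, assuming only the attributes visible in the diff (delta.message, delta.usage, structured_output); collect_structured_output is a hypothetical helper, not part of this change.

from collections.abc import Generator

from core.model_runtime.entities.llm_entities import LLMResultChunkWithStructuredOutput


def collect_structured_output(
    chunks: Generator[LLMResultChunkWithStructuredOutput, None, None],
) -> tuple[object | None, str]:
    """Hypothetical helper: drain the chunk generator, return (structured_output, full_text)."""
    structured_output = None
    text_parts: list[str] = []
    for chunk in chunks:
        # Each chunk carries an assistant message fragment in chunk.delta.message.
        if chunk.delta.message and chunk.delta.message.content:
            text_parts.append(str(chunk.delta.message.content))
        # structured_output is set on the chunk itself (see the non-streaming wrapper above);
        # keep the last non-empty value seen.
        if getattr(chunk, "structured_output", None) is not None:
            structured_output = chunk.structured_output
    return structured_output, "".join(text_parts)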