|
|
|
|
@ -1,10 +1,8 @@
|
|
|
|
|
import json
|
|
|
|
|
import logging
|
|
|
|
|
from collections.abc import Mapping, Sequence
|
|
|
|
|
from typing import TYPE_CHECKING, Any, Optional, cast
|
|
|
|
|
|
|
|
|
|
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
|
|
|
|
from core.llm_generator.output_parser.errors import OutputParserError
|
|
|
|
|
from core.memory.token_buffer_memory import TokenBufferMemory
|
|
|
|
|
from core.model_manager import ModelInstance
|
|
|
|
|
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
|
|
|
|
@ -96,27 +94,28 @@ class QuestionClassifierNode(LLMNode):
|
|
|
|
|
jinja2_variables=[],
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# handle invoke result
|
|
|
|
|
generator = self._invoke_llm(
|
|
|
|
|
node_data_model=node_data.model,
|
|
|
|
|
model_instance=model_instance,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
stop=stop,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
result_text = ""
|
|
|
|
|
usage = LLMUsage.empty_usage()
|
|
|
|
|
finish_reason = None
|
|
|
|
|
for event in generator:
|
|
|
|
|
if isinstance(event, ModelInvokeCompletedEvent):
|
|
|
|
|
result_text = event.text
|
|
|
|
|
usage = event.usage
|
|
|
|
|
finish_reason = event.finish_reason
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
category_name = node_data.classes[0].name
|
|
|
|
|
category_id = node_data.classes[0].id
|
|
|
|
|
try:
|
|
|
|
|
# handle invoke result
|
|
|
|
|
generator = self._invoke_llm(
|
|
|
|
|
node_data_model=node_data.model,
|
|
|
|
|
model_instance=model_instance,
|
|
|
|
|
prompt_messages=prompt_messages,
|
|
|
|
|
stop=stop,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
for event in generator:
|
|
|
|
|
if isinstance(event, ModelInvokeCompletedEvent):
|
|
|
|
|
result_text = event.text
|
|
|
|
|
usage = event.usage
|
|
|
|
|
finish_reason = event.finish_reason
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
category_name = node_data.classes[0].name
|
|
|
|
|
category_id = node_data.classes[0].id
|
|
|
|
|
result_text_json = parse_and_check_json_markdown(result_text, [])
|
|
|
|
|
# result_text_json = json.loads(result_text.strip('```JSON\n'))
|
|
|
|
|
if "category_name" in result_text_json and "category_id" in result_text_json:
|
|
|
|
|
@ -127,10 +126,6 @@ class QuestionClassifierNode(LLMNode):
|
|
|
|
|
if category_id_result in category_ids:
|
|
|
|
|
category_name = classes_map[category_id_result]
|
|
|
|
|
category_id = category_id_result
|
|
|
|
|
|
|
|
|
|
except OutputParserError:
|
|
|
|
|
logging.exception(f"Failed to parse result text: {result_text}")
|
|
|
|
|
try:
|
|
|
|
|
process_data = {
|
|
|
|
|
"model_mode": model_config.mode,
|
|
|
|
|
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
|
|
|
|
@ -154,7 +149,7 @@ class QuestionClassifierNode(LLMNode):
|
|
|
|
|
},
|
|
|
|
|
llm_usage=usage,
|
|
|
|
|
)
|
|
|
|
|
except Exception as e:
|
|
|
|
|
except ValueError as e:
|
|
|
|
|
return NodeRunResult(
|
|
|
|
|
status=WorkflowNodeExecutionStatus.FAILED,
|
|
|
|
|
inputs=variables,
|
|
|
|
|
|