feat(api): do not save `finish_reason` from LLM node outputs

pull/20699/head
QuantumGhost 12 months ago
parent b28be1a1ff
commit 282f44822c

@ -140,13 +140,6 @@ class WorkflowEntry:
# Get node class
node_type = NodeType(node_config_data.get("type"))
node_version = node_config_data.get("version", "1")
if node_type == NodeType.START:
# special handing for start node.
#
# 1. create conversation variables and system variables
# 2. create environment variables
pass
node_cls = NODE_TYPE_CLASSES_MAPPING[node_type][node_version]
metadata_attacher = _attach_execution_metadata_based_on_node_config(node_config_data)

@ -13,6 +13,7 @@ from core.file.models import File
from core.variables import utils as variable_utils
from core.variables.segments import ArrayFileSegment, FileSegment
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.nodes.enums import NodeType
from factories.variable_factory import build_segment
from ._workflow_exc import NodeNotFoundError, WorkflowDataError
@ -77,6 +78,10 @@ class WorkflowType(Enum):
return cls.WORKFLOW if app_mode == AppMode.WORKFLOW else cls.CHAT
class _InvalidGraphDefinitionError(Exception):
pass
class Workflow(Base):
"""
Workflow, for `Workflow App` and `Chat App workflow mode`.
@ -226,6 +231,31 @@ class Workflow(Base):
raise NodeNotFoundError(node_id)
return node_config
@staticmethod
def get_node_type_from_node_config(node_config: Mapping[str, Any]) -> NodeType:
"""Extract type of a node from the node configuration returned by `get_node_config_by_id`."""
node_config_data = node_config.get("data", {})
# Get node class
node_type = NodeType(node_config_data.get("type"))
return node_type
@staticmethod
def get_enclosing_node_type_and_id(node_config: Mapping[str, Any]) -> tuple[NodeType, str] | None:
in_loop = node_config.get("isInLoop", False)
in_iteration = node_config.get("isInIteration", False)
if in_loop:
loop_id = node_config.get("loop_id")
if loop_id is None:
raise _InvalidGraphDefinitionError("invalid graph")
return NodeType.LOOP, loop_id
elif in_iteration:
iteration_id = node_config.get("iteration_id")
if iteration_id is None:
raise _InvalidGraphDefinitionError("invalid graph")
return NodeType.ITERATION, iteration_id
else:
return None
@property
def features(self) -> str:
"""

@ -498,6 +498,14 @@ class DraftVariableSaver:
_DUMMY_OUTPUT_IDENTITY: ClassVar[str] = "__dummy__"
_DUMMY_OUTPUT_VALUE: ClassVar[None] = None
# _EXCLUDE_VARIABLE_NAMES_MAPPING maps node types and versions to variable names that
# should be excluded when saving draft variables. This prevents certain internal or
# technical variables from being exposed in the draft environment, particularly those
# that aren't meant to be directly edited or viewed by users.
_EXCLUDE_VARIABLE_NAMES_MAPPING: dict[NodeType, frozenset[str]] = {
NodeType.LLM: frozenset(["finish_reason"]),
}
# Database session used for persisting draft variables.
_session: Session
@ -639,6 +647,14 @@ class DraftVariableSaver:
def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]:
draft_vars = []
for name, value in output.items():
if not self._should_variable_be_saved(name):
_logger.debug(
"Skip saving variable as it has been excluded by its node_type, name=%s, node_type=%s",
name,
self._node_type,
)
continue
value_seg = _build_segment_for_value(value)
draft_vars.append(
WorkflowDraftVariable.new_node_variable(
@ -692,20 +708,8 @@ class DraftVariableSaver:
return False
return True
# @staticmethod
# def _normalize_variable(node_type: NodeType, node_id: str, name: str) -> tuple[str, str]:
# if node_type != NodeType.START:
# return node_id, name
#
# # TODO(QuantumGhost): need special handling for dummy output variable in
# # `Start` node.
# if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."):
# return node_id, name
# logging.getLogger(__name__).info(
# "Normalizing variable: node_type=%s, node_id=%s, name=%s",
# node_type,
# node_id,
# name,
# )
# node_id, name_ = name.split(".", maxsplit=1)
# return node_id, name_
def _should_variable_be_saved(self, name: str) -> bool:
exclude_var_names = self._EXCLUDE_VARIABLE_NAMES_MAPPING.get(self._node_type)
if exclude_var_names is None:
return True
return name in exclude_var_names

Loading…
Cancel
Save