feat(api): add support for node variable value resetting

pull/20699/head
QuantumGhost 12 months ago
parent 91ee015114
commit b28be1a1ff

@ -293,14 +293,7 @@ class VariableResetApi(Resource):
if variable.app_id != app_model.id: if variable.app_id != app_model.id:
raise NotFoundError(description=f"variable not found, id={variable_id}") raise NotFoundError(description=f"variable not found, id={variable_id}")
if variable.node_id != CONVERSATION_VARIABLE_NODE_ID: resetted = draft_var_srv.reset_variable(draft_workflow, variable)
error_msg = "variable is not a conversation variable, id={}, node_id={}, name={}".format(
variable.id,
variable.node_id,
variable.name,
)
raise InvalidArgumentError(error_msg)
resetted = draft_var_srv.reset_conversation_variable(draft_workflow, variable)
db.session.commit() db.session.commit()
if resetted is None: if resetted is None:
return Response("", 204) return Response("", 204)

@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence
from enum import StrEnum from enum import StrEnum
from typing import Any, ClassVar from typing import Any, ClassVar
from sqlalchemy import Engine, orm from sqlalchemy import Engine, orm, select
from sqlalchemy.dialects.postgresql import insert from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_, or_ from sqlalchemy.sql.expression import and_, or_
@ -23,7 +23,8 @@ from core.workflow.nodes.variable_assigner.common.helpers import get_updated_var
from core.workflow.variable_loader import VariableLoader from core.workflow.variable_loader import VariableLoader
from factories.variable_factory import build_segment, segment_to_variable from factories.variable_factory import build_segment, segment_to_variable
from models import App, Conversation from models import App, Conversation
from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable from models.enums import DraftVariableType
from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -34,7 +35,7 @@ class WorkflowDraftVariableList:
total: int | None = None total: int | None = None
class _DraftVarServiceError(Exception): class VariableResetError(Exception):
pass pass
@ -206,9 +207,7 @@ class WorkflowDraftVariableService:
self._session.flush() self._session.flush()
return variable return variable
def reset_conversation_variable( def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
self, workflow: Workflow, variable: WorkflowDraftVariable
) -> WorkflowDraftVariable | None:
conv_var_by_name = {i.name: i for i in workflow.conversation_variables} conv_var_by_name = {i.name: i for i in workflow.conversation_variables}
conv_var = conv_var_by_name.get(variable.name) conv_var = conv_var_by_name.get(variable.name)
@ -226,6 +225,61 @@ class WorkflowDraftVariableService:
self._session.flush() self._session.flush()
return variable return variable
def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
# No execution record for this variable, delete the variable instead.
if variable.node_execution_id is None:
self._session.delete(instance=variable)
self._session.flush()
_logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name)
return None
query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id)
node_exec = self._session.scalars(query).first()
if node_exec is None:
_logger.warning(
"Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s",
variable.id,
variable.name,
variable.node_execution_id,
)
self._session.delete(instance=variable)
self._session.flush()
return None
def _extract_pair_from_dict(d: dict[str, Any], needle: str) -> dict[str, Any]:
result: dict[str, Any] = {}
for key, value in d:
if key == needle:
result[needle] = value
return result
return result
to_save_process_data = _extract_pair_from_dict(node_exec.process_data_dict or {}, variable.name)
to_save_outputs = _extract_pair_from_dict(node_exec.outputs_dict or {}, variable.name)
node_config = workflow.get_node_config_by_id(variable.node_id)
node_type = workflow.get_node_type_from_node_config(node_config)
saver = DraftVariableSaver(
session=self._session,
app_id=workflow.app_id,
node_id=variable.node_id,
node_type=node_type,
invoke_from=InvokeFrom.DEBUGGER,
node_execution_id=variable.node_execution_id,
)
saver.save(to_save_process_data, to_save_outputs)
def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None:
variable_type = variable.get_variable_type()
if variable_type == DraftVariableType.CONVERSATION:
return self._reset_conv_var(workflow, variable)
elif variable_type == DraftVariableType.NODE:
return self._reset_node_var(workflow, variable)
else:
raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}")
def delete_variable(self, variable: WorkflowDraftVariable): def delete_variable(self, variable: WorkflowDraftVariable):
self._session.delete(variable) self._session.delete(variable)

Loading…
Cancel
Save