diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 16de3dbe42..ca0b274faf 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -293,14 +293,7 @@ class VariableResetApi(Resource): if variable.app_id != app_model.id: raise NotFoundError(description=f"variable not found, id={variable_id}") - if variable.node_id != CONVERSATION_VARIABLE_NODE_ID: - error_msg = "variable is not a conversation variable, id={}, node_id={}, name={}".format( - variable.id, - variable.node_id, - variable.name, - ) - raise InvalidArgumentError(error_msg) - resetted = draft_var_srv.reset_conversation_variable(draft_workflow, variable) + resetted = draft_var_srv.reset_variable(draft_workflow, variable) db.session.commit() if resetted is None: return Response("", 204) diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index dfa4154c78..d69a6aa8a7 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -5,7 +5,7 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any, ClassVar -from sqlalchemy import Engine, orm +from sqlalchemy import Engine, orm, select from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session from sqlalchemy.sql.expression import and_, or_ @@ -23,7 +23,8 @@ from core.workflow.nodes.variable_assigner.common.helpers import get_updated_var from core.workflow.variable_loader import VariableLoader from factories.variable_factory import build_segment, segment_to_variable from models import App, Conversation -from models.workflow import Workflow, WorkflowDraftVariable, is_system_variable_editable +from models.enums import DraftVariableType +from models.workflow import Workflow, WorkflowDraftVariable, WorkflowNodeExecutionModel, is_system_variable_editable _logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ class WorkflowDraftVariableList: total: int | None = None -class _DraftVarServiceError(Exception): +class VariableResetError(Exception): pass @@ -206,9 +207,7 @@ class WorkflowDraftVariableService: self._session.flush() return variable - def reset_conversation_variable( - self, workflow: Workflow, variable: WorkflowDraftVariable - ) -> WorkflowDraftVariable | None: + def _reset_conv_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: conv_var_by_name = {i.name: i for i in workflow.conversation_variables} conv_var = conv_var_by_name.get(variable.name) @@ -226,6 +225,61 @@ class WorkflowDraftVariableService: self._session.flush() return variable + def _reset_node_var(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + # No execution record for this variable, delete the variable instead. + if variable.node_execution_id is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning("draft variable has no node_execution_id, id=%s, name=%s", variable.id, variable.name) + return None + + query = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == variable.node_execution_id) + node_exec = self._session.scalars(query).first() + if node_exec is None: + _logger.warning( + "Node exectution not found for draft variable, id=%s, name=%s, node_execution_id=%s", + variable.id, + variable.name, + variable.node_execution_id, + ) + self._session.delete(instance=variable) + self._session.flush() + return None + + def _extract_pair_from_dict(d: dict[str, Any], needle: str) -> dict[str, Any]: + result: dict[str, Any] = {} + for key, value in d: + if key == needle: + result[needle] = value + return result + + return result + + to_save_process_data = _extract_pair_from_dict(node_exec.process_data_dict or {}, variable.name) + to_save_outputs = _extract_pair_from_dict(node_exec.outputs_dict or {}, variable.name) + + node_config = workflow.get_node_config_by_id(variable.node_id) + node_type = workflow.get_node_type_from_node_config(node_config) + + saver = DraftVariableSaver( + session=self._session, + app_id=workflow.app_id, + node_id=variable.node_id, + node_type=node_type, + invoke_from=InvokeFrom.DEBUGGER, + node_execution_id=variable.node_execution_id, + ) + saver.save(to_save_process_data, to_save_outputs) + + def reset_variable(self, workflow: Workflow, variable: WorkflowDraftVariable) -> WorkflowDraftVariable | None: + variable_type = variable.get_variable_type() + if variable_type == DraftVariableType.CONVERSATION: + return self._reset_conv_var(workflow, variable) + elif variable_type == DraftVariableType.NODE: + return self._reset_node_var(workflow, variable) + else: + raise VariableResetError(f"cannot reset system variable, variable_id={variable.id}") + def delete_variable(self, variable: WorkflowDraftVariable): self._session.delete(variable)