diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index 178c9b84c6..f68a94dedd 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -245,6 +245,41 @@ class VariableApi(Resource): return Response("", 204) +class VariableResetApi(Resource): + @_api_prerequisite + @marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS) + def put(self, app_model: App, variable_id: str): + draft_var_srv = WorkflowDraftVariableService( + session=db.session(), + ) + + workflow_srv = WorkflowService() + draft_workflow = workflow_srv.get_draft_workflow(app_model) + if draft_workflow is None: + raise NotFoundError( + f"Draft workflow not found, app_id={app_model.id}", + ) + variable = draft_var_srv.get_variable(variable_id=variable_id) + if variable is None: + raise NotFoundError(description=f"variable not found, id={variable_id}") + if variable.app_id != app_model.id: + raise NotFoundError(description=f"variable not found, id={variable_id}") + + if variable.node_id != CONVERSATION_VARIABLE_NODE_ID: + error_msg = "variable is not a conversation variable, id={}, node_id={}, name={}".format( + variable.id, + variable.node_id, + variable.name, + ) + raise InvalidArgumentError(error_msg) + resetted = draft_var_srv.reset_conversation_variable(draft_workflow, variable) + db.session.commit() + if resetted is None: + return Response("", 204) + else: + return variable + + def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList: with Session(bind=db.engine, expire_on_commit=False) as session: draft_var_srv = WorkflowDraftVariableService( @@ -321,6 +356,7 @@ api.add_resource( ) api.add_resource(NodeVariableCollectionApi, "/apps//workflows/draft/nodes//variables") api.add_resource(VariableApi, "/apps//workflows/draft/variables/") +api.add_resource(VariableApi, "/apps//workflows/draft/variables//reset") api.add_resource(ConversationVariableCollectionApi, "/apps//workflows/draft/conversation-variables") api.add_resource(SystemVariableCollectionApi, "/apps//workflows/draft/system-variables") diff --git a/api/core/file/constants.py b/api/core/file/constants.py index ce1d238e93..02d710303c 100644 --- a/api/core/file/constants.py +++ b/api/core/file/constants.py @@ -1 +1,5 @@ +# TODO(QuantumGhost): Refactor variable type identification. Instead of directly +# comparing `dify_model_identity` with constants throughout the codebase, extract +# this logic into a dedicated function. This would encapsulate the implementation +# details of how different variable types are identified. FILE_MODEL_IDENTITY = "__dify__file__" diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index 8235163b76..847f5e771f 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -32,6 +32,10 @@ class WorkflowDraftVariableList: total: int | None = None +class _DraftVarServiceError(Exception): + pass + + class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # @@ -200,6 +204,25 @@ class WorkflowDraftVariableService: self._session.flush() return variable + def reset_conversation_variable( + self, workflow: Workflow, variable: WorkflowDraftVariable + ) -> WorkflowDraftVariable | None: + conv_var_by_name = {i.name: i for i in workflow.conversation_variables} + conv_var = conv_var_by_name.get(variable.name) + + if conv_var is None: + self._session.delete(instance=variable) + self._session.flush() + _logger.warning( + "Conversation variable not found for draft variable, id=%s, name=%s", variable.id, variable.name + ) + return None + + variable.set_value(conv_var) + self._session.add(variable) + self._session.flush() + return variable + def delete_variable(self, variable: WorkflowDraftVariable): self._session.delete(variable) diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 59ff016235..87cefc3479 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -1,5 +1,4 @@ import json -import logging import time import uuid from collections.abc import Callable, Generator, Mapping, Sequence @@ -32,7 +31,6 @@ from core.workflow.workflow_entry import WorkflowEntry from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db from factories.variable_factory import segment_to_variable -from libs import gen_utils from models.account import Account from models.model import App, AppMode from models.tools import WorkflowToolProvider