diff --git a/api/models/_workflow_exc.py b/api/models/_workflow_exc.py new file mode 100644 index 0000000000..f6271bda47 --- /dev/null +++ b/api/models/_workflow_exc.py @@ -0,0 +1,20 @@ +"""All these exceptions are not meant to be caught by callers.""" + + +class WorkflowDataError(Exception): + """Base class for all workflow data related exceptions. + + This should be used to indicate issues with workflow data integrity, such as + no `graph` configuration, missing `nodes` field in `graph` configuration, or + similar issues. + """ + + pass + + +class NodeNotFoundError(WorkflowDataError): + """Raised when a node with the specified ID is not found in the workflow.""" + + def __init__(self, node_id: str): + super().__init__(f"Node with ID '{node_id}' not found in the workflow.") + self.node_id = node_id diff --git a/api/models/workflow.py b/api/models/workflow.py index 2b4fbeab37..975ca882d9 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -12,6 +12,8 @@ from core.variables import utils as variable_utils from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment +from ._workflow_exc import NodeNotFoundError, WorkflowDataError + if TYPE_CHECKING: from models.model import AppMode @@ -201,6 +203,26 @@ class Workflow(Base): # - `_get_graph_and_variable_pool_of_single_loop`. return json.loads(self.graph) if self.graph else {} + def get_node_config_by_id(self, node_id: str) -> Mapping[str, Any]: + """Extract a node configuration from the workflow graph by node ID. + A node configuration is a dictionary containing the node's properties, including + the node's id, title, and its data as a dict. + """ + workflow_graph = self.graph_dict + + if not workflow_graph: + raise WorkflowDataError(f"workflow graph not found, workflow_id={self.id}") + + nodes = workflow_graph.get("nodes") + if not nodes: + raise WorkflowDataError("nodes not found in workflow graph") + + try: + node_config = next(filter(lambda node: node["id"] == node_id, nodes)) + except StopIteration: + raise NodeNotFoundError(node_id) + return node_config + @property def features(self) -> str: """ @@ -955,7 +977,7 @@ class WorkflowDraftVariable(Base): def _set_selector(self, value: list[str]): self.selector = json.dumps(value) - def get_value(self) -> Segment | None: + def get_value(self) -> Segment: return build_segment(json.loads(self.value)) def set_name(self, name: str): @@ -1009,12 +1031,14 @@ class WorkflowDraftVariable(Base): app_id: str, name: str, value: Segment, + description: str = "", ) -> "WorkflowDraftVariable": variable = cls._new( app_id=app_id, node_id=CONVERSATION_VARIABLE_NODE_ID, name=name, value=value, + description=description, ) return variable