feat(api): implement draft var related api
parent
c7e6b9ce9c
commit
5aa044e392
@ -0,0 +1,319 @@
|
|||||||
|
import logging
|
||||||
|
from typing import NoReturn
|
||||||
|
|
||||||
|
from flask import Response
|
||||||
|
from flask_restful import Resource, fields, inputs, marshal_with, reqparse
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
|
from controllers.console import api
|
||||||
|
from controllers.console.app.error import (
|
||||||
|
DraftWorkflowNotExist,
|
||||||
|
)
|
||||||
|
from controllers.console.app.wraps import get_app_model
|
||||||
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
|
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from libs.login import current_user, login_required
|
||||||
|
from models import App, AppMode, db
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
|
from services.workflow_service import WorkflowService
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _create_pagination_parser():
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(
|
||||||
|
"page",
|
||||||
|
type=inputs.int_range(1, 100_000),
|
||||||
|
required=False,
|
||||||
|
default=1,
|
||||||
|
location="args",
|
||||||
|
help="the page of data requested",
|
||||||
|
)
|
||||||
|
parser.add_argument("limit", type=inputs.int_range(1, 100), required=False, default=20, location="args")
|
||||||
|
return parser
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda model: model.get_variable_type()),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_FIELDS = dict(
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||||
|
value=fields.Raw(attribute=lambda variable: variable.get_value().value),
|
||||||
|
)
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS = {
|
||||||
|
"id": fields.String,
|
||||||
|
"type": fields.String(attribute=lambda _: "env"),
|
||||||
|
"name": fields.String,
|
||||||
|
"description": fields.String,
|
||||||
|
"selector": fields.List(fields.String, attribute=lambda model: model.get_selector()),
|
||||||
|
"value_type": fields.String,
|
||||||
|
"edited": fields.Boolean(attribute=lambda model: model.edited),
|
||||||
|
"visible": fields.Boolean,
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_ENV_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_ENV_VARIABLE_FIELDS)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_items(var_list: WorkflowDraftVariableList) -> list[WorkflowDraftVariable]:
|
||||||
|
return var_list.variables
|
||||||
|
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS), attribute=_get_items),
|
||||||
|
"total": fields.Raw(),
|
||||||
|
}
|
||||||
|
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS = {
|
||||||
|
"items": fields.List(fields.Nested(_WORKFLOW_DRAFT_VARIABLE_FIELDS), attribute=_get_items),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _api_prerequisite(f):
|
||||||
|
"""Common prerequisites for all draft workflow variable APIs.
|
||||||
|
|
||||||
|
It ensures the following conditions are satisfied:
|
||||||
|
|
||||||
|
- Dify has been property setup.
|
||||||
|
- The request user has logged in and initialized.
|
||||||
|
- The requested app is a workflow or a chat flow.
|
||||||
|
- The request user has the edit permission for the app.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
if not current_user.is_editor:
|
||||||
|
raise Forbidden()
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
parser = _create_pagination_parser()
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow_exist = workflow_service.is_workflow_exist(app_model=app_model)
|
||||||
|
if not workflow_exist:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
workflow_vars = draft_var_srv.list_variables_without_values(
|
||||||
|
app_id=app_model.id,
|
||||||
|
page=args.page,
|
||||||
|
limit=args.limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
return workflow_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session,
|
||||||
|
)
|
||||||
|
draft_var_srv.delete_workflow_variables(app_model.id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_node_id(node_id: str) -> NoReturn | None:
|
||||||
|
if node_id in [
|
||||||
|
CONVERSATION_VARIABLE_NODE_ID,
|
||||||
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
|
]:
|
||||||
|
# NOTE(QuantumGhost): While we store the system and conversation variables as node variables
|
||||||
|
# with specific `node_id` in database, we still want to make the API separated. By disallowing
|
||||||
|
# accessing system and conversation variables in `WorkflowDraftNodeVariableListApi`,
|
||||||
|
# we mitigate the risk that user of the API depending on the implementation detail of the API.
|
||||||
|
#
|
||||||
|
# ref: [Hyrum's Law](https://www.hyrumslaw.com/)
|
||||||
|
|
||||||
|
raise InvalidArgumentError(
|
||||||
|
f"invalid node_id, please use correspond api for conversation and system variables, node_id={node_id}",
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class NodeVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
node_vars = draft_var_srv.list_node_variables(app_model.id, node_id)
|
||||||
|
|
||||||
|
return node_vars
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, node_id: str):
|
||||||
|
validate_node_id(node_id)
|
||||||
|
srv = WorkflowDraftVariableService(db.session)
|
||||||
|
srv.delete_node_variables(app_model.id, node_id)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
class VariableApi(Resource):
|
||||||
|
_PATCH_NAME_FIELD = "name"
|
||||||
|
_PATCH_VALUE_FIELD = "value"
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def get(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session,
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_FIELDS)
|
||||||
|
def patch(self, app_model: App, variable_id: str):
|
||||||
|
parser = reqparse.RequestParser()
|
||||||
|
parser.add_argument(self._PATCH_NAME_FIELD, type=str, required=False, nullable=True, location="json")
|
||||||
|
parser.add_argument(self._PATCH_VALUE_FIELD, type=build_segment, required=False, nullable=True, location="json")
|
||||||
|
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session,
|
||||||
|
)
|
||||||
|
args = parser.parse_args(strict=True)
|
||||||
|
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
|
||||||
|
new_name = args.get(self._PATCH_NAME_FIELD, None)
|
||||||
|
new_value = args.get(self._PATCH_VALUE_FIELD, None)
|
||||||
|
|
||||||
|
if new_name is None and new_value is None:
|
||||||
|
return variable
|
||||||
|
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
|
||||||
|
db.session.commit()
|
||||||
|
return variable
|
||||||
|
|
||||||
|
@_api_prerequisite
|
||||||
|
def delete(self, app_model: App, variable_id: str):
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=db.session,
|
||||||
|
)
|
||||||
|
variable = draft_var_srv.get_variable(variable_id=variable_id)
|
||||||
|
if variable is None:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
if variable.app_id != app_model.id:
|
||||||
|
raise NotFoundError(description=f"variable not found, id={variable_id}")
|
||||||
|
draft_var_srv.delete_variable(variable)
|
||||||
|
db.session.commit()
|
||||||
|
return Response("", 204)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_variable_list(app_model: App, node_id) -> WorkflowDraftVariableList:
|
||||||
|
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||||
|
draft_var_srv = WorkflowDraftVariableService(
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
if node_id == CONVERSATION_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_conversation_variables(app_model.id)
|
||||||
|
elif node_id == SYSTEM_VARIABLE_NODE_ID:
|
||||||
|
draft_vars = draft_var_srv.list_system_variables(app_model.id)
|
||||||
|
else:
|
||||||
|
draft_vars = draft_var_srv.list_node_variables(app_id=app_model.id, node_id=node_id)
|
||||||
|
return draft_vars
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class SystemVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
@marshal_with(_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
def get(self, app_model: App):
|
||||||
|
return _get_variable_list(app_model, SYSTEM_VARIABLE_NODE_ID)
|
||||||
|
|
||||||
|
|
||||||
|
class EnvironmentVariableCollectionApi(Resource):
|
||||||
|
@_api_prerequisite
|
||||||
|
def get(self, app_model: App):
|
||||||
|
"""
|
||||||
|
Get draft workflow
|
||||||
|
"""
|
||||||
|
# fetch draft workflow by app_model
|
||||||
|
workflow_service = WorkflowService()
|
||||||
|
workflow = workflow_service.get_draft_workflow(app_model=app_model)
|
||||||
|
if workflow is None:
|
||||||
|
raise DraftWorkflowNotExist()
|
||||||
|
|
||||||
|
env_vars = workflow.environment_variables
|
||||||
|
env_vars_list = []
|
||||||
|
for v in env_vars:
|
||||||
|
env_vars_list.append(
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"type": "env",
|
||||||
|
"name": v.name,
|
||||||
|
"description": v.description,
|
||||||
|
"selector": v.selector,
|
||||||
|
"value_type": v.value_type.value,
|
||||||
|
"value": v.value,
|
||||||
|
# Do not track edited for env vars.
|
||||||
|
"edited": False,
|
||||||
|
"visible": True,
|
||||||
|
"editable": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return {"items": env_vars_list}
|
||||||
|
|
||||||
|
|
||||||
|
api.add_resource(
|
||||||
|
WorkflowVariableCollectionApi,
|
||||||
|
"/apps/<uuid:app_id>/workflows/draft/variables",
|
||||||
|
)
|
||||||
|
api.add_resource(NodeVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/variables")
|
||||||
|
api.add_resource(VariableApi, "/apps/<uuid:app_id>/workflows/draft/variables/<uuid:variable_id>")
|
||||||
|
|
||||||
|
api.add_resource(ConversationVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/conversation-variables")
|
||||||
|
api.add_resource(SystemVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/system-variables")
|
||||||
|
api.add_resource(EnvironmentVariableCollectionApi, "/apps/<uuid:app_id>/workflows/draft/environment-variables")
|
||||||
@ -0,0 +1,196 @@
|
|||||||
|
import datetime
|
||||||
|
import uuid
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import NamedTuple
|
||||||
|
|
||||||
|
from flask_restful import marshal
|
||||||
|
|
||||||
|
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList
|
||||||
|
|
||||||
|
from .workflow_draft_variable import (
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_FIELDS,
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS,
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS,
|
||||||
|
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TEST_APP_ID = "test_app_id"
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowDraftVariableFields:
|
||||||
|
def test_conversation_variable(self):
|
||||||
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||||
|
)
|
||||||
|
|
||||||
|
conv_var.id = str(uuid.uuid4())
|
||||||
|
conv_var.visible = True
|
||||||
|
|
||||||
|
expected_without_value = OrderedDict(
|
||||||
|
{
|
||||||
|
"id": str(conv_var.id),
|
||||||
|
"type": conv_var.get_variable_type().value,
|
||||||
|
"name": "conv_var",
|
||||||
|
"description": "",
|
||||||
|
"selector": [CONVERSATION_VARIABLE_NODE_ID, "conv_var"],
|
||||||
|
"value_type": "number",
|
||||||
|
"edited": False,
|
||||||
|
"visible": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||||
|
expected_with_value = expected_without_value.copy()
|
||||||
|
expected_with_value["value"] = 1
|
||||||
|
assert marshal(conv_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||||
|
|
||||||
|
def test_create_sys_variable(self):
|
||||||
|
sys_var = WorkflowDraftVariable.new_sys_variable(
|
||||||
|
app_id=_TEST_APP_ID,
|
||||||
|
name="sys_var",
|
||||||
|
value=build_segment("a"),
|
||||||
|
editable=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
sys_var.id = str(uuid.uuid4())
|
||||||
|
sys_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
|
sys_var.visible = True
|
||||||
|
|
||||||
|
expected_without_value = OrderedDict(
|
||||||
|
{
|
||||||
|
"id": str(sys_var.id),
|
||||||
|
"type": sys_var.get_variable_type().value,
|
||||||
|
"name": "sys_var",
|
||||||
|
"description": "",
|
||||||
|
"selector": [SYSTEM_VARIABLE_NODE_ID, "sys_var"],
|
||||||
|
"value_type": "string",
|
||||||
|
"edited": True,
|
||||||
|
"visible": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||||
|
expected_with_value = expected_without_value.copy()
|
||||||
|
expected_with_value["value"] = "a"
|
||||||
|
assert marshal(sys_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||||
|
|
||||||
|
def test_node_variable(self):
|
||||||
|
node_var = WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=_TEST_APP_ID,
|
||||||
|
node_id="test_node",
|
||||||
|
name="node_var",
|
||||||
|
value=build_segment([1, "a"]),
|
||||||
|
visible=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
node_var.id = str(uuid.uuid4())
|
||||||
|
node_var.last_edited_at = datetime.datetime.now(datetime.UTC).replace(tzinfo=None)
|
||||||
|
|
||||||
|
expected_without_value = OrderedDict(
|
||||||
|
{
|
||||||
|
"id": str(node_var.id),
|
||||||
|
"type": node_var.get_variable_type().value,
|
||||||
|
"name": "node_var",
|
||||||
|
"description": "",
|
||||||
|
"selector": ["test_node", "node_var"],
|
||||||
|
"value_type": "array[any]",
|
||||||
|
"edited": True,
|
||||||
|
"visible": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS) == expected_without_value
|
||||||
|
expected_with_value = expected_without_value.copy()
|
||||||
|
expected_with_value["value"] = [1, "a"]
|
||||||
|
assert marshal(node_var, _WORKFLOW_DRAFT_VARIABLE_FIELDS) == expected_with_value
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowDraftVariableList:
|
||||||
|
def test_workflow_draft_variable_list(self):
|
||||||
|
class TestCase(NamedTuple):
|
||||||
|
name: str
|
||||||
|
var_list: WorkflowDraftVariableList
|
||||||
|
expected: dict
|
||||||
|
|
||||||
|
node_var = WorkflowDraftVariable.new_node_variable(
|
||||||
|
app_id=_TEST_APP_ID,
|
||||||
|
node_id="test_node",
|
||||||
|
name="test_var",
|
||||||
|
value=build_segment("a"),
|
||||||
|
visible=True,
|
||||||
|
)
|
||||||
|
node_var.id = str(uuid.uuid4())
|
||||||
|
node_var_dict = OrderedDict(
|
||||||
|
{
|
||||||
|
"id": str(node_var.id),
|
||||||
|
"type": node_var.get_variable_type().value,
|
||||||
|
"name": "test_var",
|
||||||
|
"description": "",
|
||||||
|
"selector": ["test_node", "test_var"],
|
||||||
|
"value_type": "string",
|
||||||
|
"edited": False,
|
||||||
|
"visible": True,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
TestCase(
|
||||||
|
name="empty variable list",
|
||||||
|
var_list=WorkflowDraftVariableList(variables=[]),
|
||||||
|
expected=OrderedDict(
|
||||||
|
{
|
||||||
|
"items": [],
|
||||||
|
"total": None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="empty variable list with total",
|
||||||
|
var_list=WorkflowDraftVariableList(variables=[], total=10),
|
||||||
|
expected=OrderedDict(
|
||||||
|
{
|
||||||
|
"items": [],
|
||||||
|
"total": 10,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="non-empty variable list",
|
||||||
|
var_list=WorkflowDraftVariableList(variables=[node_var], total=None),
|
||||||
|
expected=OrderedDict(
|
||||||
|
{
|
||||||
|
"items": [node_var_dict],
|
||||||
|
"total": None,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
TestCase(
|
||||||
|
name="non-empty variable list with total",
|
||||||
|
var_list=WorkflowDraftVariableList(variables=[node_var], total=10),
|
||||||
|
expected=OrderedDict(
|
||||||
|
{
|
||||||
|
"items": [node_var_dict],
|
||||||
|
"total": 10,
|
||||||
|
}
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for idx, case in enumerate(cases, 1):
|
||||||
|
assert marshal(case.var_list, _WORKFLOW_DRAFT_VARIABLE_LIST_WITHOUT_VALUE_FIELDS) == case.expected, (
|
||||||
|
f"Test case {idx} failed, {case.name=}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_node_variables_fields():
|
||||||
|
conv_var = WorkflowDraftVariable.new_conversation_variable(
|
||||||
|
app_id=_TEST_APP_ID, name="conv_var", value=build_segment(1)
|
||||||
|
)
|
||||||
|
resp = marshal(WorkflowDraftVariableList(variables=[conv_var]), _WORKFLOW_DRAFT_VARIABLE_LIST_FIELDS)
|
||||||
|
assert isinstance(resp, dict)
|
||||||
|
assert len(resp["items"]) == 1
|
||||||
|
item_dict = resp["items"][0]
|
||||||
|
assert item_dict["name"] == "conv_var"
|
||||||
|
assert item_dict["value"] == 1
|
||||||
@ -1 +1,21 @@
|
|||||||
|
from typing import Any
|
||||||
|
|
||||||
FILE_MODEL_IDENTITY = "__dify__file__"
|
FILE_MODEL_IDENTITY = "__dify__file__"
|
||||||
|
|
||||||
|
# DUMMY_OUTPUT_IDENTITY is a placeholder output for workflow nodes.
|
||||||
|
# Its sole possible value is `None`.
|
||||||
|
#
|
||||||
|
# This is used to signal the execution of a workflow node when it has no other outputs.
|
||||||
|
_DUMMY_OUTPUT_IDENTITY = "__dummy__"
|
||||||
|
_DUMMY_OUTPUT_VALUE: None = None
|
||||||
|
|
||||||
|
|
||||||
|
def add_dummy_output(original: dict[str, Any] | None) -> dict[str, Any]:
|
||||||
|
if original is None:
|
||||||
|
original = {}
|
||||||
|
original[_DUMMY_OUTPUT_IDENTITY] = _DUMMY_OUTPUT_VALUE
|
||||||
|
return original
|
||||||
|
|
||||||
|
|
||||||
|
def is_dummy_output_variable(name: str) -> bool:
|
||||||
|
return name == _DUMMY_OUTPUT_IDENTITY
|
||||||
|
|||||||
@ -0,0 +1,46 @@
|
|||||||
|
import uuid
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from controllers.console.app import workflow_draft_variable as draft_variable_api
|
||||||
|
from controllers.console.app import wraps
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from models import App, AppMode
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
|
||||||
|
|
||||||
|
|
||||||
|
def _get_mock_srv_class() -> type[WorkflowDraftVariableService]:
|
||||||
|
return mock.create_autospec(WorkflowDraftVariableService)
|
||||||
|
|
||||||
|
|
||||||
|
class TestWorkflowDraftNodeVariableListApi:
|
||||||
|
def test_get(self, test_client, auth_header, monkeypatch):
|
||||||
|
srv_class = _get_mock_srv_class()
|
||||||
|
mock_app_model: App = App()
|
||||||
|
mock_app_model.id = str(uuid.uuid4())
|
||||||
|
test_node_id = "test_node_id"
|
||||||
|
mock_app_model.mode = AppMode.ADVANCED_CHAT
|
||||||
|
mock_load_app_model = mock.Mock(return_value=mock_app_model)
|
||||||
|
|
||||||
|
monkeypatch.setattr(draft_variable_api, "WorkflowDraftVariableService", srv_class)
|
||||||
|
monkeypatch.setattr(wraps, "_load_app_model", mock_load_app_model)
|
||||||
|
|
||||||
|
var1 = WorkflowDraftVariable.create_node_variable(
|
||||||
|
app_id="test_app_1",
|
||||||
|
node_id="test_node_1",
|
||||||
|
name="str_var",
|
||||||
|
value=build_segment("str_value"),
|
||||||
|
)
|
||||||
|
srv_instance = mock.create_autospec(WorkflowDraftVariableService, instance=True)
|
||||||
|
srv_class.return_value = srv_instance
|
||||||
|
srv_instance.list_node_variables.return_value = WorkflowDraftVariableList(variables=[var1])
|
||||||
|
|
||||||
|
response = test_client.get(
|
||||||
|
f"/console/api/apps/{mock_app_model.id}/workflows/draft/nodes/{test_node_id}/variables",
|
||||||
|
headers=auth_header,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
response_dict = response.json
|
||||||
|
assert isinstance(response_dict, dict)
|
||||||
|
assert "items" in response_dict
|
||||||
|
assert len(response_dict["items"]) == 1
|
||||||
@ -0,0 +1,142 @@
|
|||||||
|
import unittest
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from factories.variable_factory import build_segment
|
||||||
|
from models import db
|
||||||
|
from models.workflow import WorkflowDraftVariable
|
||||||
|
from services.workflow_draft_variable_service import WorkflowDraftVariableService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("flask_req_ctx")
|
||||||
|
class TestWorkflowDraftVariableService(unittest.TestCase):
|
||||||
|
_test_app_id: str
|
||||||
|
_session: Session
|
||||||
|
_node2_id = "test_node_2"
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self._test_app_id = str(uuid.uuid4())
|
||||||
|
self._session: Session = db.session
|
||||||
|
sys_var = WorkflowDraftVariable.create_sys_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="sys_var",
|
||||||
|
value=build_segment("sys_value"),
|
||||||
|
)
|
||||||
|
conv_var = WorkflowDraftVariable.create_conversation_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
name="conv_var",
|
||||||
|
value=build_segment("conv_value"),
|
||||||
|
)
|
||||||
|
node2_vars = [
|
||||||
|
WorkflowDraftVariable.create_node_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
node_id=self._node2_id,
|
||||||
|
name="int_var",
|
||||||
|
value=build_segment(1),
|
||||||
|
visible=False,
|
||||||
|
),
|
||||||
|
WorkflowDraftVariable.create_node_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
node_id=self._node2_id,
|
||||||
|
name="str_var",
|
||||||
|
value=build_segment("str_value"),
|
||||||
|
visible=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
node1_var = WorkflowDraftVariable.create_node_variable(
|
||||||
|
app_id=self._test_app_id,
|
||||||
|
node_id="node_1",
|
||||||
|
name="str_var",
|
||||||
|
value=build_segment("str_value"),
|
||||||
|
visible=True,
|
||||||
|
)
|
||||||
|
_variables = list(node2_vars)
|
||||||
|
_variables.extend(
|
||||||
|
[
|
||||||
|
node1_var,
|
||||||
|
sys_var,
|
||||||
|
conv_var,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
db.session.add_all(_variables)
|
||||||
|
db.session.flush()
|
||||||
|
self._variable_ids = [v.id for v in _variables]
|
||||||
|
self._node1_str_var_id = node1_var.id
|
||||||
|
self._sys_var_id = sys_var.id
|
||||||
|
self._conv_var_id = conv_var.id
|
||||||
|
self._node2_var_ids = [v.id for v in node2_vars]
|
||||||
|
|
||||||
|
def _get_test_srv(self) -> WorkflowDraftVariableService:
|
||||||
|
return WorkflowDraftVariableService(session=self._session)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self._session.rollback()
|
||||||
|
|
||||||
|
def test_list_variables(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
var_list = srv.list_variables_without_values(self._test_app_id, page=1, limit=2)
|
||||||
|
assert var_list.total == 5
|
||||||
|
assert len(var_list.variables) == 2
|
||||||
|
page1_var_ids = {v.id for v in var_list.variables}
|
||||||
|
assert page1_var_ids.issubset(self._variable_ids)
|
||||||
|
|
||||||
|
var_list_2 = srv.list_variables_without_values(self._test_app_id, page=2, limit=2)
|
||||||
|
assert var_list_2.total is None
|
||||||
|
assert len(var_list_2.variables) == 2
|
||||||
|
page2_var_ids = {v.id for v in var_list_2.variables}
|
||||||
|
assert page2_var_ids.isdisjoint(page1_var_ids)
|
||||||
|
assert page2_var_ids.issubset(self._variable_ids)
|
||||||
|
|
||||||
|
def test_get_node_variable(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
node_var = srv.get_node_variable(self._test_app_id, "node_1", "str_var")
|
||||||
|
assert node_var.id == self._node1_str_var_id
|
||||||
|
assert node_var.name == "str_var"
|
||||||
|
assert node_var.get_value() == build_segment("str_value")
|
||||||
|
|
||||||
|
def test_get_system_variable(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
sys_var = srv.get_system_variable(self._test_app_id, "sys_var")
|
||||||
|
assert sys_var.id == self._sys_var_id
|
||||||
|
assert sys_var.name == "sys_var"
|
||||||
|
assert sys_var.get_value() == build_segment("sys_value")
|
||||||
|
|
||||||
|
def test_get_conversation_variable(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
conv_var = srv.get_conversation_variable(self._test_app_id, "conv_var")
|
||||||
|
assert conv_var.id == self._conv_var_id
|
||||||
|
assert conv_var.name == "conv_var"
|
||||||
|
assert conv_var.get_value() == build_segment("conv_value")
|
||||||
|
|
||||||
|
def test_delete_node_variables(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
srv.delete_node_variables(self._test_app_id, self._node2_id)
|
||||||
|
node2_var_count = (
|
||||||
|
self._session.query(WorkflowDraftVariable)
|
||||||
|
.where(
|
||||||
|
WorkflowDraftVariable.app_id == self._test_app_id,
|
||||||
|
WorkflowDraftVariable.node_id == self._node2_id,
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
assert node2_var_count == 0
|
||||||
|
|
||||||
|
def test_delete_variable(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
node_1_var = (
|
||||||
|
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).one()
|
||||||
|
)
|
||||||
|
srv.delete_variable(node_1_var)
|
||||||
|
exists = bool(
|
||||||
|
self._session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id == self._node1_str_var_id).first()
|
||||||
|
)
|
||||||
|
assert exists is False
|
||||||
|
|
||||||
|
def test__list_node_variables(self):
|
||||||
|
srv = self._get_test_srv()
|
||||||
|
node_vars = srv._list_node_variables(self._test_app_id, self._node2_id)
|
||||||
|
assert len(node_vars) == 2
|
||||||
|
assert {v.id for v in node_vars} == set(self._node2_var_ids)
|
||||||
@ -0,0 +1,25 @@
|
|||||||
|
from core.file import File, FileTransferMethod, FileType
|
||||||
|
|
||||||
|
|
||||||
|
def test_file():
|
||||||
|
file = File(
|
||||||
|
id="test-file",
|
||||||
|
tenant_id="test-tenant-id",
|
||||||
|
type=FileType.IMAGE,
|
||||||
|
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||||
|
related_id="test-related-id",
|
||||||
|
filename="image.png",
|
||||||
|
extension=".png",
|
||||||
|
mime_type="image/png",
|
||||||
|
size=67,
|
||||||
|
storage_key="test-storage-key",
|
||||||
|
url="https://example.com/image.png",
|
||||||
|
)
|
||||||
|
assert file.tenant_id == "test-tenant-id"
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
assert file.transfer_method == FileTransferMethod.TOOL_FILE
|
||||||
|
assert file.related_id == "test-related-id"
|
||||||
|
assert file.filename == "image.png"
|
||||||
|
assert file.extension == ".png"
|
||||||
|
assert file.mime_type == "image/png"
|
||||||
|
assert file.size == 67
|
||||||
Loading…
Reference in New Issue