feat: add backwards invoke node api

pull/9184/head
Yeuoly 2 years ago
parent 592f85f7a9
commit 68c10a1672
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -8,13 +8,15 @@ from controllers.inner_api.plugin.wraps import get_tenant, plugin_data
from controllers.inner_api.wraps import plugin_inner_api_only from controllers.inner_api.wraps import plugin_inner_api_only
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.node import PluginNodeBackwardsInvocation
from core.plugin.encrypt import PluginEncrypter from core.plugin.encrypt import PluginEncrypter
from core.plugin.entities.request import ( from core.plugin.entities.request import (
RequestInvokeApp, RequestInvokeApp,
RequestInvokeEncrypt, RequestInvokeEncrypt,
RequestInvokeLLM, RequestInvokeLLM,
RequestInvokeModeration, RequestInvokeModeration,
RequestInvokeNode, RequestInvokeParameterExtractorNode,
RequestInvokeQuestionClassifierNode,
RequestInvokeRerank, RequestInvokeRerank,
RequestInvokeSpeech2Text, RequestInvokeSpeech2Text,
RequestInvokeTextEmbedding, RequestInvokeTextEmbedding,
@ -96,23 +98,46 @@ class PluginInvokeToolApi(Resource):
yield ( yield (
ToolInvokeMessage( ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.TEXT, type=ToolInvokeMessage.MessageType.TEXT,
message=ToolInvokeMessage.TextMessage(text='helloworld'), message=ToolInvokeMessage.TextMessage(text="helloworld"),
) )
.model_dump_json() .model_dump_json()
.encode() .encode()
+ b'\n\n' + b"\n\n"
) )
return compact_generate_response(generator()) return compact_generate_response(generator())
class PluginInvokeNodeApi(Resource): class PluginInvokeParameterExtractorNodeApi(Resource):
@setup_required @setup_required
@plugin_inner_api_only @plugin_inner_api_only
@get_tenant @get_tenant
@plugin_data(payload_type=RequestInvokeNode) @plugin_data(payload_type=RequestInvokeParameterExtractorNode)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeNode): def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeParameterExtractorNode):
pass return PluginNodeBackwardsInvocation.invoke_parameter_extractor(
tenant_id=tenant_model.id,
user_id=user_id,
parameters=payload.parameters,
model_config=payload.model,
instruction=payload.instruction,
query=payload.query,
)
class PluginInvokeQuestionClassifierNodeApi(Resource):
@setup_required
@plugin_inner_api_only
@get_tenant
@plugin_data(payload_type=RequestInvokeQuestionClassifierNode)
def post(self, user_id: str, tenant_model: Tenant, payload: RequestInvokeQuestionClassifierNode):
return PluginNodeBackwardsInvocation.invoke_question_classifier(
tenant_id=tenant_model.id,
user_id=user_id,
query=payload.query,
model_config=payload.model,
classes=payload.classes,
instruction=payload.instruction,
)
class PluginInvokeAppApi(Resource): class PluginInvokeAppApi(Resource):
@ -127,15 +152,13 @@ class PluginInvokeAppApi(Resource):
tenant_id=tenant_model.id, tenant_id=tenant_model.id,
conversation_id=payload.conversation_id, conversation_id=payload.conversation_id,
query=payload.query, query=payload.query,
stream=payload.response_mode == 'streaming', stream=payload.response_mode == "streaming",
inputs=payload.inputs, inputs=payload.inputs,
files=payload.files files=payload.files,
)
return compact_generate_response(
PluginAppBackwardsInvocation.convert_to_event_stream(response)
) )
return compact_generate_response(PluginAppBackwardsInvocation.convert_to_event_stream(response))
class PluginInvokeEncryptApi(Resource): class PluginInvokeEncryptApi(Resource):
@setup_required @setup_required
@ -149,13 +172,14 @@ class PluginInvokeEncryptApi(Resource):
return PluginEncrypter.invoke_encrypt(tenant_model, payload) return PluginEncrypter.invoke_encrypt(tenant_model, payload)
api.add_resource(PluginInvokeLLMApi, '/invoke/llm') api.add_resource(PluginInvokeLLMApi, "/invoke/llm")
api.add_resource(PluginInvokeTextEmbeddingApi, '/invoke/text-embedding') api.add_resource(PluginInvokeTextEmbeddingApi, "/invoke/text-embedding")
api.add_resource(PluginInvokeRerankApi, '/invoke/rerank') api.add_resource(PluginInvokeRerankApi, "/invoke/rerank")
api.add_resource(PluginInvokeTTSApi, '/invoke/tts') api.add_resource(PluginInvokeTTSApi, "/invoke/tts")
api.add_resource(PluginInvokeSpeech2TextApi, '/invoke/speech2text') api.add_resource(PluginInvokeSpeech2TextApi, "/invoke/speech2text")
api.add_resource(PluginInvokeModerationApi, '/invoke/moderation') api.add_resource(PluginInvokeModerationApi, "/invoke/moderation")
api.add_resource(PluginInvokeToolApi, '/invoke/tool') api.add_resource(PluginInvokeToolApi, "/invoke/tool")
api.add_resource(PluginInvokeNodeApi, '/invoke/node') api.add_resource(PluginInvokeParameterExtractorNodeApi, "/invoke/parameter-extractor")
api.add_resource(PluginInvokeAppApi, '/invoke/app') api.add_resource(PluginInvokeQuestionClassifierNodeApi, "/invoke/question-classifier")
api.add_resource(PluginInvokeEncryptApi, '/invoke/encrypt') api.add_resource(PluginInvokeAppApi, "/invoke/app")
api.add_resource(PluginInvokeEncryptApi, "/invoke/encrypt")

@ -0,0 +1,114 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
ParameterExtractorNodeData,
)
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
QuestionClassifierNodeData,
)
from core.workflow.nodes.question_classifier.entities import (
ModelConfig as QuestionClassifierModelConfig,
)
from services.workflow_service import WorkflowService
class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
@classmethod
def invoke_parameter_extractor(
cls,
tenant_id: str,
user_id: str,
parameters: list[ParameterConfig],
model_config: ParameterExtractorModelConfig,
instruction: str,
query: str,
) -> dict:
"""
Invoke parameter extractor node.
:param tenant_id: str
:param user_id: str
:param parameters: list[ParameterConfig]
:param model_config: ModelConfig
:param instruction: str
:param query: str
:return: dict with __reason, __is_success, and other parameters
"""
workflow_service = WorkflowService()
node_id = "1919810"
node_data = ParameterExtractorNodeData(
title="parameter_extractor",
desc="parameter_extractor",
parameters=parameters,
reasoning_mode="function_call",
query=[node_id, "query"],
model=model_config,
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,
user_id=user_id,
node_id=node_id,
user_inputs={
f"{node_id}.query": query,
},
)
output = execution.outputs_dict
return output or {
"__reason": "No parameters extracted",
"__is_success": False,
}
@classmethod
def invoke_question_classifier(
cls,
tenant_id: str,
user_id: str,
model_config: QuestionClassifierModelConfig,
classes: list[ClassConfig],
instruction: str,
query: str,
) -> dict:
"""
Invoke question classifier node.
:param tenant_id: str
:param user_id: str
:param model_config: ModelConfig
:param classes: list[ClassConfig]
:param instruction: str
:param query: str
:return: dict with class_name
"""
workflow_service = WorkflowService()
node_id = "1919810"
node_data = QuestionClassifierNodeData(
title="question_classifier",
desc="question_classifier",
query_variable_selector=[node_id, "query"],
model=model_config,
classes=classes,
instruction=instruction, # instruct with variables are not supported
)
node_data_dict = node_data.model_dump()
execution = workflow_service.run_free_workflow_node(
node_data_dict,
tenant_id=tenant_id,
user_id=user_id,
node_id=node_id,
user_inputs={
f"{node_id}.query": query,
},
)
output = execution.outputs_dict
return output or {
"class_name": classes[0].name,
}

@ -14,6 +14,16 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
ModelConfig as QuestionClassifierModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig,
)
class RequestInvokeTool(BaseModel): class RequestInvokeTool(BaseModel):
@ -92,11 +102,27 @@ class RequestInvokeModeration(BaseModel):
""" """
class RequestInvokeNode(BaseModel): class RequestInvokeParameterExtractorNode(BaseModel):
""" """
Request to invoke node Request to invoke parameter extractor node
""" """
parameters: list[ParameterConfig]
model: ParameterExtractorModelConfig
instruction: str
query: str
class RequestInvokeQuestionClassifierNode(BaseModel):
"""
Request to invoke question classifier node
"""
query: str
model: QuestionClassifierModelConfig
classes: list[ClassConfig]
instruction: str
class RequestInvokeApp(BaseModel): class RequestInvokeApp(BaseModel):
""" """

@ -205,6 +205,88 @@ class WorkflowEntry:
except Exception as e: except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e)) raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@classmethod
def run_free_node(
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]:
"""
Run free node
NOTE: only parameter_extractor/question_classifier are supported
:param node_data: node data
:param user_id: user id
:param user_inputs: user inputs
:return:
"""
# generate a fake graph
node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
graph_dict = {
"nodes": [node_config],
}
node_type = NodeType.value_of(node_data.get("type", ""))
if node_type not in {NodeType.PARAMETER_EXTRACTOR, NodeType.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
node_cls = node_classes.get(node_type)
if not node_cls:
raise ValueError(f"Node class not found for node type {node_type}")
graph = Graph.init(graph_config=graph_dict)
# init variable pool
variable_pool = VariablePool(
system_variables={},
user_inputs={},
environment_variables=[],
)
node_cls = cast(type[BaseNode], node_cls)
# init workflow run state
node_instance: BaseNode = node_cls(
id=str(uuid.uuid4()),
config=node_config,
graph_init_params=GraphInitParams(
tenant_id=tenant_id,
app_id="",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="",
graph_config=graph_dict,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
)
try:
# variable selector to variable mapping
try:
variable_mapping = node_cls.extract_variable_selector_to_variable_mapping(
graph_config=graph_dict, config=node_config
)
except NotImplementedError:
variable_mapping = {}
cls.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping,
user_inputs=user_inputs,
variable_pool=variable_pool,
tenant_id=tenant_id,
node_type=node_type,
node_data=node_instance.node_data,
)
# run node
generator = node_instance.run()
return node_instance, generator
except Exception as e:
raise WorkflowNodeRunFailedError(node_instance=node_instance, error=str(e))
@classmethod @classmethod
def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]: def handle_special_values(cls, value: Optional[Mapping[str, Any]]) -> Optional[dict]:
""" """

@ -1,8 +1,8 @@
import json import json
import time import time
from collections.abc import Sequence from collections.abc import Callable, Generator, Sequence
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Optional from typing import Any, Optional
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
@ -10,7 +10,9 @@ from core.app.segments import Variable
from core.model_runtime.utils.encoders import jsonable_encoder from core.model_runtime.utils.encoders import jsonable_encoder
from core.workflow.entities.node_entities import NodeRunResult, NodeType from core.workflow.entities.node_entities import NodeRunResult, NodeType
from core.workflow.errors import WorkflowNodeRunFailedError from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.nodes.event import RunCompletedEvent from core.workflow.graph_engine.entities.event import InNodeEvent
from core.workflow.nodes.base_node import BaseNode
from core.workflow.nodes.event import RunCompletedEvent, RunEvent
from core.workflow.nodes.node_mapping import node_classes from core.workflow.nodes.node_mapping import node_classes
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
@ -216,13 +218,64 @@ class WorkflowService:
# run draft workflow node # run draft workflow node
start_at = time.perf_counter() start_at = time.perf_counter()
try: workflow_node_execution = self._handle_node_run_result(
node_instance, generator = WorkflowEntry.single_step_run( getter=lambda: WorkflowEntry.single_step_run(
workflow=draft_workflow, workflow=draft_workflow,
node_id=node_id, node_id=node_id,
user_inputs=user_inputs, user_inputs=user_inputs,
user_id=account.id, user_id=account.id,
) ),
start_at=start_at,
tenant_id=app_model.tenant_id,
node_id=node_id,
)
db.session.add(workflow_node_execution)
db.session.commit()
return workflow_node_execution
def run_free_workflow_node(
self, node_data: dict, tenant_id: str, user_id: str, node_id: str, user_inputs: dict[str, Any]
) -> WorkflowNodeExecution:
"""
Run draft workflow node
"""
# run draft workflow node
start_at = time.perf_counter()
workflow_node_execution = self._handle_node_run_result(
getter=lambda: WorkflowEntry.run_free_node(
node_id=node_id,
node_data=node_data,
tenant_id=tenant_id,
user_id=user_id,
user_inputs=user_inputs,
),
start_at=start_at,
tenant_id=tenant_id,
node_id=node_id
)
return workflow_node_execution
def _handle_node_run_result(
self,
getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]],
start_at: float,
tenant_id: str,
node_id: str,
):
"""
Handle node run result
:param getter: Callable[[], tuple[BaseNode, Generator[RunEvent | InNodeEvent, None, None]]]
:param start_at: float
:param tenant_id: str
:param node_id: str
"""
try:
node_instance, generator = getter()
node_run_result: NodeRunResult | None = None node_run_result: NodeRunResult | None = None
for event in generator: for event in generator:
@ -245,9 +298,7 @@ class WorkflowService:
error = e.error error = e.error
workflow_node_execution = WorkflowNodeExecution() workflow_node_execution = WorkflowNodeExecution()
workflow_node_execution.tenant_id = app_model.tenant_id workflow_node_execution.tenant_id = tenant_id
workflow_node_execution.app_id = app_model.id
workflow_node_execution.workflow_id = draft_workflow.id
workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value workflow_node_execution.triggered_from = WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP.value
workflow_node_execution.index = 1 workflow_node_execution.index = 1
workflow_node_execution.node_id = node_id workflow_node_execution.node_id = node_id
@ -255,7 +306,6 @@ class WorkflowService:
workflow_node_execution.title = node_instance.node_data.title workflow_node_execution.title = node_instance.node_data.title
workflow_node_execution.elapsed_time = time.perf_counter() - start_at workflow_node_execution.elapsed_time = time.perf_counter() - start_at
workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value workflow_node_execution.created_by_role = CreatedByRole.ACCOUNT.value
workflow_node_execution.created_by = account.id
workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None) workflow_node_execution.finished_at = datetime.now(timezone.utc).replace(tzinfo=None)
@ -277,9 +327,6 @@ class WorkflowService:
workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value workflow_node_execution.status = WorkflowNodeExecutionStatus.FAILED.value
workflow_node_execution.error = error workflow_node_execution.error = error
db.session.add(workflow_node_execution)
db.session.commit()
return workflow_node_execution return workflow_node_execution
def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App: def convert_to_workflow(self, app_model: App, account: Account, args: dict) -> App:
@ -302,10 +349,10 @@ class WorkflowService:
new_app = workflow_converter.convert_to_workflow( new_app = workflow_converter.convert_to_workflow(
app_model=app_model, app_model=app_model,
account=account, account=account,
name=args.get("name"), name=args.get("name", ""),
icon_type=args.get("icon_type"), icon_type=args.get("icon_type", ""),
icon=args.get("icon"), icon=args.get("icon", ""),
icon_background=args.get("icon_background"), icon_background=args.get("icon_background", ""),
) )
return new_app return new_app

Loading…
Cancel
Save