feat: invoke node

pull/9184/head
Yeuoly 2 years ago
parent 68c10a1672
commit a91951b374
No known key found for this signature in database
GPG Key ID: A66E7E320FB19F61

@ -1,4 +1,5 @@
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from core.workflow.entities.node_entities import NodeType
from core.workflow.nodes.parameter_extractor.entities import ( from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig, ModelConfig as ParameterExtractorModelConfig,
) )
@ -36,7 +37,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param model_config: ModelConfig :param model_config: ModelConfig
:param instruction: str :param instruction: str
:param query: str :param query: str
:return: dict with __reason, __is_success, and other parameters :return: dict
""" """
workflow_service = WorkflowService() workflow_service = WorkflowService()
node_id = "1919810" node_id = "1919810"
@ -50,6 +51,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
instruction=instruction, # instruct with variables are not supported instruction=instruction, # instruct with variables are not supported
) )
node_data_dict = node_data.model_dump() node_data_dict = node_data.model_dump()
node_data_dict["type"] = NodeType.PARAMETER_EXTRACTOR.value
execution = workflow_service.run_free_workflow_node( execution = workflow_service.run_free_workflow_node(
node_data_dict, node_data_dict,
tenant_id=tenant_id, tenant_id=tenant_id,
@ -60,10 +62,10 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
}, },
) )
output = execution.outputs_dict return {
return output or { "inputs": execution.inputs_dict,
"__reason": "No parameters extracted", "outputs": execution.outputs_dict,
"__is_success": False, "process_data": execution.process_data_dict,
} }
@classmethod @classmethod
@ -85,7 +87,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
:param classes: list[ClassConfig] :param classes: list[ClassConfig]
:param instruction: str :param instruction: str
:param query: str :param query: str
:return: dict with class_name :return: dict
""" """
workflow_service = WorkflowService() workflow_service = WorkflowService()
node_id = "1919810" node_id = "1919810"
@ -108,7 +110,8 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation):
}, },
) )
output = execution.outputs_dict return {
return output or { "inputs": execution.inputs_dict,
"class_name": classes[0].name, "outputs": execution.outputs_dict,
"process_data": execution.process_data_dict,
} }

@ -14,16 +14,18 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage, UserPromptMessage,
) )
from core.model_runtime.entities.model_entities import ModelType from core.model_runtime.entities.model_entities import ModelType
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
ModelConfig as QuestionClassifierModelConfig,
)
from core.workflow.nodes.parameter_extractor.entities import ( from core.workflow.nodes.parameter_extractor.entities import (
ModelConfig as ParameterExtractorModelConfig, ModelConfig as ParameterExtractorModelConfig,
) )
from core.workflow.nodes.parameter_extractor.entities import ( from core.workflow.nodes.parameter_extractor.entities import (
ParameterConfig, ParameterConfig,
) )
from core.workflow.nodes.question_classifier.entities import (
ClassConfig,
)
from core.workflow.nodes.question_classifier.entities import (
ModelConfig as QuestionClassifierModelConfig,
)
class RequestInvokeTool(BaseModel): class RequestInvokeTool(BaseModel):

@ -221,8 +221,27 @@ class WorkflowEntry:
""" """
# generate a fake graph # generate a fake graph
node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data} node_config = {"id": node_id, "width": 114, "height": 514, "type": "custom", "data": node_data}
start_node_config = {
"id": "start",
"width": 114,
"height": 514,
"type": "custom",
"data": {
"type": NodeType.START.value,
"title": "Start",
"desc": "Start",
},
}
graph_dict = { graph_dict = {
"nodes": [node_config], "nodes": [start_node_config, node_config],
"edges": [
{
"source": "start",
"target": node_id,
"sourceHandle": "source",
"targetHandle": "target",
}
],
} }
node_type = NodeType.value_of(node_data.get("type", "")) node_type = NodeType.value_of(node_data.get("type", ""))

@ -230,6 +230,10 @@ class WorkflowService:
node_id=node_id, node_id=node_id,
) )
workflow_node_execution.app_id = app_model.id
workflow_node_execution.created_by = account.id
workflow_node_execution.workflow_id = draft_workflow.id
db.session.add(workflow_node_execution) db.session.add(workflow_node_execution)
db.session.commit() db.session.commit()

Loading…
Cancel
Save