feat(api): inject the variable_loader all the way down for loop & iteration

# Conflicts:
#	api/core/app/apps/advanced_chat/app_generator.py
#	api/core/app/apps/workflow/app_generator.py
pull/20699/head
QuantumGhost 12 months ago
parent 31fa5baadf
commit 4c7873bdd9

@ -29,12 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService from services.conversation_service import ConversationService
from services.errors.message import MessageNotExistsError from services.errors.message import MessageNotExistsError
from services.workflow_draft_variable_service import DraftVarLoader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -260,6 +262,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -270,6 +276,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
variable_loader=var_loader,
) )
def single_loop_generate( def single_loop_generate(
@ -335,6 +342,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate( return self._generate(
workflow=workflow, workflow=workflow,
@ -345,6 +356,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
conversation=None, conversation=None,
stream=streaming, stream=streaming,
variable_loader=var_loader,
) )
def _generate( def _generate(
@ -358,6 +370,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Optional[Conversation] = None, conversation: Optional[Conversation] = None,
stream: bool = True, stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
""" """
Generate App response. Generate App response.
@ -410,6 +423,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id=conversation.id, conversation_id=conversation.id,
message_id=message.id, message_id=message.id,
context=context, context=context,
variable_loader=variable_loader,
) )
worker_thread = threading.Thread(target=worker_with_context) worker_thread = threading.Thread(target=worker_with_context)
@ -439,6 +453,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation_id: str, conversation_id: str,
message_id: str, message_id: str,
context: contextvars.Context, context: contextvars.Context,
variable_loader: VariableLoader,
) -> None: ) -> None:
""" """
Generate worker in a new thread. Generate worker in a new thread.
@ -480,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
conversation=conversation, conversation=conversation,
message=message, message=message,
dialogue_count=self._dialogue_count, dialogue_count=self._dialogue_count,
variable_loader=variable_loader,
) )
runner.run() runner.run()

@ -19,6 +19,7 @@ from core.moderation.base import ModerationError
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
@ -40,9 +41,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation: Conversation, conversation: Conversation,
message: Message, message: Message,
dialogue_count: int, dialogue_count: int,
variable_loader: VariableLoader,
) -> None: ) -> None:
super().__init__(queue_manager) super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.conversation = conversation self.conversation = conversation
self.message = message self.message = message

@ -27,10 +27,12 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER
from extensions.ext_database import db from extensions.ext_database import db
from factories import file_factory from factories import file_factory
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.enums import WorkflowRunTriggeredFrom from models.enums import WorkflowRunTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -185,6 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository, workflow_node_execution_repository: WorkflowNodeExecutionRepository,
streaming: bool = True, streaming: bool = True,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
variable_loader: DraftVarLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
""" """
Generate App response. Generate App response.
@ -219,6 +222,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
queue_manager=queue_manager, queue_manager=queue_manager,
context=context, context=context,
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
) )
worker_thread = threading.Thread(target=worker_with_context) worker_thread = threading.Thread(target=worker_with_context)
@ -304,6 +308,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
@ -314,6 +322,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader,
) )
def single_loop_generate( def single_loop_generate(
@ -380,7 +389,10 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP,
) )
var_loader = DraftVarLoader(
engine=db.engine,
app_id=application_generate_entity.app_config.app_id,
)
return self._generate( return self._generate(
app_model=app_model, app_model=app_model,
workflow=workflow, workflow=workflow,
@ -390,6 +402,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_execution_repository=workflow_execution_repository, workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming, streaming=streaming,
variable_loader=var_loader,
) )
def _generate_worker( def _generate_worker(
@ -398,6 +411,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
context: contextvars.Context, context: contextvars.Context,
variable_loader: DraftVarLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
@ -431,6 +445,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
application_generate_entity=application_generate_entity, application_generate_entity=application_generate_entity,
queue_manager=queue_manager, queue_manager=queue_manager,
workflow_thread_pool_id=workflow_thread_pool_id, workflow_thread_pool_id=workflow_thread_pool_id,
variable_loader=variable_loader,
) )
runner.run() runner.run()

@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import (
from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback
from core.workflow.entities.variable_pool import VariablePool from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey from core.workflow.enums import SystemVariableKey
from core.workflow.variable_loader import VariableLoader
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.enums import UserFrom from models.enums import UserFrom
@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self, self,
application_generate_entity: WorkflowAppGenerateEntity, application_generate_entity: WorkflowAppGenerateEntity,
queue_manager: AppQueueManager, queue_manager: AppQueueManager,
variable_loader: VariableLoader,
workflow_thread_pool_id: Optional[str] = None, workflow_thread_pool_id: Optional[str] = None,
) -> None: ) -> None:
""" """
@ -37,8 +39,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
:param queue_manager: application queue manager :param queue_manager: application queue manager
:param workflow_thread_pool_id: workflow thread pool id :param workflow_thread_pool_id: workflow thread pool id
""" """
super().__init__(queue_manager, variable_loader)
self.application_generate_entity = application_generate_entity self.application_generate_entity = application_generate_entity
self.queue_manager = queue_manager
self.workflow_thread_pool_id = workflow_thread_pool_id self.workflow_thread_pool_id = workflow_thread_pool_id
def _get_app_id(self) -> str: def _get_app_id(self) -> str:
@ -80,6 +82,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow=workflow, workflow=workflow,
node_id=self.application_generate_entity.single_iteration_run.node_id, node_id=self.application_generate_entity.single_iteration_run.node_id,
user_inputs=self.application_generate_entity.single_iteration_run.inputs, user_inputs=self.application_generate_entity.single_iteration_run.inputs,
variable_loader=self._var_loader,
) )
elif self.application_generate_entity.single_loop_run: elif self.application_generate_entity.single_loop_run:
# if only single loop run is requested # if only single loop run is requested

@ -64,19 +64,20 @@ from core.workflow.graph_engine.entities.event import (
from core.workflow.graph_engine.entities.graph import Graph from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.nodes import NodeType from core.workflow.nodes import NodeType
from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool
from core.workflow.workflow_entry import WorkflowEntry from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_database import db from extensions.ext_database import db
from models.model import App from models.model import App
from models.workflow import Workflow from models.workflow import Workflow
from services.workflow_draft_variable_service import ( from services.workflow_draft_variable_service import (
WorkflowDraftVariableService, DraftVariableSaver,
should_save_output_variables_for_draft,
) )
class WorkflowBasedAppRunner(AppRunner): class WorkflowBasedAppRunner(AppRunner):
def __init__(self, queue_manager: AppQueueManager): def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None:
self.queue_manager = queue_manager self.queue_manager = queue_manager
self._variable_loader = variable_loader
def _get_app_id(self) -> str: def _get_app_id(self) -> str:
raise NotImplementedError("not implemented") raise NotImplementedError("not implemented")
@ -182,6 +183,13 @@ class WorkflowBasedAppRunner(AppRunner):
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
load_into_variable_pool(
variable_loader=self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool( WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping, variable_mapping=variable_mapping,
user_inputs=user_inputs, user_inputs=user_inputs,
@ -271,6 +279,12 @@ class WorkflowBasedAppRunner(AppRunner):
) )
except NotImplementedError: except NotImplementedError:
variable_mapping = {} variable_mapping = {}
load_into_variable_pool(
self._variable_loader,
variable_pool=variable_pool,
variable_mapping=variable_mapping,
user_inputs=user_inputs,
)
WorkflowEntry.mapping_user_inputs_to_variable_pool( WorkflowEntry.mapping_user_inputs_to_variable_pool(
variable_mapping=variable_mapping, variable_mapping=variable_mapping,
@ -385,23 +399,17 @@ class WorkflowBasedAppRunner(AppRunner):
in_loop_id=event.in_loop_id, in_loop_id=event.in_loop_id,
) )
) )
with Session(bind=db.engine) as session, session.begin():
# FIXME(QuantumGhost): rely on private state of queue_manager is not ideal. draft_var_saver = DraftVariableSaver(
should_save = should_save_output_variables_for_draft( session=session,
self.queue_manager._invoke_from, app_id=self._get_app_id(),
loop_id=event.in_loop_id, node_id=event.node_id,
iteration_id=event.in_iteration_id, node_type=event.node_type,
) # FIXME(QuantumGhost): rely on private state of queue_manager is not ideal.
if should_save and outputs is not None: invoke_from=self.queue_manager._invoke_from,
with Session(bind=db.engine) as session: enclosing_node_id=event.in_loop_id or event.in_iteration_id or None,
draft_var_srv = WorkflowDraftVariableService(session) )
draft_var_srv.save_output_variables( draft_var_saver.save(outputs)
app_id=self._get_app_id(),
node_id=event.node_id,
node_type=event.node_type,
output=outputs,
)
session.commit()
elif isinstance(event, NodeRunFailedEvent): elif isinstance(event, NodeRunFailedEvent):
self._publish_event( self._publish_event(
@ -717,3 +725,11 @@ class WorkflowBasedAppRunner(AppRunner):
def _publish_event(self, event: AppQueueEvent) -> None: def _publish_event(self, event: AppQueueEvent) -> None:
self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
def _remove_first_element_from_variable_string(key: str) -> str:
"""
Remove the first element from the prefix.
"""
prefix, remaining = key.split(".", maxsplit=1)
return remaining

Loading…
Cancel
Save