diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index fdd1a776f8..095b42e66a 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -29,12 +29,14 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from extensions.ext_database import db from factories import file_factory from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom from services.conversation_service import ConversationService from services.errors.message import MessageNotExistsError +from services.workflow_draft_variable_service import DraftVarLoader logger = logging.getLogger(__name__) @@ -260,6 +262,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + ) return self._generate( workflow=workflow, @@ -270,6 +276,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -335,6 +342,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + ) return self._generate( workflow=workflow, @@ -345,6 +356,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository=workflow_node_execution_repository, conversation=None, stream=streaming, + variable_loader=var_loader, ) def _generate( @@ -358,6 +370,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, conversation: Optional[Conversation] = None, stream: bool = True, + variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER, ) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]: """ Generate App response. @@ -410,6 +423,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id=conversation.id, message_id=message.id, context=context, + variable_loader=variable_loader, ) worker_thread = threading.Thread(target=worker_with_context) @@ -439,6 +453,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation_id: str, message_id: str, context: contextvars.Context, + variable_loader: VariableLoader, ) -> None: """ Generate worker in a new thread. @@ -480,6 +495,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): conversation=conversation, message=message, dialogue_count=self._dialogue_count, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 18e8310793..4476d2c736 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ b/api/core/app/apps/advanced_chat/app_runner.py @@ -19,6 +19,7 @@ from core.moderation.base import ModerationError from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -40,9 +41,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): conversation: Conversation, message: Message, dialogue_count: int, + variable_loader: VariableLoader, ) -> None: - super().__init__(queue_manager) - + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity self.conversation = conversation self.message = message diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6ea90e5a3d..fd66dc9fe5 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -27,10 +27,12 @@ from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.workflow.repository.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom from models.enums import WorkflowRunTriggeredFrom +from services.workflow_draft_variable_service import DraftVarLoader logger = logging.getLogger(__name__) @@ -185,6 +187,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_node_execution_repository: WorkflowNodeExecutionRepository, streaming: bool = True, workflow_thread_pool_id: Optional[str] = None, + variable_loader: DraftVarLoader = DUMMY_VARIABLE_LOADER, ) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]: """ Generate App response. @@ -219,6 +222,7 @@ class WorkflowAppGenerator(BaseAppGenerator): queue_manager=queue_manager, context=context, workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, ) worker_thread = threading.Thread(target=worker_with_context) @@ -304,6 +308,10 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + ) return self._generate( app_model=app_model, @@ -314,6 +322,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def single_loop_generate( @@ -380,7 +389,10 @@ class WorkflowAppGenerator(BaseAppGenerator): app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.SINGLE_STEP, ) - + var_loader = DraftVarLoader( + engine=db.engine, + app_id=application_generate_entity.app_config.app_id, + ) return self._generate( app_model=app_model, workflow=workflow, @@ -390,6 +402,7 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_execution_repository=workflow_execution_repository, workflow_node_execution_repository=workflow_node_execution_repository, streaming=streaming, + variable_loader=var_loader, ) def _generate_worker( @@ -398,6 +411,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, context: contextvars.Context, + variable_loader: DraftVarLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -431,6 +445,7 @@ class WorkflowAppGenerator(BaseAppGenerator): application_generate_entity=application_generate_entity, queue_manager=queue_manager, workflow_thread_pool_id=workflow_thread_pool_id, + variable_loader=variable_loader, ) runner.run() diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index 13cf4581ce..b43daa8a15 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -12,6 +12,7 @@ from core.app.entities.app_invoke_entities import ( from core.workflow.callbacks import WorkflowCallback, WorkflowLoggingCallback from core.workflow.entities.variable_pool import VariablePool from core.workflow.enums import SystemVariableKey +from core.workflow.variable_loader import VariableLoader from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.enums import UserFrom @@ -30,6 +31,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): self, application_generate_entity: WorkflowAppGenerateEntity, queue_manager: AppQueueManager, + variable_loader: VariableLoader, workflow_thread_pool_id: Optional[str] = None, ) -> None: """ @@ -37,8 +39,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): :param queue_manager: application queue manager :param workflow_thread_pool_id: workflow thread pool id """ + super().__init__(queue_manager, variable_loader) self.application_generate_entity = application_generate_entity - self.queue_manager = queue_manager self.workflow_thread_pool_id = workflow_thread_pool_id def _get_app_id(self) -> str: @@ -80,6 +82,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow=workflow, node_id=self.application_generate_entity.single_iteration_run.node_id, user_inputs=self.application_generate_entity.single_iteration_run.inputs, + variable_loader=self._var_loader, ) elif self.application_generate_entity.single_loop_run: # if only single loop run is requested diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index 1e6c92d59b..5aaf0d66fa 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -64,19 +64,20 @@ from core.workflow.graph_engine.entities.event import ( from core.workflow.graph_engine.entities.graph import Graph from core.workflow.nodes import NodeType from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING +from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from core.workflow.workflow_entry import WorkflowEntry from extensions.ext_database import db from models.model import App from models.workflow import Workflow from services.workflow_draft_variable_service import ( - WorkflowDraftVariableService, - should_save_output_variables_for_draft, + DraftVariableSaver, ) class WorkflowBasedAppRunner(AppRunner): - def __init__(self, queue_manager: AppQueueManager): + def __init__(self, queue_manager: AppQueueManager, variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER) -> None: self.queue_manager = queue_manager + self._variable_loader = variable_loader def _get_app_id(self) -> str: raise NotImplementedError("not implemented") @@ -182,6 +183,13 @@ class WorkflowBasedAppRunner(AppRunner): except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) + WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, user_inputs=user_inputs, @@ -271,6 +279,12 @@ class WorkflowBasedAppRunner(AppRunner): ) except NotImplementedError: variable_mapping = {} + load_into_variable_pool( + self._variable_loader, + variable_pool=variable_pool, + variable_mapping=variable_mapping, + user_inputs=user_inputs, + ) WorkflowEntry.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -385,23 +399,17 @@ class WorkflowBasedAppRunner(AppRunner): in_loop_id=event.in_loop_id, ) ) - - # FIXME(QuantumGhost): rely on private state of queue_manager is not ideal. - should_save = should_save_output_variables_for_draft( - self.queue_manager._invoke_from, - loop_id=event.in_loop_id, - iteration_id=event.in_iteration_id, - ) - if should_save and outputs is not None: - with Session(bind=db.engine) as session: - draft_var_srv = WorkflowDraftVariableService(session) - draft_var_srv.save_output_variables( - app_id=self._get_app_id(), - node_id=event.node_id, - node_type=event.node_type, - output=outputs, - ) - session.commit() + with Session(bind=db.engine) as session, session.begin(): + draft_var_saver = DraftVariableSaver( + session=session, + app_id=self._get_app_id(), + node_id=event.node_id, + node_type=event.node_type, + # FIXME(QuantumGhost): rely on private state of queue_manager is not ideal. + invoke_from=self.queue_manager._invoke_from, + enclosing_node_id=event.in_loop_id or event.in_iteration_id or None, + ) + draft_var_saver.save(outputs) elif isinstance(event, NodeRunFailedEvent): self._publish_event( @@ -717,3 +725,11 @@ class WorkflowBasedAppRunner(AppRunner): def _publish_event(self, event: AppQueueEvent) -> None: self.queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER) + + +def _remove_first_element_from_variable_string(key: str) -> str: + """ + Remove the first element from the prefix. + """ + prefix, remaining = key.split(".", maxsplit=1) + return remaining