diff --git a/api/configs/feature/__init__.py b/api/configs/feature/__init__.py index 963fcbedf9..cca9f252c9 100644 --- a/api/configs/feature/__init__.py +++ b/api/configs/feature/__init__.py @@ -537,6 +537,22 @@ class WorkflowNodeExecutionConfig(BaseSettings): ) +class RepositoryConfig(BaseSettings): + """ + Configuration for repository implementations + """ + + WORKFLOW_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", + ) + + WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( + description="Repository implementation for WorkflowNodeExecution. Specify as a module path", + default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", + ) + + class AuthConfig(BaseSettings): """ Configuration for authentication and OAuth @@ -903,6 +919,7 @@ class FeatureConfig( MultiModalTransferConfig, PositionConfig, RagEtlConfig, + RepositoryConfig, SecurityConfig, ToolConfig, UpdateConfig, diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 7877408cef..649a7172d4 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -25,8 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.repositories.draft_variable_repository import ( DraftVariableSaverFactory, ) @@ -183,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -260,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -343,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 40a1e272a7..f1203dfa4a 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -23,8 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.ops.ops_trace_manager import TraceQueueManager -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository @@ -156,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator): workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING else: workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=workflow_triggered_from, ) # Create workflow node execution repository - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -306,16 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, @@ -390,16 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator): # Create session factory session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) # Create workflow execution(aka workflow run) repository - workflow_execution_repository = SQLAlchemyWorkflowExecutionRepository( + workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, ) # Create workflow node execution repository - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=user, app_id=application_generate_entity.app_config.app_id, diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index a3dbce0e59..d0b228f4ba 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from core.ops.utils import filter_none_values -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.nodes.enums import NodeType from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom @@ -123,10 +123,10 @@ class LangFuseDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index f94e5e49d7..f3f08d74b8 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.utils import filter_none_values, generate_dotted_order -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -145,10 +145,10 @@ class LangSmithDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index 8bedea20fb..31ce6fe6c8 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -160,10 +160,10 @@ class OpikDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 3917348a91..95cb0dd621 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import RepositoryFactory from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.nodes.enums import NodeType from extensions.ext_database import db @@ -144,10 +144,10 @@ class WeaveDataTrace(BaseTraceInstance): service_account = self.get_service_account_with_tenant(app_id) - workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository( + workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=session_factory, user=service_account, - app_id=trace_info.metadata.get("app_id"), + app_id=app_id, triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6452317120..bb5b3224ff 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -5,8 +5,11 @@ This package contains concrete implementations of the repository interfaces defined in the core.workflow.repository package. """ +from core.repositories.factory import RepositoryFactory, RepositoryImportError from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository __all__ = [ + "RepositoryFactory", + "RepositoryImportError", "SQLAlchemyWorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py new file mode 100644 index 0000000000..646b587244 --- /dev/null +++ b/api/core/repositories/factory.py @@ -0,0 +1,215 @@ +""" +Repository factory for dynamically creating repository instances based on configuration. + +This module provides a Django-like settings system for repository implementations, +allowing users to configure different repository backends through string paths. +""" + +import importlib +import logging +from typing import Any, Union + +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from configs import dify_config +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + +logger = logging.getLogger(__name__) + + +class RepositoryImportError(Exception): + """Raised when a repository implementation cannot be imported or instantiated.""" + + pass + + +class RepositoryFactory: + """ + Factory for creating repository instances based on configuration. + + This factory supports Django-like settings where repository implementations + are specified as module paths (e.g., 'module.submodule.ClassName'). + """ + + @staticmethod + def _import_class(class_path: str) -> Any: + """ + Import a class from a module path string. + + Args: + class_path: Full module path to the class (e.g., 'module.submodule.ClassName') + + Returns: + The imported class + + Raises: + RepositoryImportError: If the class cannot be imported + """ + try: + module_path, class_name = class_path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ValueError, ImportError, AttributeError) as e: + raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e + + @staticmethod + def _validate_repository_interface(repository_class: Any, expected_interface: Any) -> None: + """ + Validate that a class implements the expected repository interface. + + Args: + repository_class: The class to validate + expected_interface: The expected interface/protocol + + Raises: + RepositoryImportError: If the class doesn't implement the interface + """ + # Check if the class has all required methods from the protocol + required_methods = [ + method + for method in dir(expected_interface) + if not method.startswith("_") and callable(getattr(expected_interface, method, None)) + ] + + missing_methods = [] + for method_name in required_methods: + if not hasattr(repository_class, method_name): + missing_methods.append(method_name) + + if missing_methods: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' does not implement required methods " + f"{missing_methods} from interface '{expected_interface.__name__}'" + ) + + @staticmethod + def _validate_constructor_signature(repository_class: Any, required_params: list[str]) -> None: + """ + Validate that a repository class constructor accepts required parameters. + + Args: + repository_class: The class to validate + required_params: List of required parameter names + + Raises: + RepositoryImportError: If the constructor doesn't accept required parameters + """ + import inspect + + try: + signature = inspect.signature(repository_class.__init__) + param_names = list(signature.parameters.keys()) + + # Remove 'self' parameter + if "self" in param_names: + param_names.remove("self") + + missing_params = [param for param in required_params if param not in param_names] + if missing_params: + raise RepositoryImportError( + f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: " + f"{missing_params}. Expected parameters: {required_params}" + ) + except Exception as e: + raise RepositoryImportError( + f"Failed to validate constructor signature for '{repository_class.__name__}': {e}" + ) from e + + @classmethod + def create_workflow_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowRunTriggeredFrom, + ) -> WorkflowExecutionRepository: + """ + Create a WorkflowExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.WORKFLOW_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowExecutionRepository") + raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e + + @classmethod + def create_workflow_node_execution_repository( + cls, + session_factory: Union[sessionmaker, Engine], + user: Union[Account, EndUser], + app_id: str, + triggered_from: WorkflowNodeExecutionTriggeredFrom, + ) -> WorkflowNodeExecutionRepository: + """ + Create a WorkflowNodeExecutionRepository instance based on configuration. + + Args: + session_factory: SQLAlchemy sessionmaker or engine + user: Account or EndUser object + app_id: Application ID + triggered_from: Source of the execution trigger + + Returns: + Configured WorkflowNodeExecutionRepository instance + + Raises: + RepositoryImportError: If the configured repository cannot be created + """ + class_path = dify_config.WORKFLOW_NODE_EXECUTION_REPOSITORY + logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") + + try: + repository_class = cls._import_class(class_path) + cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository) + cls._validate_constructor_signature( + repository_class, ["session_factory", "user", "app_id", "triggered_from"] + ) + + return repository_class( # type: ignore[no-any-return] + session_factory=session_factory, + user=user, + app_id=app_id, + triggered_from=triggered_from, + ) + except RepositoryImportError: + # Re-raise our custom errors as-is + raise + except Exception as e: + logger.exception("Failed to create WorkflowNodeExecutionRepository") + raise RepositoryImportError( + f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}" + ) from e diff --git a/api/services/workflow_run_service.py b/api/services/workflow_run_service.py index 483c0d3086..4acf1206b1 100644 --- a/api/services/workflow_run_service.py +++ b/api/services/workflow_run_service.py @@ -2,9 +2,9 @@ import threading from collections.abc import Sequence from typing import Optional +from sqlalchemy import desc, select + import contexts -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from core.workflow.repositories.workflow_node_execution_repository import OrderConfig from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import ( @@ -15,7 +15,6 @@ from models import ( WorkflowRun, WorkflowRunTriggeredFrom, ) -from models.workflow import WorkflowNodeExecutionTriggeredFrom class WorkflowRunService: @@ -137,17 +136,19 @@ class WorkflowRunService: if not workflow_run: return [] - repository = SQLAlchemyWorkflowNodeExecutionRepository( - session_factory=db.engine, - user=user, - app_id=app_model.id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # Get tenant_id from user + tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id - # Use the repository to get the database models directly - order_config = OrderConfig(order_by=["index"], order_direction="desc") - workflow_node_executions = repository.get_db_models_by_workflow_run( - workflow_run_id=run_id, order_config=order_config + # Use SQLAlchemy 2.0 style query directly + stmt = ( + select(WorkflowNodeExecutionModel) + .where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.app_id == app_model.id, + WorkflowNodeExecutionModel.workflow_run_id == run_id, + ) + .order_by(desc(WorkflowNodeExecutionModel.index)) ) + workflow_node_executions = db.session.execute(stmt).scalars().all() return workflow_node_executions diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 2be57fd51c..e38858f73e 100644 --- a/api/services/workflow_service.py +++ b/api/services/workflow_service.py @@ -13,7 +13,7 @@ from core.app.app_config.entities import VariableEntityType from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.file import File -from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository +from core.repositories import RepositoryFactory from core.variables import Variable from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.variable_pool import VariablePool @@ -396,7 +396,7 @@ class WorkflowService: node_execution.workflow_id = draft_workflow.id # Create repository and save the node execution - repository = SQLAlchemyWorkflowNodeExecutionRepository( + repository = RepositoryFactory.create_workflow_node_execution_repository( session_factory=db.engine, user=account, app_id=app_model.id, @@ -404,8 +404,8 @@ class WorkflowService: ) repository.save(node_execution) - # Convert node_execution to WorkflowNodeExecution after save - workflow_node_execution = repository.to_db_model(node_execution) + stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == node_execution.id) + workflow_node_execution = db.session.execute(stmt).scalar_one() with Session(bind=db.engine) as session, session.begin(): draft_var_saver = DraftVariableSaver( diff --git a/api/tests/unit_tests/core/repositories/__init__.py b/api/tests/unit_tests/core/repositories/__init__.py new file mode 100644 index 0000000000..c65d7da61d --- /dev/null +++ b/api/tests/unit_tests/core/repositories/__init__.py @@ -0,0 +1 @@ +# Unit tests for core repositories module diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py new file mode 100644 index 0000000000..1d52c5daf8 --- /dev/null +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -0,0 +1,455 @@ +""" +Unit tests for the RepositoryFactory. + +This module tests the factory pattern implementation for creating repository instances +based on configuration, including error handling and validation. +""" + +from unittest.mock import MagicMock, patch + +import pytest +from pytest_mock import MockerFixture +from sqlalchemy.engine import Engine +from sqlalchemy.orm import sessionmaker + +from core.repositories.factory import RepositoryFactory, RepositoryImportError +from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from models import Account, EndUser +from models.enums import WorkflowRunTriggeredFrom +from models.workflow import WorkflowNodeExecutionTriggeredFrom + + +class TestRepositoryFactory: + """Test cases for RepositoryFactory.""" + + def test_import_class_success(self): + """Test successful class import.""" + # Test importing a real class + class_path = "unittest.mock.MagicMock" + result = RepositoryFactory._import_class(class_path) + assert result is MagicMock + + def test_import_class_invalid_path(self): + """Test import with invalid module path.""" + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._import_class("invalid.module.path") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_invalid_class_name(self): + """Test import with invalid class name.""" + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._import_class("unittest.mock.NonExistentClass") + assert "Cannot import repository class" in str(exc_info.value) + + def test_import_class_malformed_path(self): + """Test import with malformed path (no dots).""" + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._import_class("invalidpath") + assert "Cannot import repository class" in str(exc_info.value) + + def test_validate_repository_interface_success(self): + """Test successful interface validation.""" + + # Create a mock class that implements the required methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + # Create a mock interface with the same methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + # Should not raise an exception + RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_repository_interface_missing_methods(self): + """Test interface validation with missing methods.""" + + # Create a mock class that doesn't implement all required methods + class IncompleteRepository: + def save(self): + pass + + # Missing get_by_id method + + # Create a mock interface with required methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) + assert "does not implement required methods" in str(exc_info.value) + assert "get_by_id" in str(exc_info.value) + + def test_validate_constructor_signature_success(self): + """Test successful constructor signature validation.""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from): + pass + + # Should not raise an exception + RepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_missing_params(self): + """Test constructor validation with missing parameters.""" + + class IncompleteRepository: + def __init__(self, session_factory, user): + # Missing app_id and triggered_from parameters + pass + + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._validate_constructor_signature( + IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value) + + def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture): + """Test constructor validation when inspection fails.""" + # Mock inspect.signature to raise an exception + mocker.patch("inspect.signature", side_effect=Exception("Inspection failed")) + + class MockRepository: + def __init__(self, session_factory): + pass + + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) + assert "Failed to validate constructor signature" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + app_id = "test-app-id" + triggered_from = WorkflowRunTriggeredFrom.APP_RUN + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(RepositoryFactory, "_validate_repository_interface"), + patch.object(RepositoryFactory, "_validate_constructor_signature"), + ): + result = RepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_import_error(self, mock_config): + """Test WorkflowExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with validation error.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import to succeed but validation to fail + mock_repository_class = MagicMock() + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + RepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture): + """Test WorkflowExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=Account) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(RepositoryFactory, "_validate_repository_interface"), + patch.object(RepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture): + """Test successful creation of WorkflowNodeExecutionRepository.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + app_id = "test-app-id" + triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(RepositoryFactory, "_validate_repository_interface"), + patch.object(RepositoryFactory, "_validate_constructor_signature"), + ): + result = RepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + + # Verify the repository was created with correct parameters + mock_repository_class.assert_called_once_with( + session_factory=mock_session_factory, + user=mock_user, + app_id=app_id, + triggered_from=triggered_from, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_import_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with import error.""" + # Setup mock configuration with invalid class path + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Cannot import repository class" in str(exc_info.value) + + def test_repository_import_error_exception(self): + """Test RepositoryImportError exception.""" + error_message = "Test error message" + exception = RepositoryImportError(error_message) + assert str(exception) == error_message + assert isinstance(exception, Exception) + + @patch("core.repositories.factory.dify_config") + def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture): + """Test repository creation with Engine instead of sessionmaker.""" + # Setup mock configuration + mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + # Create mock dependencies with Engine instead of sessionmaker + mock_engine = MagicMock(spec=Engine) + mock_user = MagicMock(spec=Account) + + # Mock the imported class to be a valid repository + mock_repository_class = MagicMock() + mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository) + mock_repository_class.return_value = mock_repository_instance + + # Mock the validation methods + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(RepositoryFactory, "_validate_repository_interface"), + patch.object(RepositoryFactory, "_validate_constructor_signature"), + ): + result = RepositoryFactory.create_workflow_execution_repository( + session_factory=mock_engine, # Using Engine instead of sessionmaker + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + + # Verify the repository was created with the Engine + mock_repository_class.assert_called_once_with( + session_factory=mock_engine, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowRunTriggeredFrom.APP_RUN, + ) + assert result is mock_repository_instance + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_validation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with validation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import to succeed but validation to fail + mock_repository_class = MagicMock() + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object( + RepositoryFactory, + "_validate_repository_interface", + side_effect=RepositoryImportError("Interface validation failed"), + ), + ): + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Interface validation failed" in str(exc_info.value) + + @patch("core.repositories.factory.dify_config") + def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config): + """Test WorkflowNodeExecutionRepository creation with instantiation error.""" + # Setup mock configuration + mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock" + + mock_session_factory = MagicMock(spec=sessionmaker) + mock_user = MagicMock(spec=EndUser) + + # Mock import and validation to succeed but instantiation to fail + mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) + with ( + patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class), + patch.object(RepositoryFactory, "_validate_repository_interface"), + patch.object(RepositoryFactory, "_validate_constructor_signature"), + ): + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory.create_workflow_node_execution_repository( + session_factory=mock_session_factory, + user=mock_user, + app_id="test-app-id", + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value) + + def test_validate_repository_interface_with_private_methods(self): + """Test interface validation ignores private methods.""" + + # Create a mock class with private methods + class MockRepository: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Create a mock interface with private methods + class MockInterface: + def save(self): + pass + + def get_by_id(self): + pass + + def _private_method(self): + pass + + # Should not raise an exception (private methods are ignored) + RepositoryFactory._validate_repository_interface(MockRepository, MockInterface) + + def test_validate_constructor_signature_with_extra_params(self): + """Test constructor validation with extra parameters (should pass).""" + + class MockRepository: + def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None): + pass + + # Should not raise an exception (extra parameters are allowed) + RepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + + def test_validate_constructor_signature_with_kwargs(self): + """Test constructor validation with **kwargs (current implementation doesn't support this).""" + + class MockRepository: + def __init__(self, session_factory, user, **kwargs): + pass + + # Current implementation doesn't handle **kwargs, so this should raise an exception + with pytest.raises(RepositoryImportError) as exc_info: + RepositoryFactory._validate_constructor_signature( + MockRepository, ["session_factory", "user", "app_id", "triggered_from"] + ) + assert "does not accept required parameters" in str(exc_info.value) + assert "app_id" in str(exc_info.value) + assert "triggered_from" in str(exc_info.value)