feat(api/repo): Allow to config repository implementation

Signed-off-by: -LAN- <laipz8200@outlook.com>
pull/22581/head
-LAN- 11 months ago
parent 62586719b3
commit d2bb48d38f
No known key found for this signature in database
GPG Key ID: 6BA0D108DED011FF

@ -551,27 +551,16 @@ class RepositoryConfig(BaseSettings):
Configuration for repository implementations Configuration for repository implementations
""" """
CORE_WORKFLOW_EXECUTION_REPOSITORY: str = Field( WORKFLOW_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowExecution. Specify as a module path", description="Repository implementation for WorkflowExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository", default="core.repositories.sqlalchemy_workflow_execution_repository.SQLAlchemyWorkflowExecutionRepository",
) )
CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field( WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Repository implementation for WorkflowNodeExecution. Specify as a module path", description="Repository implementation for WorkflowNodeExecution. Specify as a module path",
default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository", default="core.repositories.sqlalchemy_workflow_node_execution_repository.SQLAlchemyWorkflowNodeExecutionRepository",
) )
API_WORKFLOW_NODE_EXECUTION_REPOSITORY: str = Field(
description="Service-layer repository implementation for WorkflowNodeExecutionModel operations. "
"Specify as a module path",
default="repositories.sqlalchemy_api_workflow_node_execution_repository.DifyAPISQLAlchemyWorkflowNodeExecutionRepository",
)
API_WORKFLOW_RUN_REPOSITORY: str = Field(
description="Service-layer repository implementation for WorkflowRun operations. Specify as a module path",
default="repositories.sqlalchemy_api_workflow_run_repository.DifyAPISQLAlchemyWorkflowRunRepository",
)
class AuthConfig(BaseSettings): class AuthConfig(BaseSettings):
""" """

@ -25,7 +25,7 @@ from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotA
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.repositories.draft_variable_repository import ( from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory, DraftVariableSaverFactory,
) )
@ -182,14 +182,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else: else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from, triggered_from=workflow_triggered_from,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -259,14 +259,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory # Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository # Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -342,14 +342,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# Create session factory # Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository # Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,

@ -23,7 +23,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerat
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.model_runtime.errors.invoke import InvokeAuthorizationError from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
@ -155,14 +155,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING
else: else:
workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=workflow_triggered_from, triggered_from=workflow_triggered_from,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -305,14 +305,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory # Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository # Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
@ -387,14 +387,14 @@ class WorkflowAppGenerator(BaseAppGenerator):
# Create session factory # Create session factory
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
# Create workflow execution(aka workflow run) repository # Create workflow execution(aka workflow run) repository
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( workflow_execution_repository = RepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.DEBUGGING, triggered_from=WorkflowRunTriggeredFrom.DEBUGGING,
) )
# Create workflow node execution repository # Create workflow node execution repository
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=user, user=user,
app_id=application_generate_entity.app_config.app_id, app_id=application_generate_entity.app_config.app_id,

@ -28,7 +28,7 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import (
UnitEnum, UnitEnum,
) )
from core.ops.utils import filter_none_values from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models import EndUser, WorkflowNodeExecutionTriggeredFrom
@ -123,7 +123,7 @@ class LangFuseDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id) service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=service_account, user=service_account,
app_id=app_id, app_id=app_id,

@ -27,7 +27,7 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import (
LangSmithRunUpdateModel, LangSmithRunUpdateModel,
) )
from core.ops.utils import filter_none_values, generate_dotted_order from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
@ -145,7 +145,7 @@ class LangSmithDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id) service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=service_account, user=service_account,
app_id=app_id, app_id=app_id,

@ -21,7 +21,7 @@ from core.ops.entities.trace_entity import (
TraceTaskName, TraceTaskName,
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
@ -160,7 +160,7 @@ class OpikDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id) service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=service_account, user=service_account,
app_id=app_id, app_id=app_id,

@ -22,7 +22,7 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo, WorkflowTraceInfo,
) )
from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes.enums import NodeType from core.workflow.nodes.enums import NodeType
from extensions.ext_database import db from extensions.ext_database import db
@ -144,7 +144,7 @@ class WeaveDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id) service_account = self.get_service_account_with_tenant(app_id)
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( workflow_node_execution_repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory, session_factory=session_factory,
user=service_account, user=service_account,
app_id=app_id, app_id=app_id,

@ -5,11 +5,11 @@ This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package. defined in the core.workflow.repository package.
""" """
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.factory import RepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [ __all__ = [
"DifyCoreRepositoryFactory", "RepositoryFactory",
"RepositoryImportError", "RepositoryImportError",
"SQLAlchemyWorkflowNodeExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository",
] ]

@ -6,9 +6,8 @@ allowing users to configure different repository backends through string paths.
""" """
import importlib import importlib
import inspect
import logging import logging
from typing import Protocol, Union from typing import Any, Union
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -29,7 +28,7 @@ class RepositoryImportError(Exception):
pass pass
class DifyCoreRepositoryFactory: class RepositoryFactory:
""" """
Factory for creating repository instances based on configuration. Factory for creating repository instances based on configuration.
@ -38,7 +37,7 @@ class DifyCoreRepositoryFactory:
""" """
@staticmethod @staticmethod
def _import_class(class_path: str) -> type: def _import_class(class_path: str) -> Any:
""" """
Import a class from a module path string. Import a class from a module path string.
@ -54,14 +53,12 @@ class DifyCoreRepositoryFactory:
try: try:
module_path, class_name = class_path.rsplit(".", 1) module_path, class_name = class_path.rsplit(".", 1)
module = importlib.import_module(module_path) module = importlib.import_module(module_path)
repo_class = getattr(module, class_name) return getattr(module, class_name)
assert isinstance(repo_class, type)
return repo_class
except (ValueError, ImportError, AttributeError) as e: except (ValueError, ImportError, AttributeError) as e:
raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
@staticmethod @staticmethod
def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore def _validate_repository_interface(repository_class: Any, expected_interface: Any) -> None:
""" """
Validate that a class implements the expected repository interface. Validate that a class implements the expected repository interface.
@ -91,7 +88,7 @@ class DifyCoreRepositoryFactory:
) )
@staticmethod @staticmethod
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None: def _validate_constructor_signature(repository_class: Any, required_params: list[str]) -> None:
""" """
Validate that a repository class constructor accepts required parameters. Validate that a repository class constructor accepts required parameters.
@ -102,16 +99,10 @@ class DifyCoreRepositoryFactory:
Raises: Raises:
RepositoryImportError: If the constructor doesn't accept required parameters RepositoryImportError: If the constructor doesn't accept required parameters
""" """
import inspect
try: try:
# MyPy may flag the line below with the following error: signature = inspect.signature(repository_class.__init__)
#
# > Accessing "__init__" on an instance is unsound, since
# > instance.__init__ could be from an incompatible subclass.
#
# Despite this, we need to ensure that the constructor of `repository_class`
# has a compatible signature.
signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
param_names = list(signature.parameters.keys()) param_names = list(signature.parameters.keys())
# Remove 'self' parameter # Remove 'self' parameter
@ -152,7 +143,7 @@ class DifyCoreRepositoryFactory:
Raises: Raises:
RepositoryImportError: If the configured repository cannot be created RepositoryImportError: If the configured repository cannot be created
""" """
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY class_path = dify_config.WORKFLOW_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}") logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
try: try:
@ -198,7 +189,7 @@ class DifyCoreRepositoryFactory:
Raises: Raises:
RepositoryImportError: If the configured repository cannot be created RepositoryImportError: If the configured repository cannot be created
""" """
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY class_path = dify_config.WORKFLOW_NODE_EXECUTION_REPOSITORY
logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}") logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
try: try:

@ -2,7 +2,7 @@ import threading
from collections.abc import Sequence from collections.abc import Sequence
from typing import Optional from typing import Optional
from sqlalchemy.orm import sessionmaker from sqlalchemy import desc, select
import contexts import contexts
from extensions.ext_database import db from extensions.ext_database import db
@ -15,7 +15,6 @@ from models import (
WorkflowRun, WorkflowRun,
WorkflowRunTriggeredFrom, WorkflowRunTriggeredFrom,
) )
from repositories.factory import DifyAPIRepositoryFactory
class WorkflowRunService: class WorkflowRunService:
@ -112,11 +111,17 @@ class WorkflowRunService:
# Get tenant_id from user # Get tenant_id from user
tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id tenant_id = user.tenant_id if isinstance(user, EndUser) else user.current_tenant_id
if tenant_id is None:
raise ValueError("User tenant_id cannot be None")
return self._node_execution_service_repo.get_executions_by_workflow_run( # Use SQLAlchemy 2.0 style query directly
tenant_id=tenant_id, stmt = (
app_id=app_model.id, select(WorkflowNodeExecutionModel)
workflow_run_id=run_id, .where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.app_id == app_model.id,
WorkflowNodeExecutionModel.workflow_run_id == run_id,
)
.order_by(desc(WorkflowNodeExecutionModel.index))
) )
workflow_node_executions = db.session.execute(stmt).scalars().all()
return workflow_node_executions

@ -13,7 +13,7 @@ from core.app.app_config.entities import VariableEntityType
from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File from core.file import File
from core.repositories import DifyCoreRepositoryFactory from core.repositories import RepositoryFactory
from core.variables import Variable from core.variables import Variable
from core.variables.variables import VariableUnion from core.variables.variables import VariableUnion
from core.workflow.entities.node_entities import NodeRunResult from core.workflow.entities.node_entities import NodeRunResult
@ -409,7 +409,7 @@ class WorkflowService:
node_execution.workflow_id = draft_workflow.id node_execution.workflow_id = draft_workflow.id
# Create repository and save the node execution # Create repository and save the node execution
repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( repository = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=db.engine, session_factory=db.engine,
user=account, user=account,
app_id=app_model.id, app_id=app_model.id,
@ -417,9 +417,8 @@ class WorkflowService:
) )
repository.save(node_execution) repository.save(node_execution)
workflow_node_execution = self._node_execution_service_repo.get_execution_by_id(node_execution.id) stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == node_execution.id)
if workflow_node_execution is None: workflow_node_execution = db.session.execute(stmt).scalar_one()
raise ValueError(f"WorkflowNodeExecution with id {node_execution.id} not found after saving")
with Session(bind=db.engine) as session, session.begin(): with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver( draft_var_saver = DraftVariableSaver(

@ -12,7 +12,7 @@ from pytest_mock import MockerFixture
from sqlalchemy.engine import Engine from sqlalchemy.engine import Engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError from core.repositories.factory import RepositoryFactory, RepositoryImportError
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from models import Account, EndUser from models import Account, EndUser
@ -27,25 +27,25 @@ class TestRepositoryFactory:
"""Test successful class import.""" """Test successful class import."""
# Test importing a real class # Test importing a real class
class_path = "unittest.mock.MagicMock" class_path = "unittest.mock.MagicMock"
result = DifyCoreRepositoryFactory._import_class(class_path) result = RepositoryFactory._import_class(class_path)
assert result is MagicMock assert result is MagicMock
def test_import_class_invalid_path(self): def test_import_class_invalid_path(self):
"""Test import with invalid module path.""" """Test import with invalid module path."""
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalid.module.path") RepositoryFactory._import_class("invalid.module.path")
assert "Cannot import repository class" in str(exc_info.value) assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_invalid_class_name(self): def test_import_class_invalid_class_name(self):
"""Test import with invalid class name.""" """Test import with invalid class name."""
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass") RepositoryFactory._import_class("unittest.mock.NonExistentClass")
assert "Cannot import repository class" in str(exc_info.value) assert "Cannot import repository class" in str(exc_info.value)
def test_import_class_malformed_path(self): def test_import_class_malformed_path(self):
"""Test import with malformed path (no dots).""" """Test import with malformed path (no dots)."""
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._import_class("invalidpath") RepositoryFactory._import_class("invalidpath")
assert "Cannot import repository class" in str(exc_info.value) assert "Cannot import repository class" in str(exc_info.value)
def test_validate_repository_interface_success(self): def test_validate_repository_interface_success(self):
@ -68,7 +68,7 @@ class TestRepositoryFactory:
pass pass
# Should not raise an exception # Should not raise an exception
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) RepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_repository_interface_missing_methods(self): def test_validate_repository_interface_missing_methods(self):
"""Test interface validation with missing methods.""" """Test interface validation with missing methods."""
@ -89,7 +89,7 @@ class TestRepositoryFactory:
pass pass
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface) RepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
assert "does not implement required methods" in str(exc_info.value) assert "does not implement required methods" in str(exc_info.value)
assert "get_by_id" in str(exc_info.value) assert "get_by_id" in str(exc_info.value)
@ -101,7 +101,7 @@ class TestRepositoryFactory:
pass pass
# Should not raise an exception # Should not raise an exception
DifyCoreRepositoryFactory._validate_constructor_signature( RepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"] MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
) )
@ -114,7 +114,7 @@ class TestRepositoryFactory:
pass pass
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature( RepositoryFactory._validate_constructor_signature(
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"] IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
) )
assert "does not accept required parameters" in str(exc_info.value) assert "does not accept required parameters" in str(exc_info.value)
@ -131,7 +131,7 @@ class TestRepositoryFactory:
pass pass
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"]) RepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
assert "Failed to validate constructor signature" in str(exc_info.value) assert "Failed to validate constructor signature" in str(exc_info.value)
@patch("core.repositories.factory.dify_config") @patch("core.repositories.factory.dify_config")
@ -153,11 +153,11 @@ class TestRepositoryFactory:
# Mock the validation methods # Mock the validation methods
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), patch.object(RepositoryFactory, "_validate_constructor_signature"),
): ):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository( result = RepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id=app_id, app_id=app_id,
@ -183,7 +183,7 @@ class TestRepositoryFactory:
mock_user = MagicMock(spec=Account) mock_user = MagicMock(spec=Account)
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository( RepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -203,15 +203,15 @@ class TestRepositoryFactory:
# Mock import to succeed but validation to fail # Mock import to succeed but validation to fail
mock_repository_class = MagicMock() mock_repository_class = MagicMock()
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object( patch.object(
DifyCoreRepositoryFactory, RepositoryFactory,
"_validate_repository_interface", "_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"), side_effect=RepositoryImportError("Interface validation failed"),
), ),
): ):
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository( RepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -231,12 +231,12 @@ class TestRepositoryFactory:
# Mock import and validation to succeed but instantiation to fail # Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), patch.object(RepositoryFactory, "_validate_constructor_signature"),
): ):
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_execution_repository( RepositoryFactory.create_workflow_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -263,11 +263,11 @@ class TestRepositoryFactory:
# Mock the validation methods # Mock the validation methods
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), patch.object(RepositoryFactory, "_validate_constructor_signature"),
): ):
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( result = RepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id=app_id, app_id=app_id,
@ -293,7 +293,7 @@ class TestRepositoryFactory:
mock_user = MagicMock(spec=EndUser) mock_user = MagicMock(spec=EndUser)
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository( RepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -325,11 +325,11 @@ class TestRepositoryFactory:
# Mock the validation methods # Mock the validation methods
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), patch.object(RepositoryFactory, "_validate_constructor_signature"),
): ):
result = DifyCoreRepositoryFactory.create_workflow_execution_repository( result = RepositoryFactory.create_workflow_execution_repository(
session_factory=mock_engine, # Using Engine instead of sessionmaker session_factory=mock_engine, # Using Engine instead of sessionmaker
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -357,15 +357,15 @@ class TestRepositoryFactory:
# Mock import to succeed but validation to fail # Mock import to succeed but validation to fail
mock_repository_class = MagicMock() mock_repository_class = MagicMock()
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object( patch.object(
DifyCoreRepositoryFactory, RepositoryFactory,
"_validate_repository_interface", "_validate_repository_interface",
side_effect=RepositoryImportError("Interface validation failed"), side_effect=RepositoryImportError("Interface validation failed"),
), ),
): ):
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository( RepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -385,12 +385,12 @@ class TestRepositoryFactory:
# Mock import and validation to succeed but instantiation to fail # Mock import and validation to succeed but instantiation to fail
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed")) mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
with ( with (
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class), patch.object(RepositoryFactory, "_import_class", return_value=mock_repository_class),
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"), patch.object(RepositoryFactory, "_validate_repository_interface"),
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"), patch.object(RepositoryFactory, "_validate_constructor_signature"),
): ):
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory.create_workflow_node_execution_repository( RepositoryFactory.create_workflow_node_execution_repository(
session_factory=mock_session_factory, session_factory=mock_session_factory,
user=mock_user, user=mock_user,
app_id="test-app-id", app_id="test-app-id",
@ -424,7 +424,7 @@ class TestRepositoryFactory:
pass pass
# Should not raise an exception (private methods are ignored) # Should not raise an exception (private methods are ignored)
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface) RepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
def test_validate_constructor_signature_with_extra_params(self): def test_validate_constructor_signature_with_extra_params(self):
"""Test constructor validation with extra parameters (should pass).""" """Test constructor validation with extra parameters (should pass)."""
@ -434,7 +434,7 @@ class TestRepositoryFactory:
pass pass
# Should not raise an exception (extra parameters are allowed) # Should not raise an exception (extra parameters are allowed)
DifyCoreRepositoryFactory._validate_constructor_signature( RepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"] MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
) )
@ -447,7 +447,7 @@ class TestRepositoryFactory:
# Current implementation doesn't handle **kwargs, so this should raise an exception # Current implementation doesn't handle **kwargs, so this should raise an exception
with pytest.raises(RepositoryImportError) as exc_info: with pytest.raises(RepositoryImportError) as exc_info:
DifyCoreRepositoryFactory._validate_constructor_signature( RepositoryFactory._validate_constructor_signature(
MockRepository, ["session_factory", "user", "app_id", "triggered_from"] MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
) )
assert "does not accept required parameters" in str(exc_info.value) assert "does not accept required parameters" in str(exc_info.value)

Loading…
Cancel
Save