Merge branch 'main' into fix-agent-node-file-hanlding-one-turn-delay-bug
commit
16159a54ec
@ -0,0 +1,224 @@
|
||||
"""
|
||||
Repository factory for dynamically creating repository instances based on configuration.
|
||||
|
||||
This module provides a Django-like settings system for repository implementations,
|
||||
allowing users to configure different repository backends through string paths.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Protocol, Union
|
||||
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RepositoryImportError(Exception):
|
||||
"""Raised when a repository implementation cannot be imported or instantiated."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class DifyCoreRepositoryFactory:
|
||||
"""
|
||||
Factory for creating repository instances based on configuration.
|
||||
|
||||
This factory supports Django-like settings where repository implementations
|
||||
are specified as module paths (e.g., 'module.submodule.ClassName').
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _import_class(class_path: str) -> type:
|
||||
"""
|
||||
Import a class from a module path string.
|
||||
|
||||
Args:
|
||||
class_path: Full module path to the class (e.g., 'module.submodule.ClassName')
|
||||
|
||||
Returns:
|
||||
The imported class
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the class cannot be imported
|
||||
"""
|
||||
try:
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
repo_class = getattr(module, class_name)
|
||||
assert isinstance(repo_class, type)
|
||||
return repo_class
|
||||
except (ValueError, ImportError, AttributeError) as e:
|
||||
raise RepositoryImportError(f"Cannot import repository class '{class_path}': {e}") from e
|
||||
|
||||
@staticmethod
|
||||
def _validate_repository_interface(repository_class: type, expected_interface: type[Protocol]) -> None: # type: ignore
|
||||
"""
|
||||
Validate that a class implements the expected repository interface.
|
||||
|
||||
Args:
|
||||
repository_class: The class to validate
|
||||
expected_interface: The expected interface/protocol
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the class doesn't implement the interface
|
||||
"""
|
||||
# Check if the class has all required methods from the protocol
|
||||
required_methods = [
|
||||
method
|
||||
for method in dir(expected_interface)
|
||||
if not method.startswith("_") and callable(getattr(expected_interface, method, None))
|
||||
]
|
||||
|
||||
missing_methods = []
|
||||
for method_name in required_methods:
|
||||
if not hasattr(repository_class, method_name):
|
||||
missing_methods.append(method_name)
|
||||
|
||||
if missing_methods:
|
||||
raise RepositoryImportError(
|
||||
f"Repository class '{repository_class.__name__}' does not implement required methods "
|
||||
f"{missing_methods} from interface '{expected_interface.__name__}'"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_constructor_signature(repository_class: type, required_params: list[str]) -> None:
|
||||
"""
|
||||
Validate that a repository class constructor accepts required parameters.
|
||||
|
||||
Args:
|
||||
repository_class: The class to validate
|
||||
required_params: List of required parameter names
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the constructor doesn't accept required parameters
|
||||
"""
|
||||
|
||||
try:
|
||||
# MyPy may flag the line below with the following error:
|
||||
#
|
||||
# > Accessing "__init__" on an instance is unsound, since
|
||||
# > instance.__init__ could be from an incompatible subclass.
|
||||
#
|
||||
# Despite this, we need to ensure that the constructor of `repository_class`
|
||||
# has a compatible signature.
|
||||
signature = inspect.signature(repository_class.__init__) # type: ignore[misc]
|
||||
param_names = list(signature.parameters.keys())
|
||||
|
||||
# Remove 'self' parameter
|
||||
if "self" in param_names:
|
||||
param_names.remove("self")
|
||||
|
||||
missing_params = [param for param in required_params if param not in param_names]
|
||||
if missing_params:
|
||||
raise RepositoryImportError(
|
||||
f"Repository class '{repository_class.__name__}' constructor does not accept required parameters: "
|
||||
f"{missing_params}. Expected parameters: {required_params}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise RepositoryImportError(
|
||||
f"Failed to validate constructor signature for '{repository_class.__name__}': {e}"
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def create_workflow_execution_repository(
|
||||
cls,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str,
|
||||
triggered_from: WorkflowRunTriggeredFrom,
|
||||
) -> WorkflowExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowExecutionRepository instance based on configuration.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine
|
||||
user: Account or EndUser object
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
|
||||
Returns:
|
||||
Configured WorkflowExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be created
|
||||
"""
|
||||
class_path = dify_config.CORE_WORKFLOW_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating WorkflowExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create WorkflowExecutionRepository")
|
||||
raise RepositoryImportError(f"Failed to create WorkflowExecutionRepository from '{class_path}': {e}") from e
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls,
|
||||
session_factory: Union[sessionmaker, Engine],
|
||||
user: Union[Account, EndUser],
|
||||
app_id: str,
|
||||
triggered_from: WorkflowNodeExecutionTriggeredFrom,
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance based on configuration.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine
|
||||
user: Account or EndUser object
|
||||
app_id: Application ID
|
||||
triggered_from: Source of the execution trigger
|
||||
|
||||
Returns:
|
||||
Configured WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be created
|
||||
"""
|
||||
class_path = dify_config.CORE_WORKFLOW_NODE_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating WorkflowNodeExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, WorkflowNodeExecutionRepository)
|
||||
cls._validate_constructor_signature(
|
||||
repository_class, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
return repository_class( # type: ignore[no-any-return]
|
||||
session_factory=session_factory,
|
||||
user=user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create WorkflowNodeExecutionRepository")
|
||||
raise RepositoryImportError(
|
||||
f"Failed to create WorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||
) from e
|
||||
@ -0,0 +1,197 @@
|
||||
"""
|
||||
Service-layer repository protocol for WorkflowNodeExecutionModel operations.
|
||||
|
||||
This module provides a protocol interface for service-layer operations on WorkflowNodeExecutionModel
|
||||
that abstracts database queries currently done directly in service classes. This repository is
|
||||
specifically designed for service-layer needs and is separate from the core domain repository.
|
||||
|
||||
The service repository handles operations that require access to database-specific fields like
|
||||
tenant_id, app_id, triggered_from, etc., which are not part of the core domain model.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
|
||||
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
|
||||
"""
|
||||
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
||||
|
||||
This repository provides database access patterns specifically needed by service classes,
|
||||
handling queries that involve database-specific fields and multi-tenancy concerns.
|
||||
|
||||
Key responsibilities:
|
||||
- Manages database operations for workflow node executions
|
||||
- Handles multi-tenant data isolation
|
||||
- Provides batch processing capabilities
|
||||
- Supports execution lifecycle management
|
||||
|
||||
Implementation notes:
|
||||
- Returns database models directly (WorkflowNodeExecutionModel)
|
||||
- Handles tenant/app filtering automatically
|
||||
- Provides service-specific query patterns
|
||||
- Focuses on database operations without domain logic
|
||||
- Supports cleanup and maintenance operations
|
||||
"""
|
||||
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get the most recent execution for a specific node.
|
||||
|
||||
This method finds the latest execution of a specific node within a workflow,
|
||||
ordered by creation time. Used primarily for debugging and inspection purposes.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
node_id: The node identifier
|
||||
|
||||
Returns:
|
||||
The most recent WorkflowNodeExecutionModel for the node, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
def get_executions_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get all node executions for a specific workflow run.
|
||||
|
||||
This method retrieves all node executions that belong to a specific workflow run,
|
||||
ordered by index in descending order for proper trace visualization.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
|
||||
"""
|
||||
...
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a workflow node execution by its ID.
|
||||
|
||||
This method retrieves a specific execution by its unique identifier.
|
||||
Tenant filtering is optional for cases where the execution ID is globally unique.
|
||||
|
||||
When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
|
||||
If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
|
||||
set `tenant_id` to prevent horizontal privilege escalation.
|
||||
|
||||
Args:
|
||||
execution_id: The execution identifier
|
||||
tenant_id: Optional tenant identifier for additional filtering
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecutionModel if found, or None if not found
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_expired_executions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions that are older than the specified date.
|
||||
|
||||
This method is used for cleanup operations to remove expired executions
|
||||
in batches to avoid overwhelming the database.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Delete executions created before this date
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_executions_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow node executions for a specific app.
|
||||
|
||||
This method is used when removing an app and all its related data.
|
||||
Executions are deleted in batches to avoid overwhelming the database.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The total number of executions deleted
|
||||
"""
|
||||
...
|
||||
|
||||
def get_expired_executions_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a batch of expired workflow node executions for backup purposes.
|
||||
|
||||
This method retrieves expired executions without deleting them,
|
||||
allowing the caller to backup the data before deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Get executions created before this date
|
||||
batch_size: Maximum number of executions to retrieve
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_executions_by_ids(
|
||||
self,
|
||||
execution_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions by their IDs.
|
||||
|
||||
This method deletes specific executions by their IDs,
|
||||
typically used after backing up the data.
|
||||
|
||||
This method does not perform tenant isolation checks. The caller is responsible for ensuring proper
|
||||
data isolation between tenants. When execution IDs come from untrusted sources (e.g., API requests),
|
||||
additional tenant validation should be implemented to prevent unauthorized access.
|
||||
|
||||
Args:
|
||||
execution_ids: List of execution IDs to delete
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,181 @@
|
||||
"""
|
||||
API WorkflowRun Repository Protocol
|
||||
|
||||
This module defines the protocol for service-layer WorkflowRun operations.
|
||||
The repository provides an abstraction layer for WorkflowRun database operations
|
||||
used by service classes, separating service-layer concerns from core domain logic.
|
||||
|
||||
Key Features:
|
||||
- Paginated workflow run queries with filtering
|
||||
- Bulk deletion operations with OSS backup support
|
||||
- Multi-tenant data isolation
|
||||
- Expired record cleanup with data retention
|
||||
- Service-layer specific query patterns
|
||||
|
||||
Usage:
|
||||
This protocol should be used by service classes that need to perform
|
||||
WorkflowRun database operations. It provides a clean interface that
|
||||
hides implementation details and supports dependency injection.
|
||||
|
||||
Example:
|
||||
```python
|
||||
from repositories.dify_api_repository_factory import DifyAPIRepositoryFactory
|
||||
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
# Get paginated workflow runs
|
||||
runs = repo.get_paginated_workflow_runs(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
triggered_from="debugging",
|
||||
limit=20
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional, Protocol
|
||||
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
|
||||
class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
"""
|
||||
Protocol for service-layer WorkflowRun repository operations.
|
||||
|
||||
This protocol defines the interface for WorkflowRun database operations
|
||||
that are specific to service-layer needs, including pagination, filtering,
|
||||
and bulk operations with data backup support.
|
||||
"""
|
||||
|
||||
def get_paginated_workflow_runs(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
limit: int = 20,
|
||||
last_id: Optional[str] = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs with filtering.
|
||||
|
||||
Retrieves workflow runs for a specific app and trigger source with
|
||||
cursor-based pagination support. Used primarily for debugging and
|
||||
workflow run listing in the UI.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
app_id: Application identifier
|
||||
triggered_from: Filter by trigger source (e.g., "debugging", "app-run")
|
||||
limit: Maximum number of records to return (default: 20)
|
||||
last_id: Cursor for pagination - ID of the last record from previous page
|
||||
|
||||
Returns:
|
||||
InfiniteScrollPagination object containing:
|
||||
- data: List of WorkflowRun objects
|
||||
- limit: Applied limit
|
||||
- has_more: Boolean indicating if more records exist
|
||||
|
||||
Raises:
|
||||
ValueError: If last_id is provided but the corresponding record doesn't exist
|
||||
"""
|
||||
...
|
||||
|
||||
def get_workflow_run_by_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
run_id: str,
|
||||
) -> Optional[WorkflowRun]:
|
||||
"""
|
||||
Get a specific workflow run by ID.
|
||||
|
||||
Retrieves a single workflow run with tenant and app isolation.
|
||||
Used for workflow run detail views and execution tracking.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
app_id: Application identifier
|
||||
run_id: Workflow run identifier
|
||||
|
||||
Returns:
|
||||
WorkflowRun object if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Get a batch of expired workflow runs for cleanup.
|
||||
|
||||
Retrieves workflow runs created before the specified date for
|
||||
cleanup operations. Used by scheduled tasks to remove old data
|
||||
while maintaining data retention policies.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
before_date: Only return runs created before this date
|
||||
batch_size: Maximum number of records to return
|
||||
|
||||
Returns:
|
||||
Sequence of WorkflowRun objects to be processed for cleanup
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_runs_by_ids(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow runs by their IDs.
|
||||
|
||||
Performs bulk deletion of workflow runs by ID. This method should
|
||||
be used after backing up the data to OSS storage for retention.
|
||||
|
||||
Args:
|
||||
run_ids: Sequence of workflow run IDs to delete
|
||||
|
||||
Returns:
|
||||
Number of records actually deleted
|
||||
|
||||
Note:
|
||||
This method performs hard deletion. Ensure data is backed up
|
||||
to OSS storage before calling this method for compliance with
|
||||
data retention policies.
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_runs_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow runs for a specific app.
|
||||
|
||||
Performs bulk deletion of all workflow runs associated with an app.
|
||||
Used during app cleanup operations. Processes records in batches
|
||||
to avoid memory issues and long-running transactions.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
app_id: Application identifier
|
||||
batch_size: Number of records to process in each batch
|
||||
|
||||
Returns:
|
||||
Total number of records deleted across all batches
|
||||
|
||||
Note:
|
||||
This method performs hard deletion without backup. Use with caution
|
||||
and ensure proper data retention policies are followed.
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,103 @@
|
||||
"""
|
||||
DifyAPI Repository Factory for creating repository instances.
|
||||
|
||||
This factory is specifically designed for DifyAPI repositories that handle
|
||||
service-layer operations with dependency injection patterns.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repositories import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyAPIRepositoryFactory(DifyCoreRepositoryFactory):
|
||||
"""
|
||||
Factory for creating DifyAPI repository instances based on configuration.
|
||||
|
||||
This factory handles the creation of repositories that are specifically designed
|
||||
for service-layer operations and use dependency injection with sessionmaker
|
||||
for better testability and separation of concerns.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def create_api_workflow_node_execution_repository(
|
||||
cls, session_maker: sessionmaker
|
||||
) -> DifyAPIWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a DifyAPIWorkflowNodeExecutionRepository instance based on configuration.
|
||||
|
||||
This repository is designed for service-layer operations and uses dependency injection
|
||||
with a sessionmaker for better testability and separation of concerns. It provides
|
||||
database access patterns specifically needed by service classes, handling queries
|
||||
that involve database-specific fields and multi-tenancy concerns.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker to inject for database session management.
|
||||
|
||||
Returns:
|
||||
Configured DifyAPIWorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be imported or instantiated
|
||||
"""
|
||||
class_path = dify_config.API_WORKFLOW_NODE_EXECUTION_REPOSITORY
|
||||
logger.debug(f"Creating DifyAPIWorkflowNodeExecutionRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, DifyAPIWorkflowNodeExecutionRepository)
|
||||
# Service repository requires session_maker parameter
|
||||
cls._validate_constructor_signature(repository_class, ["session_maker"])
|
||||
|
||||
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create DifyAPIWorkflowNodeExecutionRepository")
|
||||
raise RepositoryImportError(
|
||||
f"Failed to create DifyAPIWorkflowNodeExecutionRepository from '{class_path}': {e}"
|
||||
) from e
|
||||
|
||||
@classmethod
|
||||
def create_api_workflow_run_repository(cls, session_maker: sessionmaker) -> APIWorkflowRunRepository:
|
||||
"""
|
||||
Create an APIWorkflowRunRepository instance based on configuration.
|
||||
|
||||
This repository is designed for service-layer WorkflowRun operations and uses dependency
|
||||
injection with a sessionmaker for better testability and separation of concerns. It provides
|
||||
database access patterns specifically needed by service classes for workflow run management,
|
||||
including pagination, filtering, and bulk operations.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker to inject for database session management.
|
||||
|
||||
Returns:
|
||||
Configured APIWorkflowRunRepository instance
|
||||
|
||||
Raises:
|
||||
RepositoryImportError: If the configured repository cannot be imported or instantiated
|
||||
"""
|
||||
class_path = dify_config.API_WORKFLOW_RUN_REPOSITORY
|
||||
logger.debug(f"Creating APIWorkflowRunRepository from: {class_path}")
|
||||
|
||||
try:
|
||||
repository_class = cls._import_class(class_path)
|
||||
cls._validate_repository_interface(repository_class, APIWorkflowRunRepository)
|
||||
# Service repository requires session_maker parameter
|
||||
cls._validate_constructor_signature(repository_class, ["session_maker"])
|
||||
|
||||
return repository_class(session_maker=session_maker) # type: ignore[no-any-return]
|
||||
except RepositoryImportError:
|
||||
# Re-raise our custom errors as-is
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create APIWorkflowRunRepository")
|
||||
raise RepositoryImportError(f"Failed to create APIWorkflowRunRepository from '{class_path}': {e}") from e
|
||||
@ -0,0 +1,290 @@
|
||||
"""
|
||||
SQLAlchemy implementation of WorkflowNodeExecutionServiceRepository.
|
||||
|
||||
This module provides a concrete implementation of the service repository protocol
|
||||
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, desc, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
|
||||
This repository provides service-layer database operations for WorkflowNodeExecutionModel
|
||||
using SQLAlchemy 2.0 style queries. It implements the DifyAPIWorkflowNodeExecutionRepository
|
||||
protocol with the following features:
|
||||
|
||||
- Multi-tenancy data isolation through tenant_id filtering
|
||||
- Direct database model operations without domain conversion
|
||||
- Batch processing for efficient large-scale operations
|
||||
- Optimized query patterns for common access patterns
|
||||
- Dependency injection for better testability and maintainability
|
||||
- Session management and transaction handling with proper cleanup
|
||||
- Maintenance operations for data lifecycle management
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker for creating database sessions
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_node_last_execution(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
node_id: str,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get the most recent execution for a specific node.
|
||||
|
||||
This method replicates the query pattern from WorkflowService.get_node_last_run()
|
||||
using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
node_id: The node identifier
|
||||
|
||||
Returns:
|
||||
The most recent WorkflowNodeExecutionModel for the node, or None if not found
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.node_id == node_id,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_executions_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get all node executions for a specific workflow run.
|
||||
|
||||
This method replicates the query pattern from WorkflowRunService.get_workflow_run_node_executions()
|
||||
using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances ordered by index (desc)
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.index))
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
tenant_id: Optional[str] = None,
|
||||
) -> Optional[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a workflow node execution by its ID.
|
||||
|
||||
This method replicates the query pattern from WorkflowDraftVariableService
|
||||
and WorkflowService.single_step_run_workflow_node() using SQLAlchemy 2.0 style syntax.
|
||||
|
||||
When `tenant_id` is None, it's the caller's responsibility to ensure proper data isolation between tenants.
|
||||
If the `execution_id` comes from untrusted sources (e.g., retrieved from an API request), the caller should
|
||||
set `tenant_id` to prevent horizontal privilege escalation.
|
||||
|
||||
Args:
|
||||
execution_id: The execution identifier
|
||||
tenant_id: Optional tenant identifier for additional filtering
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecutionModel if found, or None if not found
|
||||
"""
|
||||
stmt = select(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id == execution_id)
|
||||
|
||||
# Add tenant filtering if provided
|
||||
if tenant_id is not None:
|
||||
stmt = stmt.where(WorkflowNodeExecutionModel.tenant_id == tenant_id)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.scalar(stmt)
|
||||
|
||||
def delete_expired_executions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions that are older than the specified date.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Delete executions created before this date
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Find executions to delete in batches
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel.id)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
execution_ids = session.execute(stmt).scalars().all()
|
||||
if not execution_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# If we deleted fewer than the batch size, we're done
|
||||
if len(execution_ids) < batch_size:
|
||||
break
|
||||
|
||||
return total_deleted
|
||||
|
||||
def delete_executions_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow node executions for a specific app.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
batch_size: Maximum number of executions to delete in one batch
|
||||
|
||||
Returns:
|
||||
The total number of executions deleted
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Find executions to delete in batches
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel.id)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
execution_ids = session.execute(stmt).scalars().all()
|
||||
if not execution_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
total_deleted += result.rowcount
|
||||
|
||||
# If we deleted fewer than the batch size, we're done
|
||||
if len(execution_ids) < batch_size:
|
||||
break
|
||||
|
||||
return total_deleted
|
||||
|
||||
def get_expired_executions_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get a batch of expired workflow node executions for backup purposes.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
before_date: Get executions created before this date
|
||||
batch_size: Maximum number of executions to retrieve
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionModel instances
|
||||
"""
|
||||
stmt = (
|
||||
select(WorkflowNodeExecutionModel)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def delete_executions_by_ids(
|
||||
self,
|
||||
execution_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow node executions by their IDs.
|
||||
|
||||
Args:
|
||||
execution_ids: List of execution IDs to delete
|
||||
|
||||
Returns:
|
||||
The number of executions deleted
|
||||
"""
|
||||
if not execution_ids:
|
||||
return 0
|
||||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowNodeExecutionModel).where(WorkflowNodeExecutionModel.id.in_(execution_ids))
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
return result.rowcount
|
||||
@ -0,0 +1,202 @@
|
||||
"""
|
||||
SQLAlchemy API WorkflowRun Repository Implementation
|
||||
|
||||
This module provides the SQLAlchemy-based implementation of the APIWorkflowRunRepository
|
||||
protocol. It handles service-layer WorkflowRun database operations using SQLAlchemy 2.0
|
||||
style queries with proper session management and multi-tenant data isolation.
|
||||
|
||||
Key Features:
|
||||
- SQLAlchemy 2.0 style queries for modern database operations
|
||||
- Cursor-based pagination for efficient large dataset handling
|
||||
- Bulk operations with batch processing for performance
|
||||
- Multi-tenant data isolation and security
|
||||
- Proper session management with dependency injection
|
||||
|
||||
Implementation Notes:
|
||||
- Uses sessionmaker for consistent session management
|
||||
- Implements cursor-based pagination using created_at timestamps
|
||||
- Provides efficient bulk deletion with batch processing
|
||||
- Maintains data consistency with proper transaction handling
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Optional, cast
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.workflow import WorkflowRun
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository:
|
||||
"""
|
||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||
|
||||
Provides service-layer WorkflowRun database operations using SQLAlchemy 2.0
|
||||
style queries. Supports dependency injection through sessionmaker and
|
||||
maintains proper multi-tenant data isolation.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker instance for database connections
|
||||
"""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
|
||||
Args:
|
||||
session_maker: SQLAlchemy sessionmaker for database connections
|
||||
"""
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_paginated_workflow_runs(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
triggered_from: str,
|
||||
limit: int = 20,
|
||||
last_id: Optional[str] = None,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs with filtering.
|
||||
|
||||
Implements cursor-based pagination using created_at timestamps for
|
||||
efficient handling of large datasets. Filters by tenant, app, and
|
||||
trigger source for proper data isolation.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
# Build base query with filters
|
||||
base_stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
WorkflowRun.triggered_from == triggered_from,
|
||||
)
|
||||
|
||||
if last_id:
|
||||
# Get the last workflow run for cursor-based pagination
|
||||
last_run_stmt = base_stmt.where(WorkflowRun.id == last_id)
|
||||
last_workflow_run = session.scalar(last_run_stmt)
|
||||
|
||||
if not last_workflow_run:
|
||||
raise ValueError("Last workflow run not exists")
|
||||
|
||||
# Get records created before the last run's timestamp
|
||||
base_stmt = base_stmt.where(
|
||||
WorkflowRun.created_at < last_workflow_run.created_at,
|
||||
WorkflowRun.id != last_workflow_run.id,
|
||||
)
|
||||
|
||||
# First page - get most recent records
|
||||
workflow_runs = session.scalars(base_stmt.order_by(WorkflowRun.created_at.desc()).limit(limit + 1)).all()
|
||||
|
||||
# Check if there are more records for pagination
|
||||
has_more = len(workflow_runs) > limit
|
||||
if has_more:
|
||||
workflow_runs = workflow_runs[:-1]
|
||||
|
||||
return InfiniteScrollPagination(data=workflow_runs, limit=limit, has_more=has_more)
|
||||
|
||||
def get_workflow_run_by_id(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
run_id: str,
|
||||
) -> Optional[WorkflowRun]:
|
||||
"""
|
||||
Get a specific workflow run by ID with tenant and app isolation.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
return cast(Optional[WorkflowRun], session.scalar(stmt))
|
||||
|
||||
def get_expired_runs_batch(
|
||||
self,
|
||||
tenant_id: str,
|
||||
before_date: datetime,
|
||||
batch_size: int = 1000,
|
||||
) -> Sequence[WorkflowRun]:
|
||||
"""
|
||||
Get a batch of expired workflow runs for cleanup operations.
|
||||
"""
|
||||
with self._session_maker() as session:
|
||||
stmt = (
|
||||
select(WorkflowRun)
|
||||
.where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.created_at < before_date,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
return cast(Sequence[WorkflowRun], session.scalars(stmt).all())
|
||||
|
||||
def delete_runs_by_ids(
|
||||
self,
|
||||
run_ids: Sequence[str],
|
||||
) -> int:
|
||||
"""
|
||||
Delete workflow runs by their IDs using bulk deletion.
|
||||
"""
|
||||
if not run_ids:
|
||||
return 0
|
||||
|
||||
with self._session_maker() as session:
|
||||
stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
deleted_count = cast(int, result.rowcount)
|
||||
logger.info(f"Deleted {deleted_count} workflow runs by IDs")
|
||||
return deleted_count
|
||||
|
||||
def delete_runs_by_app(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
batch_size: int = 1000,
|
||||
) -> int:
|
||||
"""
|
||||
Delete all workflow runs for a specific app in batches.
|
||||
"""
|
||||
total_deleted = 0
|
||||
|
||||
while True:
|
||||
with self._session_maker() as session:
|
||||
# Get a batch of run IDs to delete
|
||||
stmt = (
|
||||
select(WorkflowRun.id)
|
||||
.where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.app_id == app_id,
|
||||
)
|
||||
.limit(batch_size)
|
||||
)
|
||||
run_ids = session.scalars(stmt).all()
|
||||
|
||||
if not run_ids:
|
||||
break
|
||||
|
||||
# Delete the batch
|
||||
delete_stmt = delete(WorkflowRun).where(WorkflowRun.id.in_(run_ids))
|
||||
result = session.execute(delete_stmt)
|
||||
session.commit()
|
||||
|
||||
batch_deleted = result.rowcount
|
||||
total_deleted += batch_deleted
|
||||
|
||||
logger.info(f"Deleted batch of {batch_deleted} workflow runs for app {app_id}")
|
||||
|
||||
# If we deleted fewer records than the batch size, we're done
|
||||
if batch_deleted < batch_size:
|
||||
break
|
||||
|
||||
logger.info(f"Total deleted {total_deleted} workflow runs for app {app_id}")
|
||||
return total_deleted
|
||||
@ -0,0 +1 @@
|
||||
# Unit tests for core repositories module
|
||||
@ -0,0 +1,455 @@
|
||||
"""
|
||||
Unit tests for the RepositoryFactory.
|
||||
|
||||
This module tests the factory pattern implementation for creating repository instances
|
||||
based on configuration, including error handling and validation.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from models import Account, EndUser
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
|
||||
class TestRepositoryFactory:
|
||||
"""Test cases for RepositoryFactory."""
|
||||
|
||||
def test_import_class_success(self):
|
||||
"""Test successful class import."""
|
||||
# Test importing a real class
|
||||
class_path = "unittest.mock.MagicMock"
|
||||
result = DifyCoreRepositoryFactory._import_class(class_path)
|
||||
assert result is MagicMock
|
||||
|
||||
def test_import_class_invalid_path(self):
|
||||
"""Test import with invalid module path."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalid.module.path")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_import_class_invalid_class_name(self):
|
||||
"""Test import with invalid class name."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("unittest.mock.NonExistentClass")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_import_class_malformed_path(self):
|
||||
"""Test import with malformed path (no dots)."""
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._import_class("invalidpath")
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_success(self):
|
||||
"""Test successful interface validation."""
|
||||
|
||||
# Create a mock class that implements the required methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with the same methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_repository_interface_missing_methods(self):
|
||||
"""Test interface validation with missing methods."""
|
||||
|
||||
# Create a mock class that doesn't implement all required methods
|
||||
class IncompleteRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
# Missing get_by_id method
|
||||
|
||||
# Create a mock interface with required methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(IncompleteRepository, MockInterface)
|
||||
assert "does not implement required methods" in str(exc_info.value)
|
||||
assert "get_by_id" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_success(self):
|
||||
"""Test successful constructor signature validation."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from):
|
||||
pass
|
||||
|
||||
# Should not raise an exception
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_missing_params(self):
|
||||
"""Test constructor validation with missing parameters."""
|
||||
|
||||
class IncompleteRepository:
|
||||
def __init__(self, session_factory, user):
|
||||
# Missing app_id and triggered_from parameters
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
IncompleteRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
||||
|
||||
def test_validate_constructor_signature_inspection_error(self, mocker: MockerFixture):
|
||||
"""Test constructor validation when inspection fails."""
|
||||
# Mock inspect.signature to raise an exception
|
||||
mocker.patch("inspect.signature", side_effect=Exception("Inspection failed"))
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory):
|
||||
pass
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(MockRepository, ["session_factory"])
|
||||
assert "Failed to validate constructor signature" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowExecutionRepository."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowRunTriggeredFrom.APP_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
# Verify the repository was created with correct parameters
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_validation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_execution_repository_instantiation_error(self, mock_config, mocker: MockerFixture):
|
||||
"""Test WorkflowExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert "Failed to create WorkflowExecutionRepository" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_success(self, mock_config, mocker: MockerFixture):
|
||||
"""Test successful creation of WorkflowNodeExecutionRepository."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
app_id = "test-app-id"
|
||||
triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowNodeExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
|
||||
# Verify the repository was created with correct parameters
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id=app_id,
|
||||
triggered_from=triggered_from,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_import_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with import error."""
|
||||
# Setup mock configuration with invalid class path
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "invalid.module.InvalidClass"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Cannot import repository class" in str(exc_info.value)
|
||||
|
||||
def test_repository_import_error_exception(self):
|
||||
"""Test RepositoryImportError exception."""
|
||||
error_message = "Test error message"
|
||||
exception = RepositoryImportError(error_message)
|
||||
assert str(exception) == error_message
|
||||
assert isinstance(exception, Exception)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_with_engine_instead_of_sessionmaker(self, mock_config, mocker: MockerFixture):
|
||||
"""Test repository creation with Engine instead of sessionmaker."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
# Create mock dependencies with Engine instead of sessionmaker
|
||||
mock_engine = MagicMock(spec=Engine)
|
||||
mock_user = MagicMock(spec=Account)
|
||||
|
||||
# Mock the imported class to be a valid repository
|
||||
mock_repository_class = MagicMock()
|
||||
mock_repository_instance = MagicMock(spec=WorkflowExecutionRepository)
|
||||
mock_repository_class.return_value = mock_repository_instance
|
||||
|
||||
# Mock the validation methods
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
result = DifyCoreRepositoryFactory.create_workflow_execution_repository(
|
||||
session_factory=mock_engine, # Using Engine instead of sessionmaker
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
|
||||
# Verify the repository was created with the Engine
|
||||
mock_repository_class.assert_called_once_with(
|
||||
session_factory=mock_engine,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
)
|
||||
assert result is mock_repository_instance
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_validation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with validation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import to succeed but validation to fail
|
||||
mock_repository_class = MagicMock()
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(
|
||||
DifyCoreRepositoryFactory,
|
||||
"_validate_repository_interface",
|
||||
side_effect=RepositoryImportError("Interface validation failed"),
|
||||
),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Interface validation failed" in str(exc_info.value)
|
||||
|
||||
@patch("core.repositories.factory.dify_config")
|
||||
def test_create_workflow_node_execution_repository_instantiation_error(self, mock_config):
|
||||
"""Test WorkflowNodeExecutionRepository creation with instantiation error."""
|
||||
# Setup mock configuration
|
||||
mock_config.WORKFLOW_NODE_EXECUTION_REPOSITORY = "unittest.mock.MagicMock"
|
||||
|
||||
mock_session_factory = MagicMock(spec=sessionmaker)
|
||||
mock_user = MagicMock(spec=EndUser)
|
||||
|
||||
# Mock import and validation to succeed but instantiation to fail
|
||||
mock_repository_class = MagicMock(side_effect=Exception("Instantiation failed"))
|
||||
with (
|
||||
patch.object(DifyCoreRepositoryFactory, "_import_class", return_value=mock_repository_class),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_repository_interface"),
|
||||
patch.object(DifyCoreRepositoryFactory, "_validate_constructor_signature"),
|
||||
):
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
|
||||
session_factory=mock_session_factory,
|
||||
user=mock_user,
|
||||
app_id="test-app-id",
|
||||
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
assert "Failed to create WorkflowNodeExecutionRepository" in str(exc_info.value)
|
||||
|
||||
def test_validate_repository_interface_with_private_methods(self):
|
||||
"""Test interface validation ignores private methods."""
|
||||
|
||||
# Create a mock class with private methods
|
||||
class MockRepository:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Create a mock interface with private methods
|
||||
class MockInterface:
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def get_by_id(self):
|
||||
pass
|
||||
|
||||
def _private_method(self):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (private methods are ignored)
|
||||
DifyCoreRepositoryFactory._validate_repository_interface(MockRepository, MockInterface)
|
||||
|
||||
def test_validate_constructor_signature_with_extra_params(self):
|
||||
"""Test constructor validation with extra parameters (should pass)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, app_id, triggered_from, extra_param=None):
|
||||
pass
|
||||
|
||||
# Should not raise an exception (extra parameters are allowed)
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
|
||||
def test_validate_constructor_signature_with_kwargs(self):
|
||||
"""Test constructor validation with **kwargs (current implementation doesn't support this)."""
|
||||
|
||||
class MockRepository:
|
||||
def __init__(self, session_factory, user, **kwargs):
|
||||
pass
|
||||
|
||||
# Current implementation doesn't handle **kwargs, so this should raise an exception
|
||||
with pytest.raises(RepositoryImportError) as exc_info:
|
||||
DifyCoreRepositoryFactory._validate_constructor_signature(
|
||||
MockRepository, ["session_factory", "user", "app_id", "triggered_from"]
|
||||
)
|
||||
assert "does not accept required parameters" in str(exc_info.value)
|
||||
assert "app_id" in str(exc_info.value)
|
||||
assert "triggered_from" in str(exc_info.value)
|
||||
@ -0,0 +1,232 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
from flask_login import LoginManager, UserMixin
|
||||
|
||||
from libs.login import _get_user, current_user, login_required
|
||||
|
||||
|
||||
class MockUser(UserMixin):
|
||||
"""Mock user class for testing."""
|
||||
|
||||
def __init__(self, id: str, is_authenticated: bool = True):
|
||||
self.id = id
|
||||
self._is_authenticated = is_authenticated
|
||||
|
||||
@property
|
||||
def is_authenticated(self):
|
||||
return self._is_authenticated
|
||||
|
||||
|
||||
class TestLoginRequired:
|
||||
"""Test cases for login_required decorator."""
|
||||
|
||||
@pytest.fixture
|
||||
def setup_app(self, app: Flask):
|
||||
"""Set up Flask app with login manager."""
|
||||
# Initialize login manager
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
# Mock unauthorized handler
|
||||
login_manager.unauthorized = MagicMock(return_value="Unauthorized")
|
||||
|
||||
# Add a dummy user loader to prevent exceptions
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id):
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
def test_authenticated_user_can_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that authenticated users can access protected views."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock authenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
def test_unauthenticated_user_cannot_access_protected_view(self, setup_app: Flask):
|
||||
"""Test that unauthenticated users are redirected."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Unauthorized"
|
||||
setup_app.login_manager.unauthorized.assert_called_once()
|
||||
|
||||
def test_login_disabled_allows_unauthenticated_access(self, setup_app: Flask):
|
||||
"""Test that LOGIN_DISABLED config bypasses authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context():
|
||||
# Mock unauthenticated user and LOGIN_DISABLED
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
with patch("libs.login.dify_config") as mock_config:
|
||||
mock_config.LOGIN_DISABLED = True
|
||||
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_options_request_bypasses_authentication(self, setup_app: Flask):
|
||||
"""Test that OPTIONS requests are exempt from authentication."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
with setup_app.test_request_context(method="OPTIONS"):
|
||||
# Mock unauthenticated user
|
||||
mock_user = MockUser("test_user", is_authenticated=False)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
# Ensure unauthorized was not called
|
||||
setup_app.login_manager.unauthorized.assert_not_called()
|
||||
|
||||
def test_flask_2_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 2.x compatibility with ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Mock Flask 2.x ensure_sync
|
||||
setup_app.ensure_sync = MagicMock(return_value=lambda: "Synced content")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Synced content"
|
||||
setup_app.ensure_sync.assert_called_once()
|
||||
|
||||
def test_flask_1_compatibility(self, setup_app: Flask):
|
||||
"""Test Flask 1.x compatibility without ensure_sync."""
|
||||
|
||||
@login_required
|
||||
def protected_view():
|
||||
return "Protected content"
|
||||
|
||||
# Remove ensure_sync to simulate Flask 1.x
|
||||
if hasattr(setup_app, "ensure_sync"):
|
||||
delattr(setup_app, "ensure_sync")
|
||||
|
||||
with setup_app.test_request_context():
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
result = protected_view()
|
||||
assert result == "Protected content"
|
||||
|
||||
|
||||
class TestGetUser:
|
||||
"""Test cases for _get_user function."""
|
||||
|
||||
def test_get_user_returns_user_from_g(self, app: Flask):
|
||||
"""Test that _get_user returns user from g._login_user."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
with app.test_request_context():
|
||||
g._login_user = mock_user
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
assert user.id == "test_user"
|
||||
|
||||
def test_get_user_loads_user_if_not_in_g(self, app: Flask):
|
||||
"""Test that _get_user loads user if not already in g."""
|
||||
mock_user = MockUser("test_user")
|
||||
|
||||
# Mock login manager
|
||||
login_manager = MagicMock()
|
||||
login_manager._load_user = MagicMock()
|
||||
app.login_manager = login_manager
|
||||
|
||||
with app.test_request_context():
|
||||
# Simulate _load_user setting g._login_user
|
||||
def side_effect():
|
||||
g._login_user = mock_user
|
||||
|
||||
login_manager._load_user.side_effect = side_effect
|
||||
|
||||
user = _get_user()
|
||||
assert user == mock_user
|
||||
login_manager._load_user.assert_called_once()
|
||||
|
||||
def test_get_user_returns_none_without_request_context(self, app: Flask):
|
||||
"""Test that _get_user returns None outside request context."""
|
||||
# Outside of request context
|
||||
user = _get_user()
|
||||
assert user is None
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
"""Test cases for current_user proxy."""
|
||||
|
||||
def test_current_user_proxy_returns_authenticated_user(self, app: Flask):
|
||||
"""Test that current_user proxy returns authenticated user."""
|
||||
mock_user = MockUser("test_user", is_authenticated=True)
|
||||
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
assert current_user.id == "test_user"
|
||||
assert current_user.is_authenticated is True
|
||||
|
||||
def test_current_user_proxy_returns_none_when_no_user(self, app: Flask):
|
||||
"""Test that current_user proxy handles None user."""
|
||||
with app.test_request_context():
|
||||
with patch("libs.login._get_user", return_value=None):
|
||||
# When _get_user returns None, accessing attributes should fail
|
||||
# or current_user should evaluate to falsy
|
||||
try:
|
||||
# Try to access an attribute that would exist on a real user
|
||||
_ = current_user.id
|
||||
pytest.fail("Should have raised AttributeError")
|
||||
except AttributeError:
|
||||
# This is expected when current_user is None
|
||||
pass
|
||||
|
||||
def test_current_user_proxy_thread_safety(self, app: Flask):
|
||||
"""Test that current_user proxy is thread-safe."""
|
||||
import threading
|
||||
|
||||
results = {}
|
||||
|
||||
def check_user_in_thread(user_id: str, index: int):
|
||||
with app.test_request_context():
|
||||
mock_user = MockUser(user_id)
|
||||
with patch("libs.login._get_user", return_value=mock_user):
|
||||
results[index] = current_user.id
|
||||
|
||||
# Create multiple threads with different users
|
||||
threads = []
|
||||
for i in range(5):
|
||||
thread = threading.Thread(target=check_user_in_thread, args=(f"user_{i}", i))
|
||||
threads.append(thread)
|
||||
thread.start()
|
||||
|
||||
# Wait for all threads to complete
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Verify each thread got its own user
|
||||
for i in range(5):
|
||||
assert results[i] == f"user_{i}"
|
||||
@ -0,0 +1,205 @@
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
from libs.passport import PassportService
|
||||
|
||||
|
||||
class TestPassportService:
|
||||
"""Test PassportService JWT operations"""
|
||||
|
||||
@pytest.fixture
|
||||
def passport_service(self):
|
||||
"""Create PassportService instance with test secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
@pytest.fixture
|
||||
def another_passport_service(self):
|
||||
"""Create another PassportService instance with different secret key"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "another-secret-key-for-testing"
|
||||
return PassportService()
|
||||
|
||||
# Core functionality tests
|
||||
def test_should_issue_and_verify_token(self, passport_service):
|
||||
"""Test complete JWT lifecycle: issue and verify"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Verify token format
|
||||
assert isinstance(token, str)
|
||||
assert len(token.split(".")) == 3 # JWT format: header.payload.signature
|
||||
|
||||
# Verify token content
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_different_payload_types(self, passport_service):
|
||||
"""Test issuing and verifying tokens with different payload types"""
|
||||
test_cases = [
|
||||
{"string": "value"},
|
||||
{"number": 42},
|
||||
{"float": 3.14},
|
||||
{"boolean": True},
|
||||
{"null": None},
|
||||
{"array": [1, 2, 3]},
|
||||
{"nested": {"key": "value"}},
|
||||
{"unicode": "中文测试"},
|
||||
{"emoji": "🔐"},
|
||||
{}, # Empty payload
|
||||
]
|
||||
|
||||
for payload in test_cases:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
# Security tests
|
||||
def test_should_reject_modified_token(self, passport_service):
|
||||
"""Test that any modification to token invalidates it"""
|
||||
token = passport_service.issue({"user": "test"})
|
||||
|
||||
# Test multiple modification points
|
||||
test_positions = [0, len(token) // 3, len(token) // 2, len(token) - 1]
|
||||
|
||||
for pos in test_positions:
|
||||
if pos < len(token) and token[pos] != ".":
|
||||
# Change one character
|
||||
tampered = token[:pos] + ("X" if token[pos] != "X" else "Y") + token[pos + 1 :]
|
||||
with pytest.raises(Unauthorized):
|
||||
passport_service.verify(tampered)
|
||||
|
||||
def test_should_reject_token_with_different_secret_key(self, passport_service, another_passport_service):
|
||||
"""Test key isolation - token from one service should not work with another"""
|
||||
payload = {"user_id": "123", "app_code": "test-app"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
another_passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token signature."
|
||||
|
||||
def test_should_use_hs256_algorithm(self, passport_service):
|
||||
"""Test that HS256 algorithm is used for signing"""
|
||||
payload = {"test": "data"}
|
||||
token = passport_service.issue(payload)
|
||||
|
||||
# Decode header without relying on JWT internals
|
||||
# Use jwt.get_unverified_header which is a public API
|
||||
header = jwt.get_unverified_header(token)
|
||||
assert header["alg"] == "HS256"
|
||||
|
||||
def test_should_reject_token_with_wrong_algorithm(self, passport_service):
|
||||
"""Test rejection of token signed with different algorithm"""
|
||||
payload = {"user_id": "123"}
|
||||
|
||||
# Create token with different algorithm
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
# Create token with HS512 instead of HS256
|
||||
wrong_alg_token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS512")
|
||||
|
||||
# Should fail because service expects HS256
|
||||
# InvalidAlgorithmError is now caught by PyJWTError handler
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(wrong_alg_token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
|
||||
# Exception handling tests
|
||||
def test_should_handle_invalid_tokens(self, passport_service):
|
||||
"""Test handling of various invalid token formats"""
|
||||
invalid_tokens = [
|
||||
("not.a.token", "Invalid token."),
|
||||
("invalid-jwt-format", "Invalid token."),
|
||||
("xxx.yyy.zzz", "Invalid token."),
|
||||
("a.b", "Invalid token."), # Missing signature
|
||||
("", "Invalid token."), # Empty string
|
||||
(" ", "Invalid token."), # Whitespace
|
||||
(None, "Invalid token."), # None value
|
||||
# Malformed base64
|
||||
("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.INVALID_BASE64!@#$.signature", "Invalid token."),
|
||||
]
|
||||
|
||||
for invalid_token, expected_message in invalid_tokens:
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(invalid_token)
|
||||
assert expected_message in str(exc_info.value)
|
||||
|
||||
def test_should_reject_expired_token(self, passport_service):
|
||||
"""Test rejection of expired token"""
|
||||
past_time = datetime.now(UTC) - timedelta(hours=1)
|
||||
payload = {"user_id": "123", "exp": past_time.timestamp()}
|
||||
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = "test-secret-key-for-testing"
|
||||
token = jwt.encode(payload, mock_config.SECRET_KEY, algorithm="HS256")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify(token)
|
||||
assert str(exc_info.value) == "401 Unauthorized: Token has expired."
|
||||
|
||||
# Configuration tests
|
||||
def test_should_handle_empty_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is empty"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = ""
|
||||
service = PassportService()
|
||||
|
||||
# Empty secret key should still work but is insecure
|
||||
payload = {"test": "data"}
|
||||
token = service.issue(payload)
|
||||
decoded = service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_handle_none_secret_key(self):
|
||||
"""Test behavior when SECRET_KEY is None"""
|
||||
with patch("libs.passport.dify_config") as mock_config:
|
||||
mock_config.SECRET_KEY = None
|
||||
service = PassportService()
|
||||
|
||||
payload = {"test": "data"}
|
||||
# JWT library will raise TypeError when secret is None
|
||||
with pytest.raises((TypeError, jwt.exceptions.InvalidKeyError)):
|
||||
service.issue(payload)
|
||||
|
||||
# Boundary condition tests
|
||||
def test_should_handle_large_payload(self, passport_service):
|
||||
"""Test handling of large payload"""
|
||||
# Test with 100KB instead of 1MB for faster tests
|
||||
large_data = "x" * (100 * 1024)
|
||||
payload = {"data": large_data}
|
||||
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
|
||||
assert decoded["data"] == large_data
|
||||
|
||||
def test_should_handle_special_characters_in_payload(self, passport_service):
|
||||
"""Test handling of special characters in payload"""
|
||||
special_payloads = [
|
||||
{"special": "!@#$%^&*()"},
|
||||
{"quotes": 'He said "Hello"'},
|
||||
{"backslash": "path\\to\\file"},
|
||||
{"newline": "line1\nline2"},
|
||||
{"unicode": "🔐🔑🛡️"},
|
||||
{"mixed": "Test123!@#中文🔐"},
|
||||
]
|
||||
|
||||
for payload in special_payloads:
|
||||
token = passport_service.issue(payload)
|
||||
decoded = passport_service.verify(token)
|
||||
assert decoded == payload
|
||||
|
||||
def test_should_catch_generic_pyjwt_errors(self, passport_service):
|
||||
"""Test that generic PyJWTError exceptions are caught and converted to Unauthorized"""
|
||||
# Mock jwt.decode to raise a generic PyJWTError
|
||||
with patch("libs.passport.jwt.decode") as mock_decode:
|
||||
mock_decode.side_effect = jwt.exceptions.PyJWTError("Generic JWT error")
|
||||
|
||||
with pytest.raises(Unauthorized) as exc_info:
|
||||
passport_service.verify("some-token")
|
||||
assert str(exc_info.value) == "401 Unauthorized: Invalid token."
|
||||
@ -0,0 +1,59 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
class ServiceDbTestHelper:
|
||||
"""
|
||||
Helper class for service database query tests.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def setup_db_query_filter_by_mock(mock_db, query_results):
|
||||
"""
|
||||
Smart database query mock that responds based on model type and query parameters.
|
||||
|
||||
Args:
|
||||
mock_db: Mock database session
|
||||
query_results: Dict mapping (model_name, filter_key, filter_value) to return value
|
||||
Example: {('Account', 'email', 'test@example.com'): mock_account}
|
||||
"""
|
||||
|
||||
def query_side_effect(model):
|
||||
mock_query = MagicMock()
|
||||
|
||||
def filter_by_side_effect(**kwargs):
|
||||
mock_filter_result = MagicMock()
|
||||
|
||||
def first_side_effect():
|
||||
# Find matching result based on model and filter parameters
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if model.__name__ == model_name and filter_key in kwargs and kwargs[filter_key] == filter_value:
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_filter_result.first.side_effect = first_side_effect
|
||||
|
||||
# Handle order_by calls for complex queries
|
||||
def order_by_side_effect(*args, **kwargs):
|
||||
mock_order_result = MagicMock()
|
||||
|
||||
def order_first_side_effect():
|
||||
# Look for order_by results in the same query_results dict
|
||||
for (model_name, filter_key, filter_value), result in query_results.items():
|
||||
if (
|
||||
model.__name__ == model_name
|
||||
and filter_key == "order_by"
|
||||
and filter_value == "first_available"
|
||||
):
|
||||
return result
|
||||
return None
|
||||
|
||||
mock_order_result.first.side_effect = order_first_side_effect
|
||||
return mock_order_result
|
||||
|
||||
mock_filter_result.order_by.side_effect = order_by_side_effect
|
||||
return mock_filter_result
|
||||
|
||||
mock_query.filter_by.side_effect = filter_by_side_effect
|
||||
return mock_query
|
||||
|
||||
mock_db.session.query.side_effect = query_side_effect
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,288 @@
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
)
|
||||
|
||||
|
||||
class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
@pytest.fixture
|
||||
def repository(self):
|
||||
mock_session_maker = MagicMock()
|
||||
return DifyAPISQLAlchemyWorkflowNodeExecutionRepository(session_maker=mock_session_maker)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_execution(self):
|
||||
execution = MagicMock(spec=WorkflowNodeExecutionModel)
|
||||
execution.id = str(uuid4())
|
||||
execution.tenant_id = "tenant-123"
|
||||
execution.app_id = "app-456"
|
||||
execution.workflow_id = "workflow-789"
|
||||
execution.workflow_run_id = "run-101"
|
||||
execution.node_id = "node-202"
|
||||
execution.index = 1
|
||||
execution.created_at = "2023-01-01T00:00:00Z"
|
||||
return execution
|
||||
|
||||
def test_get_node_last_execution_found(self, repository, mock_execution):
|
||||
"""Test getting the last execution for a node when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_node_last_execution(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_id="workflow-789",
|
||||
node_id="node-202",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = []
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == []
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_found(self, repository, mock_execution):
|
||||
"""Test getting execution by ID when it exists."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = mock_execution
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id(mock_execution.id)
|
||||
|
||||
# Assert
|
||||
assert result == mock_execution
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_execution_by_id_not_found(self, repository):
|
||||
"""Test getting execution by ID when it doesn't exist."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
mock_session.scalar.return_value = None
|
||||
|
||||
# Act
|
||||
result = repository.get_execution_by_id("non-existent-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_repository_implements_protocol(self, repository):
|
||||
"""Test that the repository implements the required protocol methods."""
|
||||
# Verify all protocol methods are implemented
|
||||
assert hasattr(repository, "get_node_last_execution")
|
||||
assert hasattr(repository, "get_executions_by_workflow_run")
|
||||
assert hasattr(repository, "get_execution_by_id")
|
||||
|
||||
# Verify methods are callable
|
||||
assert callable(repository.get_node_last_execution)
|
||||
assert callable(repository.get_executions_by_workflow_run)
|
||||
assert callable(repository.get_execution_by_id)
|
||||
assert callable(repository.delete_expired_executions)
|
||||
assert callable(repository.delete_executions_by_app)
|
||||
assert callable(repository.get_expired_executions_batch)
|
||||
assert callable(repository.delete_executions_by_ids)
|
||||
|
||||
def test_delete_expired_executions(self, repository):
|
||||
"""Test deleting expired executions."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"] # Less than batch_size to trigger break
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.delete_expired_executions(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_app(self, repository):
|
||||
"""Test deleting executions by app."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the select query to return some IDs first time, then empty to stop loop
|
||||
execution_ids = ["id1", "id2"]
|
||||
|
||||
# Mock execute method to handle both select and delete statements
|
||||
def mock_execute(stmt):
|
||||
mock_result = MagicMock()
|
||||
# For select statements, return execution IDs
|
||||
if hasattr(stmt, "limit"): # This is our select statement
|
||||
mock_result.scalars.return_value.all.return_value = execution_ids
|
||||
else: # This is our delete statement
|
||||
mock_result.rowcount = 2
|
||||
return mock_result
|
||||
|
||||
mock_session.execute.side_effect = mock_execute
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_app(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == 2
|
||||
assert mock_session.execute.call_count == 2 # One select call, one delete call
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_get_expired_executions_batch(self, repository):
|
||||
"""Test getting expired executions batch for backup."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Create mock execution objects
|
||||
mock_execution1 = MagicMock()
|
||||
mock_execution1.id = "exec-1"
|
||||
mock_execution2 = MagicMock()
|
||||
mock_execution2.id = "exec-2"
|
||||
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = [mock_execution1, mock_execution2]
|
||||
|
||||
before_date = datetime(2023, 1, 1)
|
||||
|
||||
# Act
|
||||
result = repository.get_expired_executions_batch(
|
||||
tenant_id="tenant-123",
|
||||
before_date=before_date,
|
||||
batch_size=1000,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert len(result) == 2
|
||||
assert result[0].id == "exec-1"
|
||||
assert result[1].id == "exec-2"
|
||||
mock_session.execute.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids(self, repository):
|
||||
"""Test deleting executions by IDs."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Mock the delete query result
|
||||
mock_result = MagicMock()
|
||||
mock_result.rowcount = 3
|
||||
mock_session.execute.return_value = mock_result
|
||||
|
||||
execution_ids = ["id1", "id2", "id3"]
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids(execution_ids)
|
||||
|
||||
# Assert
|
||||
assert result == 3
|
||||
mock_session.execute.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_delete_executions_by_ids_empty_list(self, repository):
|
||||
"""Test deleting executions with empty ID list."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
|
||||
# Act
|
||||
result = repository.delete_executions_by_ids([])
|
||||
|
||||
# Assert
|
||||
assert result == 0
|
||||
mock_session.query.assert_not_called()
|
||||
mock_session.commit.assert_not_called()
|
||||
@ -0,0 +1,78 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import React, { useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import produce from 'immer'
|
||||
import { useContext } from 'use-context-selector'
|
||||
|
||||
import { Microphone01 } from '@/app/components/base/icons/src/vender/features'
|
||||
import Tooltip from '@/app/components/base/tooltip'
|
||||
import ConfigContext from '@/context/debug-configuration'
|
||||
import { SupportUploadFileTypes } from '@/app/components/workflow/types'
|
||||
import { useFeatures, useFeaturesStore } from '@/app/components/base/features/hooks'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
|
||||
const ConfigAudio: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const file = useFeatures(s => s.features.file)
|
||||
const featuresStore = useFeaturesStore()
|
||||
const { isShowAudioConfig } = useContext(ConfigContext)
|
||||
|
||||
const isAudioEnabled = file?.allowed_file_types?.includes(SupportUploadFileTypes.audio) ?? false
|
||||
|
||||
const handleChange = useCallback((value: boolean) => {
|
||||
const {
|
||||
features,
|
||||
setFeatures,
|
||||
} = featuresStore!.getState()
|
||||
|
||||
const newFeatures = produce(features, (draft) => {
|
||||
if (value) {
|
||||
draft.file!.allowed_file_types = Array.from(new Set([
|
||||
...(draft.file?.allowed_file_types || []),
|
||||
SupportUploadFileTypes.audio,
|
||||
]))
|
||||
}
|
||||
else {
|
||||
draft.file!.allowed_file_types = draft.file!.allowed_file_types?.filter(
|
||||
type => type !== SupportUploadFileTypes.audio,
|
||||
)
|
||||
}
|
||||
if (draft.file)
|
||||
draft.file.enabled = (draft.file.allowed_file_types?.length ?? 0) > 0
|
||||
})
|
||||
setFeatures(newFeatures)
|
||||
}, [featuresStore])
|
||||
|
||||
if (!isShowAudioConfig)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className='mt-2 flex items-center gap-2 rounded-xl border-l-[0.5px] border-t-[0.5px] bg-background-section-burn p-2'>
|
||||
<div className='shrink-0 p-1'>
|
||||
<div className='rounded-lg border-[0.5px] border-divider-subtle bg-util-colors-violet-violet-600 p-1 shadow-xs'>
|
||||
<Microphone01 className='h-4 w-4 text-text-primary-on-surface' />
|
||||
</div>
|
||||
</div>
|
||||
<div className='flex grow items-center'>
|
||||
<div className='system-sm-semibold mr-1 text-text-secondary'>{t('appDebug.feature.audioUpload.title')}</div>
|
||||
<Tooltip
|
||||
popupContent={
|
||||
<div className='w-[180px]' >
|
||||
{t('appDebug.feature.audioUpload.description')}
|
||||
</div>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className='flex shrink-0 items-center'>
|
||||
<div className='ml-1 mr-3 h-3.5 w-[1px] bg-divider-subtle'></div>
|
||||
<Switch
|
||||
defaultValue={isAudioEnabled}
|
||||
onChange={handleChange}
|
||||
size='md'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
export default React.memo(ConfigAudio)
|
||||
@ -1,234 +1,16 @@
|
||||
import { fetchNodeInspectVars } from '@/service/workflow'
|
||||
import { useStore, useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import type { ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { VarInInspect } from '@/types/workflow'
|
||||
import { VarInInspectType } from '@/types/workflow'
|
||||
import {
|
||||
useDeleteAllInspectorVars,
|
||||
useDeleteInspectVar,
|
||||
useDeleteNodeInspectorVars,
|
||||
useEditInspectorVar,
|
||||
useInvalidateConversationVarValues,
|
||||
useInvalidateSysVarValues,
|
||||
useResetConversationVar,
|
||||
useResetToLastRunValue,
|
||||
} from '@/service/use-workflow'
|
||||
import { useCallback } from 'react'
|
||||
import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
|
||||
import produce from 'immer'
|
||||
import type { Node } from '@/app/components/workflow/types'
|
||||
import { useNodesInteractionsWithoutSync } from '@/app/components/workflow/hooks/use-nodes-interactions-without-sync'
|
||||
import { useEdgesInteractionsWithoutSync } from '@/app/components/workflow/hooks/use-edges-interactions-without-sync'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { useInspectVarsCrudCommon } from '../../workflow/hooks/use-inspect-vars-crud-common'
|
||||
import { useConfigsMap } from './use-configs-map'
|
||||
|
||||
export const useInspectVarsCrud = () => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const appId = useStore(s => s.appId)
|
||||
const { conversationVarsUrl, systemVarsUrl } = useConfigsMap()
|
||||
const invalidateConversationVarValues = useInvalidateConversationVarValues(conversationVarsUrl)
|
||||
const { mutateAsync: doResetConversationVar } = useResetConversationVar(appId)
|
||||
const { mutateAsync: doResetToLastRunValue } = useResetToLastRunValue(appId)
|
||||
const invalidateSysVarValues = useInvalidateSysVarValues(systemVarsUrl)
|
||||
|
||||
const { mutateAsync: doDeleteAllInspectorVars } = useDeleteAllInspectorVars(appId)
|
||||
const { mutate: doDeleteNodeInspectorVars } = useDeleteNodeInspectorVars(appId)
|
||||
const { mutate: doDeleteInspectVar } = useDeleteInspectVar(appId)
|
||||
|
||||
const { mutateAsync: doEditInspectorVar } = useEditInspectorVar(appId)
|
||||
const { handleCancelNodeSuccessStatus } = useNodesInteractionsWithoutSync()
|
||||
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
|
||||
const getNodeInspectVars = useCallback((nodeId: string) => {
|
||||
const { nodesWithInspectVars } = workflowStore.getState()
|
||||
const node = nodesWithInspectVars.find(node => node.nodeId === nodeId)
|
||||
return node
|
||||
}, [workflowStore])
|
||||
|
||||
const getVarId = useCallback((nodeId: string, varName: string) => {
|
||||
const node = getNodeInspectVars(nodeId)
|
||||
if (!node)
|
||||
return undefined
|
||||
const varId = node.vars.find((varItem) => {
|
||||
return varItem.selector[1] === varName
|
||||
})?.id
|
||||
return varId
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const getInspectVar = useCallback((nodeId: string, name: string): VarInInspect | undefined => {
|
||||
const node = getNodeInspectVars(nodeId)
|
||||
if (!node)
|
||||
return undefined
|
||||
|
||||
const variable = node.vars.find((varItem) => {
|
||||
return varItem.name === name
|
||||
})
|
||||
return variable
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const hasSetInspectVar = useCallback((nodeId: string, name: string, sysVars: VarInInspect[], conversationVars: VarInInspect[]) => {
|
||||
const isEnv = isENV([nodeId])
|
||||
if (isEnv) // always have value
|
||||
return true
|
||||
const isSys = isSystemVar([nodeId])
|
||||
if (isSys)
|
||||
return sysVars.some(varItem => varItem.selector?.[1]?.replace('sys.', '') === name)
|
||||
const isChatVar = isConversationVar([nodeId])
|
||||
if (isChatVar)
|
||||
return conversationVars.some(varItem => varItem.selector?.[1] === name)
|
||||
return getInspectVar(nodeId, name) !== undefined
|
||||
}, [getInspectVar])
|
||||
|
||||
const hasNodeInspectVars = useCallback((nodeId: string) => {
|
||||
return !!getNodeInspectVars(nodeId)
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const fetchInspectVarValue = useCallback(async (selector: ValueSelector) => {
|
||||
const {
|
||||
appId,
|
||||
setNodeInspectVars,
|
||||
} = workflowStore.getState()
|
||||
const nodeId = selector[0]
|
||||
const isSystemVar = nodeId === 'sys'
|
||||
const isConversationVar = nodeId === 'conversation'
|
||||
if (isSystemVar) {
|
||||
invalidateSysVarValues()
|
||||
return
|
||||
}
|
||||
if (isConversationVar) {
|
||||
invalidateConversationVarValues()
|
||||
return
|
||||
}
|
||||
const vars = await fetchNodeInspectVars(appId, nodeId)
|
||||
setNodeInspectVars(nodeId, vars)
|
||||
}, [workflowStore, invalidateSysVarValues, invalidateConversationVarValues])
|
||||
|
||||
// after last run would call this
|
||||
const appendNodeInspectVars = useCallback((nodeId: string, payload: VarInInspect[], allNodes: Node[]) => {
|
||||
const {
|
||||
nodesWithInspectVars,
|
||||
setNodesWithInspectVars,
|
||||
} = workflowStore.getState()
|
||||
const nodes = produce(nodesWithInspectVars, (draft) => {
|
||||
const nodeInfo = allNodes.find(node => node.id === nodeId)
|
||||
if (nodeInfo) {
|
||||
const index = draft.findIndex(node => node.nodeId === nodeId)
|
||||
if (index === -1) {
|
||||
draft.unshift({
|
||||
nodeId,
|
||||
nodeType: nodeInfo.data.type,
|
||||
title: nodeInfo.data.title,
|
||||
vars: payload,
|
||||
nodePayload: nodeInfo.data,
|
||||
})
|
||||
}
|
||||
else {
|
||||
draft[index].vars = payload
|
||||
// put the node to the topAdd commentMore actions
|
||||
draft.unshift(draft.splice(index, 1)[0])
|
||||
}
|
||||
}
|
||||
const configsMap = useConfigsMap()
|
||||
const apis = useInspectVarsCrudCommon({
|
||||
flowId: appId,
|
||||
...configsMap,
|
||||
})
|
||||
setNodesWithInspectVars(nodes)
|
||||
handleCancelNodeSuccessStatus(nodeId)
|
||||
}, [workflowStore, handleCancelNodeSuccessStatus])
|
||||
|
||||
const hasNodeInspectVar = useCallback((nodeId: string, varId: string) => {
|
||||
const { nodesWithInspectVars } = workflowStore.getState()
|
||||
const targetNode = nodesWithInspectVars.find(item => item.nodeId === nodeId)
|
||||
if(!targetNode || !targetNode.vars)
|
||||
return false
|
||||
return targetNode.vars.some(item => item.id === varId)
|
||||
}, [workflowStore])
|
||||
|
||||
const deleteInspectVar = useCallback(async (nodeId: string, varId: string) => {
|
||||
const { deleteInspectVar } = workflowStore.getState()
|
||||
if(hasNodeInspectVar(nodeId, varId)) {
|
||||
await doDeleteInspectVar(varId)
|
||||
deleteInspectVar(nodeId, varId)
|
||||
}
|
||||
}, [doDeleteInspectVar, workflowStore, hasNodeInspectVar])
|
||||
|
||||
const resetConversationVar = useCallback(async (varId: string) => {
|
||||
await doResetConversationVar(varId)
|
||||
invalidateConversationVarValues()
|
||||
}, [doResetConversationVar, invalidateConversationVarValues])
|
||||
|
||||
const deleteNodeInspectorVars = useCallback(async (nodeId: string) => {
|
||||
const { deleteNodeInspectVars } = workflowStore.getState()
|
||||
if (hasNodeInspectVars(nodeId)) {
|
||||
await doDeleteNodeInspectorVars(nodeId)
|
||||
deleteNodeInspectVars(nodeId)
|
||||
}
|
||||
}, [doDeleteNodeInspectorVars, workflowStore, hasNodeInspectVars])
|
||||
|
||||
const deleteAllInspectorVars = useCallback(async () => {
|
||||
const { deleteAllInspectVars } = workflowStore.getState()
|
||||
await doDeleteAllInspectorVars()
|
||||
await invalidateConversationVarValues()
|
||||
await invalidateSysVarValues()
|
||||
deleteAllInspectVars()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}, [doDeleteAllInspectorVars, invalidateConversationVarValues, invalidateSysVarValues, workflowStore, handleEdgeCancelRunningStatus])
|
||||
|
||||
const editInspectVarValue = useCallback(async (nodeId: string, varId: string, value: any) => {
|
||||
const { setInspectVarValue } = workflowStore.getState()
|
||||
await doEditInspectorVar({
|
||||
varId,
|
||||
value,
|
||||
})
|
||||
setInspectVarValue(nodeId, varId, value)
|
||||
if (nodeId === VarInInspectType.conversation)
|
||||
invalidateConversationVarValues()
|
||||
if (nodeId === VarInInspectType.system)
|
||||
invalidateSysVarValues()
|
||||
}, [doEditInspectorVar, invalidateConversationVarValues, invalidateSysVarValues, workflowStore])
|
||||
|
||||
const renameInspectVarName = useCallback(async (nodeId: string, oldName: string, newName: string) => {
|
||||
const { renameInspectVarName } = workflowStore.getState()
|
||||
const varId = getVarId(nodeId, oldName)
|
||||
if (!varId)
|
||||
return
|
||||
|
||||
const newSelector = [nodeId, newName]
|
||||
await doEditInspectorVar({
|
||||
varId,
|
||||
name: newName,
|
||||
})
|
||||
renameInspectVarName(nodeId, varId, newSelector)
|
||||
}, [doEditInspectorVar, getVarId, workflowStore])
|
||||
|
||||
const isInspectVarEdited = useCallback((nodeId: string, name: string) => {
|
||||
const inspectVar = getInspectVar(nodeId, name)
|
||||
if (!inspectVar)
|
||||
return false
|
||||
|
||||
return inspectVar.edited
|
||||
}, [getInspectVar])
|
||||
|
||||
const resetToLastRunVar = useCallback(async (nodeId: string, varId: string) => {
|
||||
const { resetToLastRunVar } = workflowStore.getState()
|
||||
const isSysVar = nodeId === 'sys'
|
||||
const data = await doResetToLastRunValue(varId)
|
||||
|
||||
if(isSysVar)
|
||||
invalidateSysVarValues()
|
||||
else
|
||||
resetToLastRunVar(nodeId, varId, data.value)
|
||||
}, [doResetToLastRunValue, invalidateSysVarValues, workflowStore])
|
||||
|
||||
return {
|
||||
hasNodeInspectVars,
|
||||
hasSetInspectVar,
|
||||
fetchInspectVarValue,
|
||||
editInspectVarValue,
|
||||
renameInspectVarName,
|
||||
appendNodeInspectVars,
|
||||
deleteInspectVar,
|
||||
deleteNodeInspectorVars,
|
||||
deleteAllInspectorVars,
|
||||
isInspectVarEdited,
|
||||
resetToLastRunVar,
|
||||
invalidateSysVarValues,
|
||||
resetConversationVar,
|
||||
invalidateConversationVarValues,
|
||||
...apis,
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,240 @@
|
||||
import { fetchNodeInspectVars } from '@/service/workflow'
|
||||
import { useWorkflowStore } from '@/app/components/workflow/store'
|
||||
import type { ValueSelector } from '@/app/components/workflow/types'
|
||||
import type { VarInInspect } from '@/types/workflow'
|
||||
import { VarInInspectType } from '@/types/workflow'
|
||||
import {
|
||||
useDeleteAllInspectorVars,
|
||||
useDeleteInspectVar,
|
||||
useDeleteNodeInspectorVars,
|
||||
useEditInspectorVar,
|
||||
useInvalidateConversationVarValues,
|
||||
useInvalidateSysVarValues,
|
||||
useResetConversationVar,
|
||||
useResetToLastRunValue,
|
||||
} from '@/service/use-workflow'
|
||||
import { useCallback } from 'react'
|
||||
import { isConversationVar, isENV, isSystemVar } from '@/app/components/workflow/nodes/_base/components/variable/utils'
|
||||
import produce from 'immer'
|
||||
import type { Node } from '@/app/components/workflow/types'
|
||||
import { useNodesInteractionsWithoutSync } from '@/app/components/workflow/hooks/use-nodes-interactions-without-sync'
|
||||
import { useEdgesInteractionsWithoutSync } from '@/app/components/workflow/hooks/use-edges-interactions-without-sync'
|
||||
|
||||
type Params = {
|
||||
flowId: string
|
||||
conversationVarsUrl: string
|
||||
systemVarsUrl: string
|
||||
}
|
||||
export const useInspectVarsCrudCommon = ({
|
||||
flowId,
|
||||
conversationVarsUrl,
|
||||
systemVarsUrl,
|
||||
}: Params) => {
|
||||
const workflowStore = useWorkflowStore()
|
||||
const invalidateConversationVarValues = useInvalidateConversationVarValues(conversationVarsUrl!)
|
||||
const { mutateAsync: doResetConversationVar } = useResetConversationVar(flowId)
|
||||
const { mutateAsync: doResetToLastRunValue } = useResetToLastRunValue(flowId)
|
||||
const invalidateSysVarValues = useInvalidateSysVarValues(systemVarsUrl!)
|
||||
|
||||
const { mutateAsync: doDeleteAllInspectorVars } = useDeleteAllInspectorVars(flowId)
|
||||
const { mutate: doDeleteNodeInspectorVars } = useDeleteNodeInspectorVars(flowId)
|
||||
const { mutate: doDeleteInspectVar } = useDeleteInspectVar(flowId)
|
||||
|
||||
const { mutateAsync: doEditInspectorVar } = useEditInspectorVar(flowId)
|
||||
const { handleCancelNodeSuccessStatus } = useNodesInteractionsWithoutSync()
|
||||
const { handleEdgeCancelRunningStatus } = useEdgesInteractionsWithoutSync()
|
||||
const getNodeInspectVars = useCallback((nodeId: string) => {
|
||||
const { nodesWithInspectVars } = workflowStore.getState()
|
||||
const node = nodesWithInspectVars.find(node => node.nodeId === nodeId)
|
||||
return node
|
||||
}, [workflowStore])
|
||||
|
||||
const getVarId = useCallback((nodeId: string, varName: string) => {
|
||||
const node = getNodeInspectVars(nodeId)
|
||||
if (!node)
|
||||
return undefined
|
||||
const varId = node.vars.find((varItem) => {
|
||||
return varItem.selector[1] === varName
|
||||
})?.id
|
||||
return varId
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const getInspectVar = useCallback((nodeId: string, name: string): VarInInspect | undefined => {
|
||||
const node = getNodeInspectVars(nodeId)
|
||||
if (!node)
|
||||
return undefined
|
||||
|
||||
const variable = node.vars.find((varItem) => {
|
||||
return varItem.name === name
|
||||
})
|
||||
return variable
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const hasSetInspectVar = useCallback((nodeId: string, name: string, sysVars: VarInInspect[], conversationVars: VarInInspect[]) => {
|
||||
const isEnv = isENV([nodeId])
|
||||
if (isEnv) // always have value
|
||||
return true
|
||||
const isSys = isSystemVar([nodeId])
|
||||
if (isSys)
|
||||
return sysVars.some(varItem => varItem.selector?.[1]?.replace('sys.', '') === name)
|
||||
const isChatVar = isConversationVar([nodeId])
|
||||
if (isChatVar)
|
||||
return conversationVars.some(varItem => varItem.selector?.[1] === name)
|
||||
return getInspectVar(nodeId, name) !== undefined
|
||||
}, [getInspectVar])
|
||||
|
||||
const hasNodeInspectVars = useCallback((nodeId: string) => {
|
||||
return !!getNodeInspectVars(nodeId)
|
||||
}, [getNodeInspectVars])
|
||||
|
||||
const fetchInspectVarValue = useCallback(async (selector: ValueSelector) => {
|
||||
const {
|
||||
appId,
|
||||
setNodeInspectVars,
|
||||
} = workflowStore.getState()
|
||||
const nodeId = selector[0]
|
||||
const isSystemVar = nodeId === 'sys'
|
||||
const isConversationVar = nodeId === 'conversation'
|
||||
if (isSystemVar) {
|
||||
invalidateSysVarValues()
|
||||
return
|
||||
}
|
||||
if (isConversationVar) {
|
||||
invalidateConversationVarValues()
|
||||
return
|
||||
}
|
||||
const vars = await fetchNodeInspectVars(appId, nodeId)
|
||||
setNodeInspectVars(nodeId, vars)
|
||||
}, [workflowStore, invalidateSysVarValues, invalidateConversationVarValues])
|
||||
|
||||
// after last run would call this
|
||||
const appendNodeInspectVars = useCallback((nodeId: string, payload: VarInInspect[], allNodes: Node[]) => {
|
||||
const {
|
||||
nodesWithInspectVars,
|
||||
setNodesWithInspectVars,
|
||||
} = workflowStore.getState()
|
||||
const nodes = produce(nodesWithInspectVars, (draft) => {
|
||||
const nodeInfo = allNodes.find(node => node.id === nodeId)
|
||||
if (nodeInfo) {
|
||||
const index = draft.findIndex(node => node.nodeId === nodeId)
|
||||
if (index === -1) {
|
||||
draft.unshift({
|
||||
nodeId,
|
||||
nodeType: nodeInfo.data.type,
|
||||
title: nodeInfo.data.title,
|
||||
vars: payload,
|
||||
nodePayload: nodeInfo.data,
|
||||
})
|
||||
}
|
||||
else {
|
||||
draft[index].vars = payload
|
||||
// put the node to the topAdd commentMore actions
|
||||
draft.unshift(draft.splice(index, 1)[0])
|
||||
}
|
||||
}
|
||||
})
|
||||
setNodesWithInspectVars(nodes)
|
||||
handleCancelNodeSuccessStatus(nodeId)
|
||||
}, [workflowStore, handleCancelNodeSuccessStatus])
|
||||
|
||||
const hasNodeInspectVar = useCallback((nodeId: string, varId: string) => {
|
||||
const { nodesWithInspectVars } = workflowStore.getState()
|
||||
const targetNode = nodesWithInspectVars.find(item => item.nodeId === nodeId)
|
||||
if(!targetNode || !targetNode.vars)
|
||||
return false
|
||||
return targetNode.vars.some(item => item.id === varId)
|
||||
}, [workflowStore])
|
||||
|
||||
const deleteInspectVar = useCallback(async (nodeId: string, varId: string) => {
|
||||
const { deleteInspectVar } = workflowStore.getState()
|
||||
if(hasNodeInspectVar(nodeId, varId)) {
|
||||
await doDeleteInspectVar(varId)
|
||||
deleteInspectVar(nodeId, varId)
|
||||
}
|
||||
}, [doDeleteInspectVar, workflowStore, hasNodeInspectVar])
|
||||
|
||||
const resetConversationVar = useCallback(async (varId: string) => {
|
||||
await doResetConversationVar(varId)
|
||||
invalidateConversationVarValues()
|
||||
}, [doResetConversationVar, invalidateConversationVarValues])
|
||||
|
||||
const deleteNodeInspectorVars = useCallback(async (nodeId: string) => {
|
||||
const { deleteNodeInspectVars } = workflowStore.getState()
|
||||
if (hasNodeInspectVars(nodeId)) {
|
||||
await doDeleteNodeInspectorVars(nodeId)
|
||||
deleteNodeInspectVars(nodeId)
|
||||
}
|
||||
}, [doDeleteNodeInspectorVars, workflowStore, hasNodeInspectVars])
|
||||
|
||||
const deleteAllInspectorVars = useCallback(async () => {
|
||||
const { deleteAllInspectVars } = workflowStore.getState()
|
||||
await doDeleteAllInspectorVars()
|
||||
await invalidateConversationVarValues()
|
||||
await invalidateSysVarValues()
|
||||
deleteAllInspectVars()
|
||||
handleEdgeCancelRunningStatus()
|
||||
}, [doDeleteAllInspectorVars, invalidateConversationVarValues, invalidateSysVarValues, workflowStore, handleEdgeCancelRunningStatus])
|
||||
|
||||
const editInspectVarValue = useCallback(async (nodeId: string, varId: string, value: any) => {
|
||||
const { setInspectVarValue } = workflowStore.getState()
|
||||
await doEditInspectorVar({
|
||||
varId,
|
||||
value,
|
||||
})
|
||||
setInspectVarValue(nodeId, varId, value)
|
||||
if (nodeId === VarInInspectType.conversation)
|
||||
invalidateConversationVarValues()
|
||||
if (nodeId === VarInInspectType.system)
|
||||
invalidateSysVarValues()
|
||||
}, [doEditInspectorVar, invalidateConversationVarValues, invalidateSysVarValues, workflowStore])
|
||||
|
||||
const renameInspectVarName = useCallback(async (nodeId: string, oldName: string, newName: string) => {
|
||||
const { renameInspectVarName } = workflowStore.getState()
|
||||
const varId = getVarId(nodeId, oldName)
|
||||
if (!varId)
|
||||
return
|
||||
|
||||
const newSelector = [nodeId, newName]
|
||||
await doEditInspectorVar({
|
||||
varId,
|
||||
name: newName,
|
||||
})
|
||||
renameInspectVarName(nodeId, varId, newSelector)
|
||||
}, [doEditInspectorVar, getVarId, workflowStore])
|
||||
|
||||
const isInspectVarEdited = useCallback((nodeId: string, name: string) => {
|
||||
const inspectVar = getInspectVar(nodeId, name)
|
||||
if (!inspectVar)
|
||||
return false
|
||||
|
||||
return inspectVar.edited
|
||||
}, [getInspectVar])
|
||||
|
||||
const resetToLastRunVar = useCallback(async (nodeId: string, varId: string) => {
|
||||
const { resetToLastRunVar } = workflowStore.getState()
|
||||
const isSysVar = nodeId === 'sys'
|
||||
const data = await doResetToLastRunValue(varId)
|
||||
|
||||
if(isSysVar)
|
||||
invalidateSysVarValues()
|
||||
else
|
||||
resetToLastRunVar(nodeId, varId, data.value)
|
||||
}, [doResetToLastRunValue, invalidateSysVarValues, workflowStore])
|
||||
|
||||
return {
|
||||
hasNodeInspectVars,
|
||||
hasSetInspectVar,
|
||||
fetchInspectVarValue,
|
||||
editInspectVarValue,
|
||||
renameInspectVarName,
|
||||
appendNodeInspectVars,
|
||||
deleteInspectVar,
|
||||
deleteNodeInspectorVars,
|
||||
deleteAllInspectorVars,
|
||||
isInspectVarEdited,
|
||||
resetToLastRunVar,
|
||||
invalidateSysVarValues,
|
||||
resetConversationVar,
|
||||
invalidateConversationVarValues,
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue