Merge remote-tracking branch 'company/feat/17150_2'
commit
452fa1b4c6
@ -0,0 +1,45 @@
|
|||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||||
|
|
||||||
|
|
||||||
|
def get_parameters_from_feature_dict(
|
||||||
|
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||||
|
) -> Mapping[str, Any]:
|
||||||
|
"""
|
||||||
|
Mapping from feature dict to webapp parameters
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
"opening_statement": features_dict.get("opening_statement"),
|
||||||
|
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||||
|
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||||
|
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||||
|
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||||
|
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||||
|
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||||
|
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||||
|
"user_input_form": user_input_form,
|
||||||
|
"sensitive_word_avoidance": features_dict.get(
|
||||||
|
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||||
|
),
|
||||||
|
"file_upload": features_dict.get(
|
||||||
|
"file_upload",
|
||||||
|
{
|
||||||
|
"image": {
|
||||||
|
"enabled": False,
|
||||||
|
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
|
||||||
|
"detail": "high",
|
||||||
|
"transfer_methods": ["remote_url", "local_file"],
|
||||||
|
}
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"system_parameters": {
|
||||||
|
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||||
|
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||||
|
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||||
|
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||||
|
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||||
|
},
|
||||||
|
}
|
||||||
@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Repository interfaces for data access.
|
||||||
|
|
||||||
|
This package contains repository interfaces that define the contract
|
||||||
|
for accessing and manipulating data, regardless of the underlying
|
||||||
|
storage mechanism.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
|
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"RepositoryFactory",
|
||||||
|
"WorkflowNodeExecutionRepository",
|
||||||
|
]
|
||||||
@ -0,0 +1,97 @@
|
|||||||
|
"""
|
||||||
|
Repository factory for creating repository instances.
|
||||||
|
|
||||||
|
This module provides a simple factory interface for creating repository instances.
|
||||||
|
It does not contain any implementation details or dependencies on specific repositories.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable, Mapping
|
||||||
|
from typing import Any, Literal, Optional, cast
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||||
|
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||||
|
|
||||||
|
# Type for workflow node execution factory function
|
||||||
|
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||||
|
|
||||||
|
# Repository type literals
|
||||||
|
_RepositoryType = Literal["workflow_node_execution"]
|
||||||
|
|
||||||
|
|
||||||
|
class RepositoryFactory:
|
||||||
|
"""
|
||||||
|
Factory class for creating repository instances.
|
||||||
|
|
||||||
|
This factory delegates the actual repository creation to implementation-specific
|
||||||
|
factory functions that are registered with the factory at runtime.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Dictionary to store factory functions
|
||||||
|
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for a specific repository type.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||||
|
factory_func: A function that takes parameters and returns a repository instance
|
||||||
|
"""
|
||||||
|
cls._factory_functions[repository_type] = factory_func
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||||
|
"""
|
||||||
|
Create a new repository instance with the provided parameters.
|
||||||
|
This is a private method and should not be called directly.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
repository_type: The type of repository to create
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the requested repository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the repository type
|
||||||
|
"""
|
||||||
|
if repository_type not in cls._factory_functions:
|
||||||
|
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||||
|
|
||||||
|
# Use empty dict if params is None
|
||||||
|
params = params or {}
|
||||||
|
|
||||||
|
return cls._factory_functions[repository_type](params)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||||
|
"""
|
||||||
|
Register a factory function for the workflow node execution repository.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||||
|
"""
|
||||||
|
cls._register_factory("workflow_node_execution", factory_func)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create_workflow_node_execution_repository(
|
||||||
|
cls, params: Optional[Mapping[str, Any]] = None
|
||||||
|
) -> WorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: A dictionary of parameters to pass to the factory function
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A new instance of the WorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||||
|
"""
|
||||||
|
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||||
|
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
||||||
@ -0,0 +1,88 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Literal, Optional, Protocol
|
||||||
|
|
||||||
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OrderConfig:
|
||||||
|
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||||
|
|
||||||
|
order_by: list[str]
|
||||||
|
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||||
|
|
||||||
|
|
||||||
|
class WorkflowNodeExecutionRepository(Protocol):
|
||||||
|
"""
|
||||||
|
Repository interface for WorkflowNodeExecution.
|
||||||
|
|
||||||
|
This interface defines the contract for accessing and manipulating
|
||||||
|
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||||
|
|
||||||
|
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||||
|
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||||
|
the core interface. This keeps the core domain model clean and independent of specific
|
||||||
|
application domains or deployment scenarios.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save a WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to save
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowNodeExecution instance if found, None otherwise
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of running WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing WorkflowNodeExecution instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to update
|
||||||
|
"""
|
||||||
|
...
|
||||||
@ -0,0 +1,18 @@
|
|||||||
|
"""
|
||||||
|
Extension for initializing repositories.
|
||||||
|
|
||||||
|
This extension registers repository implementations with the RepositoryFactory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dify_app import DifyApp
|
||||||
|
from repositories.repository_registry import register_repositories
|
||||||
|
|
||||||
|
|
||||||
|
def init_app(_app: DifyApp) -> None:
|
||||||
|
"""
|
||||||
|
Initialize repository implementations.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_app: The Flask application instance (unused)
|
||||||
|
"""
|
||||||
|
register_repositories()
|
||||||
@ -1,4 +0,0 @@
|
|||||||
[virtualenvs]
|
|
||||||
in-project = true
|
|
||||||
create = true
|
|
||||||
prefer-active-python = true
|
|
||||||
@ -0,0 +1,6 @@
|
|||||||
|
"""
|
||||||
|
Repository implementations for data access.
|
||||||
|
|
||||||
|
This package contains concrete implementations of the repository interfaces
|
||||||
|
defined in the core.repository package.
|
||||||
|
"""
|
||||||
@ -0,0 +1,87 @@
|
|||||||
|
"""
|
||||||
|
Registry for repository implementations.
|
||||||
|
|
||||||
|
This module is responsible for registering factory functions with the repository factory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from configs import dify_config
|
||||||
|
from core.repository.repository_factory import RepositoryFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
|
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Storage type constants
|
||||||
|
STORAGE_TYPE_RDBMS = "rdbms"
|
||||||
|
STORAGE_TYPE_HYBRID = "hybrid"
|
||||||
|
|
||||||
|
|
||||||
|
def register_repositories() -> None:
|
||||||
|
"""
|
||||||
|
Register repository factory functions with the RepositoryFactory.
|
||||||
|
|
||||||
|
This function reads configuration settings to determine which repository
|
||||||
|
implementations to register.
|
||||||
|
"""
|
||||||
|
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||||
|
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||||
|
|
||||||
|
# Check storage type and register appropriate implementation
|
||||||
|
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||||
|
# Register SQLAlchemy implementation for RDBMS storage
|
||||||
|
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||||
|
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||||
|
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||||
|
# Hybrid storage is not yet implemented
|
||||||
|
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||||
|
else:
|
||||||
|
# Unknown storage type
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||||
|
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||||
|
|
||||||
|
This factory function creates a repository for the RDBMS storage type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
params: Parameters for creating the repository, including:
|
||||||
|
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||||
|
- app_id: Optional. The application ID for filtering.
|
||||||
|
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||||
|
a new sessionmaker will be created using the global database engine.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A WorkflowNodeExecutionRepository instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If required parameters are missing
|
||||||
|
"""
|
||||||
|
# Extract required parameters
|
||||||
|
tenant_id = params.get("tenant_id")
|
||||||
|
if tenant_id is None:
|
||||||
|
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||||
|
|
||||||
|
# Extract optional parameters
|
||||||
|
app_id = params.get("app_id")
|
||||||
|
|
||||||
|
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||||
|
session_factory = params.get("session_factory")
|
||||||
|
if session_factory is None:
|
||||||
|
# Create a sessionmaker using the same engine as the global db session
|
||||||
|
session_factory = sessionmaker(bind=db.engine)
|
||||||
|
|
||||||
|
# Create and return the repository
|
||||||
|
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||||
|
)
|
||||||
@ -0,0 +1,9 @@
|
|||||||
|
"""
|
||||||
|
WorkflowNodeExecution repository implementations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||||
|
]
|
||||||
@ -0,0 +1,170 @@
|
|||||||
|
"""
|
||||||
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import UnaryExpression, asc, desc, select
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||||
|
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||||
|
"""
|
||||||
|
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||||
|
|
||||||
|
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||||
|
Each method creates its own session, handles the transaction, and commits changes
|
||||||
|
to the database. This prevents long-running connections in the workflow core.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||||
|
tenant_id: Tenant ID for multi-tenancy
|
||||||
|
app_id: Optional app ID for filtering by application
|
||||||
|
"""
|
||||||
|
# If an engine is provided, create a sessionmaker from it
|
||||||
|
if isinstance(session_factory, Engine):
|
||||||
|
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||||
|
else:
|
||||||
|
self._session_factory = session_factory
|
||||||
|
|
||||||
|
self._tenant_id = tenant_id
|
||||||
|
self._app_id = app_id
|
||||||
|
|
||||||
|
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to save
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
# Ensure tenant_id is set
|
||||||
|
if not execution.tenant_id:
|
||||||
|
execution.tenant_id = self._tenant_id
|
||||||
|
|
||||||
|
# Set app_id if provided and not already set
|
||||||
|
if self._app_id and not execution.app_id:
|
||||||
|
execution.app_id = self._app_id
|
||||||
|
|
||||||
|
session.add(execution)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
node_execution_id: The node execution ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WorkflowNodeExecution instance if found, None otherwise
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
return session.scalar(stmt)
|
||||||
|
|
||||||
|
def get_by_workflow_run(
|
||||||
|
self,
|
||||||
|
workflow_run_id: str,
|
||||||
|
order_config: Optional[OrderConfig] = None,
|
||||||
|
) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
order_config: Optional configuration for ordering results
|
||||||
|
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||||
|
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
# Apply ordering if provided
|
||||||
|
if order_config and order_config.order_by:
|
||||||
|
order_columns: list[UnaryExpression] = []
|
||||||
|
for field in order_config.order_by:
|
||||||
|
column = getattr(WorkflowNodeExecution, field, None)
|
||||||
|
if not column:
|
||||||
|
continue
|
||||||
|
if order_config.order_direction == "desc":
|
||||||
|
order_columns.append(desc(column))
|
||||||
|
else:
|
||||||
|
order_columns.append(asc(column))
|
||||||
|
|
||||||
|
if order_columns:
|
||||||
|
stmt = stmt.order_by(*order_columns)
|
||||||
|
|
||||||
|
return session.scalars(stmt).all()
|
||||||
|
|
||||||
|
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||||
|
"""
|
||||||
|
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: The workflow run ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of running WorkflowNodeExecution instances
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
stmt = select(WorkflowNodeExecution).where(
|
||||||
|
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||||
|
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||||
|
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||||
|
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._app_id:
|
||||||
|
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||||
|
|
||||||
|
return session.scalars(stmt).all()
|
||||||
|
|
||||||
|
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||||
|
"""
|
||||||
|
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
execution: The WorkflowNodeExecution instance to update
|
||||||
|
"""
|
||||||
|
with self._session_factory() as session:
|
||||||
|
# Ensure tenant_id is set
|
||||||
|
if not execution.tenant_id:
|
||||||
|
execution.tenant_id = self._tenant_id
|
||||||
|
|
||||||
|
# Set app_id if provided and not already set
|
||||||
|
if self._app_id and not execution.app_id:
|
||||||
|
execution.app_id = self._app_id
|
||||||
|
|
||||||
|
session.merge(execution)
|
||||||
|
session.commit()
|
||||||
@ -0,0 +1,99 @@
|
|||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||||
|
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||||
|
|
||||||
|
ToolCall = AssistantPromptMessage.ToolCall
|
||||||
|
|
||||||
|
# CASE 1: Single tool call
|
||||||
|
INPUTS_CASE_1 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_1 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
|
||||||
|
INPUTS_CASE_2 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_2 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
|
||||||
|
INPUTS_CASE_3 = [
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_3 = [
|
||||||
|
ToolCall(
|
||||||
|
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# CASE 4: Tool call sequences with no IDs
|
||||||
|
INPUTS_CASE_4 = [
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||||
|
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||||
|
]
|
||||||
|
EXPECTED_CASE_4 = [
|
||||||
|
ToolCall(
|
||||||
|
id="RANDOM_ID_1",
|
||||||
|
type="function",
|
||||||
|
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
|
||||||
|
),
|
||||||
|
ToolCall(
|
||||||
|
id="RANDOM_ID_2",
|
||||||
|
type="function",
|
||||||
|
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
|
||||||
|
actual = []
|
||||||
|
_increase_tool_call(inputs, actual)
|
||||||
|
assert actual == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test__increase_tool_call():
|
||||||
|
# case 1:
|
||||||
|
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
|
||||||
|
|
||||||
|
# case 2:
|
||||||
|
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
|
||||||
|
|
||||||
|
# case 3:
|
||||||
|
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
|
||||||
|
|
||||||
|
# case 4:
|
||||||
|
mock_id_generator = MagicMock()
|
||||||
|
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
|
||||||
|
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
|
||||||
|
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||||
@ -0,0 +1,198 @@
|
|||||||
|
import uuid
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import Response
|
||||||
|
|
||||||
|
from factories.file_factory import (
|
||||||
|
File,
|
||||||
|
FileTransferMethod,
|
||||||
|
FileType,
|
||||||
|
FileUploadConfig,
|
||||||
|
build_from_mapping,
|
||||||
|
)
|
||||||
|
from models import ToolFile, UploadFile
|
||||||
|
|
||||||
|
# Test Data
|
||||||
|
TEST_TENANT_ID = "test_tenant_id"
|
||||||
|
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
||||||
|
TEST_TOOL_FILE_ID = str(uuid.uuid4())
|
||||||
|
TEST_REMOTE_URL = "http://example.com/test.jpg"
|
||||||
|
|
||||||
|
# Test Config
|
||||||
|
TEST_CONFIG = FileUploadConfig(
|
||||||
|
allowed_file_types=["image", "document"],
|
||||||
|
allowed_file_extensions=[".jpg", ".pdf"],
|
||||||
|
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
|
||||||
|
number_limits=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Fixtures
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_upload_file():
|
||||||
|
mock = MagicMock(spec=UploadFile)
|
||||||
|
mock.id = TEST_UPLOAD_FILE_ID
|
||||||
|
mock.tenant_id = TEST_TENANT_ID
|
||||||
|
mock.name = "test.jpg"
|
||||||
|
mock.extension = "jpg"
|
||||||
|
mock.mime_type = "image/jpeg"
|
||||||
|
mock.source_url = TEST_REMOTE_URL
|
||||||
|
mock.size = 1024
|
||||||
|
mock.key = "test_key"
|
||||||
|
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
|
||||||
|
yield m
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_tool_file():
|
||||||
|
mock = MagicMock(spec=ToolFile)
|
||||||
|
mock.id = TEST_TOOL_FILE_ID
|
||||||
|
mock.tenant_id = TEST_TENANT_ID
|
||||||
|
mock.name = "tool_file.pdf"
|
||||||
|
mock.file_key = "tool_file.pdf"
|
||||||
|
mock.mimetype = "application/pdf"
|
||||||
|
mock.original_url = "http://example.com/tool.pdf"
|
||||||
|
mock.size = 2048
|
||||||
|
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.filter.return_value.first.return_value = mock
|
||||||
|
yield mock
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_http_head():
|
||||||
|
def _mock_response(filename, size, content_type):
|
||||||
|
return Response(
|
||||||
|
status_code=200,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||||
|
"Content-Length": str(size),
|
||||||
|
"Content-Type": content_type,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
|
||||||
|
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||||
|
yield mock_head
|
||||||
|
|
||||||
|
|
||||||
|
# Helper functions
|
||||||
|
def local_file_mapping(file_type="image"):
|
||||||
|
return {
|
||||||
|
"transfer_method": "local_file",
|
||||||
|
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||||
|
"type": file_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def tool_file_mapping(file_type="document"):
|
||||||
|
return {
|
||||||
|
"transfer_method": "tool_file",
|
||||||
|
"tool_file_id": TEST_TOOL_FILE_ID,
|
||||||
|
"type": file_type,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Tests
|
||||||
|
def test_build_from_mapping_backward_compatibility(mock_upload_file):
|
||||||
|
mapping = local_file_mapping(file_type="image")
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
assert isinstance(file, File)
|
||||||
|
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
assert file.related_id == TEST_UPLOAD_FILE_ID
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("image", True, None),
|
||||||
|
("document", False, "Detected file type does not match"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
|
||||||
|
mapping = local_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
assert file.type == FileType(file_type)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("document", True, None),
|
||||||
|
("image", False, "Detected file type does not match"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
|
||||||
|
"""Strict type validation for tool_file."""
|
||||||
|
mapping = tool_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
assert file.type == FileType(file_type)
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_from_remote_url(mock_http_head):
|
||||||
|
mapping = {
|
||||||
|
"transfer_method": "remote_url",
|
||||||
|
"url": TEST_REMOTE_URL,
|
||||||
|
"type": "image",
|
||||||
|
}
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
assert file.transfer_method == FileTransferMethod.REMOTE_URL
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
assert file.filename == "remote_test.jpg"
|
||||||
|
assert file.size == 2048
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_file_not_found():
|
||||||
|
"""Test ToolFile not found in database."""
|
||||||
|
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||||
|
mock_query.return_value.filter.return_value.first.return_value = None
|
||||||
|
mapping = tool_file_mapping()
|
||||||
|
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def test_local_file_not_found():
|
||||||
|
"""Test UploadFile not found in database."""
|
||||||
|
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||||
|
mapping = local_file_mapping()
|
||||||
|
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_without_type_specification(mock_upload_file):
|
||||||
|
"""Test the situation where no file type is specified"""
|
||||||
|
mapping = {
|
||||||
|
"transfer_method": "local_file",
|
||||||
|
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||||
|
# leave out the type
|
||||||
|
}
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||||
|
# It should automatically infer the type as "image" based on the file extension
|
||||||
|
assert file.type == FileType.IMAGE
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("file_type", "should_pass", "expected_error"),
|
||||||
|
[
|
||||||
|
("image", True, None),
|
||||||
|
("video", False, "File validation failed"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
|
||||||
|
"""Test the validation of files and configurations"""
|
||||||
|
mapping = local_file_mapping(file_type=file_type)
|
||||||
|
if should_pass:
|
||||||
|
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||||
|
assert file is not None
|
||||||
|
else:
|
||||||
|
with pytest.raises(ValueError, match=expected_error):
|
||||||
|
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for repositories.
|
||||||
|
"""
|
||||||
@ -0,0 +1,3 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for workflow_node_execution repositories.
|
||||||
|
"""
|
||||||
@ -0,0 +1,154 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pytest_mock import MockerFixture
|
||||||
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
|
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||||
|
from models.workflow import WorkflowNodeExecution
|
||||||
|
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session():
|
||||||
|
"""Create a mock SQLAlchemy session."""
|
||||||
|
session = MagicMock(spec=Session)
|
||||||
|
# Configure the session to be used as a context manager
|
||||||
|
session.__enter__ = MagicMock(return_value=session)
|
||||||
|
session.__exit__ = MagicMock(return_value=None)
|
||||||
|
|
||||||
|
# Configure the session factory to return the session
|
||||||
|
session_factory = MagicMock(spec=sessionmaker)
|
||||||
|
session_factory.return_value = session
|
||||||
|
return session, session_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def repository(session):
|
||||||
|
"""Create a repository instance with test data."""
|
||||||
|
_, session_factory = session
|
||||||
|
tenant_id = "test-tenant"
|
||||||
|
app_id = "test-app"
|
||||||
|
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||||
|
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save(repository, session):
|
||||||
|
"""Test save method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Create a mock execution
|
||||||
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
execution.tenant_id = None
|
||||||
|
execution.app_id = None
|
||||||
|
|
||||||
|
# Call save method
|
||||||
|
repository.save(execution)
|
||||||
|
|
||||||
|
# Assert tenant_id and app_id are set
|
||||||
|
assert execution.tenant_id == repository._tenant_id
|
||||||
|
assert execution.app_id == repository._app_id
|
||||||
|
|
||||||
|
# Assert session.add was called
|
||||||
|
session_obj.add.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_save_with_existing_tenant_id(repository, session):
|
||||||
|
"""Test save method with existing tenant_id."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Create a mock execution with existing tenant_id
|
||||||
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
execution.tenant_id = "existing-tenant"
|
||||||
|
execution.app_id = None
|
||||||
|
|
||||||
|
# Call save method
|
||||||
|
repository.save(execution)
|
||||||
|
|
||||||
|
# Assert tenant_id is not changed and app_id is set
|
||||||
|
assert execution.tenant_id == "existing-tenant"
|
||||||
|
assert execution.app_id == repository._app_id
|
||||||
|
|
||||||
|
# Assert session.add was called
|
||||||
|
session_obj.add.assert_called_once_with(execution)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||||
|
"""Test get_by_node_execution_id method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Set up mock
|
||||||
|
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||||
|
mock_stmt = mocker.MagicMock()
|
||||||
|
mock_select.return_value = mock_stmt
|
||||||
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
|
||||||
|
# Call method
|
||||||
|
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||||
|
|
||||||
|
# Assert select was called with correct parameters
|
||||||
|
mock_select.assert_called_once()
|
||||||
|
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||||
|
"""Test get_by_workflow_run method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Set up mock
|
||||||
|
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||||
|
mock_stmt = mocker.MagicMock()
|
||||||
|
mock_select.return_value = mock_stmt
|
||||||
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
mock_stmt.order_by.return_value = mock_stmt
|
||||||
|
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||||
|
|
||||||
|
# Call method
|
||||||
|
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||||
|
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||||
|
|
||||||
|
# Assert select was called with correct parameters
|
||||||
|
mock_select.assert_called_once()
|
||||||
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||||
|
"""Test get_running_executions method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Set up mock
|
||||||
|
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||||
|
mock_stmt = mocker.MagicMock()
|
||||||
|
mock_select.return_value = mock_stmt
|
||||||
|
mock_stmt.where.return_value = mock_stmt
|
||||||
|
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||||
|
|
||||||
|
# Call method
|
||||||
|
result = repository.get_running_executions("test-workflow-run-id")
|
||||||
|
|
||||||
|
# Assert select was called with correct parameters
|
||||||
|
mock_select.assert_called_once()
|
||||||
|
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||||
|
assert len(result) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_update(repository, session):
|
||||||
|
"""Test update method."""
|
||||||
|
session_obj, _ = session
|
||||||
|
# Create a mock execution
|
||||||
|
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||||
|
execution.tenant_id = None
|
||||||
|
execution.app_id = None
|
||||||
|
|
||||||
|
# Call update method
|
||||||
|
repository.update(execution)
|
||||||
|
|
||||||
|
# Assert tenant_id and app_id are set
|
||||||
|
assert execution.tenant_id == repository._tenant_id
|
||||||
|
assert execution.app_id == repository._app_id
|
||||||
|
|
||||||
|
# Assert session.merge was called
|
||||||
|
session_obj.merge.assert_called_once_with(execution)
|
||||||
@ -0,0 +1,69 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useState,
|
||||||
|
} from 'react'
|
||||||
|
import type { EnvironmentVariable } from '@/app/components/workflow/types'
|
||||||
|
import { DSL_EXPORT_CHECK } from '@/app/components/workflow/constants'
|
||||||
|
import { useStore } from '@/app/components/workflow/store'
|
||||||
|
import Features from '@/app/components/workflow/features'
|
||||||
|
import PluginDependency from '@/app/components/workflow/plugin-dependency'
|
||||||
|
import UpdateDSLModal from '@/app/components/workflow/update-dsl-modal'
|
||||||
|
import DSLExportConfirmModal from '@/app/components/workflow/dsl-export-confirm-modal'
|
||||||
|
import {
|
||||||
|
useDSL,
|
||||||
|
usePanelInteractions,
|
||||||
|
} from '@/app/components/workflow/hooks'
|
||||||
|
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||||
|
import WorkflowHeader from './workflow-header'
|
||||||
|
import WorkflowPanel from './workflow-panel'
|
||||||
|
|
||||||
|
const WorkflowChildren = () => {
|
||||||
|
const { eventEmitter } = useEventEmitterContextContext()
|
||||||
|
const [secretEnvList, setSecretEnvList] = useState<EnvironmentVariable[]>([])
|
||||||
|
const showFeaturesPanel = useStore(s => s.showFeaturesPanel)
|
||||||
|
const showImportDSLModal = useStore(s => s.showImportDSLModal)
|
||||||
|
const setShowImportDSLModal = useStore(s => s.setShowImportDSLModal)
|
||||||
|
const {
|
||||||
|
handlePaneContextmenuCancel,
|
||||||
|
} = usePanelInteractions()
|
||||||
|
const {
|
||||||
|
exportCheck,
|
||||||
|
handleExportDSL,
|
||||||
|
} = useDSL()
|
||||||
|
|
||||||
|
eventEmitter?.useSubscription((v: any) => {
|
||||||
|
if (v.type === DSL_EXPORT_CHECK)
|
||||||
|
setSecretEnvList(v.payload.data as EnvironmentVariable[])
|
||||||
|
})
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<PluginDependency />
|
||||||
|
{
|
||||||
|
showFeaturesPanel && <Features />
|
||||||
|
}
|
||||||
|
{
|
||||||
|
showImportDSLModal && (
|
||||||
|
<UpdateDSLModal
|
||||||
|
onCancel={() => setShowImportDSLModal(false)}
|
||||||
|
onBackup={exportCheck}
|
||||||
|
onImport={handlePaneContextmenuCancel}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
{
|
||||||
|
secretEnvList.length > 0 && (
|
||||||
|
<DSLExportConfirmModal
|
||||||
|
envList={secretEnvList}
|
||||||
|
onConfirm={handleExportDSL}
|
||||||
|
onClose={() => setSecretEnvList([])}
|
||||||
|
/>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
<WorkflowHeader />
|
||||||
|
<WorkflowPanel />
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(WorkflowChildren)
|
||||||
@ -0,0 +1,11 @@
|
|||||||
|
import { memo } from 'react'
|
||||||
|
import ChatVariableButton from '@/app/components/workflow/header/chat-variable-button'
|
||||||
|
import {
|
||||||
|
useNodesReadOnly,
|
||||||
|
} from '@/app/components/workflow/hooks'
|
||||||
|
|
||||||
|
const ChatVariableTrigger = () => {
|
||||||
|
const { nodesReadOnly } = useNodesReadOnly()
|
||||||
|
return <ChatVariableButton disabled={nodesReadOnly} />
|
||||||
|
}
|
||||||
|
export default memo(ChatVariableTrigger)
|
||||||
@ -0,0 +1,152 @@
|
|||||||
|
import {
|
||||||
|
memo,
|
||||||
|
useCallback,
|
||||||
|
useMemo,
|
||||||
|
} from 'react'
|
||||||
|
import { useNodes } from 'reactflow'
|
||||||
|
import { RiApps2AddLine } from '@remixicon/react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import {
|
||||||
|
useStore,
|
||||||
|
useWorkflowStore,
|
||||||
|
} from '@/app/components/workflow/store'
|
||||||
|
import {
|
||||||
|
useChecklistBeforePublish,
|
||||||
|
useNodesReadOnly,
|
||||||
|
useNodesSyncDraft,
|
||||||
|
} from '@/app/components/workflow/hooks'
|
||||||
|
import Button from '@/app/components/base/button'
|
||||||
|
import AppPublisher from '@/app/components/app/app-publisher'
|
||||||
|
import { useFeatures } from '@/app/components/base/features/hooks'
|
||||||
|
import {
|
||||||
|
BlockEnum,
|
||||||
|
InputVarType,
|
||||||
|
} from '@/app/components/workflow/types'
|
||||||
|
import type { StartNodeType } from '@/app/components/workflow/nodes/start/types'
|
||||||
|
import { useToastContext } from '@/app/components/base/toast'
|
||||||
|
import { usePublishWorkflow, useResetWorkflowVersionHistory } from '@/service/use-workflow'
|
||||||
|
import type { PublishWorkflowParams } from '@/types/workflow'
|
||||||
|
import { fetchAppDetail, fetchAppSSO } from '@/service/apps'
|
||||||
|
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||||
|
import { useSelector as useAppSelector } from '@/context/app-context'
|
||||||
|
|
||||||
|
const FeaturesTrigger = () => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const workflowStore = useWorkflowStore()
|
||||||
|
const appDetail = useAppStore(s => s.appDetail)
|
||||||
|
const appID = appDetail?.id
|
||||||
|
const setAppDetail = useAppStore(s => s.setAppDetail)
|
||||||
|
const systemFeatures = useAppSelector(state => state.systemFeatures)
|
||||||
|
const {
|
||||||
|
nodesReadOnly,
|
||||||
|
getNodesReadOnly,
|
||||||
|
} = useNodesReadOnly()
|
||||||
|
const publishedAt = useStore(s => s.publishedAt)
|
||||||
|
const draftUpdatedAt = useStore(s => s.draftUpdatedAt)
|
||||||
|
const toolPublished = useStore(s => s.toolPublished)
|
||||||
|
const nodes = useNodes<StartNodeType>()
|
||||||
|
const startNode = nodes.find(node => node.data.type === BlockEnum.Start)
|
||||||
|
const startVariables = startNode?.data.variables
|
||||||
|
const fileSettings = useFeatures(s => s.features.file)
|
||||||
|
const variables = useMemo(() => {
|
||||||
|
const data = startVariables || []
|
||||||
|
if (fileSettings?.image?.enabled) {
|
||||||
|
return [
|
||||||
|
...data,
|
||||||
|
{
|
||||||
|
type: InputVarType.files,
|
||||||
|
variable: '__image',
|
||||||
|
required: false,
|
||||||
|
label: 'files',
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
return data
|
||||||
|
}, [fileSettings?.image?.enabled, startVariables])
|
||||||
|
|
||||||
|
const { handleCheckBeforePublish } = useChecklistBeforePublish()
|
||||||
|
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
|
||||||
|
const { notify } = useToastContext()
|
||||||
|
|
||||||
|
const handleShowFeatures = useCallback(() => {
|
||||||
|
const {
|
||||||
|
showFeaturesPanel,
|
||||||
|
isRestoring,
|
||||||
|
setShowFeaturesPanel,
|
||||||
|
} = workflowStore.getState()
|
||||||
|
if (getNodesReadOnly() && !isRestoring)
|
||||||
|
return
|
||||||
|
setShowFeaturesPanel(!showFeaturesPanel)
|
||||||
|
}, [workflowStore, getNodesReadOnly])
|
||||||
|
|
||||||
|
const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id)
|
||||||
|
|
||||||
|
const updateAppDetail = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
const res = await fetchAppDetail({ url: '/apps', id: appID! })
|
||||||
|
if (systemFeatures.enable_web_sso_switch_component) {
|
||||||
|
const ssoRes = await fetchAppSSO({ appId: appID! })
|
||||||
|
setAppDetail({ ...res, enable_sso: ssoRes.enabled })
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
setAppDetail({ ...res })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
catch (error) {
|
||||||
|
console.error(error)
|
||||||
|
}
|
||||||
|
}, [appID, setAppDetail, systemFeatures.enable_web_sso_switch_component])
|
||||||
|
const { mutateAsync: publishWorkflow } = usePublishWorkflow(appID!)
|
||||||
|
const onPublish = useCallback(async (params?: PublishWorkflowParams) => {
|
||||||
|
if (await handleCheckBeforePublish()) {
|
||||||
|
const res = await publishWorkflow({
|
||||||
|
title: params?.title || '',
|
||||||
|
releaseNotes: params?.releaseNotes || '',
|
||||||
|
})
|
||||||
|
|
||||||
|
if (res) {
|
||||||
|
notify({ type: 'success', message: t('common.api.actionSuccess') })
|
||||||
|
updateAppDetail()
|
||||||
|
workflowStore.getState().setPublishedAt(res.created_at)
|
||||||
|
resetWorkflowVersionHistory()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
throw new Error('Checklist failed')
|
||||||
|
}
|
||||||
|
}, [handleCheckBeforePublish, notify, t, workflowStore, publishWorkflow, resetWorkflowVersionHistory, updateAppDetail])
|
||||||
|
|
||||||
|
const onPublisherToggle = useCallback((state: boolean) => {
|
||||||
|
if (state)
|
||||||
|
handleSyncWorkflowDraft(true)
|
||||||
|
}, [handleSyncWorkflowDraft])
|
||||||
|
|
||||||
|
const handleToolConfigureUpdate = useCallback(() => {
|
||||||
|
workflowStore.setState({ toolPublished: true })
|
||||||
|
}, [workflowStore])
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Button className='text-components-button-secondary-text' onClick={handleShowFeatures}>
|
||||||
|
<RiApps2AddLine className='mr-1 h-4 w-4 text-components-button-secondary-text' />
|
||||||
|
{t('workflow.common.features')}
|
||||||
|
</Button>
|
||||||
|
<AppPublisher
|
||||||
|
{...{
|
||||||
|
publishedAt,
|
||||||
|
draftUpdatedAt,
|
||||||
|
disabled: nodesReadOnly,
|
||||||
|
toolPublished,
|
||||||
|
inputs: variables,
|
||||||
|
onRefreshData: handleToolConfigureUpdate,
|
||||||
|
onPublish,
|
||||||
|
onToggle: onPublisherToggle,
|
||||||
|
crossAxisOffset: 4,
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default memo(FeaturesTrigger)
|
||||||
@ -0,0 +1,31 @@
|
|||||||
|
import { useMemo } from 'react'
|
||||||
|
import type { HeaderProps } from '@/app/components/workflow/header'
|
||||||
|
import Header from '@/app/components/workflow/header'
|
||||||
|
import { useStore as useAppStore } from '@/app/components/app/store'
|
||||||
|
import ChatVariableTrigger from './chat-variable-trigger'
|
||||||
|
import FeaturesTrigger from './features-trigger'
|
||||||
|
import { useResetWorkflowVersionHistory } from '@/service/use-workflow'
|
||||||
|
|
||||||
|
const WorkflowHeader = () => {
|
||||||
|
const appDetail = useAppStore(s => s.appDetail)
|
||||||
|
const resetWorkflowVersionHistory = useResetWorkflowVersionHistory(appDetail!.id)
|
||||||
|
|
||||||
|
const headerProps: HeaderProps = useMemo(() => {
|
||||||
|
return {
|
||||||
|
normal: {
|
||||||
|
components: {
|
||||||
|
left: <ChatVariableTrigger />,
|
||||||
|
middle: <FeaturesTrigger />,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
restoring: {
|
||||||
|
onRestoreSettled: resetWorkflowVersionHistory,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}, [resetWorkflowVersionHistory])
|
||||||
|
return (
|
||||||
|
<Header {...headerProps} />
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default WorkflowHeader
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue