Merge branch 'langgenius:main' into deploy
commit
029e72395b
@ -0,0 +1,15 @@
|
||||
"""
|
||||
Repository interfaces for data access.
|
||||
|
||||
This package contains repository interfaces that define the contract
|
||||
for accessing and manipulating data, regardless of the underlying
|
||||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"RepositoryFactory",
|
||||
"WorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,97 @@
|
||||
"""
|
||||
Repository factory for creating repository instances.
|
||||
|
||||
This module provides a simple factory interface for creating repository instances.
|
||||
It does not contain any implementation details or dependencies on specific repositories.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||
|
||||
# Type for workflow node execution factory function
|
||||
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||
|
||||
# Repository type literals
|
||||
_RepositoryType = Literal["workflow_node_execution"]
|
||||
|
||||
|
||||
class RepositoryFactory:
|
||||
"""
|
||||
Factory class for creating repository instances.
|
||||
|
||||
This factory delegates the actual repository creation to implementation-specific
|
||||
factory functions that are registered with the factory at runtime.
|
||||
"""
|
||||
|
||||
# Dictionary to store factory functions
|
||||
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||
|
||||
@classmethod
|
||||
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for a specific repository type.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||
factory_func: A function that takes parameters and returns a repository instance
|
||||
"""
|
||||
cls._factory_functions[repository_type] = factory_func
|
||||
|
||||
@classmethod
|
||||
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Create a new repository instance with the provided parameters.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository to create
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the requested repository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the repository type
|
||||
"""
|
||||
if repository_type not in cls._factory_functions:
|
||||
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||
|
||||
# Use empty dict if params is None
|
||||
params = params or {}
|
||||
|
||||
return cls._factory_functions[repository_type](params)
|
||||
|
||||
@classmethod
|
||||
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for the workflow node execution repository.
|
||||
|
||||
Args:
|
||||
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||
"""
|
||||
cls._register_factory("workflow_node_execution", factory_func)
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls, params: Optional[Mapping[str, Any]] = None
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the WorkflowNodeExecutionRepository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||
"""
|
||||
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
||||
@ -0,0 +1,97 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderConfig:
|
||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Repository interface for WorkflowNodeExecution.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
...
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records based on implementation-specific criteria.
|
||||
|
||||
This method is intended to be used for bulk deletion operations, such as removing
|
||||
all records associated with a specific app_id and tenant_id in multi-tenant implementations.
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,24 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class ResponseFormat(StrEnum):
|
||||
"""Constants for model response formats"""
|
||||
|
||||
JSON_SCHEMA = "json_schema" # model's structured output mode. some model like gemini, gpt-4o, support this mode.
|
||||
JSON = "JSON" # model's json mode. some model like claude support this mode.
|
||||
JSON_OBJECT = "json_object" # json mode's another alias. some model like deepseek-chat, qwen use this alias.
|
||||
|
||||
|
||||
class SpecialModelType(StrEnum):
|
||||
"""Constants for identifying model types"""
|
||||
|
||||
GEMINI = "gemini"
|
||||
OLLAMA = "ollama"
|
||||
|
||||
|
||||
class SupportStructuredOutputStatus(StrEnum):
|
||||
"""Constants for structured output support status"""
|
||||
|
||||
SUPPORTED = "supported"
|
||||
UNSUPPORTED = "unsupported"
|
||||
DISABLED = "disabled"
|
||||
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Extension for initializing repositories.
|
||||
|
||||
This extension registers repository implementations with the RepositoryFactory.
|
||||
"""
|
||||
|
||||
from dify_app import DifyApp
|
||||
from repositories.repository_registry import register_repositories
|
||||
|
||||
|
||||
def init_app(_app: DifyApp) -> None:
|
||||
"""
|
||||
Initialize repository implementations.
|
||||
|
||||
Args:
|
||||
_app: The Flask application instance (unused)
|
||||
"""
|
||||
register_repositories()
|
||||
@ -0,0 +1,6 @@
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.repository package.
|
||||
"""
|
||||
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Registry for repository implementations.
|
||||
|
||||
This module is responsible for registering factory functions with the repository factory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage type constants
|
||||
STORAGE_TYPE_RDBMS = "rdbms"
|
||||
STORAGE_TYPE_HYBRID = "hybrid"
|
||||
|
||||
|
||||
def register_repositories() -> None:
|
||||
"""
|
||||
Register repository factory functions with the RepositoryFactory.
|
||||
|
||||
This function reads configuration settings to determine which repository
|
||||
implementations to register.
|
||||
"""
|
||||
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||
|
||||
# Check storage type and register appropriate implementation
|
||||
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||
# Register SQLAlchemy implementation for RDBMS storage
|
||||
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||
# Hybrid storage is not yet implemented
|
||||
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||
else:
|
||||
# Unknown storage type
|
||||
raise ValueError(
|
||||
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||
)
|
||||
|
||||
|
||||
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||
|
||||
This factory function creates a repository for the RDBMS storage type.
|
||||
|
||||
Args:
|
||||
params: Parameters for creating the repository, including:
|
||||
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||
- app_id: Optional. The application ID for filtering.
|
||||
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||
a new sessionmaker will be created using the global database engine.
|
||||
|
||||
Returns:
|
||||
A WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing
|
||||
"""
|
||||
# Extract required parameters
|
||||
tenant_id = params.get("tenant_id")
|
||||
if tenant_id is None:
|
||||
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||
|
||||
# Extract optional parameters
|
||||
app_id = params.get("app_id")
|
||||
|
||||
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||
session_factory = params.get("session_factory")
|
||||
if session_factory is None:
|
||||
# Create a sessionmaker using the same engine as the global db session
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Create and return the repository
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
@ -0,0 +1,9 @@
|
||||
"""
|
||||
WorkflowNodeExecution repository implementations.
|
||||
"""
|
||||
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,192 @@
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, delete, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||
Each method creates its own session, handles the transaction, and commits changes
|
||||
to the database. This prevents long-running connections in the workflow core.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Optional app ID for filtering by application
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
else:
|
||||
self._session_factory = session_factory
|
||||
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.add(execution)
|
||||
session.commit()
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
order_columns.append(desc(column))
|
||||
else:
|
||||
order_columns.append(asc(column))
|
||||
|
||||
if order_columns:
|
||||
stmt = stmt.order_by(*order_columns)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.merge(execution)
|
||||
session.commit()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all WorkflowNodeExecution records for the current tenant_id and app_id.
|
||||
|
||||
This method deletes all WorkflowNodeExecution records that match the tenant_id
|
||||
and app_id (if provided) associated with this repository instance.
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = delete(WorkflowNodeExecution).where(WorkflowNodeExecution.tenant_id == self._tenant_id)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
result = session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
deleted_count = result.rowcount
|
||||
logger.info(
|
||||
f"Cleared {deleted_count} workflow node execution records for tenant {self._tenant_id}"
|
||||
+ (f" and app {self._app_id}" if self._app_id else "")
|
||||
)
|
||||
@ -0,0 +1,99 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||
|
||||
ToolCall = AssistantPromptMessage.ToolCall
|
||||
|
||||
# CASE 1: Single tool call
|
||||
INPUTS_CASE_1 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_1 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
|
||||
INPUTS_CASE_2 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_2 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
ToolCall(
|
||||
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
|
||||
INPUTS_CASE_3 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_3 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
ToolCall(
|
||||
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 4: Tool call sequences with no IDs
|
||||
INPUTS_CASE_4 = [
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_4 = [
|
||||
ToolCall(
|
||||
id="RANDOM_ID_1",
|
||||
type="function",
|
||||
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
|
||||
),
|
||||
ToolCall(
|
||||
id="RANDOM_ID_2",
|
||||
type="function",
|
||||
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
|
||||
actual = []
|
||||
_increase_tool_call(inputs, actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test__increase_tool_call():
|
||||
# case 1:
|
||||
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
|
||||
|
||||
# case 2:
|
||||
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
|
||||
|
||||
# case 3:
|
||||
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
|
||||
|
||||
# case 4:
|
||||
mock_id_generator = MagicMock()
|
||||
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
|
||||
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
|
||||
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for repositories.
|
||||
"""
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for workflow_node_execution repositories.
|
||||
"""
|
||||
@ -0,0 +1,178 @@
|
||||
"""
|
||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
"""Create a mock SQLAlchemy session."""
|
||||
session = MagicMock(spec=Session)
|
||||
# Configure the session to be used as a context manager
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Configure the session factory to return the session
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session_factory.return_value = session
|
||||
return session, session_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(session):
|
||||
"""Create a repository instance with test data."""
|
||||
_, session_factory = session
|
||||
tenant_id = "test-tenant"
|
||||
app_id = "test-app"
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
|
||||
|
||||
def test_save(repository, session):
|
||||
"""Test save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_save_with_existing_tenant_id(repository, session):
|
||||
"""Test save method with existing tenant_id."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution with existing tenant_id
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = "existing-tenant"
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id is not changed and app_id is set
|
||||
assert execution.tenant_id == "existing-tenant"
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_node_execution_id method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||
|
||||
# Call method
|
||||
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_workflow_run method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||
"""Test get_running_executions method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
result = repository.get_running_executions("test-workflow-run-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_update(repository, session):
|
||||
"""Test update method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call update method
|
||||
repository.update(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.merge was called
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_clear(repository, session, mocker: MockerFixture):
|
||||
"""Test clear method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_delete = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.delete")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_delete.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
|
||||
# Mock the execute result with rowcount
|
||||
mock_result = mocker.MagicMock()
|
||||
mock_result.rowcount = 5 # Simulate 5 records deleted
|
||||
session_obj.execute.return_value = mock_result
|
||||
|
||||
# Call method
|
||||
repository.clear()
|
||||
|
||||
# Assert delete was called with correct parameters
|
||||
mock_delete.assert_called_once_with(WorkflowNodeExecution)
|
||||
mock_stmt.where.assert_called()
|
||||
session_obj.execute.assert_called_once_with(mock_stmt)
|
||||
session_obj.commit.assert_called_once()
|
||||
@ -0,0 +1,82 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import ConfigSelect from './index'
|
||||
|
||||
jest.mock('react-sortablejs', () => ({
|
||||
ReactSortable: ({ children }: { children: React.ReactNode }) => <div>{children}</div>,
|
||||
}))
|
||||
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('ConfigSelect Component', () => {
|
||||
const defaultProps = {
|
||||
options: ['Option 1', 'Option 2'],
|
||||
onChange: jest.fn(),
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders all options', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
|
||||
defaultProps.options.forEach((option) => {
|
||||
expect(screen.getByDisplayValue(option)).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('renders add button', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
|
||||
expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles option deletion', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
|
||||
const deleteButton = optionContainer?.querySelector('div[role="button"]')
|
||||
|
||||
if (!deleteButton) return
|
||||
fireEvent.click(deleteButton)
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(['Option 2'])
|
||||
})
|
||||
|
||||
it('handles adding new option', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
const addButton = screen.getByText('appDebug.variableConfig.addOption')
|
||||
|
||||
fireEvent.click(addButton)
|
||||
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith([...defaultProps.options, ''])
|
||||
})
|
||||
|
||||
it('applies focus styles on input focus', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
const firstInput = screen.getByDisplayValue('Option 1')
|
||||
|
||||
fireEvent.focus(firstInput)
|
||||
|
||||
expect(firstInput.closest('div')).toHaveClass('border-components-input-border-active')
|
||||
})
|
||||
|
||||
it('applies delete hover styles', () => {
|
||||
render(<ConfigSelect {...defaultProps} />)
|
||||
const optionContainer = screen.getByDisplayValue('Option 1').closest('div')
|
||||
const deleteButton = optionContainer?.querySelector('div[role="button"]')
|
||||
|
||||
if (!deleteButton) return
|
||||
fireEvent.mouseEnter(deleteButton)
|
||||
expect(optionContainer).toHaveClass('border-components-input-border-destructive')
|
||||
})
|
||||
|
||||
it('renders empty state correctly', () => {
|
||||
render(<ConfigSelect options={[]} onChange={defaultProps.onChange} />)
|
||||
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('appDebug.variableConfig.addOption')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,11 @@
|
||||
const IndeterminateIcon = () => {
|
||||
return (
|
||||
<div data-testid='indeterminate-icon'>
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="12" height="12" viewBox="0 0 12 12" fill="none">
|
||||
<path d="M2.5 6H9.5" stroke="currentColor" strokeWidth="1.5" strokeLinecap="round"/>
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default IndeterminateIcon
|
||||
@ -1,5 +0,0 @@
|
||||
<svg width="12" height="12" viewBox="0 0 12 12" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="check">
|
||||
<path id="Vector 1" d="M2.5 6H9.5" stroke="white" stroke-width="1.5" stroke-linecap="round"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
Before Width: | Height: | Size: 217 B |
@ -1,10 +0,0 @@
|
||||
.mixed {
|
||||
background: var(--color-components-checkbox-bg) url(./assets/mixed.svg) center center no-repeat;
|
||||
background-size: 12px 12px;
|
||||
border: none;
|
||||
}
|
||||
|
||||
.checked.disabled {
|
||||
background-color: #d0d5dd;
|
||||
border-color: #d0d5dd;
|
||||
}
|
||||
@ -0,0 +1,67 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Checkbox from './index'
|
||||
|
||||
describe('Checkbox Component', () => {
|
||||
const mockProps = {
|
||||
id: 'test',
|
||||
}
|
||||
|
||||
it('renders unchecked checkbox by default', () => {
|
||||
render(<Checkbox {...mockProps} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
expect(checkbox).toBeInTheDocument()
|
||||
expect(checkbox).not.toHaveClass('bg-components-checkbox-bg')
|
||||
})
|
||||
|
||||
it('renders checked checkbox when checked prop is true', () => {
|
||||
render(<Checkbox {...mockProps} checked />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
expect(checkbox).toHaveClass('bg-components-checkbox-bg')
|
||||
expect(screen.getByTestId('check-icon-test')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders indeterminate state correctly', () => {
|
||||
render(<Checkbox {...mockProps} indeterminate />)
|
||||
expect(screen.getByTestId('indeterminate-icon')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles click events when not disabled', () => {
|
||||
const onCheck = jest.fn()
|
||||
render(<Checkbox {...mockProps} onCheck={onCheck} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
|
||||
fireEvent.click(checkbox)
|
||||
expect(onCheck).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
it('does not handle click events when disabled', () => {
|
||||
const onCheck = jest.fn()
|
||||
render(<Checkbox {...mockProps} disabled onCheck={onCheck} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
|
||||
fireEvent.click(checkbox)
|
||||
expect(onCheck).not.toHaveBeenCalled()
|
||||
expect(checkbox).toHaveClass('cursor-not-allowed')
|
||||
})
|
||||
|
||||
it('applies custom className when provided', () => {
|
||||
const customClass = 'custom-class'
|
||||
render(<Checkbox {...mockProps} className={customClass} />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
expect(checkbox).toHaveClass(customClass)
|
||||
})
|
||||
|
||||
it('applies correct styles for disabled checked state', () => {
|
||||
render(<Checkbox {...mockProps} checked disabled />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled-checked')
|
||||
expect(checkbox).toHaveClass('cursor-not-allowed')
|
||||
})
|
||||
|
||||
it('applies correct styles for disabled unchecked state', () => {
|
||||
render(<Checkbox {...mockProps} disabled />)
|
||||
const checkbox = screen.getByTestId('checkbox-test')
|
||||
expect(checkbox).toHaveClass('bg-components-checkbox-bg-disabled')
|
||||
expect(checkbox).toHaveClass('cursor-not-allowed')
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,43 @@
|
||||
import cn from '@/utils/classnames'
|
||||
import { useFieldContext } from '../..'
|
||||
import Checkbox from '../../../checkbox'
|
||||
|
||||
type CheckboxFieldProps = {
|
||||
label: string;
|
||||
labelClassName?: string;
|
||||
}
|
||||
|
||||
const CheckboxField = ({
|
||||
label,
|
||||
labelClassName,
|
||||
}: CheckboxFieldProps) => {
|
||||
const field = useFieldContext<boolean>()
|
||||
|
||||
return (
|
||||
<div className='flex gap-2'>
|
||||
<div className='flex h-6 shrink-0 items-center'>
|
||||
<Checkbox
|
||||
id={field.name}
|
||||
checked={field.state.value}
|
||||
onCheck={() => {
|
||||
field.handleChange(!field.state.value)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<label
|
||||
htmlFor={field.name}
|
||||
className={cn(
|
||||
'system-sm-medium grow cursor-pointer pt-1 text-text-secondary',
|
||||
labelClassName,
|
||||
)}
|
||||
onClick={() => {
|
||||
field.handleChange(!field.state.value)
|
||||
}}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default CheckboxField
|
||||
@ -0,0 +1,49 @@
|
||||
import React from 'react'
|
||||
import { useFieldContext } from '../..'
|
||||
import Label from '../label'
|
||||
import cn from '@/utils/classnames'
|
||||
import type { InputNumberProps } from '../../../input-number'
|
||||
import { InputNumber } from '../../../input-number'
|
||||
|
||||
type TextFieldProps = {
|
||||
label: string
|
||||
isRequired?: boolean
|
||||
showOptional?: boolean
|
||||
tooltip?: string
|
||||
className?: string
|
||||
labelClassName?: string
|
||||
} & Omit<InputNumberProps, 'id' | 'value' | 'onChange' | 'onBlur'>
|
||||
|
||||
const NumberInputField = ({
|
||||
label,
|
||||
isRequired,
|
||||
showOptional,
|
||||
tooltip,
|
||||
className,
|
||||
labelClassName,
|
||||
...inputProps
|
||||
}: TextFieldProps) => {
|
||||
const field = useFieldContext<number | undefined>()
|
||||
|
||||
return (
|
||||
<div className={cn('flex flex-col gap-y-0.5', className)}>
|
||||
<Label
|
||||
htmlFor={field.name}
|
||||
label={label}
|
||||
isRequired={isRequired}
|
||||
showOptional={showOptional}
|
||||
tooltip={tooltip}
|
||||
className={labelClassName}
|
||||
/>
|
||||
<InputNumber
|
||||
id={field.name}
|
||||
value={field.state.value}
|
||||
onChange={value => field.handleChange(value)}
|
||||
onBlur={field.handleBlur}
|
||||
{...inputProps}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default NumberInputField
|
||||
@ -0,0 +1,34 @@
|
||||
import cn from '@/utils/classnames'
|
||||
import { useFieldContext } from '../..'
|
||||
import Label from '../label'
|
||||
import ConfigSelect from '@/app/components/app/configuration/config-var/config-select'
|
||||
|
||||
type OptionsFieldProps = {
|
||||
label: string;
|
||||
className?: string;
|
||||
labelClassName?: string;
|
||||
}
|
||||
|
||||
const OptionsField = ({
|
||||
label,
|
||||
className,
|
||||
labelClassName,
|
||||
}: OptionsFieldProps) => {
|
||||
const field = useFieldContext<string[]>()
|
||||
|
||||
return (
|
||||
<div className={cn('flex flex-col gap-y-0.5', className)}>
|
||||
<Label
|
||||
htmlFor={field.name}
|
||||
label={label}
|
||||
className={labelClassName}
|
||||
/>
|
||||
<ConfigSelect
|
||||
options={field.state.value}
|
||||
onChange={value => field.handleChange(value)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default OptionsField
|
||||
@ -0,0 +1,51 @@
|
||||
import cn from '@/utils/classnames'
|
||||
import { useFieldContext } from '../..'
|
||||
import PureSelect from '../../../select/pure'
|
||||
import Label from '../label'
|
||||
|
||||
type SelectOption = {
|
||||
value: string
|
||||
label: string
|
||||
}
|
||||
|
||||
type SelectFieldProps = {
|
||||
label: string
|
||||
options: SelectOption[]
|
||||
isRequired?: boolean
|
||||
showOptional?: boolean
|
||||
tooltip?: string
|
||||
className?: string
|
||||
labelClassName?: string
|
||||
}
|
||||
|
||||
const SelectField = ({
|
||||
label,
|
||||
options,
|
||||
isRequired,
|
||||
showOptional,
|
||||
tooltip,
|
||||
className,
|
||||
labelClassName,
|
||||
}: SelectFieldProps) => {
|
||||
const field = useFieldContext<string>()
|
||||
|
||||
return (
|
||||
<div className={cn('flex flex-col gap-y-0.5', className)}>
|
||||
<Label
|
||||
htmlFor={field.name}
|
||||
label={label}
|
||||
isRequired={isRequired}
|
||||
showOptional={showOptional}
|
||||
tooltip={tooltip}
|
||||
className={labelClassName}
|
||||
/>
|
||||
<PureSelect
|
||||
value={field.state.value}
|
||||
options={options}
|
||||
onChange={value => field.handleChange(value)}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default SelectField
|
||||
@ -0,0 +1,48 @@
|
||||
import React from 'react'
|
||||
import { useFieldContext } from '../..'
|
||||
import Input, { type InputProps } from '../../../input'
|
||||
import Label from '../label'
|
||||
import cn from '@/utils/classnames'
|
||||
|
||||
type TextFieldProps = {
|
||||
label: string
|
||||
isRequired?: boolean
|
||||
showOptional?: boolean
|
||||
tooltip?: string
|
||||
className?: string
|
||||
labelClassName?: string
|
||||
} & Omit<InputProps, 'className' | 'onChange' | 'onBlur' | 'value' | 'id'>
|
||||
|
||||
const TextField = ({
|
||||
label,
|
||||
isRequired,
|
||||
showOptional,
|
||||
tooltip,
|
||||
className,
|
||||
labelClassName,
|
||||
...inputProps
|
||||
}: TextFieldProps) => {
|
||||
const field = useFieldContext<string>()
|
||||
|
||||
return (
|
||||
<div className={cn('flex flex-col gap-y-0.5', className)}>
|
||||
<Label
|
||||
htmlFor={field.name}
|
||||
label={label}
|
||||
isRequired={isRequired}
|
||||
showOptional={showOptional}
|
||||
tooltip={tooltip}
|
||||
className={labelClassName}
|
||||
/>
|
||||
<Input
|
||||
id={field.name}
|
||||
value={field.state.value}
|
||||
onChange={e => field.handleChange(e.target.value)}
|
||||
onBlur={field.handleBlur}
|
||||
{...inputProps}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default TextField
|
||||
@ -0,0 +1,25 @@
|
||||
import { useStore } from '@tanstack/react-form'
|
||||
import { useFormContext } from '../..'
|
||||
import Button, { type ButtonProps } from '../../../button'
|
||||
|
||||
type SubmitButtonProps = Omit<ButtonProps, 'disabled' | 'loading' | 'onClick'>
|
||||
|
||||
const SubmitButton = ({ ...buttonProps }: SubmitButtonProps) => {
|
||||
const form = useFormContext()
|
||||
|
||||
const [isSubmitting, canSubmit] = useStore(form.store, state => [
|
||||
state.isSubmitting,
|
||||
state.canSubmit,
|
||||
])
|
||||
|
||||
return (
|
||||
<Button
|
||||
disabled={isSubmitting || !canSubmit}
|
||||
loading={isSubmitting}
|
||||
onClick={() => form.handleSubmit()}
|
||||
{...buttonProps}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default SubmitButton
|
||||
@ -0,0 +1,53 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import Label from './label'
|
||||
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('Label Component', () => {
|
||||
const defaultProps = {
|
||||
htmlFor: 'test-input',
|
||||
label: 'Test Label',
|
||||
}
|
||||
|
||||
it('renders basic label correctly', () => {
|
||||
render(<Label {...defaultProps} />)
|
||||
const label = screen.getByTestId('label')
|
||||
expect(label).toBeInTheDocument()
|
||||
expect(label).toHaveAttribute('for', 'test-input')
|
||||
})
|
||||
|
||||
it('shows optional text when showOptional is true', () => {
|
||||
render(<Label {...defaultProps} showOptional />)
|
||||
expect(screen.getByText('common.label.optional')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('shows required asterisk when isRequired is true', () => {
|
||||
render(<Label {...defaultProps} isRequired />)
|
||||
expect(screen.getByText('*')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('renders tooltip when tooltip prop is provided', () => {
|
||||
const tooltipText = 'Test Tooltip'
|
||||
render(<Label {...defaultProps} tooltip={tooltipText} />)
|
||||
const trigger = screen.getByTestId('test-input-tooltip')
|
||||
fireEvent.mouseEnter(trigger)
|
||||
expect(screen.getByText(tooltipText)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('applies custom className when provided', () => {
|
||||
const customClass = 'custom-label'
|
||||
render(<Label {...defaultProps} className={customClass} />)
|
||||
const label = screen.getByTestId('label')
|
||||
expect(label).toHaveClass(customClass)
|
||||
})
|
||||
|
||||
it('does not show optional text and required asterisk simultaneously', () => {
|
||||
render(<Label {...defaultProps} isRequired showOptional />)
|
||||
expect(screen.queryByText('common.label.optional')).not.toBeInTheDocument()
|
||||
expect(screen.getByText('*')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,48 @@
|
||||
import cn from '@/utils/classnames'
|
||||
import Tooltip from '../../tooltip'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
export type LabelProps = {
|
||||
htmlFor: string
|
||||
label: string
|
||||
isRequired?: boolean
|
||||
showOptional?: boolean
|
||||
tooltip?: string
|
||||
className?: string
|
||||
}
|
||||
|
||||
const Label = ({
|
||||
htmlFor,
|
||||
label,
|
||||
isRequired,
|
||||
showOptional,
|
||||
tooltip,
|
||||
className,
|
||||
}: LabelProps) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className='flex h-6 items-center'>
|
||||
<label
|
||||
data-testid='label'
|
||||
htmlFor={htmlFor}
|
||||
className={cn('system-sm-medium text-text-secondary', className)}
|
||||
>
|
||||
{label}
|
||||
</label>
|
||||
{!isRequired && showOptional && <div className='system-xs-regular ml-1 text-text-tertiary'>{t('common.label.optional')}</div>}
|
||||
{isRequired && <div className='system-xs-regular ml-1 text-text-destructive-secondary'>*</div>}
|
||||
{tooltip && (
|
||||
<Tooltip
|
||||
popupContent={
|
||||
<div className='w-[200px]'>{tooltip}</div>
|
||||
}
|
||||
triggerClassName='ml-0.5 w-4 h-4'
|
||||
triggerTestId={`${htmlFor}-tooltip`}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default Label
|
||||
@ -0,0 +1,35 @@
|
||||
import { withForm } from '../..'
|
||||
import { demoFormOpts } from './shared-options'
|
||||
import { ContactMethods } from './types'
|
||||
|
||||
const ContactFields = withForm({
|
||||
...demoFormOpts,
|
||||
render: ({ form }) => {
|
||||
return (
|
||||
<div className='my-2'>
|
||||
<h3 className='title-lg-bold text-text-primary'>Contacts</h3>
|
||||
<div className='flex flex-col gap-4'>
|
||||
<form.AppField
|
||||
name='contact.email'
|
||||
children={field => <field.TextField label='Email' />}
|
||||
/>
|
||||
<form.AppField
|
||||
name='contact.phone'
|
||||
children={field => <field.TextField label='Phone' />}
|
||||
/>
|
||||
<form.AppField
|
||||
name='contact.preferredContactMethod'
|
||||
children={field => (
|
||||
<field.SelectField
|
||||
label='Preferred Contact Method'
|
||||
options={ContactMethods}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
},
|
||||
})
|
||||
|
||||
export default ContactFields
|
||||
@ -0,0 +1,68 @@
|
||||
import { useStore } from '@tanstack/react-form'
|
||||
import { useAppForm } from '../..'
|
||||
import ContactFields from './contact-fields'
|
||||
import { demoFormOpts } from './shared-options'
|
||||
import { UserSchema } from './types'
|
||||
|
||||
const DemoForm = () => {
|
||||
const form = useAppForm({
|
||||
...demoFormOpts,
|
||||
validators: {
|
||||
onSubmit: ({ value }) => {
|
||||
// Validate the entire form
|
||||
const result = UserSchema.safeParse(value)
|
||||
if (!result.success) {
|
||||
const issues = result.error.issues
|
||||
console.log('Validation errors:', issues)
|
||||
return issues[0].message
|
||||
}
|
||||
return undefined
|
||||
},
|
||||
},
|
||||
onSubmit: ({ value }) => {
|
||||
console.log('Form submitted:', value)
|
||||
},
|
||||
})
|
||||
|
||||
const name = useStore(form.store, state => state.values.name)
|
||||
|
||||
return (
|
||||
<form
|
||||
className='flex w-[400px] flex-col gap-4'
|
||||
onSubmit={(e) => {
|
||||
e.preventDefault()
|
||||
e.stopPropagation()
|
||||
form.handleSubmit()
|
||||
}}
|
||||
>
|
||||
<form.AppField
|
||||
name='name'
|
||||
children={field => (
|
||||
<field.TextField label='Name' />
|
||||
)}
|
||||
/>
|
||||
<form.AppField
|
||||
name='surname'
|
||||
children={field => (
|
||||
<field.TextField label='Surname' />
|
||||
)}
|
||||
/>
|
||||
<form.AppField
|
||||
name='isAcceptingTerms'
|
||||
children={field => (
|
||||
<field.CheckboxField label='I accept the terms and conditions.' />
|
||||
)}
|
||||
/>
|
||||
{
|
||||
!!name && (
|
||||
<ContactFields form={form} />
|
||||
)
|
||||
}
|
||||
<form.AppForm>
|
||||
<form.SubmitButton>Submit</form.SubmitButton>
|
||||
</form.AppForm>
|
||||
</form>
|
||||
)
|
||||
}
|
||||
|
||||
export default DemoForm
|
||||
@ -0,0 +1,14 @@
|
||||
import { formOptions } from '@tanstack/react-form'
|
||||
|
||||
export const demoFormOpts = formOptions({
|
||||
defaultValues: {
|
||||
name: '',
|
||||
surname: '',
|
||||
isAcceptingTerms: false,
|
||||
contact: {
|
||||
email: '',
|
||||
phone: '',
|
||||
preferredContactMethod: 'email',
|
||||
},
|
||||
},
|
||||
})
|
||||
@ -0,0 +1,34 @@
|
||||
import { z } from 'zod'
|
||||
|
||||
const ContactMethod = z.union([
|
||||
z.literal('email'),
|
||||
z.literal('phone'),
|
||||
z.literal('whatsapp'),
|
||||
z.literal('sms'),
|
||||
])
|
||||
|
||||
export const ContactMethods = ContactMethod.options.map(({ value }) => ({
|
||||
value,
|
||||
label: value.charAt(0).toUpperCase() + value.slice(1),
|
||||
}))
|
||||
|
||||
export const UserSchema = z.object({
|
||||
name: z
|
||||
.string()
|
||||
.regex(/^[A-Z]/, 'Name must start with a capital letter')
|
||||
.min(3, 'Name must be at least 3 characters long'),
|
||||
surname: z
|
||||
.string()
|
||||
.min(3, 'Surname must be at least 3 characters long')
|
||||
.regex(/^[A-Z]/, 'Surname must start with a capital letter'),
|
||||
isAcceptingTerms: z.boolean().refine(val => val, {
|
||||
message: 'You must accept the terms and conditions',
|
||||
}),
|
||||
contact: z.object({
|
||||
email: z.string().email('Invalid email address'),
|
||||
phone: z.string().optional(),
|
||||
preferredContactMethod: ContactMethod,
|
||||
}),
|
||||
})
|
||||
|
||||
export type User = z.infer<typeof UserSchema>
|
||||
@ -0,0 +1,25 @@
|
||||
import { createFormHook, createFormHookContexts } from '@tanstack/react-form'
|
||||
import TextField from './components/field/text'
|
||||
import NumberInputField from './components/field/number-input'
|
||||
import CheckboxField from './components/field/checkbox'
|
||||
import SelectField from './components/field/select'
|
||||
import OptionsField from './components/field/options'
|
||||
import SubmitButton from './components/form/submit-button'
|
||||
|
||||
export const { fieldContext, useFieldContext, formContext, useFormContext }
|
||||
= createFormHookContexts()
|
||||
|
||||
export const { useAppForm, withForm } = createFormHook({
|
||||
fieldComponents: {
|
||||
TextField,
|
||||
NumberInputField,
|
||||
CheckboxField,
|
||||
SelectField,
|
||||
OptionsField,
|
||||
},
|
||||
formComponents: {
|
||||
SubmitButton,
|
||||
},
|
||||
fieldContext,
|
||||
formContext,
|
||||
})
|
||||
@ -0,0 +1,5 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="arrow-down-round-fill">
|
||||
<path id="Vector" d="M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z" fill="#101828"/>
|
||||
</g>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 380 B |
@ -0,0 +1,36 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "16",
|
||||
"height": "16",
|
||||
"viewBox": "0 0 16 16",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"id": "arrow-down-round-fill"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"id": "Vector",
|
||||
"d": "M6.02913 6.23572C5.08582 6.23572 4.56482 7.33027 5.15967 8.06239L7.13093 10.4885C7.57922 11.0403 8.42149 11.0403 8.86986 10.4885L10.8411 8.06239C11.4359 7.33027 10.9149 6.23572 9.97158 6.23572H6.02913Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "ArrowDownRoundFill"
|
||||
}
|
||||
@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './ArrowDownRoundFill.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.MutableRefObject<HTMLOrSVGElement>>;
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'ArrowDownRoundFill'
|
||||
|
||||
export default Icon
|
||||
@ -0,0 +1,97 @@
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { InputNumber } from './index'
|
||||
|
||||
jest.mock('react-i18next', () => ({
|
||||
useTranslation: () => ({
|
||||
t: (key: string) => key,
|
||||
}),
|
||||
}))
|
||||
|
||||
describe('InputNumber Component', () => {
|
||||
const defaultProps = {
|
||||
onChange: jest.fn(),
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
jest.clearAllMocks()
|
||||
})
|
||||
|
||||
it('renders input with default values', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
expect(input).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('handles increment button click', () => {
|
||||
render(<InputNumber {...defaultProps} value={5} />)
|
||||
const incrementBtn = screen.getByRole('button', { name: /increment/i })
|
||||
|
||||
fireEvent.click(incrementBtn)
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(6)
|
||||
})
|
||||
|
||||
it('handles decrement button click', () => {
|
||||
render(<InputNumber {...defaultProps} value={5} />)
|
||||
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
|
||||
|
||||
fireEvent.click(decrementBtn)
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(4)
|
||||
})
|
||||
|
||||
it('respects max value constraint', () => {
|
||||
render(<InputNumber {...defaultProps} value={10} max={10} />)
|
||||
const incrementBtn = screen.getByRole('button', { name: /increment/i })
|
||||
|
||||
fireEvent.click(incrementBtn)
|
||||
expect(defaultProps.onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('respects min value constraint', () => {
|
||||
render(<InputNumber {...defaultProps} value={0} min={0} />)
|
||||
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
|
||||
|
||||
fireEvent.click(decrementBtn)
|
||||
expect(defaultProps.onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('handles direct input changes', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
|
||||
fireEvent.change(input, { target: { value: '42' } })
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(42)
|
||||
})
|
||||
|
||||
it('handles empty input', () => {
|
||||
render(<InputNumber {...defaultProps} value={0} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
|
||||
fireEvent.change(input, { target: { value: '' } })
|
||||
expect(defaultProps.onChange).toHaveBeenCalledWith(undefined)
|
||||
})
|
||||
|
||||
it('handles invalid input', () => {
|
||||
render(<InputNumber {...defaultProps} />)
|
||||
const input = screen.getByRole('textbox')
|
||||
|
||||
fireEvent.change(input, { target: { value: 'abc' } })
|
||||
expect(defaultProps.onChange).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('displays unit when provided', () => {
|
||||
const unit = 'px'
|
||||
render(<InputNumber {...defaultProps} unit={unit} />)
|
||||
expect(screen.getByText(unit)).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('disables controls when disabled prop is true', () => {
|
||||
render(<InputNumber {...defaultProps} disabled />)
|
||||
const input = screen.getByRole('textbox')
|
||||
const incrementBtn = screen.getByRole('button', { name: /increment/i })
|
||||
const decrementBtn = screen.getByRole('button', { name: /decrement/i })
|
||||
|
||||
expect(input).toBeDisabled()
|
||||
expect(incrementBtn).toBeDisabled()
|
||||
expect(decrementBtn).toBeDisabled()
|
||||
})
|
||||
})
|
||||
@ -0,0 +1,37 @@
|
||||
import abcjs from 'abcjs'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import 'abcjs/abcjs-audio.css'
|
||||
|
||||
const MarkdownMusic = ({ children }: { children: React.ReactNode }) => {
|
||||
const containerRef = useRef<HTMLDivElement>(null)
|
||||
const controlsRef = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
if (containerRef.current && controlsRef.current) {
|
||||
if (typeof children === 'string') {
|
||||
const visualObjs = abcjs.renderAbc(containerRef.current, children, {
|
||||
add_classes: true, // Add classes to SVG elements for cursor tracking
|
||||
responsive: 'resize', // Make notation responsive
|
||||
})
|
||||
const synthControl = new abcjs.synth.SynthController()
|
||||
synthControl.load(controlsRef.current, {}, { displayPlay: true })
|
||||
const synth = new abcjs.synth.CreateSynth()
|
||||
const visualObj = visualObjs[0]
|
||||
synth.init({ visualObj }).then(() => {
|
||||
synthControl.setTune(visualObj, false)
|
||||
})
|
||||
containerRef.current.style.overflow = 'auto'
|
||||
}
|
||||
}
|
||||
}, [children])
|
||||
|
||||
return (
|
||||
<div style={{ minWidth: '100%', overflow: 'auto' }}>
|
||||
<div ref={containerRef} />
|
||||
<div ref={controlsRef} />
|
||||
</div>
|
||||
)
|
||||
}
|
||||
MarkdownMusic.displayName = 'MarkdownMusic'
|
||||
|
||||
export default MarkdownMusic
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue