Merge branch 'main' into feat/rag-pipeline
commit
d238da9826
@ -1,3 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd api && poetry install
|
||||
cd api && uv sync
|
||||
|
||||
@ -1,36 +0,0 @@
|
||||
name: Setup Poetry and Python
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version to use and the Poetry installed with
|
||||
required: true
|
||||
default: '3.11'
|
||||
poetry-version:
|
||||
description: Poetry version to set up
|
||||
required: true
|
||||
default: '2.0.1'
|
||||
poetry-lockfile:
|
||||
description: Path to the Poetry lockfile to restore cache from
|
||||
required: true
|
||||
default: ''
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Set up Python ${{ inputs.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
cache: pip
|
||||
|
||||
- name: Install Poetry
|
||||
shell: bash
|
||||
run: pip install poetry==${{ inputs.poetry-version }}
|
||||
|
||||
- name: Restore Poetry cache
|
||||
if: ${{ inputs.poetry-lockfile != '' }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
cache: poetry
|
||||
cache-dependency-path: ${{ inputs.poetry-lockfile }}
|
||||
@ -0,0 +1,34 @@
|
||||
name: Setup UV and Python
|
||||
|
||||
inputs:
|
||||
python-version:
|
||||
description: Python version to use and the UV installed with
|
||||
required: true
|
||||
default: '3.12'
|
||||
uv-version:
|
||||
description: UV version to set up
|
||||
required: true
|
||||
default: '0.6.14'
|
||||
uv-lockfile:
|
||||
description: Path to the UV lockfile to restore cache from
|
||||
required: true
|
||||
default: ''
|
||||
enable-cache:
|
||||
required: true
|
||||
default: true
|
||||
|
||||
runs:
|
||||
using: composite
|
||||
steps:
|
||||
- name: Set up Python ${{ inputs.python-version }}
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ inputs.python-version }}
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
version: ${{ inputs.uv-version }}
|
||||
python-version: ${{ inputs.python-version }}
|
||||
enable-cache: ${{ inputs.enable-cache }}
|
||||
cache-dependency-glob: ${{ inputs.uv-lockfile }}
|
||||
@ -0,0 +1,45 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from configs import dify_config
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS
|
||||
|
||||
|
||||
def get_parameters_from_feature_dict(
|
||||
*, features_dict: Mapping[str, Any], user_input_form: list[dict[str, Any]]
|
||||
) -> Mapping[str, Any]:
|
||||
"""
|
||||
Mapping from feature dict to webapp parameters
|
||||
"""
|
||||
return {
|
||||
"opening_statement": features_dict.get("opening_statement"),
|
||||
"suggested_questions": features_dict.get("suggested_questions", []),
|
||||
"suggested_questions_after_answer": features_dict.get("suggested_questions_after_answer", {"enabled": False}),
|
||||
"speech_to_text": features_dict.get("speech_to_text", {"enabled": False}),
|
||||
"text_to_speech": features_dict.get("text_to_speech", {"enabled": False}),
|
||||
"retriever_resource": features_dict.get("retriever_resource", {"enabled": False}),
|
||||
"annotation_reply": features_dict.get("annotation_reply", {"enabled": False}),
|
||||
"more_like_this": features_dict.get("more_like_this", {"enabled": False}),
|
||||
"user_input_form": user_input_form,
|
||||
"sensitive_word_avoidance": features_dict.get(
|
||||
"sensitive_word_avoidance", {"enabled": False, "type": "", "configs": []}
|
||||
),
|
||||
"file_upload": features_dict.get(
|
||||
"file_upload",
|
||||
{
|
||||
"image": {
|
||||
"enabled": False,
|
||||
"number_limits": DEFAULT_FILE_NUMBER_LIMITS,
|
||||
"detail": "high",
|
||||
"transfer_methods": ["remote_url", "local_file"],
|
||||
}
|
||||
},
|
||||
),
|
||||
"system_parameters": {
|
||||
"image_file_size_limit": dify_config.UPLOAD_IMAGE_FILE_SIZE_LIMIT,
|
||||
"video_file_size_limit": dify_config.UPLOAD_VIDEO_FILE_SIZE_LIMIT,
|
||||
"audio_file_size_limit": dify_config.UPLOAD_AUDIO_FILE_SIZE_LIMIT,
|
||||
"file_size_limit": dify_config.UPLOAD_FILE_SIZE_LIMIT,
|
||||
"workflow_file_upload_limit": dify_config.WORKFLOW_FILE_UPLOAD_LIMIT,
|
||||
},
|
||||
}
|
||||
@ -0,0 +1,15 @@
|
||||
"""
|
||||
Repository interfaces for data access.
|
||||
|
||||
This package contains repository interfaces that define the contract
|
||||
for accessing and manipulating data, regardless of the underlying
|
||||
storage mechanism.
|
||||
"""
|
||||
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"RepositoryFactory",
|
||||
"WorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,97 @@
|
||||
"""
|
||||
Repository factory for creating repository instances.
|
||||
|
||||
This module provides a simple factory interface for creating repository instances.
|
||||
It does not contain any implementation details or dependencies on specific repositories.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any, Literal, Optional, cast
|
||||
|
||||
from core.repository.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
|
||||
# Type for factory functions - takes a dict of parameters and returns any repository type
|
||||
RepositoryFactoryFunc = Callable[[Mapping[str, Any]], Any]
|
||||
|
||||
# Type for workflow node execution factory function
|
||||
WorkflowNodeExecutionFactoryFunc = Callable[[Mapping[str, Any]], WorkflowNodeExecutionRepository]
|
||||
|
||||
# Repository type literals
|
||||
_RepositoryType = Literal["workflow_node_execution"]
|
||||
|
||||
|
||||
class RepositoryFactory:
|
||||
"""
|
||||
Factory class for creating repository instances.
|
||||
|
||||
This factory delegates the actual repository creation to implementation-specific
|
||||
factory functions that are registered with the factory at runtime.
|
||||
"""
|
||||
|
||||
# Dictionary to store factory functions
|
||||
_factory_functions: dict[str, RepositoryFactoryFunc] = {}
|
||||
|
||||
@classmethod
|
||||
def _register_factory(cls, repository_type: _RepositoryType, factory_func: RepositoryFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for a specific repository type.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository (e.g., 'workflow_node_execution')
|
||||
factory_func: A function that takes parameters and returns a repository instance
|
||||
"""
|
||||
cls._factory_functions[repository_type] = factory_func
|
||||
|
||||
@classmethod
|
||||
def _create_repository(cls, repository_type: _RepositoryType, params: Optional[Mapping[str, Any]] = None) -> Any:
|
||||
"""
|
||||
Create a new repository instance with the provided parameters.
|
||||
This is a private method and should not be called directly.
|
||||
|
||||
Args:
|
||||
repository_type: The type of repository to create
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the requested repository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the repository type
|
||||
"""
|
||||
if repository_type not in cls._factory_functions:
|
||||
raise ValueError(f"No factory function registered for repository type '{repository_type}'")
|
||||
|
||||
# Use empty dict if params is None
|
||||
params = params or {}
|
||||
|
||||
return cls._factory_functions[repository_type](params)
|
||||
|
||||
@classmethod
|
||||
def register_workflow_node_execution_factory(cls, factory_func: WorkflowNodeExecutionFactoryFunc) -> None:
|
||||
"""
|
||||
Register a factory function for the workflow node execution repository.
|
||||
|
||||
Args:
|
||||
factory_func: A function that takes parameters and returns a WorkflowNodeExecutionRepository instance
|
||||
"""
|
||||
cls._register_factory("workflow_node_execution", factory_func)
|
||||
|
||||
@classmethod
|
||||
def create_workflow_node_execution_repository(
|
||||
cls, params: Optional[Mapping[str, Any]] = None
|
||||
) -> WorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a new WorkflowNodeExecutionRepository instance with the provided parameters.
|
||||
|
||||
Args:
|
||||
params: A dictionary of parameters to pass to the factory function
|
||||
|
||||
Returns:
|
||||
A new instance of the WorkflowNodeExecutionRepository
|
||||
|
||||
Raises:
|
||||
ValueError: If no factory function is registered for the workflow_node_execution repository type
|
||||
"""
|
||||
# We can safely cast here because we've registered a WorkflowNodeExecutionFactoryFunc
|
||||
return cast(WorkflowNodeExecutionRepository, cls._create_repository("workflow_node_execution", params))
|
||||
@ -0,0 +1,88 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Literal, Optional, Protocol
|
||||
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
|
||||
|
||||
@dataclass
|
||||
class OrderConfig:
|
||||
"""Configuration for ordering WorkflowNodeExecution instances."""
|
||||
|
||||
order_by: list[str]
|
||||
order_direction: Optional[Literal["asc", "desc"]] = None
|
||||
|
||||
|
||||
class WorkflowNodeExecutionRepository(Protocol):
|
||||
"""
|
||||
Repository interface for WorkflowNodeExecution.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
WorkflowNodeExecution data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and trigger sources (triggered_from) should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
...
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
...
|
||||
@ -0,0 +1,18 @@
|
||||
"""
|
||||
Extension for initializing repositories.
|
||||
|
||||
This extension registers repository implementations with the RepositoryFactory.
|
||||
"""
|
||||
|
||||
from dify_app import DifyApp
|
||||
from repositories.repository_registry import register_repositories
|
||||
|
||||
|
||||
def init_app(_app: DifyApp) -> None:
|
||||
"""
|
||||
Initialize repository implementations.
|
||||
|
||||
Args:
|
||||
_app: The Flask application instance (unused)
|
||||
"""
|
||||
register_repositories()
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,4 +0,0 @@
|
||||
[virtualenvs]
|
||||
in-project = true
|
||||
create = true
|
||||
prefer-active-python = true
|
||||
@ -1,215 +1,197 @@
|
||||
[project]
|
||||
name = "dify-api"
|
||||
version = "1.2.0"
|
||||
requires-python = ">=3.11,<3.13"
|
||||
dynamic = ["dependencies"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=2.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
dependencies = [
|
||||
"authlib==1.3.1",
|
||||
"azure-identity==1.16.1",
|
||||
"beautifulsoup4==4.12.2",
|
||||
"boto3==1.35.99",
|
||||
"bs4~=0.0.1",
|
||||
"cachetools~=5.3.0",
|
||||
"celery~=5.4.0",
|
||||
"chardet~=5.1.0",
|
||||
"flask~=3.1.0",
|
||||
"flask-compress~=1.17",
|
||||
"flask-cors~=4.0.0",
|
||||
"flask-login~=0.6.3",
|
||||
"flask-migrate~=4.0.7",
|
||||
"flask-restful~=0.3.10",
|
||||
"flask-sqlalchemy~=3.1.1",
|
||||
"gevent~=24.11.1",
|
||||
"gmpy2~=2.2.1",
|
||||
"google-api-core==2.18.0",
|
||||
"google-api-python-client==2.90.0",
|
||||
"google-auth==2.29.0",
|
||||
"google-auth-httplib2==0.2.0",
|
||||
"google-cloud-aiplatform==1.49.0",
|
||||
"googleapis-common-protos==1.63.0",
|
||||
"gunicorn~=23.0.0",
|
||||
"httpx[socks]~=0.27.0",
|
||||
"jieba==0.42.1",
|
||||
"langfuse~=2.51.3",
|
||||
"langsmith~=0.1.77",
|
||||
"mailchimp-transactional~=1.0.50",
|
||||
"markdown~=3.5.1",
|
||||
"numpy~=1.26.4",
|
||||
"oci~=2.135.1",
|
||||
"openai~=1.61.0",
|
||||
"openpyxl~=3.1.5",
|
||||
"opik~=1.3.4",
|
||||
"opentelemetry-api==1.27.0",
|
||||
"opentelemetry-distro==0.48b0",
|
||||
"opentelemetry-exporter-otlp==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-common==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-grpc==1.27.0",
|
||||
"opentelemetry-exporter-otlp-proto-http==1.27.0",
|
||||
"opentelemetry-instrumentation==0.48b0",
|
||||
"opentelemetry-instrumentation-celery==0.48b0",
|
||||
"opentelemetry-instrumentation-flask==0.48b0",
|
||||
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
|
||||
"opentelemetry-propagator-b3==1.27.0",
|
||||
# opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0),
|
||||
# which is conflict with googleapis-common-protos (1.63.0)
|
||||
"opentelemetry-proto==1.27.0",
|
||||
"opentelemetry-sdk==1.27.0",
|
||||
"opentelemetry-semantic-conventions==0.48b0",
|
||||
"opentelemetry-util-http==0.48b0",
|
||||
"pandas-stubs~=2.2.3.241009",
|
||||
"pandas[excel,output-formatting,performance]~=2.2.2",
|
||||
"pandoc~=2.4",
|
||||
"psycogreen~=1.0.2",
|
||||
"psycopg2-binary~=2.9.6",
|
||||
"pycryptodome==3.19.1",
|
||||
"pydantic~=2.9.2",
|
||||
"pydantic-extra-types~=2.9.0",
|
||||
"pydantic-settings~=2.6.0",
|
||||
"pyjwt~=2.8.0",
|
||||
"pypdfium2~=4.30.0",
|
||||
"python-docx~=1.1.0",
|
||||
"python-dotenv==1.0.1",
|
||||
"pyyaml~=6.0.1",
|
||||
"readabilipy==0.2.0",
|
||||
"redis[hiredis]~=5.0.3",
|
||||
"resend~=0.7.0",
|
||||
"sentry-sdk[flask]~=1.44.1",
|
||||
"sqlalchemy~=2.0.29",
|
||||
"starlette==0.41.0",
|
||||
"tiktoken~=0.8.0",
|
||||
"tokenizers~=0.15.0",
|
||||
"transformers~=4.35.0",
|
||||
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
|
||||
"validators==0.21.0",
|
||||
"yarl~=1.18.3",
|
||||
]
|
||||
# Before adding new dependency, consider place it in
|
||||
# alphabet order (a-z) and suitable group.
|
||||
|
||||
[tool.poetry]
|
||||
package-mode = false
|
||||
[tool.uv]
|
||||
default-groups = ["storage", "tools", "vdb"]
|
||||
|
||||
############################################################
|
||||
# [ Main ] Dependency group
|
||||
############################################################
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
authlib = "1.3.1"
|
||||
azure-identity = "1.16.1"
|
||||
beautifulsoup4 = "4.12.2"
|
||||
boto3 = "1.35.99"
|
||||
bs4 = "~0.0.1"
|
||||
cachetools = "~5.3.0"
|
||||
celery = "~5.4.0"
|
||||
chardet = "~5.1.0"
|
||||
flask = "~3.1.0"
|
||||
flask-compress = "~1.17"
|
||||
flask-cors = "~4.0.0"
|
||||
flask-login = "~0.6.3"
|
||||
flask-migrate = "~4.0.7"
|
||||
flask-restful = "~0.3.10"
|
||||
flask-sqlalchemy = "~3.1.1"
|
||||
gevent = "~24.11.1"
|
||||
gmpy2 = "~2.2.1"
|
||||
google-api-core = "2.18.0"
|
||||
google-api-python-client = "2.90.0"
|
||||
google-auth = "2.29.0"
|
||||
google-auth-httplib2 = "0.2.0"
|
||||
google-cloud-aiplatform = "1.49.0"
|
||||
googleapis-common-protos = "1.63.0"
|
||||
gunicorn = "~23.0.0"
|
||||
httpx = { version = "~0.27.0", extras = ["socks"] }
|
||||
jieba = "0.42.1"
|
||||
langfuse = "~2.51.3"
|
||||
langsmith = "~0.1.77"
|
||||
mailchimp-transactional = "~1.0.50"
|
||||
markdown = "~3.5.1"
|
||||
numpy = "~1.26.4"
|
||||
oci = "~2.135.1"
|
||||
openai = "~1.61.0"
|
||||
openpyxl = "~3.1.5"
|
||||
opentelemetry-api = "1.27.0"
|
||||
opentelemetry-distro = "0.48b0"
|
||||
opentelemetry-exporter-otlp = "1.27.0"
|
||||
opentelemetry-exporter-otlp-proto-common = "1.27.0"
|
||||
opentelemetry-exporter-otlp-proto-grpc = "1.27.0"
|
||||
opentelemetry-exporter-otlp-proto-http = "1.27.0"
|
||||
opentelemetry-instrumentation = "0.48b0"
|
||||
opentelemetry-instrumentation-celery = "0.48b0"
|
||||
opentelemetry-instrumentation-flask = "0.48b0"
|
||||
opentelemetry-instrumentation-sqlalchemy = "0.48b0"
|
||||
opentelemetry-propagator-b3 = "1.27.0"
|
||||
opentelemetry-proto = "1.27.0" # 1.28.0 depends on protobuf (>=5.0,<6.0), conflict with googleapis-common-protos (1.63.0)
|
||||
opentelemetry-sdk = "1.27.0"
|
||||
opentelemetry-semantic-conventions = "0.48b0"
|
||||
opentelemetry-util-http = "0.48b0"
|
||||
opik = "~1.3.4"
|
||||
pandas = { version = "~2.2.2", extras = ["performance", "excel", "output-formatting"] }
|
||||
pandas-stubs = "~2.2.3.241009"
|
||||
pandoc = "~2.4"
|
||||
psycogreen = "~1.0.2"
|
||||
psycopg2-binary = "~2.9.6"
|
||||
pycryptodome = "3.19.1"
|
||||
pydantic = "~2.9.2"
|
||||
pydantic-settings = "~2.6.0"
|
||||
pydantic_extra_types = "~2.9.0"
|
||||
pyjwt = "~2.8.0"
|
||||
pypdfium2 = "~4.30.0"
|
||||
python = ">=3.11,<3.13"
|
||||
python-docx = "~1.1.0"
|
||||
python-dotenv = "1.0.1"
|
||||
pyyaml = "~6.0.1"
|
||||
readabilipy = "0.2.0"
|
||||
redis = { version = "~5.0.3", extras = ["hiredis"] }
|
||||
resend = "~0.7.0"
|
||||
sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
|
||||
sqlalchemy = "~2.0.29"
|
||||
starlette = "0.41.0"
|
||||
tiktoken = "~0.8.0"
|
||||
tokenizers = "~0.15.0"
|
||||
transformers = "~4.35.0"
|
||||
unstructured = { version = "~0.16.1", extras = ["docx", "epub", "md", "ppt", "pptx"] }
|
||||
validators = "0.21.0"
|
||||
yarl = "~1.18.3"
|
||||
# Before adding new dependency, consider place it in alphabet order (a-z) and suitable group.
|
||||
|
||||
############################################################
|
||||
# [ Indirect ] dependency group
|
||||
# Related transparent dependencies with pinned version
|
||||
# required by main implementations
|
||||
############################################################
|
||||
[tool.poetry.group.indirect.dependencies]
|
||||
kaleido = "0.2.1"
|
||||
rank-bm25 = "~0.2.2"
|
||||
safetensors = "~0.4.3"
|
||||
[dependency-groups]
|
||||
|
||||
############################################################
|
||||
# [ Tools ] dependency group
|
||||
# [ Dev ] dependency group
|
||||
# Required for development and running tests
|
||||
############################################################
|
||||
[tool.poetry.group.tools.dependencies]
|
||||
cloudscraper = "1.2.71"
|
||||
nltk = "3.9.1"
|
||||
dev = [
|
||||
"coverage~=7.2.4",
|
||||
"dotenv-linter~=0.5.0",
|
||||
"faker~=32.1.0",
|
||||
"lxml-stubs~=0.5.1",
|
||||
"mypy~=1.15.0",
|
||||
"ruff~=0.11.5",
|
||||
"pytest~=8.3.2",
|
||||
"pytest-benchmark~=4.0.0",
|
||||
"pytest-cov~=4.1.0",
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"types-aiofiles~=24.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
"types-colorama~=0.4.15",
|
||||
"types-defusedxml~=0.7.0",
|
||||
"types-deprecated~=1.2.15",
|
||||
"types-docutils~=0.21.0",
|
||||
"types-flask-cors~=5.0.0",
|
||||
"types-flask-migrate~=4.1.0",
|
||||
"types-gevent~=24.11.0",
|
||||
"types-greenlet~=3.1.0",
|
||||
"types-html5lib~=1.1.11",
|
||||
"types-markdown~=3.7.0",
|
||||
"types-oauthlib~=3.2.0",
|
||||
"types-objgraph~=3.6.0",
|
||||
"types-olefile~=0.47.0",
|
||||
"types-openpyxl~=3.1.5",
|
||||
"types-pexpect~=4.9.0",
|
||||
"types-protobuf~=5.29.1",
|
||||
"types-psutil~=7.0.0",
|
||||
"types-psycopg2~=2.9.21",
|
||||
"types-pygments~=2.19.0",
|
||||
"types-pymysql~=1.1.0",
|
||||
"types-python-dateutil~=2.9.0",
|
||||
"types-pywin32~=310.0.0",
|
||||
"types-pyyaml~=6.0.12",
|
||||
"types-regex~=2024.11.6",
|
||||
"types-requests~=2.32.0",
|
||||
"types-requests-oauthlib~=2.0.0",
|
||||
"types-shapely~=2.0.0",
|
||||
"types-simplejson~=3.20.0",
|
||||
"types-six~=1.17.0",
|
||||
"types-tensorflow~=2.18.0",
|
||||
"types-tqdm~=4.67.0",
|
||||
"types-ujson~=5.10.0",
|
||||
]
|
||||
|
||||
############################################################
|
||||
# [ Storage ] dependency group
|
||||
# Required for storage clients
|
||||
############################################################
|
||||
[tool.poetry.group.storage.dependencies]
|
||||
azure-storage-blob = "12.13.0"
|
||||
bce-python-sdk = "~0.9.23"
|
||||
cos-python-sdk-v5 = "1.9.30"
|
||||
esdk-obs-python = "3.24.6.1"
|
||||
google-cloud-storage = "2.16.0"
|
||||
opendal = "~0.45.16"
|
||||
oss2 = "2.18.5"
|
||||
supabase = "~2.8.1"
|
||||
tos = "~2.7.1"
|
||||
storage = [
|
||||
"azure-storage-blob==12.13.0",
|
||||
"bce-python-sdk~=0.9.23",
|
||||
"cos-python-sdk-v5==1.9.30",
|
||||
"esdk-obs-python==3.24.6.1",
|
||||
"google-cloud-storage==2.16.0",
|
||||
"opendal~=0.45.16",
|
||||
"oss2==2.18.5",
|
||||
"supabase~=2.8.1",
|
||||
"tos~=2.7.1",
|
||||
]
|
||||
|
||||
############################################################
|
||||
# [ VDB ] dependency group
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
[tool.poetry.group.vdb.dependencies]
|
||||
alibabacloud_gpdb20160503 = "~3.8.0"
|
||||
alibabacloud_tea_openapi = "~0.3.9"
|
||||
chromadb = "0.5.20"
|
||||
clickhouse-connect = "~0.7.16"
|
||||
couchbase = "~4.3.0"
|
||||
elasticsearch = "8.14.0"
|
||||
opensearch-py = "2.4.0"
|
||||
oracledb = "~2.2.1"
|
||||
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
|
||||
pgvector = "0.2.5"
|
||||
pymilvus = "~2.5.0"
|
||||
pymochow = "1.3.1"
|
||||
pyobvector = "~0.1.6"
|
||||
qdrant-client = "1.7.3"
|
||||
tablestore = "6.1.0"
|
||||
tcvectordb = "~1.6.4"
|
||||
tidb-vector = "0.0.9"
|
||||
upstash-vector = "0.6.0"
|
||||
volcengine-compat = "~1.0.156"
|
||||
weaviate-client = "~3.21.0"
|
||||
xinference-client = "~1.2.2"
|
||||
|
||||
############################################################
|
||||
# [ Dev ] dependency group
|
||||
# Required for development and running tests
|
||||
# [ Tools ] dependency group
|
||||
############################################################
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
coverage = "~7.2.4"
|
||||
faker = "~32.1.0"
|
||||
lxml-stubs = "~0.5.1"
|
||||
mypy = "~1.15.0"
|
||||
pytest = "~8.3.2"
|
||||
pytest-benchmark = "~4.0.0"
|
||||
pytest-env = "~1.1.3"
|
||||
pytest-mock = "~3.14.0"
|
||||
types-aiofiles = "~24.1.0"
|
||||
types-beautifulsoup4 = "~4.12.0"
|
||||
types-cachetools = "~5.5.0"
|
||||
types-colorama = "~0.4.15"
|
||||
types-defusedxml = "~0.7.0"
|
||||
types-deprecated = "~1.2.15"
|
||||
types-docutils = "~0.21.0"
|
||||
types-flask-cors = "~5.0.0"
|
||||
types-flask-migrate = "~4.1.0"
|
||||
types-gevent = "~24.11.0"
|
||||
types-greenlet = "~3.1.0"
|
||||
types-html5lib = "~1.1.11"
|
||||
types-markdown = "~3.7.0"
|
||||
types-oauthlib = "~3.2.0"
|
||||
types-objgraph = "~3.6.0"
|
||||
types-olefile = "~0.47.0"
|
||||
types-openpyxl = "~3.1.5"
|
||||
types-pexpect = "~4.9.0"
|
||||
types-protobuf = "~5.29.1"
|
||||
types-psutil = "~7.0.0"
|
||||
types-psycopg2 = "~2.9.21"
|
||||
types-pygments = "~2.19.0"
|
||||
types-pymysql = "~1.1.0"
|
||||
types-python-dateutil = "~2.9.0"
|
||||
types-pywin32 = "~310.0.0"
|
||||
types-pyyaml = "~6.0.12"
|
||||
types-regex = "~2024.11.6"
|
||||
types-requests = "~2.32.0"
|
||||
types-requests-oauthlib = "~2.0.0"
|
||||
types-shapely = "~2.0.0"
|
||||
types-simplejson = "~3.20.0"
|
||||
types-six = "~1.17.0"
|
||||
types-tensorflow = "~2.18.0"
|
||||
types-tqdm = "~4.67.0"
|
||||
types-ujson = "~5.10.0"
|
||||
tools = [
|
||||
"cloudscraper~=1.2.71",
|
||||
"nltk~=3.9.1",
|
||||
]
|
||||
|
||||
############################################################
|
||||
# [ Lint ] dependency group
|
||||
# Required for code style linting
|
||||
# [ VDB ] dependency group
|
||||
# Required by vector store clients
|
||||
############################################################
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
dotenv-linter = "~0.5.0"
|
||||
ruff = "~0.11.0"
|
||||
vdb = [
|
||||
"alibabacloud_gpdb20160503~=3.8.0",
|
||||
"alibabacloud_tea_openapi~=0.3.9",
|
||||
"chromadb==0.5.20",
|
||||
"clickhouse-connect~=0.7.16",
|
||||
"couchbase~=4.3.0",
|
||||
"elasticsearch==8.14.0",
|
||||
"opensearch-py==2.4.0",
|
||||
"oracledb~=2.2.1",
|
||||
"pgvecto-rs[sqlalchemy]~=0.2.1",
|
||||
"pgvector==0.2.5",
|
||||
"pymilvus~=2.5.0",
|
||||
"pymochow==1.3.1",
|
||||
"pyobvector~=0.1.6",
|
||||
"qdrant-client==1.7.3",
|
||||
"tablestore==6.1.0",
|
||||
"tcvectordb~=1.6.4",
|
||||
"tidb-vector==0.0.9",
|
||||
"upstash-vector==0.6.0",
|
||||
"volcengine-compat~=1.0.156",
|
||||
"weaviate-client~=3.21.0",
|
||||
"xinference-client~=1.2.2",
|
||||
]
|
||||
|
||||
@ -0,0 +1,6 @@
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.repository package.
|
||||
"""
|
||||
@ -0,0 +1,87 @@
|
||||
"""
|
||||
Registry for repository implementations.
|
||||
|
||||
This module is responsible for registering factory functions with the repository factory.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from core.repository.repository_factory import RepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from repositories.workflow_node_execution import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Storage type constants
|
||||
STORAGE_TYPE_RDBMS = "rdbms"
|
||||
STORAGE_TYPE_HYBRID = "hybrid"
|
||||
|
||||
|
||||
def register_repositories() -> None:
|
||||
"""
|
||||
Register repository factory functions with the RepositoryFactory.
|
||||
|
||||
This function reads configuration settings to determine which repository
|
||||
implementations to register.
|
||||
"""
|
||||
# Configure WorkflowNodeExecutionRepository factory based on configuration
|
||||
workflow_node_execution_storage = dify_config.WORKFLOW_NODE_EXECUTION_STORAGE
|
||||
|
||||
# Check storage type and register appropriate implementation
|
||||
if workflow_node_execution_storage == STORAGE_TYPE_RDBMS:
|
||||
# Register SQLAlchemy implementation for RDBMS storage
|
||||
logger.info("Registering WorkflowNodeExecution repository with RDBMS storage")
|
||||
RepositoryFactory.register_workflow_node_execution_factory(create_workflow_node_execution_repository)
|
||||
elif workflow_node_execution_storage == STORAGE_TYPE_HYBRID:
|
||||
# Hybrid storage is not yet implemented
|
||||
raise NotImplementedError("Hybrid storage for WorkflowNodeExecution repository is not yet implemented")
|
||||
else:
|
||||
# Unknown storage type
|
||||
raise ValueError(
|
||||
f"Unknown storage type '{workflow_node_execution_storage}' for WorkflowNodeExecution repository. "
|
||||
f"Supported types: {STORAGE_TYPE_RDBMS}"
|
||||
)
|
||||
|
||||
|
||||
def create_workflow_node_execution_repository(params: Mapping[str, Any]) -> SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
Create a WorkflowNodeExecutionRepository instance using SQLAlchemy implementation.
|
||||
|
||||
This factory function creates a repository for the RDBMS storage type.
|
||||
|
||||
Args:
|
||||
params: Parameters for creating the repository, including:
|
||||
- tenant_id: Required. The tenant ID for multi-tenancy.
|
||||
- app_id: Optional. The application ID for filtering.
|
||||
- session_factory: Optional. A SQLAlchemy sessionmaker instance. If not provided,
|
||||
a new sessionmaker will be created using the global database engine.
|
||||
|
||||
Returns:
|
||||
A WorkflowNodeExecutionRepository instance
|
||||
|
||||
Raises:
|
||||
ValueError: If required parameters are missing
|
||||
"""
|
||||
# Extract required parameters
|
||||
tenant_id = params.get("tenant_id")
|
||||
if tenant_id is None:
|
||||
raise ValueError("tenant_id is required for WorkflowNodeExecution repository with RDBMS storage")
|
||||
|
||||
# Extract optional parameters
|
||||
app_id = params.get("app_id")
|
||||
|
||||
# Use the session_factory from params if provided, otherwise create one using the global db engine
|
||||
session_factory = params.get("session_factory")
|
||||
if session_factory is None:
|
||||
# Create a sessionmaker using the same engine as the global db session
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
|
||||
# Create and return the repository
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
@ -0,0 +1,9 @@
|
||||
"""
|
||||
WorkflowNodeExecution repository implementations.
|
||||
"""
|
||||
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
@ -0,0 +1,170 @@
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import UnaryExpression, asc, desc, select
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution, WorkflowNodeExecutionStatus, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SQLAlchemyWorkflowNodeExecutionRepository:
|
||||
"""
|
||||
SQLAlchemy implementation of the WorkflowNodeExecutionRepository interface.
|
||||
|
||||
This implementation supports multi-tenancy by filtering operations based on tenant_id.
|
||||
Each method creates its own session, handles the transaction, and commits changes
|
||||
to the database. This prevents long-running connections in the workflow core.
|
||||
"""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine, tenant_id: str, app_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize the repository with a SQLAlchemy sessionmaker or engine and tenant context.
|
||||
|
||||
Args:
|
||||
session_factory: SQLAlchemy sessionmaker or engine for creating sessions
|
||||
tenant_id: Tenant ID for multi-tenancy
|
||||
app_id: Optional app ID for filtering by application
|
||||
"""
|
||||
# If an engine is provided, create a sessionmaker from it
|
||||
if isinstance(session_factory, Engine):
|
||||
self._session_factory = sessionmaker(bind=session_factory, expire_on_commit=False)
|
||||
else:
|
||||
self._session_factory = session_factory
|
||||
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def save(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Save a WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to save
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.add(execution)
|
||||
session.commit()
|
||||
|
||||
def get_by_node_execution_id(self, node_execution_id: str) -> Optional[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve a WorkflowNodeExecution by its node_execution_id.
|
||||
|
||||
Args:
|
||||
node_execution_id: The node execution ID
|
||||
|
||||
Returns:
|
||||
The WorkflowNodeExecution instance if found, None otherwise
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.node_execution_id == node_execution_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalar(stmt)
|
||||
|
||||
def get_by_workflow_run(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
order_config: Optional[OrderConfig] = None,
|
||||
) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
order_config: Optional configuration for ordering results
|
||||
order_config.order_by: List of fields to order by (e.g., ["index", "created_at"])
|
||||
order_config.order_direction: Direction to order ("asc" or "desc")
|
||||
|
||||
Returns:
|
||||
A list of WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
# Apply ordering if provided
|
||||
if order_config and order_config.order_by:
|
||||
order_columns: list[UnaryExpression] = []
|
||||
for field in order_config.order_by:
|
||||
column = getattr(WorkflowNodeExecution, field, None)
|
||||
if not column:
|
||||
continue
|
||||
if order_config.order_direction == "desc":
|
||||
order_columns.append(desc(column))
|
||||
else:
|
||||
order_columns.append(asc(column))
|
||||
|
||||
if order_columns:
|
||||
stmt = stmt.order_by(*order_columns)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def get_running_executions(self, workflow_run_id: str) -> Sequence[WorkflowNodeExecution]:
|
||||
"""
|
||||
Retrieve all running WorkflowNodeExecution instances for a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: The workflow run ID
|
||||
|
||||
Returns:
|
||||
A list of running WorkflowNodeExecution instances
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
stmt = select(WorkflowNodeExecution).where(
|
||||
WorkflowNodeExecution.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecution.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecution.status == WorkflowNodeExecutionStatus.RUNNING,
|
||||
WorkflowNodeExecution.triggered_from == WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
stmt = stmt.where(WorkflowNodeExecution.app_id == self._app_id)
|
||||
|
||||
return session.scalars(stmt).all()
|
||||
|
||||
def update(self, execution: WorkflowNodeExecution) -> None:
|
||||
"""
|
||||
Update an existing WorkflowNodeExecution instance and commit changes to the database.
|
||||
|
||||
Args:
|
||||
execution: The WorkflowNodeExecution instance to update
|
||||
"""
|
||||
with self._session_factory() as session:
|
||||
# Ensure tenant_id is set
|
||||
if not execution.tenant_id:
|
||||
execution.tenant_id = self._tenant_id
|
||||
|
||||
# Set app_id if provided and not already set
|
||||
if self._app_id and not execution.app_id:
|
||||
execution.app_id = self._app_id
|
||||
|
||||
session.merge(execution)
|
||||
session.commit()
|
||||
@ -1,49 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import toml # type: ignore
|
||||
|
||||
|
||||
def load_api_poetry_configs() -> dict[str, Any]:
|
||||
pyproject_toml = toml.load("api/pyproject.toml")
|
||||
return pyproject_toml["tool"]["poetry"]
|
||||
|
||||
|
||||
def load_all_dependency_groups() -> dict[str, dict[str, dict[str, Any]]]:
|
||||
configs = load_api_poetry_configs()
|
||||
configs_by_group = {"main": configs}
|
||||
for group_name in configs["group"]:
|
||||
configs_by_group[group_name] = configs["group"][group_name]
|
||||
dependencies_by_group = {group_name: base["dependencies"] for group_name, base in configs_by_group.items()}
|
||||
return dependencies_by_group
|
||||
|
||||
|
||||
def test_group_dependencies_sorted():
|
||||
for group_name, dependencies in load_all_dependency_groups().items():
|
||||
dependency_names = list(dependencies.keys())
|
||||
expected_dependency_names = sorted(set(dependency_names))
|
||||
section = f"tool.poetry.group.{group_name}.dependencies" if group_name else "tool.poetry.dependencies"
|
||||
assert expected_dependency_names == dependency_names, (
|
||||
f"Dependencies in group {group_name} are not sorted. "
|
||||
f"Check and fix [{section}] section in pyproject.toml file"
|
||||
)
|
||||
|
||||
|
||||
def test_group_dependencies_version_operator():
|
||||
for group_name, dependencies in load_all_dependency_groups().items():
|
||||
for dependency_name, specification in dependencies.items():
|
||||
version_spec = specification if isinstance(specification, str) else specification["version"]
|
||||
assert not version_spec.startswith("^"), (
|
||||
f"Please replace '{dependency_name} = {version_spec}' with '{dependency_name} = ~{version_spec[1:]}' "
|
||||
f"'^' operator is too wide and not allowed in the version specification."
|
||||
)
|
||||
|
||||
|
||||
def test_duplicated_dependency_crossing_groups() -> None:
|
||||
all_dependency_names: list[str] = []
|
||||
for dependencies in load_all_dependency_groups().values():
|
||||
dependency_names = list(dependencies.keys())
|
||||
all_dependency_names.extend(dependency_names)
|
||||
expected_all_dependency_names = set(all_dependency_names)
|
||||
assert sorted(expected_all_dependency_names) == sorted(all_dependency_names), (
|
||||
"Duplicated dependencies crossing groups are found"
|
||||
)
|
||||
@ -0,0 +1,99 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.model_runtime.entities.message_entities import AssistantPromptMessage
|
||||
from core.model_runtime.model_providers.__base.large_language_model import _increase_tool_call
|
||||
|
||||
ToolCall = AssistantPromptMessage.ToolCall
|
||||
|
||||
# CASE 1: Single tool call
|
||||
INPUTS_CASE_1 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_1 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 2: Tool call sequences where IDs are anchored to the first chunk (vLLM/SiliconFlow ...)
|
||||
INPUTS_CASE_2 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_2 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
ToolCall(
|
||||
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 3: Tool call sequences where IDs are anchored to every chunk (SGLang ...)
|
||||
INPUTS_CASE_3 = [
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="1", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="2", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_3 = [
|
||||
ToolCall(
|
||||
id="1", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}')
|
||||
),
|
||||
ToolCall(
|
||||
id="2", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}')
|
||||
),
|
||||
]
|
||||
|
||||
# CASE 4: Tool call sequences with no IDs
|
||||
INPUTS_CASE_4 = [
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_foo", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg1": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="func_bar", arguments="")),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='{"arg2": ')),
|
||||
ToolCall(id="", type="function", function=ToolCall.ToolCallFunction(name="", arguments='"value"}')),
|
||||
]
|
||||
EXPECTED_CASE_4 = [
|
||||
ToolCall(
|
||||
id="RANDOM_ID_1",
|
||||
type="function",
|
||||
function=ToolCall.ToolCallFunction(name="func_foo", arguments='{"arg1": "value"}'),
|
||||
),
|
||||
ToolCall(
|
||||
id="RANDOM_ID_2",
|
||||
type="function",
|
||||
function=ToolCall.ToolCallFunction(name="func_bar", arguments='{"arg2": "value"}'),
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _run_case(inputs: list[ToolCall], expected: list[ToolCall]):
|
||||
actual = []
|
||||
_increase_tool_call(inputs, actual)
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test__increase_tool_call():
|
||||
# case 1:
|
||||
_run_case(INPUTS_CASE_1, EXPECTED_CASE_1)
|
||||
|
||||
# case 2:
|
||||
_run_case(INPUTS_CASE_2, EXPECTED_CASE_2)
|
||||
|
||||
# case 3:
|
||||
_run_case(INPUTS_CASE_3, EXPECTED_CASE_3)
|
||||
|
||||
# case 4:
|
||||
mock_id_generator = MagicMock()
|
||||
mock_id_generator.side_effect = [_exp_case.id for _exp_case in EXPECTED_CASE_4]
|
||||
with patch("core.model_runtime.model_providers.__base.large_language_model._gen_tool_call_id", mock_id_generator):
|
||||
_run_case(INPUTS_CASE_4, EXPECTED_CASE_4)
|
||||
@ -0,0 +1,198 @@
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import Response
|
||||
|
||||
from factories.file_factory import (
|
||||
File,
|
||||
FileTransferMethod,
|
||||
FileType,
|
||||
FileUploadConfig,
|
||||
build_from_mapping,
|
||||
)
|
||||
from models import ToolFile, UploadFile
|
||||
|
||||
# Test Data
|
||||
TEST_TENANT_ID = "test_tenant_id"
|
||||
TEST_UPLOAD_FILE_ID = str(uuid.uuid4())
|
||||
TEST_TOOL_FILE_ID = str(uuid.uuid4())
|
||||
TEST_REMOTE_URL = "http://example.com/test.jpg"
|
||||
|
||||
# Test Config
|
||||
TEST_CONFIG = FileUploadConfig(
|
||||
allowed_file_types=["image", "document"],
|
||||
allowed_file_extensions=[".jpg", ".pdf"],
|
||||
allowed_file_upload_methods=[FileTransferMethod.LOCAL_FILE, FileTransferMethod.TOOL_FILE],
|
||||
number_limits=10,
|
||||
)
|
||||
|
||||
|
||||
# Fixtures
|
||||
@pytest.fixture
|
||||
def mock_upload_file():
|
||||
mock = MagicMock(spec=UploadFile)
|
||||
mock.id = TEST_UPLOAD_FILE_ID
|
||||
mock.tenant_id = TEST_TENANT_ID
|
||||
mock.name = "test.jpg"
|
||||
mock.extension = "jpg"
|
||||
mock.mime_type = "image/jpeg"
|
||||
mock.source_url = TEST_REMOTE_URL
|
||||
mock.size = 1024
|
||||
mock.key = "test_key"
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=mock) as m:
|
||||
yield m
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_tool_file():
|
||||
mock = MagicMock(spec=ToolFile)
|
||||
mock.id = TEST_TOOL_FILE_ID
|
||||
mock.tenant_id = TEST_TENANT_ID
|
||||
mock.name = "tool_file.pdf"
|
||||
mock.file_key = "tool_file.pdf"
|
||||
mock.mimetype = "application/pdf"
|
||||
mock.original_url = "http://example.com/tool.pdf"
|
||||
mock.size = 2048
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = mock
|
||||
yield mock
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_http_head():
|
||||
def _mock_response(filename, size, content_type):
|
||||
return Response(
|
||||
status_code=200,
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="{filename}"',
|
||||
"Content-Length": str(size),
|
||||
"Content-Type": content_type,
|
||||
},
|
||||
)
|
||||
|
||||
with patch("factories.file_factory.ssrf_proxy.head") as mock_head:
|
||||
mock_head.return_value = _mock_response("remote_test.jpg", 2048, "image/jpeg")
|
||||
yield mock_head
|
||||
|
||||
|
||||
# Helper functions
|
||||
def local_file_mapping(file_type="image"):
|
||||
return {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||
"type": file_type,
|
||||
}
|
||||
|
||||
|
||||
def tool_file_mapping(file_type="document"):
|
||||
return {
|
||||
"transfer_method": "tool_file",
|
||||
"tool_file_id": TEST_TOOL_FILE_ID,
|
||||
"type": file_type,
|
||||
}
|
||||
|
||||
|
||||
# Tests
|
||||
def test_build_from_mapping_backward_compatibility(mock_upload_file):
|
||||
mapping = local_file_mapping(file_type="image")
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
assert isinstance(file, File)
|
||||
assert file.transfer_method == FileTransferMethod.LOCAL_FILE
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.related_id == TEST_UPLOAD_FILE_ID
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("image", True, None),
|
||||
("document", False, "Detected file type does not match"),
|
||||
],
|
||||
)
|
||||
def test_build_from_local_file_strict_validation(mock_upload_file, file_type, should_pass, expected_error):
|
||||
mapping = local_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
assert file.type == FileType(file_type)
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("document", True, None),
|
||||
("image", False, "Detected file type does not match"),
|
||||
],
|
||||
)
|
||||
def test_build_from_tool_file_strict_validation(mock_tool_file, file_type, should_pass, expected_error):
|
||||
"""Strict type validation for tool_file."""
|
||||
mapping = tool_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
assert file.type == FileType(file_type)
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, strict_type_validation=True)
|
||||
|
||||
|
||||
def test_build_from_remote_url(mock_http_head):
|
||||
mapping = {
|
||||
"transfer_method": "remote_url",
|
||||
"url": TEST_REMOTE_URL,
|
||||
"type": "image",
|
||||
}
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
assert file.transfer_method == FileTransferMethod.REMOTE_URL
|
||||
assert file.type == FileType.IMAGE
|
||||
assert file.filename == "remote_test.jpg"
|
||||
assert file.size == 2048
|
||||
|
||||
|
||||
def test_tool_file_not_found():
|
||||
"""Test ToolFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.query") as mock_query:
|
||||
mock_query.return_value.filter.return_value.first.return_value = None
|
||||
mapping = tool_file_mapping()
|
||||
with pytest.raises(ValueError, match=f"ToolFile {TEST_TOOL_FILE_ID} not found"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
|
||||
def test_local_file_not_found():
|
||||
"""Test UploadFile not found in database."""
|
||||
with patch("factories.file_factory.db.session.scalar", return_value=None):
|
||||
mapping = local_file_mapping()
|
||||
with pytest.raises(ValueError, match="Invalid upload file"):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
|
||||
|
||||
def test_build_without_type_specification(mock_upload_file):
|
||||
"""Test the situation where no file type is specified"""
|
||||
mapping = {
|
||||
"transfer_method": "local_file",
|
||||
"upload_file_id": TEST_UPLOAD_FILE_ID,
|
||||
# leave out the type
|
||||
}
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID)
|
||||
# It should automatically infer the type as "image" based on the file extension
|
||||
assert file.type == FileType.IMAGE
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("file_type", "should_pass", "expected_error"),
|
||||
[
|
||||
("image", True, None),
|
||||
("video", False, "File validation failed"),
|
||||
],
|
||||
)
|
||||
def test_file_validation_with_config(mock_upload_file, file_type, should_pass, expected_error):
|
||||
"""Test the validation of files and configurations"""
|
||||
mapping = local_file_mapping(file_type=file_type)
|
||||
if should_pass:
|
||||
file = build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||
assert file is not None
|
||||
else:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
build_from_mapping(mapping=mapping, tenant_id=TEST_TENANT_ID, config=TEST_CONFIG)
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for repositories.
|
||||
"""
|
||||
@ -0,0 +1,3 @@
|
||||
"""
|
||||
Unit tests for workflow_node_execution repositories.
|
||||
"""
|
||||
@ -0,0 +1,154 @@
|
||||
"""
|
||||
Unit tests for the SQLAlchemy implementation of WorkflowNodeExecutionRepository.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.repository.workflow_node_execution_repository import OrderConfig
|
||||
from models.workflow import WorkflowNodeExecution
|
||||
from repositories.workflow_node_execution.sqlalchemy_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
"""Create a mock SQLAlchemy session."""
|
||||
session = MagicMock(spec=Session)
|
||||
# Configure the session to be used as a context manager
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=None)
|
||||
|
||||
# Configure the session factory to return the session
|
||||
session_factory = MagicMock(spec=sessionmaker)
|
||||
session_factory.return_value = session
|
||||
return session, session_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def repository(session):
|
||||
"""Create a repository instance with test data."""
|
||||
_, session_factory = session
|
||||
tenant_id = "test-tenant"
|
||||
app_id = "test-app"
|
||||
return SQLAlchemyWorkflowNodeExecutionRepository(
|
||||
session_factory=session_factory, tenant_id=tenant_id, app_id=app_id
|
||||
)
|
||||
|
||||
|
||||
def test_save(repository, session):
|
||||
"""Test save method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_save_with_existing_tenant_id(repository, session):
|
||||
"""Test save method with existing tenant_id."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution with existing tenant_id
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = "existing-tenant"
|
||||
execution.app_id = None
|
||||
|
||||
# Call save method
|
||||
repository.save(execution)
|
||||
|
||||
# Assert tenant_id is not changed and app_id is set
|
||||
assert execution.tenant_id == "existing-tenant"
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.add was called
|
||||
session_obj.add.assert_called_once_with(execution)
|
||||
|
||||
|
||||
def test_get_by_node_execution_id(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_node_execution_id method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalar.return_value = mocker.MagicMock(spec=WorkflowNodeExecution)
|
||||
|
||||
# Call method
|
||||
result = repository.get_by_node_execution_id("test-node-execution-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalar.assert_called_once_with(mock_stmt)
|
||||
assert result is not None
|
||||
|
||||
|
||||
def test_get_by_workflow_run(repository, session, mocker: MockerFixture):
|
||||
"""Test get_by_workflow_run method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
mock_stmt.order_by.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
order_config = OrderConfig(order_by=["index"], order_direction="desc")
|
||||
result = repository.get_by_workflow_run(workflow_run_id="test-workflow-run-id", order_config=order_config)
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_get_running_executions(repository, session, mocker: MockerFixture):
|
||||
"""Test get_running_executions method."""
|
||||
session_obj, _ = session
|
||||
# Set up mock
|
||||
mock_select = mocker.patch("repositories.workflow_node_execution.sqlalchemy_repository.select")
|
||||
mock_stmt = mocker.MagicMock()
|
||||
mock_select.return_value = mock_stmt
|
||||
mock_stmt.where.return_value = mock_stmt
|
||||
session_obj.scalars.return_value.all.return_value = [mocker.MagicMock(spec=WorkflowNodeExecution)]
|
||||
|
||||
# Call method
|
||||
result = repository.get_running_executions("test-workflow-run-id")
|
||||
|
||||
# Assert select was called with correct parameters
|
||||
mock_select.assert_called_once()
|
||||
session_obj.scalars.assert_called_once_with(mock_stmt)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_update(repository, session):
|
||||
"""Test update method."""
|
||||
session_obj, _ = session
|
||||
# Create a mock execution
|
||||
execution = MagicMock(spec=WorkflowNodeExecution)
|
||||
execution.tenant_id = None
|
||||
execution.app_id = None
|
||||
|
||||
# Call update method
|
||||
repository.update(execution)
|
||||
|
||||
# Assert tenant_id and app_id are set
|
||||
assert execution.tenant_id == repository._tenant_id
|
||||
assert execution.app_id == repository._app_id
|
||||
|
||||
# Assert session.merge was called
|
||||
session_obj.merge.assert_called_once_with(execution)
|
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -x
|
||||
|
||||
# run mypy checks
|
||||
uv run --directory api --dev \
|
||||
python -m mypy --install-types --non-interactive .
|
||||
@ -1,11 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -x
|
||||
|
||||
if ! command -v mypy &> /dev/null; then
|
||||
poetry install -C api --with dev
|
||||
fi
|
||||
|
||||
# run mypy checks
|
||||
poetry run -C api \
|
||||
python -m mypy --install-types --non-interactive .
|
||||
@ -1,18 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# rely on `poetry` in path
|
||||
if ! command -v poetry &> /dev/null; then
|
||||
echo "Installing Poetry ..."
|
||||
pip install poetry
|
||||
fi
|
||||
|
||||
# check poetry.lock in sync with pyproject.toml
|
||||
poetry check -C api --lock
|
||||
if [ $? -ne 0 ]; then
|
||||
# update poetry.lock
|
||||
# refreshing lockfile only without updating locked versions
|
||||
echo "poetry.lock is outdated, refreshing without updating locked versions ..."
|
||||
poetry lock -C api
|
||||
else
|
||||
echo "poetry.lock is ready."
|
||||
fi
|
||||
@ -0,0 +1,10 @@
|
||||
#!/bin/bash
|
||||
|
||||
# rely on `uv` in path
|
||||
if ! command -v uv &> /dev/null; then
|
||||
echo "Installing uv ..."
|
||||
pip install uv
|
||||
fi
|
||||
|
||||
# check uv.lock in sync with pyproject.toml
|
||||
uv lock --project api
|
||||
@ -1,13 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# rely on `poetry` in path
|
||||
if ! command -v poetry &> /dev/null; then
|
||||
echo "Installing Poetry ..."
|
||||
pip install poetry
|
||||
fi
|
||||
|
||||
# refreshing lockfile, updating locked versions
|
||||
poetry update -C api
|
||||
|
||||
# check poetry.lock in sync with pyproject.toml
|
||||
poetry check -C api --lock
|
||||
@ -0,0 +1,22 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Update dependencies in dify/api project using uv
|
||||
set -e
|
||||
set -o pipefail
|
||||
|
||||
SCRIPT_DIR="$(dirname "$0")"
|
||||
REPO_ROOT="$(dirname "${SCRIPT_DIR}")"
|
||||
|
||||
# rely on `poetry` in path
|
||||
if ! command -v uv &> /dev/null; then
|
||||
echo "Installing uv ..."
|
||||
pip install uv
|
||||
fi
|
||||
|
||||
cd "${REPO_ROOT}"
|
||||
|
||||
# refreshing lockfile, updating locked versions
|
||||
uv lock --project api --upgrade
|
||||
|
||||
# check uv.lock in sync with pyproject.toml
|
||||
uv lock --project api --check
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue