Merge branch 'main' into e-300

pull/19898/head
NFish 1 year ago
commit 0301bd3ac1

@ -5,18 +5,35 @@ root = true
# Unix-style newlines with a newline ending every file
[*]
charset = utf-8
end_of_line = lf
insert_final_newline = true
trim_trailing_whitespace = true
[*.py]
indent_size = 4
indent_style = space
[*.{yml,yaml}]
indent_style = space
indent_size = 2
[*.toml]
indent_size = 4
indent_style = space
# Markdown and MDX are whitespace sensitive languages.
# Do not remove trailing spaces.
[*.{md,mdx}]
trim_trailing_whitespace = false
# Matches multiple files with brace expansion notation
# Set default charset
[*.{js,tsx}]
charset = utf-8
indent_style = space
indent_size = 2
# Matches the exact files either package.json or .travis.yml
[{package.json,.travis.yml}]
# Matches the exact files package.json
[package.json]
indent_style = space
indent_size = 2

@ -0,0 +1,22 @@
{
"Verbose": false,
"Debug": false,
"IgnoreDefaults": false,
"SpacesAfterTabs": false,
"NoColor": false,
"Exclude": [
"^web/public/vs/",
"^web/public/pdf.worker.min.mjs$",
"web/app/components/base/icons/src/vender/"
],
"AllowedContentTypes": [],
"PassedFiles": [],
"Disable": {
"EndOfLine": false,
"Indentation": false,
"IndentSize": true,
"InsertFinalNewline": false,
"TrimTrailingWhitespace": false,
"MaxLineLength": false
}
}

@ -88,3 +88,6 @@ jobs:
- name: Run Workflow
run: uv run --project api bash dev/pytest/pytest_workflow.sh
- name: Run Tool
run: uv run --project api bash dev/pytest/pytest_tools.sh

@ -9,6 +9,12 @@ concurrency:
group: style-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
permissions:
checks: write
statuses: write
contents: read
jobs:
python-style:
name: Python Style
@ -43,8 +49,8 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
run: |
uv run --directory api ruff --version
uv run --directory api ruff check ./
uv run --directory api ruff format --check ./
uv run --directory api ruff check --diff ./
uv run --directory api ruff format --check --diff ./
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'
@ -163,3 +169,14 @@ jobs:
VALIDATE_DOCKERFILE_HADOLINT: true
VALIDATE_XML: true
VALIDATE_YAML: true
- name: EditorConfig checks
uses: super-linter/super-linter/slim@v7
env:
DEFAULT_BRANCH: main
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
IGNORE_GENERATED_FILES: true
IGNORE_GITIGNORED_FILES: true
# EditorConfig validation
VALIDATE_EDITORCONFIG: true
EDITORCONFIG_FILE_NAME: editorconfig-checker.json

@ -90,3 +90,4 @@
```bash
uv run -P api bash dev/pytest/pytest_all_tests.sh
```

@ -818,8 +818,9 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records():
def clear_orphaned_file_records(force: bool):
"""
Clear orphaned file records in the database.
"""
@ -845,7 +846,15 @@ def clear_orphaned_file_records():
# notify user and ask for confirmation
click.echo(
click.style("This command will find and delete orphaned file records in the following tables:", fg="yellow")
click.style(
"This command will first find and delete orphaned file records from the message_files table,", fg="yellow"
)
)
click.echo(
click.style(
"and then it will find and delete orphaned file records in the following tables:",
fg="yellow",
)
)
for files_table in files_tables:
click.echo(click.style(f"- {files_table['table']}", fg="yellow"))
@ -878,11 +887,55 @@ def clear_orphaned_file_records():
fg="yellow",
)
)
if not force:
click.confirm("Do you want to proceed?", abort=True)
# start the cleanup process
click.echo(click.style("Starting orphaned file records cleanup.", fg="white"))
# clean up the orphaned records in the message_files table where message_id doesn't exist in messages table
try:
click.echo(
click.style("- Listing message_files records where message_id doesn't exist in messages table", fg="white")
)
query = (
"SELECT mf.id, mf.message_id "
"FROM message_files mf LEFT JOIN messages m ON mf.message_id = m.id "
"WHERE m.id IS NULL"
)
orphaned_message_files = []
with db.engine.begin() as conn:
rs = conn.execute(db.text(query))
for i in rs:
orphaned_message_files.append({"id": str(i[0]), "message_id": str(i[1])})
if orphaned_message_files:
click.echo(click.style(f"Found {len(orphaned_message_files)} orphaned message_files records:", fg="white"))
for record in orphaned_message_files:
click.echo(click.style(f" - id: {record['id']}, message_id: {record['message_id']}", fg="black"))
if not force:
click.confirm(
(
f"Do you want to proceed "
f"to delete all {len(orphaned_message_files)} orphaned message_files records?"
),
abort=True,
)
click.echo(click.style("- Deleting orphaned message_files records", fg="white"))
query = "DELETE FROM message_files WHERE id IN :ids"
with db.engine.begin() as conn:
conn.execute(db.text(query), {"ids": tuple([record["id"] for record in orphaned_message_files])})
click.echo(
click.style(f"Removed {len(orphaned_message_files)} orphaned message_files records.", fg="green")
)
else:
click.echo(click.style("No orphaned message_files records found. There is nothing to delete.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error deleting orphaned message_files records: {str(e)}", fg="red"))
# clean up the orphaned records in the rest of the *_files tables
try:
# fetch file id and keys from each table
all_files_in_tables = []
@ -964,6 +1017,7 @@ def clear_orphaned_file_records():
click.echo(click.style(f"Found {len(orphaned_files)} orphaned file records.", fg="white"))
for file in orphaned_files:
click.echo(click.style(f"- orphaned file id: {file}", fg="black"))
if not force:
click.confirm(f"Do you want to proceed to delete all {len(orphaned_files)} orphaned file records?", abort=True)
# delete orphaned records for each file
@ -979,8 +1033,9 @@ def clear_orphaned_file_records():
click.echo(click.style(f"Removed {len(orphaned_files)} orphaned file records.", fg="green"))
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("remove-orphaned-files-on-storage", help="Remove orphaned files on the storage.")
def remove_orphaned_files_on_storage():
def remove_orphaned_files_on_storage(force: bool):
"""
Remove orphaned files on the storage.
"""
@ -1028,6 +1083,7 @@ def remove_orphaned_files_on_storage():
fg="yellow",
)
)
if not force:
click.confirm("Do you want to proceed?", abort=True)
# start the cleanup process
@ -1069,6 +1125,7 @@ def remove_orphaned_files_on_storage():
click.echo(click.style(f"Found {len(orphaned_files)} orphaned files.", fg="white"))
for file in orphaned_files:
click.echo(click.style(f"- orphaned file: {file}", fg="black"))
if not force:
click.confirm(f"Do you want to proceed to remove all {len(orphaned_files)} orphaned files?", abort=True)
# delete orphaned files

@ -398,6 +398,11 @@ class InnerAPIConfig(BaseSettings):
default=False,
)
INNER_API_KEY: Optional[str] = Field(
description="API key for accessing the internal API",
default=None,
)
class LoggingConfig(BaseSettings):
"""

@ -1,4 +1,5 @@
from typing import Optional
import enum
from typing import Literal, Optional
from pydantic import Field, PositiveInt
from pydantic_settings import BaseSettings
@ -9,6 +10,14 @@ class OpenSearchConfig(BaseSettings):
Configuration settings for OpenSearch
"""
class AuthMethod(enum.StrEnum):
"""
Authentication method for OpenSearch
"""
BASIC = "basic"
AWS_MANAGED_IAM = "aws_managed_iam"
OPENSEARCH_HOST: Optional[str] = Field(
description="Hostname or IP address of the OpenSearch server (e.g., 'localhost' or 'opensearch.example.com')",
default=None,
@ -19,6 +28,16 @@ class OpenSearchConfig(BaseSettings):
default=9200,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
)
OPENSEARCH_AUTH_METHOD: AuthMethod = Field(
description="Authentication method for OpenSearch connection (default is 'basic')",
default=AuthMethod.BASIC,
)
OPENSEARCH_USER: Optional[str] = Field(
description="Username for authenticating with OpenSearch",
default=None,
@ -29,7 +48,11 @@ class OpenSearchConfig(BaseSettings):
default=None,
)
OPENSEARCH_SECURE: bool = Field(
description="Whether to use SSL/TLS encrypted connection for OpenSearch (True for HTTPS, False for HTTP)",
default=False,
OPENSEARCH_AWS_REGION: Optional[str] = Field(
description="AWS region for OpenSearch (e.g. 'us-west-2')",
default=None,
)
OPENSEARCH_AWS_SERVICE: Optional[Literal["es", "aoss"]] = Field(
description="AWS service for OpenSearch (e.g. 'aoss' for OpenSearch Serverless)", default=None
)

@ -0,0 +1,7 @@
# The two constants below should keep in sync.
# Default content type for files which have no explicit content type.
DEFAULT_MIME_TYPE = "application/octet-stream"
# Default file extension for files which have no explicit content type, should
# correspond to the `DEFAULT_MIME_TYPE` above.
DEFAULT_EXTENSION = ".bin"

@ -2,22 +2,22 @@ import uuid
from typing import cast
from flask_login import current_user # type: ignore
from flask_restful import (Resource, inputs, marshal, # type: ignore
marshal_with, reqparse)
from flask_restful import Resource, inputs, marshal, marshal_with, reqparse # type: ignore
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden, abort
from controllers.console import api
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import (account_initialization_required,
from controllers.console.wraps import (
account_initialization_required,
cloud_edition_billing_resource_check,
enterprise_license_required,
setup_required)
setup_required,
)
from core.ops.ops_trace_manager import OpsTraceManager
from extensions.ext_database import db
from fields.app_fields import (app_detail_fields, app_detail_fields_with_site,
app_pagination_fields)
from fields.app_fields import app_detail_fields, app_detail_fields_with_site, app_pagination_fields
from libs.login import login_required
from models import Account, App
from services.app_dsl_service import AppDslService, ImportMode

@ -24,7 +24,7 @@ from libs.password import hash_password, valid_password
from models.account import Account
from services.account_service import AccountService, TenantService
from services.errors.account import AccountRegisterError
from services.errors.workspace import WorkSpaceNotAllowedCreateError, WorkspacesLimitExceededError
from services.errors.workspace import WorkSpaceNotAllowedCreateError
from services.feature_service import FeatureService

@ -4,15 +4,13 @@ from typing import Any
from flask import request
from flask_login import current_user # type: ignore
from flask_restful import (Resource, inputs, marshal_with, # type: ignore
reqparse)
from flask_restful import Resource, inputs, marshal_with, reqparse # type: ignore
from sqlalchemy import and_
from werkzeug.exceptions import BadRequest, Forbidden, NotFound
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import (account_initialization_required,
cloud_edition_billing_resource_check)
from controllers.console.wraps import account_initialization_required, cloud_edition_billing_resource_check
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.login import login_required

@ -70,6 +70,20 @@ class FilePreviewApi(Resource):
direct_passthrough=True,
headers={},
)
# add Accept-Ranges header for audio/video files
if upload_file.mime_type in [
"audio/mpeg",
"audio/wav",
"audio/mp4",
"audio/ogg",
"audio/flac",
"audio/aac",
"video/mp4",
"video/webm",
"video/quicktime",
"audio/x-m4a",
]:
response.headers["Accept-Ranges"] = "bytes"
if upload_file.size > 0:
response.headers["Content-Length"] = str(upload_file.size)
if args["as_attachment"]:

@ -1,10 +1,14 @@
from urllib.parse import quote
from flask import Response
from flask_restful import Resource, reqparse # type: ignore
from werkzeug.exceptions import Forbidden, NotFound
from controllers.files import api
from controllers.files.error import UnsupportedFileTypeError
from core.tools.signature import verify_tool_file_signature
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
class ToolFilePreviewApi(Resource):
@ -19,17 +23,14 @@ class ToolFilePreviewApi(Resource):
parser.add_argument("as_attachment", type=bool, required=False, default=False, location="args")
args = parser.parse_args()
if not ToolFileManager.verify_file(
file_id=file_id,
timestamp=args["timestamp"],
nonce=args["nonce"],
sign=args["sign"],
if not verify_tool_file_signature(
file_id=file_id, timestamp=args["timestamp"], nonce=args["nonce"], sign=args["sign"]
):
raise Forbidden("Invalid request.")
try:
stream, tool_file = ToolFileManager.get_file_generator_by_tool_file_id(
tool_file_manager = ToolFileManager(engine=global_db.engine)
stream, tool_file = tool_file_manager.get_file_generator_by_tool_file_id(
file_id,
)
@ -47,7 +48,8 @@ class ToolFilePreviewApi(Resource):
if tool_file.size > 0:
response.headers["Content-Length"] = str(tool_file.size)
if args["as_attachment"]:
response.headers["Content-Disposition"] = f"attachment; filename={tool_file.name}"
encoded_filename = quote(tool_file.name)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
return response

@ -53,7 +53,7 @@ class PluginUploadFileApi(Resource):
raise Forbidden("Invalid request.")
try:
tool_file = ToolFileManager.create_file_by_raw(
tool_file = ToolFileManager().create_file_by_raw(
user_id=user.id,
tenant_id=tenant_id,
file_binary=file.read(),

@ -5,6 +5,6 @@ from libs.external_api import ExternalApi
bp = Blueprint("inner_api", __name__, url_prefix="/inner/api")
api = ExternalApi(bp)
from .plugin import plugin
from . import mail
from .plugin import plugin
from .workspace import workspace

@ -18,7 +18,7 @@ def enterprise_inner_api_only(view):
# get header 'X-Inner-Api-Key'
inner_api_key = request.headers.get("X-Inner-Api-Key")
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY_FOR_PLUGIN:
if not inner_api_key or inner_api_key != dify_config.INNER_API_KEY:
abort(401)
return view(*args, **kwargs)

@ -69,6 +69,13 @@ class CotAgentRunner(BaseAgentRunner, ABC):
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
function_call_state = True
llm_usage: dict[str, Optional[LLMUsage]] = {"usage": None}
final_answer = ""

@ -45,6 +45,13 @@ class FunctionCallAgentRunner(BaseAgentRunner):
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
# fix metadata filter not work
if app_config.dataset is not None:
metadata_filtering_conditions = app_config.dataset.retrieve_config.metadata_filtering_conditions
for key, dataset_retriever_tool in tool_instances.items():
if hasattr(dataset_retriever_tool, "retrieval_tool"):
dataset_retriever_tool.retrieval_tool.metadata_filtering_conditions = metadata_filtering_conditions
assert app_config.agent
iteration_step = 1

@ -21,7 +21,6 @@ from core.model_runtime.entities.message_entities import (
ImagePromptMessageContent,
PromptMessage,
)
from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.errors.invoke import InvokeBadRequestError
from core.moderation.input_moderation import InputModeration
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform

@ -24,7 +24,7 @@ from core.app.entities.task_entities import (
WorkflowTaskState,
)
from core.llm_generator.llm_generator import LLMGenerator
from core.tools.tool_file_manager import ToolFileManager
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@ -154,7 +154,7 @@ class MessageCycleManage:
if message_file.url.startswith("http"):
url = message_file.url
else:
url = ToolFileManager.sign_file(tool_file_id=tool_file_id, extension=extension)
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
return MessageFileStreamResponse(
task_id=self._application_generate_entity.task_id,

@ -381,6 +381,8 @@ class WorkflowCycleManage:
workflow_node_execution.elapsed_time = elapsed_time
workflow_node_execution.execution_metadata = execution_metadata
self._workflow_node_execution_repository.update(workflow_node_execution)
return workflow_node_execution
def _handle_workflow_node_execution_retried(

@ -10,12 +10,12 @@ from core.model_runtime.entities import (
VideoPromptMessageContent,
)
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from core.tools.signature import sign_tool_file
from extensions.ext_storage import storage
from . import helpers
from .enums import FileAttribute
from .models import File, FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
def get_attr(*, file: File, attr: FileAttribute):
@ -130,6 +130,6 @@ def _to_url(f: File, /):
# add sign url
if f.related_id is None or f.extension is None:
raise ValueError("Missing file related_id or extension")
return ToolFileParser.get_tool_file_manager().sign_file(tool_file_id=f.related_id, extension=f.extension)
return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
else:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")

@ -4,11 +4,11 @@ from typing import Any, Optional
from pydantic import BaseModel, Field, model_validator
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
from core.tools.signature import sign_tool_file
from . import helpers
from .constants import FILE_MODEL_IDENTITY
from .enums import FileTransferMethod, FileType
from .tool_file_parser import ToolFileParser
class ImageConfig(BaseModel):
@ -34,13 +34,21 @@ class FileUploadConfig(BaseModel):
class File(BaseModel):
# NOTE: dify_model_identity is a special identifier used to distinguish between
# new and old data formats during serialization and deserialization.
dify_model_identity: str = FILE_MODEL_IDENTITY
id: Optional[str] = None # message file id
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
# If `transfer_method` is `FileTransferMethod.remote_url`, the
# `remote_url` attribute must not be `None`.
remote_url: Optional[str] = None # remote url
# If `transfer_method` is `FileTransferMethod.local_file` or
# `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`.
#
# It should be set to `ToolFile.id` when `transfer_method` is `tool_file`.
related_id: Optional[str] = None
filename: Optional[str] = None
extension: Optional[str] = Field(default=None, description="File extension, should contains dot")
@ -110,9 +118,7 @@ class File(BaseModel):
elif self.transfer_method == FileTransferMethod.TOOL_FILE:
assert self.related_id is not None
assert self.extension is not None
return ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=self.related_id, extension=self.extension
)
return sign_tool_file(tool_file_id=self.related_id, extension=self.extension)
def to_plugin_parameter(self) -> dict[str, Any]:
return {

@ -1,12 +1,19 @@
from typing import TYPE_CHECKING, Any, cast
from collections.abc import Callable
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from core.tools.tool_file_manager import ToolFileManager
tool_file_manager: dict[str, Any] = {"manager": None}
_tool_file_manager_factory: Callable[[], "ToolFileManager"] | None = None
class ToolFileParser:
@staticmethod
def get_tool_file_manager() -> "ToolFileManager":
return cast("ToolFileManager", tool_file_manager["manager"])
assert _tool_file_manager_factory is not None
return _tool_file_manager_factory()
def set_tool_file_manager_factory(factory: Callable[[], "ToolFileManager"]) -> None:
global _tool_file_manager_factory
_tool_file_manager_factory = factory

@ -3,6 +3,8 @@ import logging
import re
from typing import Optional, cast
import json_repair
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@ -366,7 +368,20 @@ class LLMGenerator:
),
)
generated_json_schema = cast(str, response.message.content)
raw_content = response.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:

@ -101,7 +101,7 @@ class ModelInstance:
@overload
def invoke_llm(
self,
prompt_messages: list[PromptMessage],
prompt_messages: Sequence[PromptMessage],
model_parameters: Optional[dict] = None,
tools: Sequence[PromptMessageTool] | None = None,
stop: Optional[list[str]] = None,

@ -1,4 +1,5 @@
from collections.abc import Sequence
from abc import ABC
from collections.abc import Mapping, Sequence
from enum import Enum, StrEnum
from typing import Annotated, Any, Literal, Optional, Union
@ -60,8 +61,12 @@ class PromptMessageContentType(StrEnum):
DOCUMENT = "document"
class PromptMessageContent(BaseModel):
pass
class PromptMessageContent(ABC, BaseModel):
"""
Model class for prompt message content.
"""
type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent):
@ -125,7 +130,16 @@ PromptMessageContentUnionTypes = Annotated[
]
class PromptMessage(BaseModel):
CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = {
PromptMessageContentType.TEXT: TextPromptMessageContent,
PromptMessageContentType.IMAGE: ImagePromptMessageContent,
PromptMessageContentType.AUDIO: AudioPromptMessageContent,
PromptMessageContentType.VIDEO: VideoPromptMessageContent,
PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent,
}
class PromptMessage(ABC, BaseModel):
"""
Model class for prompt message.
"""
@ -142,6 +156,23 @@ class PromptMessage(BaseModel):
"""
return not self.content
@field_validator("content", mode="before")
@classmethod
def validate_content(cls, v):
if isinstance(v, list):
prompts = []
for prompt in v:
if isinstance(prompt, PromptMessageContent):
if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent):
prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump())
elif isinstance(prompt, dict):
prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt)
else:
raise ValueError(f"invalid prompt message {prompt}")
prompts.append(prompt)
return prompts
return v
@field_serializer("content")
def serialize_content(
self, content: Optional[Union[str, Sequence[PromptMessageContent]]]

@ -24,7 +24,6 @@ from core.model_runtime.errors.invoke import (
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.plugin.entities.plugin_daemon import PluginDaemonInnerError, PluginModelProviderEntity
from core.plugin.impl.model import PluginModelClient
@ -253,15 +252,3 @@ class AIModel(BaseModel):
raise Exception(f"Invalid model parameter rule name {name}")
return default_parameter_rule
def _get_num_tokens_by_gpt2(self, text: str) -> int:
"""
Get number of tokens for given prompt messages by gpt2
Some provider models do not provide an interface for obtaining the number of tokens.
Here, the gpt2 tokenizer is used to calculate the number of tokens.
This method can be executed offline, and the gpt2 tokenizer has been cached in the project.
:param text: plain text of prompt. You need to convert the original message to plain text
:return: number of tokens
"""
return GPT2Tokenizer.get_num_tokens(text)

@ -2,7 +2,7 @@ import logging
import time
import uuid
from collections.abc import Generator, Sequence
from typing import Optional, Union, cast
from typing import Optional, Union
from pydantic import ConfigDict
@ -13,14 +13,15 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk,
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageTool,
TextPromptMessageContent,
)
from core.model_runtime.entities.model_entities import (
ModelType,
PriceType,
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
from core.plugin.impl.model import PluginModelClient
logger = logging.getLogger(__name__)
@ -238,7 +239,7 @@ class LargeLanguageModel(AIModel):
def _invoke_result_generator(
self,
model: str,
result: Generator,
result: Generator[LLMResultChunk, None, None],
credentials: dict,
prompt_messages: Sequence[PromptMessage],
model_parameters: dict,
@ -255,11 +256,21 @@ class LargeLanguageModel(AIModel):
:return: result generator
"""
callbacks = callbacks or []
assistant_message = AssistantPromptMessage(content="")
message_content: list[PromptMessageContentUnionTypes] = []
usage = None
system_fingerprint = None
real_model = model
def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None):
if not content:
return
if isinstance(content, list):
message_content.extend(content)
return
if isinstance(content, str):
message_content.append(TextPromptMessageContent(data=content))
return
try:
for chunk in result:
# Following https://github.com/langgenius/dify/issues/17799,
@ -281,9 +292,8 @@ class LargeLanguageModel(AIModel):
callbacks=callbacks,
)
text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
current_content = cast(str, assistant_message.content)
assistant_message.content = current_content + text
_update_message_content(chunk.delta.message.content)
real_model = chunk.model
if chunk.delta.usage:
usage = chunk.delta.usage
@ -293,6 +303,7 @@ class LargeLanguageModel(AIModel):
except Exception as e:
raise self._transform_invoke_error(e)
assistant_message = AssistantPromptMessage(content=message_content)
self._trigger_after_invoke_callbacks(
model=model,
result=LLMResult(

@ -30,6 +30,8 @@ class GPT2Tokenizer:
@staticmethod
def get_encoder() -> Any:
global _tokenizer, _lock
if _tokenizer is not None:
return _tokenizer
with _lock:
if _tokenizer is None:
# Try to use tiktoken to get the tokenizer because it is faster

@ -1,8 +1,6 @@
import pydantic
from pydantic import BaseModel
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
def dump_model(model: BaseModel) -> dict:
if hasattr(pydantic, "model_dump"):
@ -10,18 +8,3 @@ def dump_model(model: BaseModel) -> dict:
return pydantic.model_dump(model) # type: ignore
else:
return model.model_dump()
def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
if content is None:
message_text = ""
elif isinstance(content, str):
message_text = content
elif isinstance(content, list):
# Assuming the list contains PromptMessageContent objects with a "data" attribute
message_text = "".join(
item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
)
else:
message_text = str(content)
return message_text

@ -27,8 +27,8 @@ class MilvusConfig(BaseModel):
uri: str # Milvus server URI
token: Optional[str] = None # Optional token for authentication
user: str # Username for authentication
password: str # Password for authentication
user: Optional[str] = None # Username for authentication
password: Optional[str] = None # Password for authentication
batch_size: int = 100 # Batch size for operations
database: str = "default" # Database name
enable_hybrid_search: bool = False # Flag to enable hybrid search
@ -43,6 +43,7 @@ class MilvusConfig(BaseModel):
"""
if not values.get("uri"):
raise ValueError("config MILVUS_URI is required")
if not values.get("token"):
if not values.get("user"):
raise ValueError("config MILVUS_USER is required")
if not values.get("password"):
@ -356,10 +357,13 @@ class MilvusVector(BaseVector):
)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
def _init_client(self, config) -> MilvusClient:
def _init_client(self, config: MilvusConfig) -> MilvusClient:
"""
Initialize and return a Milvus client.
"""
if config.token:
client = MilvusClient(uri=config.uri, token=config.token, db_name=config.database)
else:
client = MilvusClient(uri=config.uri, user=config.user, password=config.password, db_name=config.database)
return client

@ -1,10 +1,9 @@
import json
import logging
import ssl
from typing import Any, Optional
from typing import Any, Literal, Optional
from uuid import uuid4
from opensearchpy import OpenSearch, helpers
from opensearchpy import OpenSearch, Urllib3AWSV4SignerAuth, Urllib3HttpConnection, helpers
from opensearchpy.helpers import BulkIndexError
from pydantic import BaseModel, model_validator
@ -24,9 +23,12 @@ logger = logging.getLogger(__name__)
class OpenSearchConfig(BaseModel):
host: str
port: int
secure: bool = False
auth_method: Literal["basic", "aws_managed_iam"] = "basic"
user: Optional[str] = None
password: Optional[str] = None
secure: bool = False
aws_region: Optional[str] = None
aws_service: Optional[str] = None
@model_validator(mode="before")
@classmethod
@ -35,24 +37,40 @@ class OpenSearchConfig(BaseModel):
raise ValueError("config OPENSEARCH_HOST is required")
if not values.get("port"):
raise ValueError("config OPENSEARCH_PORT is required")
if values.get("auth_method") == "aws_managed_iam":
if not values.get("aws_region"):
raise ValueError("config OPENSEARCH_AWS_REGION is required for AWS_MANAGED_IAM auth method")
if not values.get("aws_service"):
raise ValueError("config OPENSEARCH_AWS_SERVICE is required for AWS_MANAGED_IAM auth method")
return values
def create_ssl_context(self) -> ssl.SSLContext:
ssl_context = ssl.create_default_context()
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE # Disable Certificate Validation
return ssl_context
def create_aws_managed_iam_auth(self) -> Urllib3AWSV4SignerAuth:
import boto3 # type: ignore
return Urllib3AWSV4SignerAuth(
credentials=boto3.Session().get_credentials(),
region=self.aws_region,
service=self.aws_service, # type: ignore[arg-type]
)
def to_opensearch_params(self) -> dict[str, Any]:
params = {
"hosts": [{"host": self.host, "port": self.port}],
"use_ssl": self.secure,
"verify_certs": self.secure,
"connection_class": Urllib3HttpConnection,
"pool_maxsize": 20,
}
if self.user and self.password:
if self.auth_method == "basic":
logger.info("Using basic authentication for OpenSearch Vector DB")
params["http_auth"] = (self.user, self.password)
if self.secure:
params["ssl_context"] = self.create_ssl_context()
elif self.auth_method == "aws_managed_iam":
logger.info("Using AWS managed IAM role for OpenSearch Vector DB")
params["http_auth"] = self.create_aws_managed_iam_auth()
return params
@ -76,16 +94,23 @@ class OpenSearchVector(BaseVector):
action = {
"_op_type": "index",
"_index": self._collection_name.lower(),
"_id": uuid4().hex,
"_source": {
Field.CONTENT_KEY.value: documents[i].page_content,
Field.VECTOR.value: embeddings[i], # Make sure you pass an array here
Field.METADATA_KEY.value: documents[i].metadata,
},
}
# See https://github.com/langchain-ai/langchainjs/issues/4346#issuecomment-1935123377
if self._client_config.aws_service not in ["aoss"]:
action["_id"] = uuid4().hex
actions.append(action)
helpers.bulk(self._client, actions)
helpers.bulk(
client=self._client,
actions=actions,
timeout=30,
max_retries=3,
)
def get_ids_by_metadata_field(self, key: str, value: str):
query = {"query": {"term": {f"{Field.METADATA_KEY.value}.{key}": value}}}
@ -234,6 +259,7 @@ class OpenSearchVector(BaseVector):
},
}
logger.info(f"Creating OpenSearch index {self._collection_name.lower()}")
self._client.indices.create(index=self._collection_name.lower(), body=index_body)
redis_client.set(collection_exist_cache_key, 1, ex=3600)
@ -252,9 +278,12 @@ class OpenSearchVectorFactory(AbstractVectorFactory):
open_search_config = OpenSearchConfig(
host=dify_config.OPENSEARCH_HOST or "localhost",
port=dify_config.OPENSEARCH_PORT,
secure=dify_config.OPENSEARCH_SECURE,
auth_method=dify_config.OPENSEARCH_AUTH_METHOD.value,
user=dify_config.OPENSEARCH_USER,
password=dify_config.OPENSEARCH_PASSWORD,
secure=dify_config.OPENSEARCH_SECURE,
aws_region=dify_config.OPENSEARCH_AWS_REGION,
aws_service=dify_config.OPENSEARCH_AWS_SERVICE,
)
return OpenSearchVector(collection_name=collection_name, config=open_search_config)

@ -52,6 +52,7 @@ class RerankModelRunner(BaseRerankRunner):
rerank_documents = []
for result in rerank_result.docs:
if score_threshold is None or result.score >= score_threshold:
# format document
rerank_document = Document(
page_content=result.text,
@ -62,4 +63,5 @@ class RerankModelRunner(BaseRerankRunner):
rerank_document.metadata["score"] = result.score
rerank_documents.append(rerank_document)
return rerank_documents
rerank_documents.sort(key=lambda x: x.metadata.get("score", 0.0), reverse=True)
return rerank_documents[:top_n] if top_n else rerank_documents

@ -7,7 +7,7 @@ from collections.abc import Generator, Mapping
from typing import Any, Optional, Union, cast
from flask import Flask, current_app
from sqlalchemy import Integer, and_, or_, text
from sqlalchemy import Float, and_, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import (
@ -1005,28 +1005,24 @@ class DatasetRetrieval:
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) == value
)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) == value)
case "is not" | "":
if isinstance(value, str):
filters.append(DatasetDocument.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(
sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) != value
)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) != value)
case "empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(DatasetDocument.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) < value)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) > value)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) > value)
case "" | "<=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) <= value)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Integer) >= value)
filters.append(sqlalchemy_cast(DatasetDocument.doc_metadata[metadata_name].astext, Float) >= value)
case _:
pass
return filters

@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "", ">", "<", "", "", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint

@ -159,50 +159,6 @@ class TextSplitter(BaseDocumentTransformer, ABC):
)
return cls(length_function=lambda x: [_huggingface_tokenizer_length(text) for text in x], **kwargs)
@classmethod
def from_tiktoken_encoder(
cls: type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], Set[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
) -> TS:
"""Text splitter that uses tiktoken encoder to count length."""
try:
import tiktoken
except ImportError:
raise ImportError(
"Could not import tiktoken python package. "
"This is needed in order to calculate max_tokens_for_prompt. "
"Please install it with `pip install tiktoken`."
)
if model_name is not None:
enc = tiktoken.encoding_for_model(model_name)
else:
enc = tiktoken.get_encoding(encoding_name)
def _tiktoken_encoder(text: str) -> int:
return len(
enc.encode(
text,
allowed_special=allowed_special,
disallowed_special=disallowed_special,
)
)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=lambda x: [_tiktoken_encoder(text) for text in x], **kwargs)
def transform_documents(self, documents: Sequence[Document], **kwargs: Any) -> Sequence[Document]:
"""Transform sequence of documents by splitting them."""
return self.split_documents(list(documents))

@ -0,0 +1,41 @@
import base64
import hashlib
import hmac
import os
import time
from configs import dify_config
def sign_tool_file(tool_file_id: str, extension: str) -> str:
"""
sign file to get a temporary url
"""
base_url = dify_config.FILES_URL
file_preview_url = f"{base_url}/files/tools/{tool_file_id}{extension}"
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"file-preview|{tool_file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{file_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
"""
data_to_sign = f"file-preview|{file_id}|{timestamp}|{nonce}"
secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT

@ -4,23 +4,34 @@ import hmac
import logging
import os
import time
from collections.abc import Generator
from mimetypes import guess_extension, guess_type
from typing import Optional, Union
from uuid import uuid4
import httpx
from sqlalchemy.orm import Session
from configs import dify_config
from core.helper import ssrf_proxy
from extensions.ext_database import db
from extensions.ext_database import db as global_db
from extensions.ext_storage import storage
from models.model import MessageFile
from models.tools import ToolFile
logger = logging.getLogger(__name__)
from sqlalchemy.engine import Engine
class ToolFileManager:
_engine: Engine
def __init__(self, engine: Engine | None = None):
if engine is None:
engine = global_db.engine
self._engine = engine
@staticmethod
def sign_file(tool_file_id: str, extension: str) -> str:
"""
@ -55,8 +66,8 @@ class ToolFileManager:
current_time = int(time.time())
return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT
@staticmethod
def create_file_by_raw(
self,
*,
user_id: str,
tenant_id: str,
@ -77,6 +88,7 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{unique_filename}"
storage.save(filepath, file_binary)
with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@ -87,14 +99,14 @@ class ToolFileManager:
size=len(file_binary),
)
db.session.add(tool_file)
db.session.commit()
db.session.refresh(tool_file)
session.add(tool_file)
session.commit()
session.refresh(tool_file)
return tool_file
@staticmethod
def create_file_by_url(
self,
user_id: str,
tenant_id: str,
file_url: str,
@ -119,6 +131,7 @@ class ToolFileManager:
filepath = f"tools/{tenant_id}/{filename}"
storage.save(filepath, blob)
with Session(self._engine, expire_on_commit=False) as session:
tool_file = ToolFile(
user_id=user_id,
tenant_id=tenant_id,
@ -130,13 +143,12 @@ class ToolFileManager:
size=len(blob),
)
db.session.add(tool_file)
db.session.commit()
session.add(tool_file)
session.commit()
return tool_file
@staticmethod
def get_file_binary(id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
@ -144,8 +156,9 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
db.session.query(ToolFile)
session.query(ToolFile)
.filter(
ToolFile.id == id,
)
@ -159,8 +172,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
@staticmethod
def get_file_binary_by_message_file_id(id: str) -> Union[tuple[bytes, str], None]:
def get_file_binary_by_message_file_id(self, id: str) -> Union[tuple[bytes, str], None]:
"""
get file binary
@ -168,8 +180,9 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
message_file: MessageFile | None = (
db.session.query(MessageFile)
session.query(MessageFile)
.filter(
MessageFile.id == id,
)
@ -189,7 +202,7 @@ class ToolFileManager:
tool_file_id = None
tool_file: ToolFile | None = (
db.session.query(ToolFile)
session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
@ -203,8 +216,7 @@ class ToolFileManager:
return blob, tool_file.mimetype
@staticmethod
def get_file_generator_by_tool_file_id(tool_file_id: str):
def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Optional[Generator], Optional[ToolFile]]:
"""
get file binary
@ -212,8 +224,9 @@ class ToolFileManager:
:return: the binary of the file, mime type
"""
with Session(self._engine, expire_on_commit=False) as session:
tool_file: ToolFile | None = (
db.session.query(ToolFile)
session.query(ToolFile)
.filter(
ToolFile.id == tool_file_id,
)
@ -229,6 +242,11 @@ class ToolFileManager:
# init tool_file_parser
from core.file.tool_file_parser import tool_file_manager
from core.file.tool_file_parser import set_tool_file_manager_factory
def _factory() -> ToolFileManager:
return ToolFileManager()
tool_file_manager["manager"] = ToolFileManager
set_tool_file_manager_factory(_factory)

@ -4,6 +4,7 @@ from pydantic import BaseModel, Field
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.context_entities import DocumentContext
from core.rag.entities.metadata_entities import MetadataCondition
from core.rag.models.document import Document as RetrievalDocument
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
@ -33,6 +34,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
args_schema: type[BaseModel] = DatasetRetrieverToolInput
description: str = "use this to retrieve a dataset. "
dataset_id: str
metadata_filtering_conditions: MetadataCondition
@classmethod
def from_dataset(cls, dataset: Dataset, **kwargs):
@ -46,6 +48,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
description=description,
metadata_filtering_conditions=MetadataCondition(),
**kwargs,
)
@ -65,6 +68,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
dataset_id=dataset.id,
query=query,
external_retrieval_parameters=dataset.retrieval_model,
metadata_condition=self.metadata_filtering_conditions,
)
for external_document in external_documents:
document = RetrievalDocument(

@ -31,8 +31,8 @@ class ToolFileMessageTransformer:
# try to download image
try:
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
file = ToolFileManager.create_file_by_url(
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_url(
user_id=user_id,
tenant_id=tenant_id,
file_url=message.message.text,
@ -60,7 +60,7 @@ class ToolFileMessageTransformer:
mimetype = meta.get("mime_type", "application/octet-stream")
# get filename from meta
filename = meta.get("file_name", None)
filename = meta.get("filename", None)
# if message is str, encode it to bytes
if not isinstance(message.message, ToolInvokeMessage.BlobMessage):
@ -68,7 +68,8 @@ class ToolFileMessageTransformer:
# FIXME: should do a type check here.
assert isinstance(message.message.blob, bytes)
file = ToolFileManager.create_file_by_raw(
tool_file_manager = ToolFileManager()
file = tool_file_manager.create_file_by_raw(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=conversation_id,

@ -223,8 +223,8 @@ def _extract_text_from_doc(file_content: bytes) -> str:
"""
from unstructured.partition.api import partition_via_api
if not (dify_config.UNSTRUCTURED_API_URL and dify_config.UNSTRUCTURED_API_KEY):
raise TextExtractionError("UNSTRUCTURED_API_URL and UNSTRUCTURED_API_KEY must be set")
if not dify_config.UNSTRUCTURED_API_URL:
raise TextExtractionError("UNSTRUCTURED_API_URL must be set")
try:
with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file:
@ -235,7 +235,7 @@ def _extract_text_from_doc(file_content: bytes) -> str:
file=file,
metadata_filename=temp_file.name,
api_url=dify_config.UNSTRUCTURED_API_URL,
api_key=dify_config.UNSTRUCTURED_API_KEY,
api_key=dify_config.UNSTRUCTURED_API_KEY, # type: ignore
)
os.unlink(temp_file.name)
return "\n".join([getattr(element, "text", "") for element in elements])

@ -262,7 +262,10 @@ class Executor:
headers[authorization.config.header] = f"Bearer {authorization.config.api_key}"
elif self.auth.config.type == "basic":
credentials = authorization.config.api_key
if ":" in credentials:
encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8")
else:
encoded_credentials = credentials
headers[authorization.config.header] = f"Basic {encoded_credentials}"
elif self.auth.config.type == "custom":
headers[authorization.config.header] = authorization.config.api_key or ""

@ -191,8 +191,9 @@ class HttpRequestNode(BaseNode[HttpRequestNodeData]):
mime_type = (
content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream"
)
tool_file_manager = ToolFileManager()
tool_file = ToolFileManager.create_file_by_raw(
tool_file = tool_file_manager.create_file_by_raw(
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=None,

@ -6,7 +6,7 @@ from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Any, Optional, cast
from sqlalchemy import Integer, and_, func, or_, text
from sqlalchemy import Float, and_, func, or_, text
from sqlalchemy import cast as sqlalchemy_cast
from core.app.app_config.entities import DatasetRetrieveConfigEntity
@ -32,11 +32,11 @@ from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.question_classifier.template_prompts import QUESTION_CLASSIFIER_USER_PROMPT_2
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
@ -493,24 +493,24 @@ class KnowledgeRetrievalNode(LLMNode):
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] == f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) == value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) == value)
case "is not" | "":
if isinstance(value, str):
filters.append(Document.doc_metadata[metadata_name] != f'"{value}"')
else:
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) != value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) != value)
case "empty":
filters.append(Document.doc_metadata[metadata_name].is_(None))
case "not empty":
filters.append(Document.doc_metadata[metadata_name].isnot(None))
case "before" | "<":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) < value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) < value)
case "after" | ">":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) > value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) <= value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) > value)
case "" | "<=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) <= value)
case "" | ">=":
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Integer) >= value)
filters.append(sqlalchemy_cast(Document.doc_metadata[metadata_name].astext, Float) >= value)
case _:
pass
return filters
@ -618,7 +618,7 @@ class KnowledgeRetrievalNode(LLMNode):
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(

@ -2,7 +2,7 @@ METADATA_FILTER_SYSTEM_PROMPT = """
### Job Description',
You are a text metadata extract engine that extract text's metadata based on user input and set the metadata value
### Task
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["=", "!=", ">", "<", ">=", "<="] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
Your task is to ONLY extract the metadatas that exist in the input text from the provided metadata list and Use the following operators ["contains", "not contains", "start with", "end with", "is", "is not", "empty", "not empty", "=", "", ">", "<", "", "", "before", "after"] to express logical relationships, then return result in JSON format with the key "metadata_fields" and value "metadata_field_value" and comparison operator "comparison_operator".
### Format
The input text is in the variable input_text. Metadata are specified as a list in the variable metadata_fields.
### Constraint

@ -38,3 +38,8 @@ class MemoryRolePrefixRequiredError(LLMNodeError):
class FileTypeNotSupportError(LLMNodeError):
def __init__(self, *, type_name: str):
super().__init__(f"{type_name} type is not supported by this model")
class UnsupportedPromptContentTypeError(LLMNodeError):
def __init__(self, *, type_name: str) -> None:
super().__init__(f"Prompt content type {type_name} is not supported.")

@ -0,0 +1,160 @@
import mimetypes
import typing as tp
from sqlalchemy import Engine
from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE
from core.file import File, FileTransferMethod, FileType
from core.helper import ssrf_proxy
from core.tools.signature import sign_tool_file
from core.tools.tool_file_manager import ToolFileManager
from models import db as global_db
class LLMFileSaver(tp.Protocol):
"""LLMFileSaver is responsible for save multimodal output returned by
LLM.
"""
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
"""save_binary_string saves the inline file data returned by LLM.
Currently (2025-04-30), only some of Google Gemini models will return
multimodal output as inline data.
:param data: the contents of the file
:param mime_type: the media type of the file, specified by rfc6838
(https://datatracker.ietf.org/doc/html/rfc6838)
:param file_type: The file type of the inline file.
:param extension_override: Override the auto-detected file extension while saving this file.
The default value is `None`, which means do not override the file extension and guessing it
from the `mime_type` attribute while saving the file.
Setting it to values other than `None` means override the file's extension, and
will bypass the extension guessing saving the file.
Specially, setting it to empty string (`""`) will leave the file extension empty.
When it is not `None` or empty string (`""`), it should be a string beginning with a
dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py`
and `tar.gz` are not.
"""
pass
def save_remote_url(self, url: str, file_type: FileType) -> File:
"""save_remote_url saves the file from a remote url returned by LLM.
Currently (2025-04-30), no model returns multimodel output as a url.
:param url: the url of the file.
:param file_type: the file type of the file, check `FileType` enum for reference.
"""
pass
EngineFactory: tp.TypeAlias = tp.Callable[[], Engine]
class FileSaverImpl(LLMFileSaver):
_engine_factory: EngineFactory
_tenant_id: str
_user_id: str
def __init__(self, user_id: str, tenant_id: str, engine_factory: EngineFactory | None = None):
if engine_factory is None:
def _factory():
return global_db.engine
engine_factory = _factory
self._engine_factory = engine_factory
self._user_id = user_id
self._tenant_id = tenant_id
def _get_tool_file_manager(self):
return ToolFileManager(engine=self._engine_factory())
def save_remote_url(self, url: str, file_type: FileType) -> File:
http_response = ssrf_proxy.get(url)
http_response.raise_for_status()
data = http_response.content
mime_type_from_header = http_response.headers.get("Content-Type")
mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header)
return self.save_binary_string(data, mime_type, file_type, extension_override=extension)
def save_binary_string(
self,
data: bytes,
mime_type: str,
file_type: FileType,
extension_override: str | None = None,
) -> File:
tool_file_manager = self._get_tool_file_manager()
tool_file = tool_file_manager.create_file_by_raw(
user_id=self._user_id,
tenant_id=self._tenant_id,
# TODO(QuantumGhost): what is conversation id?
conversation_id=None,
file_binary=data,
mimetype=mime_type,
)
extension_override = _validate_extension_override(extension_override)
extension = _get_extension(mime_type, extension_override)
url = sign_tool_file(tool_file.id, extension)
return File(
tenant_id=self._tenant_id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
filename=tool_file.name,
extension=extension,
mime_type=mime_type,
size=len(data),
related_id=tool_file.id,
url=url,
# TODO(QuantumGhost): how should I set the following key?
# What's the difference between `remote_url` and `url`?
# What's the purpose of `storage_key` and `dify_model_identity`?
storage_key=tool_file.file_key,
)
def _get_extension(mime_type: str, extension_override: str | None = None) -> str:
"""get_extension return the extension of file.
If the `extension_override` parameter is set, this function should honor it and
return its value.
"""
if extension_override is not None:
return extension_override
return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION
def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]:
"""_extract_content_type_and_extension tries to
guess content type of file from url and `Content-Type` header in response.
"""
if content_type_header:
extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION
return content_type_header, extension
content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE
extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION
return content_type, extension
def _validate_extension_override(extension_override: str | None) -> str | None:
# `extension_override` is allow to be `None or `""`.
if extension_override is None:
return None
if extension_override == "":
return ""
if not extension_override.startswith("."):
raise ValueError("extension_override should start with '.' if not None or empty.", extension_override)
return extension_override

@ -1,3 +1,5 @@
import base64
import io
import json
import logging
from collections.abc import Generator, Mapping, Sequence
@ -21,7 +23,7 @@ from core.model_runtime.entities import (
PromptMessageContentType,
TextPromptMessageContent,
)
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
@ -38,7 +40,6 @@ from core.model_runtime.entities.model_entities import (
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
from core.plugin.entities.plugin import ModelProviderID
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@ -95,9 +96,13 @@ from .exc import (
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
if TYPE_CHECKING:
from core.file.models import File
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
logger = logging.getLogger(__name__)
@ -106,6 +111,43 @@ class LLMNode(BaseNode[LLMNodeData]):
_node_data_cls = LLMNodeData
_node_type = NodeType.LLM
# Instance attributes specific to LLMNode.
# Output variable for file
_file_outputs: list["File"]
_llm_file_saver: LLMFileSaver
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph: "Graph",
graph_runtime_state: "GraphRuntimeState",
previous_node_id: Optional[str] = None,
thread_pool_id: Optional[str] = None,
*,
llm_file_saver: LLMFileSaver | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
previous_node_id=previous_node_id,
thread_pool_id=thread_pool_id,
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs: list[File] = []
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
user_id=graph_init_params.user_id,
tenant_id=graph_init_params.tenant_id,
)
self._llm_file_saver = llm_file_saver
def _run(self) -> Generator[NodeEvent | InNodeEvent, None, None]:
def process_structured_output(text: str) -> Optional[dict[str, Any] | list[Any]]:
"""Process structured output if enabled"""
@ -215,6 +257,9 @@ class LLMNode(BaseNode[LLMNodeData]):
structured_output = process_structured_output(result_text)
if structured_output:
outputs["structured_output"] = structured_output
if self._file_outputs is not None:
outputs["files"] = self._file_outputs
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@ -240,6 +285,7 @@ class LLMNode(BaseNode[LLMNodeData]):
)
)
except Exception as e:
logger.exception("error while executing llm node")
yield RunCompletedEvent(
run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
@ -268,44 +314,45 @@ class LLMNode(BaseNode[LLMNodeData]):
return self._handle_invoke_result(invoke_result=invoke_result)
def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
def _handle_invoke_result(
self, invoke_result: LLMResult | Generator[LLMResultChunk, None, None]
) -> Generator[NodeEvent, None, None]:
# For blocking mode
if isinstance(invoke_result, LLMResult):
message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
yield ModelInvokeCompletedEvent(
text=message_text,
usage=invoke_result.usage,
finish_reason=None,
)
event = self._handle_blocking_result(invoke_result=invoke_result)
yield event
return
model = None
# For streaming mode
model = ""
prompt_messages: list[PromptMessage] = []
full_text = ""
usage = None
usage = LLMUsage.empty_usage()
finish_reason = None
full_text_buffer = io.StringIO()
for result in invoke_result:
text = convert_llm_result_chunk_to_str(result.delta.message.content)
full_text += text
contents = result.delta.message.content
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(contents):
full_text_buffer.write(text_part)
yield RunStreamChunkEvent(chunk_content=text_part, from_variable_selector=[self.node_id, "text"])
yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
if not model:
# Update the whole metadata
if not model and result.model:
model = result.model
if not prompt_messages:
prompt_messages = result.prompt_messages
if not usage and result.delta.usage:
if len(prompt_messages) == 0:
# TODO(QuantumGhost): it seems that this update has no visable effect.
# What's the purpose of the line below?
prompt_messages = list(result.prompt_messages)
if usage.prompt_tokens == 0 and result.delta.usage:
usage = result.delta.usage
if not finish_reason and result.delta.finish_reason:
if finish_reason is None and result.delta.finish_reason:
finish_reason = result.delta.finish_reason
if not usage:
usage = LLMUsage.empty_usage()
yield ModelInvokeCompletedEvent(text=full_text_buffer.getvalue(), usage=usage, finish_reason=finish_reason)
yield ModelInvokeCompletedEvent(text=full_text, usage=usage, finish_reason=finish_reason)
def _image_file_to_markdown(self, file: "File", /):
text_chunk = f"![]({file.generate_url()})"
return text_chunk
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
@ -963,6 +1010,42 @@ class LLMNode(BaseNode[LLMNodeData]):
return prompt_messages
def _handle_blocking_result(self, *, invoke_result: LLMResult) -> ModelInvokeCompletedEvent:
buffer = io.StringIO()
for text_part in self._save_multimodal_output_and_convert_result_to_markdown(invoke_result.message.content):
buffer.write(text_part)
return ModelInvokeCompletedEvent(
text=buffer.getvalue(),
usage=invoke_result.usage,
finish_reason=None,
)
def _save_multimodal_image_output(self, content: ImagePromptMessageContent) -> "File":
"""_save_multimodal_output saves multi-modal contents generated by LLM plugins.
There are two kinds of multimodal outputs:
- Inlined data encoded in base64, which would be saved to storage directly.
- Remote files referenced by an url, which would be downloaded and then saved to storage.
Currently, only image files are supported.
"""
# Inject the saver somehow...
_saver = self._llm_file_saver
# If this
if content.url != "":
saved_file = _saver.save_remote_url(content.url, FileType.IMAGE)
else:
saved_file = _saver.save_binary_string(
data=base64.b64decode(content.base64_data),
mime_type=content.mime_type,
file_type=FileType.IMAGE,
)
self._file_outputs.append(saved_file)
return saved_file
def _handle_native_json_schema(self, model_parameters: dict, rules: list[ParameterRule]) -> dict:
"""
Handle structured output for models with native JSON schema support.
@ -1123,6 +1206,41 @@ class LLMNode(BaseNode[LLMNodeData]):
else SupportStructuredOutputStatus.UNSUPPORTED
)
def _save_multimodal_output_and_convert_result_to_markdown(
self,
contents: str | list[PromptMessageContentUnionTypes] | None,
) -> Generator[str, None, None]:
"""Convert intermediate prompt messages into strings and yield them to the caller.
If the messages contain non-textual content (e.g., multimedia like images or videos),
it will be saved separately, and the corresponding Markdown representation will
be yielded to the caller.
"""
# NOTE(QuantumGhost): This function should yield results to the caller immediately
# whenever new content or partial content is available. Avoid any intermediate buffering
# of results. Additionally, do not yield empty strings; instead, yield from an empty list
# if necessary.
if contents is None:
yield from []
return
if isinstance(contents, str):
yield contents
elif isinstance(contents, list):
for item in contents:
if isinstance(item, TextPromptMessageContent):
yield item.data
elif isinstance(item, ImagePromptMessageContent):
file = self._save_multimodal_image_output(item)
self._file_outputs.append(file)
yield self._image_file_to_markdown(file)
else:
logger.warning("unknown item type encountered, type=%s", type(item))
yield str(item)
else:
logger.warning("unknown contents type encountered, type=%s", type(contents))
yield str(contents)
def _combine_message_content_with_role(
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole

@ -11,6 +11,8 @@ class Operation(StrEnum):
SUBTRACT = "-="
MULTIPLY = "*="
DIVIDE = "/="
REMOVE_FIRST = "remove-first"
REMOVE_LAST = "remove-last"
class InputType(StrEnum):

@ -23,6 +23,15 @@ def is_operation_supported(*, variable_type: SegmentType, operation: Operation):
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case Operation.REMOVE_FIRST | Operation.REMOVE_LAST:
# Only array variable can have elements removed
return variable_type in {
SegmentType.ARRAY_ANY,
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
SegmentType.ARRAY_NUMBER,
SegmentType.ARRAY_FILE,
}
case _:
return False
@ -51,7 +60,7 @@ def is_constant_input_supported(*, variable_type: SegmentType, operation: Operat
def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any):
if operation == Operation.CLEAR:
if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}:
return True
match variable_type:
case SegmentType.STRING:

@ -64,7 +64,7 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
# Get value from variable pool
if (
item.input_type == InputType.VARIABLE
and item.operation != Operation.CLEAR
and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}
and item.value is not None
):
value = self.graph_runtime_state.variable_pool.get(item.value)
@ -165,5 +165,15 @@ class VariableAssignerNode(BaseNode[VariableAssignerNodeData]):
return variable.value * value
case Operation.DIVIDE:
return variable.value / value
case Operation.REMOVE_FIRST:
# If array is empty, do nothing
if not variable.value:
return variable.value
return variable.value[1:]
case Operation.REMOVE_LAST:
# If array is empty, do nothing
if not variable.value:
return variable.value
return variable.value[:-1]
case _:
raise OperationNotSupportedError(operation=operation, variable_type=variable.value_type)

@ -9,6 +9,7 @@ from core.app.apps.base_app_queue_manager import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.workflow.callbacks import WorkflowCallback
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.graph_engine.entities.event import GraphEngineEvent, GraphRunFailedEvent, InNodeEvent
@ -364,4 +365,5 @@ class WorkflowEntry:
input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id)
# append variable and value to variable pool
if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID:
variable_pool.add([variable_node_id] + variable_key_list, input_value)

@ -10,4 +10,16 @@ POSTGRES_INDEXES_NAMING_CONVENTION = {
}
metadata = MetaData(naming_convention=POSTGRES_INDEXES_NAMING_CONVENTION)
# ****** IMPORTANT NOTICE ******
#
# NOTE(QuantumGhost): Avoid directly importing and using `db` in modules outside of the
# `controllers` package.
#
# Instead, import `db` within the `controllers` package and pass it as an argument to
# functions or class constructors.
#
# Directly importing `db` in other modules can make the code more difficult to read, test, and maintain.
#
# Whenever possible, avoid this pattern in new code.
db = SQLAlchemy(metadata=metadata)

@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, cast
from core.plugin.entities.plugin import GenericProviderID
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.signature import sign_tool_file
from services.plugin.plugin_service import PluginService
if TYPE_CHECKING:
@ -23,7 +24,6 @@ from configs import dify_config
from constants import DEFAULT_FILE_NUMBER_LIMITS
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType
from core.file import helpers as file_helpers
from core.file.tool_file_parser import ToolFileParser
from libs.helper import generate_string
from models.base import Base
from models.enums import CreatedByRole
@ -986,9 +986,7 @@ class Message(db.Model): # type: ignore[name-defined]
if not tool_file_id:
continue
sign_url = ToolFileParser.get_tool_file_manager().sign_file(
tool_file_id=tool_file_id, extension=extension
)
sign_url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
elif "file-preview" in url:
# get upload file id
upload_file_id_pattern = r"\/files\/([\w-]+)\/file-preview?\?timestamp="

@ -263,8 +263,8 @@ class ToolConversationVariables(Base):
class ToolFile(Base):
"""
store the file created by agent
"""This table stores file metadata generated in workflows,
not only files created by agent.
"""
__tablename__ = "tool_files"

@ -10,7 +10,7 @@ dependencies = [
"boto3==1.35.99",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.4.0",
"celery~=5.5.2",
"chardet~=5.1.0",
"flask~=3.1.0",
"flask-compress~=1.17",
@ -77,11 +77,10 @@ dependencies = [
"sentry-sdk[flask]~=1.44.1",
"sqlalchemy~=2.0.29",
"starlette==0.41.0",
"tiktoken~=0.8.0",
"tiktoken~=0.9.0",
"tokenizers~=0.15.0",
"transformers~=4.35.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"validators==0.21.0",
"weave~=0.51.34",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
@ -121,6 +120,7 @@ dev = [
"types-defusedxml~=0.7.0",
"types-deprecated~=1.2.15",
"types-docutils~=0.21.0",
"types-jsonschema~=4.23.0",
"types-flask-cors~=5.0.0",
"types-flask-migrate~=4.1.0",
"types-gevent~=24.11.0",
@ -196,6 +196,6 @@ vdb = [
"tidb-vector==0.0.9",
"upstash-vector==0.6.0",
"volcengine-compat~=1.0.156",
"weaviate-client~=3.21.0",
"weaviate-client~=3.24.0",
"xinference-client~=1.2.2",
]

@ -2,9 +2,9 @@ import json
from copy import deepcopy
from datetime import UTC, datetime
from typing import Any, Optional, Union, cast
from urllib.parse import urlparse
import httpx
import validators
from constants import HIDDEN_VALUE
from core.helper import ssrf_proxy
@ -72,7 +72,9 @@ class ExternalDatasetService:
endpoint = f"{settings['endpoint']}/retrieval"
api_key = settings["api_key"]
if not validators.url(endpoint, simple_host=True):
parsed_url = urlparse(endpoint)
if not all([parsed_url.scheme, parsed_url.netloc]):
if not endpoint.startswith("http://") and not endpoint.startswith("https://"):
raise ValueError(f"invalid endpoint: {endpoint} must start with http:// or https://")
else:

@ -5,6 +5,8 @@ import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch
from core.helper import ssrf_proxy
class MockedHttp:
@staticmethod
@ -29,6 +31,6 @@ class MockedHttp:
@pytest.fixture
def setup_http_mock(request, monkeypatch: MonkeyPatch):
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
monkeypatch.setattr(ssrf_proxy, "make_request", MockedHttp.httpx_request)
yield
monkeypatch.undo()

@ -34,10 +34,11 @@ parameters = {
def test_api_tool(setup_http_mock):
tool = ApiTool(
entity=ToolEntity(
identity=ToolIdentity(provider="", author="", name="", label=I18nObject()),
identity=ToolIdentity(provider="", author="", name="", label=I18nObject(en_US="test tool")),
),
api_bundle=ApiToolBundle(**tool_bundle),
runtime=ToolRuntime(tenant_id="", credentials={"auth_type": "none"}),
provider_id="test_tool",
)
headers = tool.assembling_request(parameters)
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)

@ -23,13 +23,70 @@ def setup_mock_redis():
ext_redis.redis_client.lock = MagicMock(return_value=mock_redis_lock)
class TestOpenSearchConfig:
def test_to_opensearch_params(self):
config = OpenSearchConfig(
host="localhost",
port=9200,
secure=True,
user="admin",
password="password",
)
params = config.to_opensearch_params()
assert params["hosts"] == [{"host": "localhost", "port": 9200}]
assert params["use_ssl"] is True
assert params["verify_certs"] is True
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] == ("admin", "password")
@patch("boto3.Session")
@patch("core.rag.datasource.vdb.opensearch.opensearch_vector.Urllib3AWSV4SignerAuth")
def test_to_opensearch_params_with_aws_managed_iam(
self, mock_aws_signer_auth: MagicMock, mock_boto_session: MagicMock
):
mock_credentials = MagicMock()
mock_boto_session.return_value.get_credentials.return_value = mock_credentials
mock_auth_instance = MagicMock()
mock_aws_signer_auth.return_value = mock_auth_instance
aws_region = "ap-southeast-2"
aws_service = "aoss"
host = f"aoss-endpoint.{aws_region}.aoss.amazonaws.com"
port = 9201
config = OpenSearchConfig(
host=host,
port=port,
secure=True,
auth_method="aws_managed_iam",
aws_region=aws_region,
aws_service=aws_service,
)
params = config.to_opensearch_params()
assert params["hosts"] == [{"host": host, "port": port}]
assert params["use_ssl"] is True
assert params["verify_certs"] is True
assert params["connection_class"].__name__ == "Urllib3HttpConnection"
assert params["http_auth"] is mock_auth_instance
mock_aws_signer_auth.assert_called_once_with(
credentials=mock_credentials, region=aws_region, service=aws_service
)
assert mock_boto_session.return_value.get_credentials.called
class TestOpenSearchVector:
def setup_method(self):
self.collection_name = "test_collection"
self.example_doc_id = "example_doc_id"
self.vector = OpenSearchVector(
collection_name=self.collection_name,
config=OpenSearchConfig(host="localhost", port=9200, user="admin", password="password", secure=False),
config=OpenSearchConfig(host="localhost", port=9200, secure=False, user="admin", password="password"),
)
self.vector._client = MagicMock()

@ -100,3 +100,9 @@ def test_flask_configs(example_env_file):
assert str(config["CODE_EXECUTION_ENDPOINT"]) == "http://sandbox:8194/"
assert str(URL(str(config["CODE_EXECUTION_ENDPOINT"])) / "v1") == "http://sandbox:8194/v1"
def test_inner_api_config_exist():
config = DifyConfig()
assert config.INNER_API is False
assert config.INNER_API_KEY is None

@ -864,10 +864,11 @@ def test_condition_parallel_correct_output(mock_close, mock_remove, app):
with patch.object(CodeNode, "_run", new=code_generator):
generator = graph_engine.run()
stream_content = ""
res_content = "VAT:\ndify 123"
wrong_content = ["Stamp Duty", "other"]
for item in generator:
if isinstance(item, NodeRunStreamChunkEvent):
stream_content += f"{item.chunk_content}\n"
if isinstance(item, GraphRunSucceededEvent):
assert item.outputs == {"answer": res_content}
assert stream_content == res_content + "\n"
assert item.outputs is not None
answer = item.outputs["answer"]
assert all(rc not in answer for rc in wrong_content)

@ -0,0 +1,192 @@
import uuid
from typing import NamedTuple
from unittest import mock
import httpx
import pytest
from sqlalchemy import Engine
from core.file import FileTransferMethod, FileType, models
from core.helper import ssrf_proxy
from core.tools import signature
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.nodes.llm.file_saver import (
FileSaverImpl,
_extract_content_type_and_extension,
_get_extension,
_validate_extension_override,
)
from models import ToolFile
_PNG_DATA = b"\x89PNG\r\n\x1a\n"
def _gen_id():
return str(uuid.uuid4())
class TestFileSaverImpl:
def test_save_binary_string(self, monkeypatch):
user_id = _gen_id()
tenant_id = _gen_id()
file_type = FileType.IMAGE
mime_type = "image/png"
mock_signed_url = "https://example.com/image.png"
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager)
mocked_engine = mock.MagicMock(spec=Engine)
mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file
monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager)
# Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here.
mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file)
# Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here.
monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file)
mocked_sign_file.return_value = mock_signed_url
storage_file_manager = FileSaverImpl(
user_id=user_id,
tenant_id=tenant_id,
engine_factory=mocked_engine,
)
file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type)
assert file.tenant_id == tenant_id
assert file.type == file_type
assert file.transfer_method == FileTransferMethod.TOOL_FILE
assert file.extension == ".png"
assert file.mime_type == mime_type
assert file.size == len(_PNG_DATA)
assert file.related_id == mock_tool_file.id
assert file.generate_url() == mock_signed_url
mocked_tool_file_manager.create_file_by_raw.assert_called_once_with(
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_binary=_PNG_DATA,
mimetype=mime_type,
)
mocked_sign_file.assert_called_once_with(mock_tool_file.id, ".png")
def test_save_remote_url_request_failed(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=401,
request=mock_request,
)
file_saver = FileSaverImpl(
user_id=_gen_id(),
tenant_id=_gen_id(),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
with pytest.raises(httpx.HTTPStatusError) as exc:
file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_get.assert_called_once_with(_TEST_URL)
assert exc.value.response.status_code == 401
def test_save_remote_url_success(self, monkeypatch):
_TEST_URL = "https://example.com/image.png"
mime_type = "image/png"
user_id = _gen_id()
tenant_id = _gen_id()
mock_request = httpx.Request("GET", _TEST_URL)
mock_response = httpx.Response(
status_code=200,
content=b"test-data",
headers={"Content-Type": mime_type},
request=mock_request,
)
file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id)
mock_tool_file = ToolFile(
id=_gen_id(),
user_id=user_id,
tenant_id=tenant_id,
conversation_id=None,
file_key="test-file-key",
mimetype=mime_type,
original_url=None,
name=f"{_gen_id()}.png",
size=len(_PNG_DATA),
)
mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response)
monkeypatch.setattr(ssrf_proxy, "get", mock_get)
mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file)
monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string)
file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE)
mock_save_binary_string.assert_called_once_with(
mock_response.content,
mime_type,
FileType.IMAGE,
extension_override=".png",
)
assert file == mock_tool_file
def test_validate_extension_override():
class TestCase(NamedTuple):
extension_override: str | None
expected: str | None
cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"]
for valid_ext_override in [None, "", ".png", ".tar.gz"]:
assert valid_ext_override == _validate_extension_override(valid_ext_override)
for invalid_ext_override in ["png", "tar.gz"]:
with pytest.raises(ValueError) as exc:
_validate_extension_override(invalid_ext_override)
class TestExtractContentTypeAndExtension:
def test_with_both_content_type_and_extension(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_url_with_file_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type)
assert content_type == "image/png"
assert extension == ".png"
def test_response_with_content_type(self):
content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png")
assert content_type == "image/png"
assert extension == ".png"
def test_no_content_type_and_no_extension(self):
for content_type in [None, ""]:
content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type)
assert content_type == "application/octet-stream"
assert extension == ".bin"
class TestGetExtension:
def test_with_extension_override(self):
mime_type = "image/png"
for override in [".jpg", ""]:
extension = _get_extension(mime_type, override)
assert extension == override
def test_without_extension_override(self):
mime_type = "image/png"
extension = _get_extension(mime_type)
assert extension == ".png"

@ -1,5 +1,8 @@
import base64
import uuid
from collections.abc import Sequence
from typing import Optional
from unittest import mock
import pytest
@ -30,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
VisionConfig,
VisionConfigOptions,
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from models.enums import UserFrom
from models.provider import ProviderType
@ -49,8 +53,8 @@ class MockTokenBufferMemory:
@pytest.fixture
def llm_node():
data = LLMNodeData(
def llm_node_data() -> LLMNodeData:
return LLMNodeData(
title="Test LLM",
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[],
@ -64,17 +68,11 @@ def llm_node():
),
),
)
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
node = LLMNode(
id="1",
config={
"id": "1",
"data": data.model_dump(),
},
graph_init_params=GraphInitParams(
@pytest.fixture
def graph_init_params() -> GraphInitParams:
return GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
@ -84,8 +82,12 @@ def llm_node():
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
),
graph=Graph(
)
@pytest.fixture
def graph() -> Graph:
return Graph(
root_node_id="1",
answer_stream_generate_routes=AnswerStreamGenerateRoute(
answer_dependencies={},
@ -95,11 +97,36 @@ def llm_node():
end_dependencies={},
end_stream_variable_selector_mapping={},
),
),
graph_runtime_state=GraphRuntimeState(
)
@pytest.fixture
def graph_runtime_state() -> GraphRuntimeState:
variable_pool = VariablePool(
system_variables={},
user_inputs={},
)
return GraphRuntimeState(
variable_pool=variable_pool,
start_at=0,
),
)
@pytest.fixture
def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph: Graph, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
return node
@ -465,3 +492,167 @@ def test_handle_list_messages_basic(llm_node):
assert len(result) == 1
assert isinstance(result[0], UserPromptMessage)
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
@pytest.fixture
def llm_node_for_multimodal(
llm_node_data, graph_init_params, graph, graph_runtime_state
) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
node = LLMNode(
id="1",
config={
"id": "1",
"data": llm_node_data.model_dump(),
},
graph_init_params=graph_init_params,
graph=graph,
graph_runtime_state=graph_runtime_state,
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver
class TestLLMNodeSaveMultiModalImageOutput:
def test_llm_node_save_inline_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
base64_data=base64.b64encode(b"test-data").decode(),
mime_type="image/png",
)
mock_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
extension=".png",
mime_type="image/png",
size=9,
)
mock_file_saver.save_binary_string.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_binary_string.assert_called_once_with(
data=b"test-data", mime_type="image/png", file_type=FileType.IMAGE
)
def test_llm_node_save_url_output(self, llm_node_for_multimodal: tuple[LLMNode, LLMFileSaver]):
llm_node, mock_file_saver = llm_node_for_multimodal
content = ImagePromptMessageContent(
format="png",
url="https://example.com/image.png",
mime_type="image/jpg",
)
mock_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
related_id=str(uuid.uuid4()),
filename="test-file.png",
extension=".png",
mime_type="image/png",
size=9,
)
mock_file_saver.save_remote_url.return_value = mock_file
file = llm_node._save_multimodal_image_output(content=content)
assert llm_node._file_outputs == [mock_file]
assert file == mock_file
mock_file_saver.save_remote_url.assert_called_once_with(content.url, FileType.IMAGE)
def test_llm_node_image_file_to_markdown(llm_node: LLMNode):
mock_file = mock.MagicMock(spec=File)
mock_file.generate_url.return_value = "https://example.com/image.png"
markdown = llm_node._image_file_to_markdown(mock_file)
assert markdown == "![](https://example.com/image.png)"
class TestSaveMultimodalOutputAndConvertResultToMarkdown:
def test_str_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown("hello world")
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_text_prompt_message_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[TextPromptMessageContent(data="hello world")]
)
assert list(gen) == ["hello world"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_image_content_with_inline_data(self, llm_node_for_multimodal, monkeypatch):
llm_node, mock_file_saver = llm_node_for_multimodal
image_raw_data = b"PNG_DATA"
image_b64_data = base64.b64encode(image_raw_data).decode()
mock_saved_file = File(
id=str(uuid.uuid4()),
tenant_id="1",
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
filename="test.png",
extension=".png",
size=len(image_raw_data),
related_id=str(uuid.uuid4()),
url="https://example.com/test.png",
storage_key="test_storage_key",
)
mock_file_saver.save_binary_string.return_value = mock_saved_file
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(
[
ImagePromptMessageContent(
format="png",
base64_data=image_b64_data,
mime_type="image/png",
)
]
)
yielded_strs = list(gen)
assert len(yielded_strs) == 1
# This assertion requires careful handling.
# `FILES_URL` settings can vary across environments, which might lead to fragile tests.
#
# Rather than asserting the complete URL returned by _save_multimodal_output_and_convert_result_to_markdown,
# we verify that the result includes the markdown image syntax and the expected file URL path.
expected_file_url_path = f"/files/tools/{mock_saved_file.related_id}.png"
assert yielded_strs[0].startswith("![](")
assert expected_file_url_path in yielded_strs[0]
assert yielded_strs[0].endswith(")")
mock_file_saver.save_binary_string.assert_called_once_with(
data=image_raw_data,
mime_type="image/png",
file_type=FileType.IMAGE,
)
assert mock_saved_file in llm_node._file_outputs
def test_unknown_content_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(frozenset(["hello world"]))
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_unknown_item_type(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown([frozenset(["hello world"])])
assert list(gen) == ["frozenset({'hello world'})"]
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()
def test_none_content(self, llm_node_for_multimodal):
llm_node, mock_file_saver = llm_node_for_multimodal
gen = llm_node._save_multimodal_output_and_convert_result_to_markdown(None)
assert list(gen) == []
mock_file_saver.save_binary_string.assert_not_called()
mock_file_saver.save_remote_url.assert_not_called()

@ -0,0 +1,390 @@
import time
import uuid
from uuid import uuid4
from core.app.entities.app_invoke_entities import InvokeFrom
from core.variables import ArrayStringVariable
from core.workflow.entities.variable_pool import VariablePool
from core.workflow.enums import SystemVariableKey
from core.workflow.graph_engine.entities.graph import Graph
from core.workflow.graph_engine.entities.graph_init_params import GraphInitParams
from core.workflow.graph_engine.entities.graph_runtime_state import GraphRuntimeState
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode
from core.workflow.nodes.variable_assigner.v2.enums import InputType, Operation
from models.enums import UserFrom
from models.workflow import WorkflowType
DEFAULT_NODE_ID = "node_id"
def test_handle_item_directly():
"""Test the _handle_item method directly for remove operations."""
# Create variables
variable1 = ArrayStringVariable(
id=str(uuid4()),
name="test_variable1",
value=["first", "second", "third"],
)
variable2 = ArrayStringVariable(
id=str(uuid4()),
name="test_variable2",
value=["first", "second", "third"],
)
# Create a mock class with just the _handle_item method
class MockNode:
def _handle_item(self, *, variable, operation, value):
match operation:
case Operation.REMOVE_FIRST:
if not variable.value:
return variable.value
return variable.value[1:]
case Operation.REMOVE_LAST:
if not variable.value:
return variable.value
return variable.value[:-1]
node = MockNode()
# Test remove-first
result1 = node._handle_item(
variable=variable1,
operation=Operation.REMOVE_FIRST,
value=None,
)
# Test remove-last
result2 = node._handle_item(
variable=variable2,
operation=Operation.REMOVE_LAST,
value=None,
)
# Check the results
assert result1 == ["second", "third"]
assert result2 == ["first", "second"]
def test_remove_first_from_array():
"""Test removing the first element from an array."""
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=["first", "second", "third"],
selector=["conversation", "test_conversation_variable"],
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
)
# Skip the mock assertion since we're in a test environment
# Print the variable before running
print(f"Before: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
# Run the node
result = list(node.run())
# Print the variable after running and the result
print(f"After: {variable_pool.get(['conversation', conversation_variable.name]).to_object()}")
print(f"Result: {result}")
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == ["second", "third"]
def test_remove_last_from_array():
"""Test removing the last element from an array."""
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=["first", "second", "third"],
selector=["conversation", "test_conversation_variable"],
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
)
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == ["first", "second"]
def test_remove_first_from_empty_array():
"""Test removing the first element from an empty array (should do nothing)."""
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=[],
selector=["conversation", "test_conversation_variable"],
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_FIRST,
"value": None,
}
],
},
},
)
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []
def test_remove_last_from_empty_array():
"""Test removing the last element from an empty array (should do nothing)."""
graph_config = {
"edges": [
{
"id": "start-source-assigner-target",
"source": "start",
"target": "assigner",
},
],
"nodes": [
{"data": {"type": "start"}, "id": "start"},
{
"data": {
"type": "assigner",
},
"id": "assigner",
},
],
}
graph = Graph.init(graph_config=graph_config)
init_params = GraphInitParams(
tenant_id="1",
app_id="1",
workflow_type=WorkflowType.WORKFLOW,
workflow_id="1",
graph_config=graph_config,
user_id="1",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
conversation_variable = ArrayStringVariable(
id=str(uuid4()),
name="test_conversation_variable",
value=[],
selector=["conversation", "test_conversation_variable"],
)
variable_pool = VariablePool(
system_variables={SystemVariableKey.CONVERSATION_ID: "conversation_id"},
user_inputs={},
environment_variables=[],
conversation_variables=[conversation_variable],
)
node = VariableAssignerNode(
id=str(uuid.uuid4()),
graph_init_params=init_params,
graph=graph,
graph_runtime_state=GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()),
config={
"id": "node_id",
"data": {
"title": "test",
"version": "2",
"items": [
{
"variable_selector": ["conversation", conversation_variable.name],
"input_type": InputType.VARIABLE,
"operation": Operation.REMOVE_LAST,
"value": None,
}
],
},
},
)
# Skip the mock assertion since we're in a test environment
list(node.run())
got = variable_pool.get(["conversation", conversation_variable.name])
assert got is not None
assert got.to_object() == []

File diff suppressed because it is too large Load Diff

@ -3,5 +3,5 @@
set -x
# run mypy checks
uv run --directory api --dev \
python -m mypy --install-types --non-interactive .
uv run --directory api --dev --with pip \
python -m mypy --install-types --non-interactive --cache-fine-grained --sqlite-cache .

@ -1,4 +1,4 @@
#!/bin/bash
set -x
pytest api/tests/integration_tests/tools/test_all_provider.py
pytest api/tests/integration_tests/tools

@ -39,6 +39,12 @@ APP_WEB_URL=
# File preview or download Url prefix.
# used to display File preview or download Url to the front-end or as Multi-model inputs;
# Url is signed and has expiration time.
# Setting FILES_URL is required for file processing plugins.
# - For https://example.com, use FILES_URL=https://example.com
# - For http://example.com, use FILES_URL=http://example.com
# Recommendation: use a dedicated domain (e.g., https://upload.example.com).
# Alternatively, use http://<your-ip>:5001 or http://api:5001,
# ensuring port 5001 is externally accessible (see docker-compose.yaml).
FILES_URL=
# ------------------------------
@ -520,9 +526,13 @@ RELYT_DATABASE=postgres
# open search configuration, only available when VECTOR_STORE is `opensearch`
OPENSEARCH_HOST=opensearch
OPENSEARCH_PORT=9200
OPENSEARCH_SECURE=true
OPENSEARCH_AUTH_METHOD=basic
OPENSEARCH_USER=admin
OPENSEARCH_PASSWORD=admin
OPENSEARCH_SECURE=true
# If using AWS managed IAM, e.g. Managed Cluster or OpenSearch Serverless
OPENSEARCH_AWS_REGION=ap-southeast-1
OPENSEARCH_AWS_SERVICE=aoss
# tencent vector configurations, only available when VECTOR_STORE is `tencent`
TENCENT_VECTOR_DB_URL=http://127.0.0.1

@ -14,7 +14,6 @@ Welcome to the new `docker` directory for deploying Dify using Docker Compose. T
- **Unified Vector Database Services**: All vector database services are now managed from a single Docker Compose file `docker-compose.yaml`. You can switch between different vector databases by setting the `VECTOR_STORE` environment variable in your `.env` file.
- **Mandatory .env File**: A `.env` file is now required to run `docker compose up`. This file is crucial for configuring your deployment and for any custom settings to persist through upgrades.
- **Legacy Support**: Previous deployment files are now located in the `docker-legacy` directory and will no longer be maintained.
### How to Deploy Dify with `docker-compose.yaml`

Binary file not shown.

Before

Width:  |  Height:  |  Size: 62 KiB

After

Width:  |  Height:  |  Size: 170 KiB

@ -225,9 +225,12 @@ x-shared-env: &shared-api-worker-env
RELYT_DATABASE: ${RELYT_DATABASE:-postgres}
OPENSEARCH_HOST: ${OPENSEARCH_HOST:-opensearch}
OPENSEARCH_PORT: ${OPENSEARCH_PORT:-9200}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_AUTH_METHOD: ${OPENSEARCH_AUTH_METHOD:-basic}
OPENSEARCH_USER: ${OPENSEARCH_USER:-admin}
OPENSEARCH_PASSWORD: ${OPENSEARCH_PASSWORD:-admin}
OPENSEARCH_SECURE: ${OPENSEARCH_SECURE:-true}
OPENSEARCH_AWS_REGION: ${OPENSEARCH_AWS_REGION:-ap-southeast-1}
OPENSEARCH_AWS_SERVICE: ${OPENSEARCH_AWS_SERVICE:-aoss}
TENCENT_VECTOR_DB_URL: ${TENCENT_VECTOR_DB_URL:-http://127.0.0.1}
TENCENT_VECTOR_DB_API_KEY: ${TENCENT_VECTOR_DB_API_KEY:-dify}
TENCENT_VECTOR_DB_TIMEOUT: ${TENCENT_VECTOR_DB_TIMEOUT:-30}

@ -87,7 +87,7 @@ const AppPublisher = ({
const setAppDetail = useAppStore(s => s.setAppDetail)
const { app_base_url: appBaseURL = '', access_token: accessToken = '' } = appDetail?.site ?? {}
const appMode = (appDetail?.mode !== 'completion' && appDetail?.mode !== 'workflow') ? 'chat' : appDetail.mode
const appURL = `${appBaseURL}/${basePath}/${appMode}/${accessToken}`
const appURL = `${appBaseURL}${basePath}/${appMode}/${accessToken}`
const isChatApp = ['chat', 'agent-chat', 'completion'].includes(appDetail?.mode || '')
const { data: useCanAccessApp, isLoading: isGettingUserCanAccessApp, refetch } = useGetUserCanAccessApp({ appId: appDetail?.id, enabled: false })
const { data: appAccessSubjects, isLoading: isGettingAppWhiteListSubjects } = useAppWhiteListSubjects(appDetail?.id, open && appDetail?.access_mode === AccessMode.SPECIFIC_GROUPS_MEMBERS)

@ -429,6 +429,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
text_to_speech: {
enabled: true,
},
questionEditEnable: false,
supportAnnotation: true,
annotation_reply: {
enabled: true,
@ -484,6 +485,7 @@ function DetailPanel({ detail, onFeedback }: IDetailPanel) {
text_to_speech: {
enabled: true,
},
questionEditEnable: false,
supportAnnotation: true,
annotation_reply: {
enabled: true,

@ -265,6 +265,7 @@ const Chat: FC<ChatProps> = ({
item={item}
questionIcon={questionIcon}
theme={themeBuilder?.theme}
enableEdit={config?.questionEditEnable}
switchSibling={switchSibling}
/>
)

@ -28,6 +28,7 @@ type QuestionProps = {
item: ChatItem
questionIcon?: ReactNode
theme: Theme | null | undefined
enableEdit?: boolean
switchSibling?: (siblingMessageId: string) => void
}
@ -35,6 +36,7 @@ const Question: FC<QuestionProps> = ({
item,
questionIcon,
theme,
enableEdit = true,
switchSibling,
}) => {
const { t } = useTranslation()
@ -87,9 +89,9 @@ const Question: FC<QuestionProps> = ({
}}>
<RiClipboardLine className='h-4 w-4' />
</ActionButton>
<ActionButton onClick={handleEdit}>
{enableEdit && <ActionButton onClick={handleEdit}>
<RiEditLine className='h-4 w-4' />
</ActionButton>
</ActionButton>}
</div>
</div>
<div

@ -46,6 +46,7 @@ export type EnableType = {
export type ChatConfig = Omit<ModelConfig, 'model'> & {
supportAnnotation?: boolean
appId?: string
questionEditEnable?: boolean
supportFeedback?: boolean
supportCitationHitInfo?: boolean
}

@ -32,6 +32,7 @@ const FileImageItem = ({
}: FileImageItemProps) => {
const { id, progress, base64Url, url, name } = file
const [imagePreviewUrl, setImagePreviewUrl] = useState('')
const download_url = url ? `${url}&as_attachment=true` : base64Url
return (
<>
@ -84,7 +85,7 @@ const FileImageItem = ({
className='absolute bottom-0.5 right-0.5 flex h-6 w-6 items-center justify-center rounded-lg bg-components-actionbar-bg shadow-md'
onClick={(e) => {
e.stopPropagation()
downloadFile(url || base64Url || '', name)
downloadFile(download_url || '', name)
}}
>
<RiDownloadLine className='h-4 w-4 text-text-tertiary' />

@ -45,6 +45,7 @@ const FileItem = ({
let tmp_preview_url = url || base64Url
if (!tmp_preview_url && file?.originalFile)
tmp_preview_url = URL.createObjectURL(file.originalFile.slice()).toString()
const download_url = url ? `${url}&as_attachment=true` : base64Url
return (
<>
@ -93,13 +94,13 @@ const FileItem = ({
}
</div>
{
showDownloadAction && tmp_preview_url && (
showDownloadAction && download_url && (
<ActionButton
size='m'
className='absolute -right-1 -top-1 hidden group-hover/file-item:flex'
onClick={(e) => {
e.stopPropagation()
downloadFile(tmp_preview_url || '', name)
downloadFile(download_url || '', name)
}}
>
<RiDownloadLine className='h-3.5 w-3.5 text-text-tertiary' />

@ -121,7 +121,7 @@ export function PreCode(props: { children: any }) {
// visit https://reactjs.org/docs/error-decoder.html?invariant=185 for the full message
// or use the non-minified dev environment for full errors and additional helpful warnings.
const CodeBlock: any = memo(({ inline, className, children, ...props }: any) => {
const CodeBlock: any = memo(({ inline, className, children = '', ...props }: any) => {
const { theme } = useTheme()
const [isSVG, setIsSVG] = useState(true)
const match = /language-(\w+)/.exec(className || '')
@ -258,7 +258,7 @@ const Link = ({ node, children, ...props }: any) => {
const { onSend } = useChatContext()
const hidden_text = decodeURIComponent(node.properties.href.toString().split('abbr:')[1])
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value}>{node.children[0]?.value}</abbr>
return <abbr className="cursor-pointer underline !decoration-primary-700 decoration-dashed" onClick={() => onSend?.(hidden_text)} title={node.children[0]?.value || ''}>{node.children[0]?.value || ''}</abbr>
}
else {
return <a {...props} target="_blank" className="cursor-pointer underline !decoration-primary-700 decoration-dashed">{children || 'Download'}</a>

@ -22,10 +22,6 @@ export function preprocessMermaidCode(code: string): string {
.replace(/section\s+([^:]+):/g, (match, sectionName) => `section ${sectionName}`)
// Fix common syntax issues
.replace(/fifopacket/g, 'rect')
// Ensure graph has direction
.replace(/^graph\s+((?:TB|BT|RL|LR)*)/, (match, direction) => {
return direction ? match : 'graph TD'
})
// Clean up empty lines and extra spaces
.trim()
}

@ -13,7 +13,7 @@ import { CodeNode } from '@lexical/code'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { RichTextPlugin } from '@lexical/react/LexicalRichTextPlugin'
import { ContentEditable } from '@lexical/react/LexicalContentEditable'
import LexicalErrorBoundary from '@lexical/react/LexicalErrorBoundary'
import { LexicalErrorBoundary } from '@lexical/react/LexicalErrorBoundary'
import { OnChangePlugin } from '@lexical/react/LexicalOnChangePlugin'
import { HistoryPlugin } from '@lexical/react/LexicalHistoryPlugin'
// import TreeView from './plugins/tree-view'

@ -120,10 +120,8 @@ const ComponentPicker = ({
}, [editor, checkForTriggerMatch, triggerString])
const handleClose = useCallback(() => {
ReactDOM.flushSync(() => {
const escapeEvent = new KeyboardEvent('keydown', { key: 'Escape' })
editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
})
}, [editor])
const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
@ -132,7 +130,11 @@ const ComponentPicker = ({
) => {
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
setTimeout(() => {
if (anchorElementRef.current)
refs.setReference(anchorElementRef.current)
}, 0)
return (
<>
@ -149,7 +151,6 @@ const ComponentPicker = ({
visibility: isPositioned ? 'visible' : 'hidden',
}}
ref={refs.setFloating}
data-testid="component-picker-container"
>
{
workflowVariableBlock?.show && (
@ -173,7 +174,7 @@ const ComponentPicker = ({
<div className='my-1 h-px w-full -translate-x-1 bg-divider-subtle'></div>
)
}
<div data-testid="options-list">
<div>
{
options.map((option, index) => (
<Fragment key={option.key}>

@ -15,6 +15,7 @@ import { useProviderContext } from '@/context/provider-context'
import GridMask from '@/app/components/base/grid-mask'
import { useAppContext } from '@/context/app-context'
import classNames from '@/utils/classnames'
import { useGetPricingPageLanguage } from '@/context/i18n'
type Props = {
onCancel: () => void
@ -33,6 +34,11 @@ const Pricing: FC<Props> = ({
useKeyPress(['esc'], onCancel)
const pricingPageLanguage = useGetPricingPageLanguage()
const pricingPageURL = pricingPageLanguage
? `https://dify.ai/${pricingPageLanguage}/pricing#plans-and-features`
: 'https://dify.ai/pricing#plans-and-features'
return createPortal(
<div
className='fixed inset-0 bottom-0 left-0 right-0 top-0 z-[1000] bg-background-overlay-backdrop p-4 backdrop-blur-[6px]'
@ -127,7 +133,7 @@ const Pricing: FC<Props> = ({
</div>
<div className='flex items-center justify-center py-4'>
<div className='flex items-center justify-center gap-x-0.5 rounded-lg px-3 py-2 text-components-button-secondary-accent-text hover:cursor-pointer hover:bg-state-accent-hover'>
<Link href='https://dify.ai/pricing#plans-and-features' className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
<Link href={pricingPageURL} className='system-sm-medium'>{t('billing.plansCommon.comparePlanAndFeatures')}</Link>
<RiArrowRightUpLine className='size-4' />
</div>
</div>

@ -37,4 +37,3 @@
white-space: pre-line;
word-break: break-all;
}

@ -17,6 +17,7 @@ type Props = {
const NoData: FC<Props> = ({
onConfig,
provider,
}) => {
const { t } = useTranslation()
@ -38,7 +39,7 @@ const NoData: FC<Props> = ({
} : null,
}
const currentProvider = Object.values(providerConfig).find(provider => provider !== null) || providerConfig[DataSourceProvider.jinaReader]
const currentProvider = providerConfig[provider] || providerConfig[DataSourceProvider.jinaReader]
if (!currentProvider) return null

@ -2,7 +2,6 @@
import { useTranslation } from 'react-i18next'
import { Fragment, useState } from 'react'
import { useRouter } from 'next/navigation'
import { useContext } from 'use-context-selector'
import {
RiAccountCircleLine,
RiArrowRightUpLine,
@ -23,13 +22,12 @@ import GithubStar from '../github-star'
import Support from './support'
import Compliance from './compliance'
import PremiumBadge from '@/app/components/base/premium-badge'
import I18n from '@/context/i18n'
import { useGetDocLanguage } from '@/context/i18n'
import Avatar from '@/app/components/base/avatar'
import { logout } from '@/service/common'
import { useAppContext } from '@/context/app-context'
import { useProviderContext } from '@/context/provider-context'
import { useModalContext } from '@/context/modal-context'
import { LanguagesSupported } from '@/i18n/language'
import { IS_CLOUD_EDITION } from '@/config'
import cn from '@/utils/classnames'
import { useGlobalPublicStore } from '@/context/global-public-context'
@ -43,11 +41,11 @@ export default function AppSelector() {
const [aboutVisible, setAboutVisible] = useState(false)
const { systemFeatures } = useGlobalPublicStore()
const { locale } = useContext(I18n)
const { t } = useTranslation()
const { userProfile, langeniusVersionInfo, isCurrentWorkspaceOwner } = useAppContext()
const { isEducationAccount } = useProviderContext()
const { setShowAccountSettingModal } = useModalContext()
const docLanguage = useGetDocLanguage()
const handleLogout = async () => {
await logout({
@ -133,9 +131,7 @@ export default function AppSelector() {
className={cn(itemClassName, 'group justify-between',
'data-[active]:bg-state-base-hover',
)}
href={
locale !== LanguagesSupported[1] ? 'https://docs.dify.ai/' : `https://docs.dify.ai/v/${locale.toLowerCase()}/`
}
href={`https://docs.dify.ai/${docLanguage}/introduction`}
target='_blank' rel='noopener noreferrer'>
<RiBookOpenLine className='size-4 shrink-0 text-text-tertiary' />
<div className='system-md-regular grow px-1 text-text-secondary'>{t('common.userProfile.helpCenter')}</div>

@ -6,7 +6,7 @@ import {
RiArrowDownSLine,
RiArrowRightSLine,
} from '@remixicon/react'
import { Menu, MenuButton, MenuItems, Transition } from '@headlessui/react'
import { Menu, MenuButton, MenuItem, MenuItems, Transition } from '@headlessui/react'
import { useRouter } from 'next/navigation'
import { debounce } from 'lodash-es'
import cn from '@/utils/classnames'
@ -77,7 +77,7 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
<div className="overflow-auto px-1 py-1" style={{ maxHeight: '50vh' }} onScroll={handleScroll}>
{
navs.map(nav => (
<MenuItems key={nav.id}>
<MenuItem key={nav.id}>
<div className='flex w-full cursor-pointer items-center truncate rounded-lg px-3 py-[6px] text-[14px] font-normal text-gray-700 hover:bg-gray-100' onClick={() => {
if (curNav?.id === nav.id)
return
@ -112,12 +112,12 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
{nav.name}
</div>
</div>
</MenuItems>
</MenuItem>
))
}
</div>
{!isApp && isCurrentWorkspaceEditor && (
<MenuButton className='w-full p-1'>
<MenuItem as="div" className='w-full p-1'>
<div onClick={() => onCreate('')} className={cn(
'flex cursor-pointer items-center gap-2 rounded-lg px-3 py-[6px] hover:bg-gray-100',
)}>
@ -126,7 +126,7 @@ const NavSelector = ({ curNav, navs, createText, isApp, onCreate, onLoadmore }:
</div>
<div className='grow text-left text-[14px] font-normal text-gray-700'>{createText}</div>
</div>
</MenuButton>
</MenuItem>
)}
{isApp && isCurrentWorkspaceEditor && (
<Menu as="div" className="relative h-full w-full">

@ -1,6 +1,8 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useContext } from 'use-context-selector'
import I18n from '@/context/i18n'
import {
RiArrowRightUpLine,
RiBugLine,
@ -9,12 +11,14 @@ import { useTranslation } from 'react-i18next'
import KeyValueItem from '../base/key-value-item'
import Tooltip from '@/app/components/base/tooltip'
import Button from '@/app/components/base/button'
import { getDocsUrl } from '@/app/components/plugins/utils'
import { useDebugKey } from '@/service/use-plugins'
const i18nPrefix = 'plugin.debugInfo'
const DebugInfo: FC = () => {
const { t } = useTranslation()
const { locale } = useContext(I18n)
const { data: info, isLoading } = useDebugKey()
// info.key likes 4580bdb7-b878-471c-a8a4-bfd760263a53 mask the middle part using *.
@ -30,7 +34,7 @@ const DebugInfo: FC = () => {
<>
<div className='flex items-center gap-1 self-stretch'>
<span className='system-sm-semibold flex shrink-0 grow basis-0 flex-col items-start justify-center text-text-secondary'>{t(`${i18nPrefix}.title`)}</span>
<a href='https://docs.dify.ai/plugins/quick-start/develop-plugins/debug-plugin' target='_blank' className='flex cursor-pointer items-center gap-0.5 text-text-accent-light-mode-only'>
<a href={getDocsUrl(locale, '/plugins/quick-start/debug-plugin')} target='_blank' className='flex cursor-pointer items-center gap-0.5 text-text-accent-light-mode-only'>
<span className='system-xs-medium'>{t(`${i18nPrefix}.viewDocs`)}</span>
<RiArrowRightUpLine className='h-3 w-3' />
</a>

@ -33,10 +33,10 @@ import {
import type { Dependency } from '../types'
import type { PluginDeclaration, PluginManifestInMarket } from '../types'
import { sleep } from '@/utils'
import { getDocsUrl } from '@/app/components/plugins/utils'
import { fetchBundleInfoFromMarketPlace, fetchManifestFromMarketPlace } from '@/service/plugins'
import { marketplaceApiPrefix } from '@/config'
import { SUPPORT_INSTALL_LOCAL_FILE_EXTENSIONS } from '@/config'
import { LanguagesSupported } from '@/i18n/language'
import I18n from '@/context/i18n'
import { noop } from 'lodash-es'
import { PLUGIN_TYPE_SEARCH_MAP } from '../marketplace/plugin-type-switch'
@ -187,7 +187,7 @@ const PluginPage = ({
isExploringMarketplace && (
<>
<Link
href={`https://docs.dify.ai/${locale === LanguagesSupported[1] ? 'v/zh-hans/' : ''}plugins/publish-plugins/publish-to-dify-marketplace`}
href={getDocsUrl(locale, '/plugins/publish-plugins/publish-to-dify-marketplace/README')}
target='_blank'
>
<Button

@ -1,3 +1,5 @@
import { LanguagesSupported } from '@/i18n/language'
import {
categoryKeys,
tagKeys,
@ -10,3 +12,15 @@ export const getValidTagKeys = (tags: string[]) => {
export const getValidCategoryKeys = (category?: string) => {
return categoryKeys.find(key => key === category)
}
export const getDocsUrl = (locale: string, path: string) => {
let localePath = 'en'
if (locale === LanguagesSupported[1])
localePath = 'zh-hans'
else if (locale === LanguagesSupported[7])
localePath = 'ja-jp'
return `https://docs.dify.ai/${localePath}${path}`
}

@ -316,7 +316,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
nodesConnectable={!nodesReadOnly}
nodesFocusable={!nodesReadOnly}
edgesFocusable={!nodesReadOnly}
panOnScroll
panOnScroll={false}
panOnDrag={controlMode === ControlMode.Hand && !workflowReadOnly}
zoomOnPinch={!workflowReadOnly}
zoomOnScroll={!workflowReadOnly}

@ -152,6 +152,7 @@ const VarList: FC<Props> = ({
/>
</div>
{item.operation !== WriteMode.clear && item.operation !== WriteMode.set
&& item.operation !== WriteMode.removeFirst && item.operation !== WriteMode.removeLast
&& !writeModeTypesNum?.includes(item.operation)
&& (
<VarReferencePicker

@ -29,7 +29,7 @@ const nodeDefault: NodeDefault<AssignerNodeType> = {
if (!errorMessages && !value.variable_selector?.length)
errorMessages = t(`${i18nPrefix}.fieldRequired`, { field: t('workflow.nodes.assigner.assignedVariable') })
if (!errorMessages && value.operation !== WriteMode.clear) {
if (!errorMessages && value.operation !== WriteMode.clear && value.operation !== WriteMode.removeFirst && value.operation !== WriteMode.removeLast) {
if (value.operation === WriteMode.set || value.operation === WriteMode.increment
|| value.operation === WriteMode.decrement || value.operation === WriteMode.multiply
|| value.operation === WriteMode.divide) {

@ -10,6 +10,8 @@ export enum WriteMode {
decrement = '-=',
multiply = '*=',
divide = '/=',
removeFirst = 'remove-first',
removeLast = 'remove-last',
}
export enum AssignerNodeInputType {

@ -69,7 +69,7 @@ const useConfig = (id: string, rawPayload: AssignerNodeType) => {
newSetInputs(newInputs)
}, [inputs, newSetInputs])
const writeModeTypesArr = [WriteMode.overwrite, WriteMode.clear, WriteMode.append, WriteMode.extend]
const writeModeTypesArr = [WriteMode.overwrite, WriteMode.clear, WriteMode.append, WriteMode.extend, WriteMode.removeFirst, WriteMode.removeLast]
const writeModeTypes = [WriteMode.overwrite, WriteMode.clear, WriteMode.set]
const writeModeTypesNum = [WriteMode.increment, WriteMode.decrement, WriteMode.multiply, WriteMode.divide]

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save