strict_bytes

pull/20523/head
Bowen Liang 12 months ago
parent 275e86a26c
commit 6209108cf3

@ -602,6 +602,16 @@ class ToolConfig(BaseSettings):
default=3600, default=3600,
) )
TOOL_FILE_CHUNK_SIZE_LIMIT: PositiveInt = Field(
description="Maximum bytes for a single file chunk of tool generated files",
default=8 * 1024,
)
TOOL_FILE_SIZE_LIMIT: PositiveInt = Field(
description="Maximum bytes for a single file of tool generated files",
default=30 * 1024 * 1024,
)
class MailConfig(BaseSettings): class MailConfig(BaseSettings):
""" """

@ -3,12 +3,44 @@ from typing import Any, Optional
from pydantic import BaseModel from pydantic import BaseModel
from configs import dify_config
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
class FileChunk:
    """
    Fixed-size reassembly buffer for a single chunked file.

    Only used for internal processing. The buffer is allocated up front with
    ``total_length`` bytes and is filled front-to-back by ``write_blob``.
    """

    __slots__ = ("bytes_written", "total_length", "data")

    bytes_written: int  # how many bytes have been copied into ``data`` so far
    total_length: int  # declared size of the complete file, in bytes
    data: bytearray  # pre-allocated, zero-filled buffer of ``total_length`` bytes

    def __init__(self, total_length: int):
        self.bytes_written = 0
        self.total_length = total_length
        self.data = bytearray(total_length)

    def write_blob(self, blob_data):
        """
        Copy ``blob_data`` into the buffer right after the bytes already written.

        An empty blob is a no-op.

        :param blob_data: bytes-like payload of the next chunk
        :raises ValueError: if the write would run past ``total_length``;
            the buffer and ``bytes_written`` are left untouched in that case
        """
        incoming = len(blob_data)
        if not incoming:
            return

        begin = self.bytes_written
        end = begin + incoming

        # Refuse the write before touching the buffer so it never grows
        # beyond the size declared at construction time.
        if end > self.total_length:
            raise ValueError(f"Chunk would exceed file size ({end} > {self.total_length})")

        self.data[begin:end] = blob_data
        self.bytes_written = end
class PluginToolManager(BasePluginClient): class PluginToolManager(BasePluginClient):
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]: def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
""" """
@ -111,20 +143,6 @@ class PluginToolManager(BasePluginClient):
}, },
) )
class FileChunk:
    """
    Pre-allocated buffer that tracks how much of one chunked file has arrived.

    Only used for internal processing.
    """

    # Instances carry exactly these three attributes; declaring them as slots
    # avoids a per-instance ``__dict__``.
    __slots__ = ("bytes_written", "total_length", "data")

    bytes_written: int  # number of bytes stored in ``data`` so far
    total_length: int  # declared size of the complete file, in bytes
    data: bytearray  # zero-filled buffer of ``total_length`` bytes

    def __init__(self, total_length: int):
        """
        :param total_length: size in bytes of the buffer to allocate
        """
        self.bytes_written = 0
        self.total_length = total_length
        self.data = bytearray(total_length)
files: dict[str, FileChunk] = {} files: dict[str, FileChunk] = {}
for resp in response: for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK: if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
@ -134,36 +152,33 @@ class PluginToolManager(BasePluginClient):
total_length = resp.message.total_length total_length = resp.message.total_length
blob_data = resp.message.blob blob_data = resp.message.blob
is_end = resp.message.end is_end = resp.message.end
blob_data_length = len(blob_data)
# Pre-check conditions to avoid unnecessary processing
file_size_limit = dify_config.TOOL_FILE_SIZE_LIMIT
chunk_size_limit = dify_config.TOOL_FILE_CHUNK_SIZE_LIMIT
if total_length > file_size_limit:
raise ValueError(f"File size {total_length} exceeds limit of {file_size_limit} bytes")
if blob_data_length > chunk_size_limit:
raise ValueError(f"Chunk size {blob_data_length} exceeds limit of {chunk_size_limit} bytes")
# Initialize buffer for this file if it doesn't exist # Initialize buffer for this file if it doesn't exist
if chunk_id not in files: if chunk_id not in files:
files[chunk_id] = FileChunk(total_length) files[chunk_id] = FileChunk(total_length)
file_chunk = files[chunk_id]
# If this is the final chunk, yield a complete blob message # If this is the final chunk, yield a complete blob message
if is_end: if is_end:
yield ToolInvokeMessage( yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB, type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data), message=ToolInvokeMessage.BlobMessage(blob=bytes(file_chunk.data)),
meta=resp.meta, meta=resp.meta,
) )
del files[chunk_id]
else: else:
# Check if file is too large (30MB limit) # Write the blob data to the file chunk
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024: file_chunk.write_blob(blob_data)
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
else: else:
yield resp yield resp

@ -1,7 +1,7 @@
import base64 import base64
import enum import enum
from collections.abc import Mapping from collections.abc import Mapping
from enum import Enum from enum import Enum, StrEnum
from typing import Any, Optional, Union from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
@ -176,7 +176,7 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data") data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log") metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
class MessageType(Enum): class MessageType(StrEnum):
TEXT = "text" TEXT = "text"
IMAGE = "image" IMAGE = "image"
LINK = "link" LINK = "link"

@ -2,6 +2,9 @@
warn_return_any = True warn_return_any = True
warn_unused_configs = True warn_unused_configs = True
check_untyped_defs = True check_untyped_defs = True
strict_bytes = True
sqlite_cache = True
cache_fine_grained = True
exclude = (?x)( exclude = (?x)(
core/model_runtime/model_providers/ core/model_runtime/model_providers/
| tests/ | tests/

@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
# run mypy checks # run mypy checks
uv run --directory api --dev --with pip \ uv run --directory api --dev --with pip \
python -m mypy --install-types --non-interactive --cache-fine-grained --sqlite-cache . python -m mypy --install-types --non-interactive .

Loading…
Cancel
Save