strict_bytes

pull/20523/head
Bowen Liang 12 months ago
parent 275e86a26c
commit 6209108cf3

@ -602,6 +602,16 @@ class ToolConfig(BaseSettings):
default=3600,
)
TOOL_FILE_CHUNK_SIZE_LIMIT: PositiveInt = Field(
description="Maximum bytes for a single file chunk of tool generated files",
default=8 * 1024,
)
TOOL_FILE_SIZE_LIMIT: PositiveInt = Field(
description="Maximum bytes for a single file of tool generated files",
default=30 * 1024 * 1024,
)
class MailConfig(BaseSettings):
"""

@ -3,12 +3,44 @@ from typing import Any, Optional
from pydantic import BaseModel
from configs import dify_config
from core.plugin.entities.plugin import GenericProviderID, ToolProviderID
from core.plugin.entities.plugin_daemon import PluginBasicBooleanResponse, PluginToolProviderEntity
from core.plugin.impl.base import BasePluginClient
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter
class FileChunk:
"""
Only used for internal processing.
"""
__slots__ = ("bytes_written", "total_length", "data")
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
def write_blob(self, blob_data):
blob_data_length = len(blob_data)
if blob_data_length == 0:
return
# Validate write boundaries
expected_final_size = self.bytes_written + blob_data_length
if expected_final_size > self.total_length:
raise ValueError(f"Chunk would exceed file size ({expected_final_size} > {self.total_length})")
start_pos = self.bytes_written
self.data[start_pos : start_pos + blob_data_length] = blob_data
self.bytes_written += blob_data_length
class PluginToolManager(BasePluginClient):
def fetch_tool_providers(self, tenant_id: str) -> list[PluginToolProviderEntity]:
"""
@ -111,20 +143,6 @@ class PluginToolManager(BasePluginClient):
},
)
class FileChunk:
"""
Only used for internal processing.
"""
bytes_written: int
total_length: int
data: bytearray
def __init__(self, total_length: int):
self.bytes_written = 0
self.total_length = total_length
self.data = bytearray(total_length)
files: dict[str, FileChunk] = {}
for resp in response:
if resp.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
@ -134,36 +152,33 @@ class PluginToolManager(BasePluginClient):
total_length = resp.message.total_length
blob_data = resp.message.blob
is_end = resp.message.end
blob_data_length = len(blob_data)
# Pre-check conditions to avoid unnecessary processing
file_size_limit = dify_config.TOOL_FILE_SIZE_LIMIT
chunk_size_limit = dify_config.TOOL_FILE_CHUNK_SIZE_LIMIT
if total_length > file_size_limit:
raise ValueError(f"File size {total_length} exceeds limit of {file_size_limit} bytes")
if blob_data_length > chunk_size_limit:
raise ValueError(f"Chunk size {blob_data_length} exceeds limit of {chunk_size_limit} bytes")
# Initialize buffer for this file if it doesn't exist
if chunk_id not in files:
files[chunk_id] = FileChunk(total_length)
file_chunk = files[chunk_id]
# If this is the final chunk, yield a complete blob message
if is_end:
yield ToolInvokeMessage(
type=ToolInvokeMessage.MessageType.BLOB,
message=ToolInvokeMessage.BlobMessage(blob=files[chunk_id].data),
message=ToolInvokeMessage.BlobMessage(blob=bytes(file_chunk.data)),
meta=resp.meta,
)
del files[chunk_id]
else:
# Check if file is too large (30MB limit)
if files[chunk_id].bytes_written + len(blob_data) > 30 * 1024 * 1024:
# Delete the file if it's too large
del files[chunk_id]
# Skip yielding this message
raise ValueError("File is too large which reached the limit of 30MB")
# Check if single chunk is too large (8KB limit)
if len(blob_data) > 8192:
# Skip yielding this message
raise ValueError("File chunk is too large which reached the limit of 8KB")
# Append the blob data to the buffer
files[chunk_id].data[
files[chunk_id].bytes_written : files[chunk_id].bytes_written + len(blob_data)
] = blob_data
files[chunk_id].bytes_written += len(blob_data)
# Write the blob data to the file chunk
file_chunk.write_blob(blob_data)
else:
yield resp

@ -1,7 +1,7 @@
import base64
import enum
from collections.abc import Mapping
from enum import Enum
from enum import Enum, StrEnum
from typing import Any, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_serializer, field_validator, model_validator
@ -176,7 +176,7 @@ class ToolInvokeMessage(BaseModel):
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Optional[Mapping[str, Any]] = Field(default=None, description="The metadata of the log")
class MessageType(Enum):
class MessageType(StrEnum):
TEXT = "text"
IMAGE = "image"
LINK = "link"

@ -2,6 +2,9 @@
warn_return_any = True
warn_unused_configs = True
check_untyped_defs = True
strict_bytes = True
sqlite_cache = True
cache_fine_grained = True
exclude = (?x)(
core/model_runtime/model_providers/
| tests/

@ -7,4 +7,4 @@ cd "$SCRIPT_DIR/.."
# run mypy checks
uv run --directory api --dev --with pip \
python -m mypy --install-types --non-interactive --cache-fine-grained --sqlite-cache .
python -m mypy --install-types --non-interactive .

Loading…
Cancel
Save