|
|
|
@ -1,6 +1,6 @@
|
|
|
|
from collections.abc import Sequence
|
|
|
|
from collections.abc import Sequence
|
|
|
|
from enum import Enum, StrEnum
|
|
|
|
from enum import Enum, StrEnum
|
|
|
|
from typing import Any, Optional, Union
|
|
|
|
from typing import Any, Annotated, Literal, Optional, Union, Union
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
|
from pydantic import BaseModel, Field, field_serializer, field_validator
|
|
|
|
|
|
|
|
|
|
|
|
@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PromptMessageContent(BaseModel):
|
|
|
|
class PromptMessageContent(BaseModel):
|
|
|
|
"""
|
|
|
|
pass
|
|
|
|
Model class for prompt message content.
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
type: PromptMessageContentType
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextPromptMessageContent(PromptMessageContent):
|
|
|
|
class TextPromptMessageContent(PromptMessageContent):
|
|
|
|
@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
|
|
|
|
Model class for text prompt message content.
|
|
|
|
Model class for text prompt message content.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.TEXT
|
|
|
|
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
|
|
|
|
data: str
|
|
|
|
data: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|
|
|
Model class for multi-modal prompt message content.
|
|
|
|
Model class for multi-modal prompt message content.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
type: PromptMessageContentType
|
|
|
|
|
|
|
|
format: str = Field(default=..., description="the format of multi-modal file")
|
|
|
|
format: str = Field(default=..., description="the format of multi-modal file")
|
|
|
|
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
|
|
|
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
|
|
|
|
url: str = Field(default="", description="the url of multi-modal file")
|
|
|
|
url: str = Field(default="", description="the url of multi-modal file")
|
|
|
|
@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
class VideoPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.VIDEO
|
|
|
|
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
class AudioPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.AUDIO
|
|
|
|
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
@ -110,13 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
LOW = "low"
|
|
|
|
LOW = "low"
|
|
|
|
HIGH = "high"
|
|
|
|
HIGH = "high"
|
|
|
|
|
|
|
|
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.IMAGE
|
|
|
|
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
|
|
|
|
detail: DETAIL = DETAIL.LOW
|
|
|
|
detail: DETAIL = DETAIL.LOW
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
|
|
|
|
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
|
|
|
|
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PromptMessageContentUnionTypes = Annotated[
|
|
|
|
|
|
|
|
Union[
|
|
|
|
|
|
|
|
TextPromptMessageContent,
|
|
|
|
|
|
|
|
ImagePromptMessageContent,
|
|
|
|
|
|
|
|
DocumentPromptMessageContent,
|
|
|
|
|
|
|
|
AudioPromptMessageContent,
|
|
|
|
|
|
|
|
VideoPromptMessageContent,
|
|
|
|
|
|
|
|
],
|
|
|
|
|
|
|
|
Field(discriminator="type"),
|
|
|
|
|
|
|
|
]
|
|
|
|
|
|
|
|
|
|
|
|
class PromptMessage(BaseModel):
|
|
|
|
class PromptMessage(BaseModel):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
@ -124,7 +130,7 @@ class PromptMessage(BaseModel):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
role: PromptMessageRole
|
|
|
|
role: PromptMessageRole
|
|
|
|
content: Optional[str | Sequence[PromptMessageContent]] = None
|
|
|
|
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
|
|
|
|
name: Optional[str] = None
|
|
|
|
name: Optional[str] = None
|
|
|
|
|
|
|
|
|
|
|
|
def is_empty(self) -> bool:
|
|
|
|
def is_empty(self) -> bool:
|
|
|
|
|