ix: Update prompt message content types to use Literal and add union type for content

pull/17136/head
朱庆超 1 year ago
parent be964c78ec
commit 9a0e2e8f42

@ -1,6 +1,6 @@
from collections.abc import Sequence from collections.abc import Sequence
from enum import Enum, StrEnum from enum import Enum, StrEnum
from typing import Any, Optional, Union from typing import Any, Annotated, Literal, Optional, Union, Union
from pydantic import BaseModel, Field, field_serializer, field_validator from pydantic import BaseModel, Field, field_serializer, field_validator
@ -61,11 +61,7 @@ class PromptMessageContentType(StrEnum):
class PromptMessageContent(BaseModel): class PromptMessageContent(BaseModel):
""" pass
Model class for prompt message content.
"""
type: PromptMessageContentType
class TextPromptMessageContent(PromptMessageContent): class TextPromptMessageContent(PromptMessageContent):
@ -73,7 +69,7 @@ class TextPromptMessageContent(PromptMessageContent):
Model class for text prompt message content. Model class for text prompt message content.
""" """
type: PromptMessageContentType = PromptMessageContentType.TEXT type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
data: str data: str
@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
Model class for multi-modal prompt message content. Model class for multi-modal prompt message content.
""" """
type: PromptMessageContentType
format: str = Field(default=..., description="the format of multi-modal file") format: str = Field(default=..., description="the format of multi-modal file")
base64_data: str = Field(default="", description="the base64 data of multi-modal file") base64_data: str = Field(default="", description="the base64 data of multi-modal file")
url: str = Field(default="", description="the url of multi-modal file") url: str = Field(default="", description="the url of multi-modal file")
@ -94,11 +89,11 @@ class MultiModalPromptMessageContent(PromptMessageContent):
class VideoPromptMessageContent(MultiModalPromptMessageContent): class VideoPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.VIDEO type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
class AudioPromptMessageContent(MultiModalPromptMessageContent): class AudioPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.AUDIO type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
class ImagePromptMessageContent(MultiModalPromptMessageContent): class ImagePromptMessageContent(MultiModalPromptMessageContent):
@ -110,13 +105,24 @@ class ImagePromptMessageContent(MultiModalPromptMessageContent):
LOW = "low" LOW = "low"
HIGH = "high" HIGH = "high"
type: PromptMessageContentType = PromptMessageContentType.IMAGE type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW detail: DETAIL = DETAIL.LOW
class DocumentPromptMessageContent(MultiModalPromptMessageContent): class DocumentPromptMessageContent(MultiModalPromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
PromptMessageContentUnionTypes = Annotated[
Union[
TextPromptMessageContent,
ImagePromptMessageContent,
DocumentPromptMessageContent,
AudioPromptMessageContent,
VideoPromptMessageContent,
],
Field(discriminator="type"),
]
class PromptMessage(BaseModel): class PromptMessage(BaseModel):
""" """
@ -124,7 +130,7 @@ class PromptMessage(BaseModel):
""" """
role: PromptMessageRole role: PromptMessageRole
content: Optional[str | Sequence[PromptMessageContent]] = None content: Optional[str | list[PromptMessageContentUnionTypes]] = None
name: Optional[str] = None name: Optional[str] = None
def is_empty(self) -> bool: def is_empty(self) -> bool:

@ -0,0 +1,23 @@
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage
def test_build_prompt_message_with_prompt_message_contents():
prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
assert isinstance(prompt.content, list)
assert isinstance(prompt.content[0], TextPromptMessageContent)
assert prompt.content[0].data == "Hello, World!"
def test_dump_prompt_message():
example_url = "https://example.com/image.jpg"
prompt = UserPromptMessage(
content=[
ImagePromptMessageContent(
url=example_url,
format="jpeg",
mime_type="image/jpeg",
)
]
)
data = prompt.model_dump()
assert data["content"][0].get("url") == example_url
Loading…
Cancel
Save