feat(tools/cogview): Updated cogview tool to support cogview-3 and the latest cogview-3-plus (#8382)
parent
0665268578
commit
740fad06c1
@ -1 +1 @@
|
||||
__version__ = "v2.0.1"
|
||||
__version__ = "v2.1.0"
|
||||
|
||||
@ -1,5 +1,34 @@
|
||||
from .chat import chat
|
||||
from .assistant import (
|
||||
Assistant,
|
||||
)
|
||||
from .batches import Batches
|
||||
from .chat import (
|
||||
AsyncCompletions,
|
||||
Chat,
|
||||
Completions,
|
||||
)
|
||||
from .embeddings import Embeddings
|
||||
from .files import Files
|
||||
from .fine_tuning import fine_tuning
|
||||
from .files import Files, FilesWithRawResponse
|
||||
from .fine_tuning import FineTuning
|
||||
from .images import Images
|
||||
from .knowledge import Knowledge
|
||||
from .tools import Tools
|
||||
from .videos import (
|
||||
Videos,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Videos",
|
||||
"AsyncCompletions",
|
||||
"Chat",
|
||||
"Completions",
|
||||
"Images",
|
||||
"Embeddings",
|
||||
"Files",
|
||||
"FilesWithRawResponse",
|
||||
"FineTuning",
|
||||
"Batches",
|
||||
"Knowledge",
|
||||
"Tools",
|
||||
"Assistant",
|
||||
]
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from .assistant import Assistant
|
||||
|
||||
__all__ = ["Assistant"]
|
||||
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.assistant import AssistantCompletion
|
||||
from ...types.assistant.assistant_conversation_resp import ConversationUsageListResp
|
||||
from ...types.assistant.assistant_support_resp import AssistantSupportResp
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
from ...types.assistant import assistant_conversation_params, assistant_create_params
|
||||
|
||||
__all__ = ["Assistant"]
|
||||
|
||||
|
||||
class Assistant(BaseAPI):
    """Client-side wrapper for the ``/assistant`` endpoints.

    Covers streaming conversations, the supported-assistant listing, and
    conversation-usage queries.
    """

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def conversation(
        self,
        assistant_id: str,
        model: str,
        messages: list[assistant_create_params.ConversationMessage],
        *,
        stream: bool = True,
        conversation_id: Optional[str] = None,
        attachments: Optional[list[assistant_create_params.AssistantAttachments]] = None,
        metadata: dict | None = None,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> StreamResponse[AssistantCompletion]:
        """Start (or continue) an assistant conversation.

        Args:
            assistant_id: ID of the assistant to talk to.
            model: Model name the assistant should use.
            messages: Conversation messages to send.
            stream: Kept for request-body symmetry; the response is always
                consumed as a stream (see note at the ``_post`` call).
            conversation_id: Existing conversation to continue, if any.
            attachments: Optional attachments for the conversation.
            metadata: Arbitrary extra metadata forwarded to the server.
            request_id: Optional client-supplied request identifier.
            user_id: Optional end-user identifier.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Returns:
            A stream of ``AssistantCompletion`` chunks.
        """
        body = deepcopy_minimal(
            {
                "assistant_id": assistant_id,
                "model": model,
                "messages": messages,
                "stream": stream,
                "conversation_id": conversation_id,
                "attachments": attachments,
                "metadata": metadata,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant",
            body=maybe_transform(body, assistant_create_params.AssistantParameters),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=AssistantCompletion,
            # NOTE: the original expression `stream or True` always evaluated
            # to True, so the response was always streamed regardless of the
            # argument; made that explicit (the declared return type is
            # StreamResponse).
            stream=True,
            stream_cls=StreamResponse[AssistantCompletion],
        )

    def query_support(
        self,
        *,
        assistant_id_list: Optional[list[str]] = None,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> AssistantSupportResp:
        """List the assistants available to the caller.

        Args:
            assistant_id_list: Restrict the query to these assistant IDs;
                ``None`` queries all.
            request_id: Optional client-supplied request identifier.
            user_id: Optional end-user identifier.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        body = deepcopy_minimal(
            {
                "assistant_id_list": assistant_id_list,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant/list",
            body=body,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=AssistantSupportResp,
        )

    def query_conversation_usage(
        self,
        assistant_id: str,
        page: int = 1,
        page_size: int = 10,
        *,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> ConversationUsageListResp:
        """Query paginated conversation usage for one assistant.

        Args:
            assistant_id: Assistant whose usage is requested.
            page: 1-based page number.
            page_size: Number of records per page.
            request_id: Optional client-supplied request identifier.
            user_id: Optional end-user identifier.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        body = deepcopy_minimal(
            {
                "assistant_id": assistant_id,
                "page": page,
                "page_size": page_size,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/assistant/conversation/list",
            body=maybe_transform(body, assistant_conversation_params.ConversationParameters),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=ConversationUsageListResp,
        )
|
||||
@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ..core import NOT_GIVEN, BaseAPI, Body, Headers, NotGiven, make_request_options, maybe_transform
|
||||
from ..core.pagination import SyncCursorPage
|
||||
from ..types import batch_create_params, batch_list_params
|
||||
from ..types.batch import Batch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._client import ZhipuAI
|
||||
|
||||
|
||||
class Batches(BaseAPI):
    """Client-side wrapper for the ``/batches`` endpoints (create, retrieve,
    list, cancel)."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        completion_window: str | None = None,
        endpoint: Literal["/v1/chat/completions", "/v1/embeddings"],
        input_file_id: str,
        metadata: Optional[dict[str, str]] | NotGiven = NOT_GIVEN,
        auto_delete_input_file: bool = True,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        """Create a batch job over a previously uploaded input file.

        Args:
            completion_window: Time window in which the batch should complete,
                forwarded verbatim to the server.
            endpoint: The API endpoint the batch requests target.
            input_file_id: ID of the uploaded file containing the requests.
            metadata: Optional key/value metadata attached to the batch.
            auto_delete_input_file: Whether the server should delete the input
                file once the batch finishes.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        return self._post(
            "/batches",
            body=maybe_transform(
                {
                    "completion_window": completion_window,
                    "endpoint": endpoint,
                    "input_file_id": input_file_id,
                    "metadata": metadata,
                    "auto_delete_input_file": auto_delete_input_file,
                },
                batch_create_params.BatchCreateParams,
            ),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )

    def retrieve(
        self,
        batch_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        """
        Retrieves a batch.

        Args:
          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not batch_id:
            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
        return self._get(
            f"/batches/{batch_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )

    def list(
        self,
        *,
        after: str | NotGiven = NOT_GIVEN,
        limit: int | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> SyncCursorPage[Batch]:
        """List your organization's batches.

        Args:
          after: A cursor for use in pagination.

              `after` is an object ID that defines your place
              in the list. For instance, if you make a list request and receive 100 objects,
              ending with obj_foo, your subsequent call can include after=obj_foo in order to
              fetch the next page of the list.

          limit: A limit on the number of objects to be returned. Limit can range between 1 and
              100, and the default is 20.

          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        return self._get_api_list(
            "/batches",
            page=SyncCursorPage[Batch],
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
                query=maybe_transform(
                    {
                        "after": after,
                        "limit": limit,
                    },
                    batch_list_params.BatchListParams,
                ),
            ),
            model=Batch,
        )

    def cancel(
        self,
        batch_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> Batch:
        """
        Cancels an in-progress batch.

        Args:
          batch_id: The ID of the batch to cancel.

          extra_headers: Send extra headers

          extra_body: Add additional JSON properties to the request

          timeout: Override the client-level default timeout for this request, in seconds
        """
        if not batch_id:
            raise ValueError(f"Expected a non-empty value for `batch_id` but received {batch_id!r}")
        return self._post(
            f"/batches/{batch_id}/cancel",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=Batch,
        )
|
||||
@ -0,0 +1,5 @@
|
||||
from .async_completions import AsyncCompletions
|
||||
from .chat import Chat
|
||||
from .completions import Completions
|
||||
|
||||
__all__ = ["AsyncCompletions", "Chat", "Completions"]
|
||||
@ -1,17 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core._base_api import BaseAPI
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .async_completions import AsyncCompletions
|
||||
from .completions import Completions
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
pass
|
||||
|
||||
|
||||
class Chat(BaseAPI):
|
||||
completions: Completions
|
||||
@cached_property
|
||||
def completions(self) -> Completions:
|
||||
return Completions(self._client)
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
self.completions = Completions(client)
|
||||
self.asyncCompletions = AsyncCompletions(client)
|
||||
@cached_property
|
||||
def asyncCompletions(self) -> AsyncCompletions: # noqa: N802
|
||||
return AsyncCompletions(self._client)
|
||||
|
||||
@ -0,0 +1,5 @@
|
||||
from .fine_tuning import FineTuning
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
__all__ = ["Jobs", "FineTunedModels", "FineTuning"]
|
||||
@ -1,15 +1,18 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...core._base_api import BaseAPI
|
||||
from ...core import BaseAPI, cached_property
|
||||
from .jobs import Jobs
|
||||
from .models import FineTunedModels
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
pass
|
||||
|
||||
|
||||
class FineTuning(BaseAPI):
|
||||
jobs: Jobs
|
||||
@cached_property
|
||||
def jobs(self) -> Jobs:
|
||||
return Jobs(self._client)
|
||||
|
||||
def __init__(self, client: "ZhipuAI") -> None:
|
||||
super().__init__(client)
|
||||
self.jobs = Jobs(client)
|
||||
@cached_property
|
||||
def models(self) -> FineTunedModels:
|
||||
return FineTunedModels(self._client)
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
from .jobs import Jobs
|
||||
|
||||
__all__ = ["Jobs"]
|
||||
@ -0,0 +1,3 @@
|
||||
from .fine_tuned_models import FineTunedModels
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
||||
@ -0,0 +1,41 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
make_request_options,
|
||||
)
|
||||
from ....types.fine_tuning.models import FineTunedModelsStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["FineTunedModels"]
|
||||
|
||||
|
||||
class FineTunedModels(BaseAPI):
    """Client-side wrapper for fine-tuned-model management endpoints."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def delete(
        self,
        fine_tuned_model: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> FineTunedModelsStatus:
        """Delete a fine-tuned model.

        Args:
            fine_tuned_model: Identifier of the fine-tuned model to delete.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``fine_tuned_model`` is empty.
        """
        if not fine_tuned_model:
            raise ValueError(f"Expected a non-empty value for `fine_tuned_model` but received {fine_tuned_model!r}")
        # NOTE(review): this path has no leading "/", unlike every other
        # endpoint in this SDK — confirm against the HTTP client's URL-join
        # behavior before "fixing" it.
        return self._delete(
            f"fine_tuning/fine_tuned_models/{fine_tuned_model}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=FineTunedModelsStatus,
        )
|
||||
@ -0,0 +1,3 @@
|
||||
from .knowledge import Knowledge
|
||||
|
||||
__all__ = ["Knowledge"]
|
||||
@ -0,0 +1,3 @@
|
||||
from .document import Document
|
||||
|
||||
__all__ = ["Document"]
|
||||
@ -0,0 +1,217 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Literal, Optional, cast
|
||||
|
||||
import httpx
|
||||
|
||||
from ....core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
NotGiven,
|
||||
deepcopy_minimal,
|
||||
extract_files,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ....types.files import UploadDetail, file_create_params
|
||||
from ....types.knowledge.document import DocumentData, DocumentObject, document_edit_params, document_list_params
|
||||
from ....types.knowledge.document.document_list_resp import DocumentPage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...._client import ZhipuAI
|
||||
|
||||
__all__ = ["Document"]
|
||||
|
||||
|
||||
class Document(BaseAPI):
    """Client-side wrapper for knowledge-base document endpoints
    (upload, edit, list, delete, retrieve)."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def create(
        self,
        *,
        file: Optional[FileTypes] = None,
        custom_separator: Optional[list[str]] = None,
        upload_detail: Optional[list[UploadDetail]] = None,
        purpose: Literal["retrieval"],
        knowledge_id: Optional[str] = None,
        sentence_size: Optional[int] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> DocumentObject:
        """Upload a document into a knowledge base.

        Exactly one of ``file`` or ``upload_detail`` must be supplied.

        Args:
            file: Local file content to upload (multipart).
            custom_separator: Optional custom sentence separators.
            upload_detail: Pre-described upload entries, as an alternative to
                ``file``.
            purpose: Upload purpose; only ``"retrieval"`` is accepted.
            knowledge_id: Target knowledge base ID.
            sentence_size: Chunking size used when splitting the document.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If neither ``file`` nor ``upload_detail`` is given.
        """
        if not file and not upload_detail:
            raise ValueError("At least one of `file` and `upload_detail` must be provided.")
        body = deepcopy_minimal(
            {
                "file": file,
                "upload_detail": upload_detail,
                "purpose": purpose,
                "custom_separator": custom_separator,
                "knowledge_id": knowledge_id,
                "sentence_size": sentence_size,
            }
        )
        files = extract_files(cast(Mapping[str, object], body), paths=[["file"]])
        if files:
            # It should be noted that the actual Content-Type header that will be
            # sent to the server will contain a `boundary` parameter, e.g.
            # multipart/form-data; boundary=---abc--
            extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
        return self._post(
            "/files",
            body=maybe_transform(body, file_create_params.FileCreateParams),
            files=files,
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=DocumentObject,
        )

    def edit(
        self,
        document_id: str,
        knowledge_type: str,
        *,
        custom_separator: Optional[list[str]] = None,
        sentence_size: Optional[int] = None,
        callback_url: Optional[str] = None,
        callback_header: Optional[dict[str, str]] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> httpx.Response:
        """Edit a knowledge-base document's settings.

        Args:
            document_id: Document (knowledge entry) ID.
            knowledge_type: Knowledge type code. Per the server contract:
                1 = article knowledge (pdf, url, docx);
                2 = Q&A knowledge, document (pdf, url, docx);
                3 = Q&A knowledge, table (xlsx);
                4 = product library, table (xlsx);
                5 = custom (pdf, url, docx).
            custom_separator: Optional custom sentence separators.
            sentence_size: Chunking size used when splitting the document.
            callback_url: Optional URL the server calls back on completion.
            callback_header: Optional headers sent with the callback.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``document_id`` is empty.
        """
        if not document_id:
            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")

        body = deepcopy_minimal(
            {
                "id": document_id,
                "knowledge_type": knowledge_type,
                "custom_separator": custom_separator,
                "sentence_size": sentence_size,
                "callback_url": callback_url,
                "callback_header": callback_header,
            }
        )

        return self._put(
            f"/document/{document_id}",
            body=maybe_transform(body, document_edit_params.DocumentEditParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=httpx.Response,
        )

    def list(
        self,
        knowledge_id: str,
        *,
        purpose: str | NotGiven = NOT_GIVEN,
        page: str | NotGiven = NOT_GIVEN,
        limit: str | NotGiven = NOT_GIVEN,
        order: Literal["desc", "asc"] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> DocumentPage:
        """List documents belonging to a knowledge base.

        Args:
            knowledge_id: Knowledge base whose documents are listed.
            purpose: Optional purpose filter.
            page: Page cursor/number (string per the server contract).
            limit: Maximum number of documents returned.
            order: Sort order, ``"desc"`` or ``"asc"``.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        return self._get(
            "/files",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
                query=maybe_transform(
                    {
                        "knowledge_id": knowledge_id,
                        "purpose": purpose,
                        "page": page,
                        "limit": limit,
                        "order": order,
                    },
                    document_list_params.DocumentListParams,
                ),
            ),
            cast_type=DocumentPage,
        )

    def delete(
        self,
        document_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> httpx.Response:
        """Delete a knowledge-base document.

        Args:
            document_id: Document (knowledge entry) ID.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``document_id`` is empty.
        """
        if not document_id:
            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")

        return self._delete(
            f"/document/{document_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=httpx.Response,
        )

    def retrieve(
        self,
        document_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> DocumentData:
        """Retrieve a knowledge-base document.

        Args:
            document_id: Document (knowledge entry) ID.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``document_id`` is empty.
        """
        if not document_id:
            raise ValueError(f"Expected a non-empty value for `document_id` but received {document_id!r}")

        return self._get(
            f"/document/{document_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=DocumentData,
        )
|
||||
@ -0,0 +1,173 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
cached_property,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.knowledge import KnowledgeInfo, KnowledgeUsed, knowledge_create_params, knowledge_list_params
|
||||
from ...types.knowledge.knowledge_list_resp import KnowledgePage
|
||||
from .document import Document
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Knowledge"]
|
||||
|
||||
|
||||
class Knowledge(BaseAPI):
    """Client-side wrapper for the ``/knowledge`` endpoints (knowledge-base
    create/modify/query/delete and capacity usage)."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    @cached_property
    def document(self) -> Document:
        # Lazily constructed sub-resource for document operations.
        return Document(self._client)

    def create(
        self,
        embedding_id: int,
        name: str,
        *,
        customer_identifier: Optional[str] = None,
        description: Optional[str] = None,
        background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
        icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
        bucket_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> KnowledgeInfo:
        """Create a knowledge base.

        Args:
            embedding_id: ID of the embedding model used by the base.
            name: Display name of the knowledge base.
            customer_identifier: Optional caller-side identifier.
            description: Optional description.
            background: Optional UI background color.
            icon: Optional UI icon.
            bucket_id: Optional storage bucket ID.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        body = deepcopy_minimal(
            {
                "embedding_id": embedding_id,
                "name": name,
                "customer_identifier": customer_identifier,
                "description": description,
                "background": background,
                "icon": icon,
                "bucket_id": bucket_id,
            }
        )
        return self._post(
            "/knowledge",
            body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=KnowledgeInfo,
        )

    def modify(
        self,
        knowledge_id: str,
        embedding_id: int,
        *,
        name: str,
        description: Optional[str] = None,
        background: Optional[Literal["blue", "red", "orange", "purple", "sky"]] = None,
        icon: Optional[Literal["question", "book", "seal", "wrench", "tag", "horn", "house"]] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> httpx.Response:
        """Modify an existing knowledge base.

        Args:
            knowledge_id: ID of the knowledge base to update.
            embedding_id: ID of the embedding model used by the base.
            name: New display name.
            description: Optional description.
            background: Optional UI background color.
            icon: Optional UI icon.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        body = deepcopy_minimal(
            {
                "id": knowledge_id,
                "embedding_id": embedding_id,
                "name": name,
                "description": description,
                "background": background,
                "icon": icon,
            }
        )
        return self._put(
            f"/knowledge/{knowledge_id}",
            body=maybe_transform(body, knowledge_create_params.KnowledgeBaseParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=httpx.Response,
        )

    def query(
        self,
        *,
        page: int | NotGiven = 1,
        size: int | NotGiven = 10,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> KnowledgePage:
        """List knowledge bases, paginated.

        Args:
            page: 1-based page number (defaults to 1).
            size: Number of entries per page (defaults to 10).
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        return self._get(
            "/knowledge",
            options=make_request_options(
                extra_headers=extra_headers,
                extra_body=extra_body,
                timeout=timeout,
                query=maybe_transform(
                    {
                        "page": page,
                        "size": size,
                    },
                    knowledge_list_params.KnowledgeListParams,
                ),
            ),
            cast_type=KnowledgePage,
        )

    def delete(
        self,
        knowledge_id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> httpx.Response:
        """Delete a knowledge base.

        Args:
            knowledge_id: ID of the knowledge base to delete.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``knowledge_id`` is empty.
        """
        if not knowledge_id:
            raise ValueError("Expected a non-empty value for `knowledge_id`")

        return self._delete(
            f"/knowledge/{knowledge_id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=httpx.Response,
        )

    def used(
        self,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> KnowledgeUsed:
        """Return knowledge-base capacity usage for the account.

        Args:
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        return self._get(
            "/knowledge/capacity",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=KnowledgeUsed,
        )
|
||||
@ -0,0 +1,3 @@
|
||||
from .tools import Tools
|
||||
|
||||
__all__ = ["Tools"]
|
||||
@ -0,0 +1,65 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Literal, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
StreamResponse,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.tools import WebSearch, WebSearchChunk, tools_web_search_params
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Tools"]
|
||||
|
||||
|
||||
class Tools(BaseAPI):
    """Client-side wrapper for the ``/tools`` endpoints (web search)."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def web_search(
        self,
        *,
        model: str,
        request_id: Optional[str] | NotGiven = NOT_GIVEN,
        stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN,
        messages: Union[str, list[str], list[int], object, None],
        scope: Optional[str] | NotGiven = NOT_GIVEN,
        location: Optional[str] | NotGiven = NOT_GIVEN,
        recent_days: Optional[int] | NotGiven = NOT_GIVEN,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> WebSearch | StreamResponse[WebSearchChunk]:
        """Run a web search through the tools endpoint.

        Args:
            model: Model name to use for the search.
            request_id: Optional client-supplied request identifier.
            stream: When True, the response is returned as a stream of
                ``WebSearchChunk``; otherwise a single ``WebSearch`` result.
            messages: Search query / conversation input.
            scope: Optional search scope filter.
            location: Optional location hint for the search.
            recent_days: Optional recency filter, in days.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.
        """
        body = deepcopy_minimal(
            {
                "model": model,
                "request_id": request_id,
                "messages": messages,
                "stream": stream,
                "scope": scope,
                "location": location,
                "recent_days": recent_days,
            }
        )
        return self._post(
            "/tools",
            body=maybe_transform(body, tools_web_search_params.WebSearchParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=WebSearch,
            # `stream or False` maps NOT_GIVEN/None to non-streaming —
            # assumes NotGiven is falsy, as in openai-style SDKs (confirm).
            stream=stream or False,
            stream_cls=StreamResponse[WebSearchChunk],
        )
|
||||
@ -0,0 +1,7 @@
|
||||
from .videos import (
|
||||
Videos,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Videos",
|
||||
]
|
||||
@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from ...core import (
|
||||
NOT_GIVEN,
|
||||
BaseAPI,
|
||||
Body,
|
||||
Headers,
|
||||
NotGiven,
|
||||
deepcopy_minimal,
|
||||
make_request_options,
|
||||
maybe_transform,
|
||||
)
|
||||
from ...types.sensitive_word_check import SensitiveWordCheckRequest
|
||||
from ...types.video import VideoObject, video_create_params
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..._client import ZhipuAI
|
||||
|
||||
__all__ = ["Videos"]
|
||||
|
||||
|
||||
class Videos(BaseAPI):
    """Client-side wrapper for video generation (``/videos/generations``)
    and async result retrieval."""

    def __init__(self, client: ZhipuAI) -> None:
        super().__init__(client)

    def generations(
        self,
        model: str,
        *,
        prompt: Optional[str] = None,
        image_url: Optional[str] = None,
        sensitive_word_check: Optional[SensitiveWordCheckRequest] | NotGiven = NOT_GIVEN,
        request_id: Optional[str] = None,
        user_id: Optional[str] = None,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> VideoObject:
        """Submit an asynchronous video-generation task.

        Args:
            model: Model name to use for generation.
            prompt: Text prompt describing the desired video.
            image_url: Optional source image for image-to-video generation.
            sensitive_word_check: Optional sensitive-word-check settings.
            request_id: Optional client-supplied request identifier.
            user_id: Optional end-user identifier.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If both ``model`` and ``prompt`` are empty.
        """
        # BUG FIX: the original condition was `not model and not model`,
        # which tested `model` twice and never looked at `prompt`, so the
        # guard matching the error message below could never fire correctly.
        if not model and not prompt:
            raise ValueError("At least one of `model` and `prompt` must be provided.")
        body = deepcopy_minimal(
            {
                "model": model,
                "prompt": prompt,
                "image_url": image_url,
                "sensitive_word_check": sensitive_word_check,
                "request_id": request_id,
                "user_id": user_id,
            }
        )
        return self._post(
            "/videos/generations",
            body=maybe_transform(body, video_create_params.VideoCreateParams),
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=VideoObject,
        )

    def retrieve_videos_result(
        self,
        id: str,
        *,
        extra_headers: Headers | None = None,
        extra_body: Body | None = None,
        timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN,
    ) -> VideoObject:
        """Fetch the result of a previously submitted generation task.

        Args:
            id: Task ID returned by :meth:`generations`.
            extra_headers: Send extra headers.
            extra_body: Add additional JSON properties to the request.
            timeout: Override the client-level default timeout, in seconds.

        Raises:
            ValueError: If ``id`` is empty.
        """
        if not id:
            raise ValueError("At least one of `id` must be provided.")

        return self._get(
            f"/async-result/{id}",
            options=make_request_options(extra_headers=extra_headers, extra_body=extra_body, timeout=timeout),
            cast_type=VideoObject,
        )
|
||||
@ -0,0 +1,108 @@
|
||||
from ._base_api import BaseAPI
|
||||
from ._base_compat import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
GenericModel,
|
||||
cached_property,
|
||||
field_get_default,
|
||||
get_args,
|
||||
get_model_config,
|
||||
get_model_fields,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
parse_obj,
|
||||
)
|
||||
from ._base_models import BaseModel, construct_type
|
||||
from ._base_type import (
|
||||
NOT_GIVEN,
|
||||
Body,
|
||||
FileTypes,
|
||||
Headers,
|
||||
IncEx,
|
||||
ModelT,
|
||||
NotGiven,
|
||||
Query,
|
||||
)
|
||||
from ._constants import (
|
||||
ZHIPUAI_DEFAULT_LIMITS,
|
||||
ZHIPUAI_DEFAULT_MAX_RETRIES,
|
||||
ZHIPUAI_DEFAULT_TIMEOUT,
|
||||
)
|
||||
from ._errors import (
|
||||
APIAuthenticationError,
|
||||
APIConnectionError,
|
||||
APIInternalError,
|
||||
APIReachLimitError,
|
||||
APIRequestFailedError,
|
||||
APIResponseError,
|
||||
APIResponseValidationError,
|
||||
APIServerFlowExceedError,
|
||||
APIStatusError,
|
||||
APITimeoutError,
|
||||
ZhipuAIError,
|
||||
)
|
||||
from ._files import is_file_content
|
||||
from ._http_client import HttpClient, make_request_options
|
||||
from ._sse_client import StreamResponse
|
||||
from ._utils import (
|
||||
deepcopy_minimal,
|
||||
drop_prefix_image_data,
|
||||
extract_files,
|
||||
is_given,
|
||||
is_list,
|
||||
is_mapping,
|
||||
maybe_transform,
|
||||
parse_date,
|
||||
parse_datetime,
|
||||
)
|
||||
|
||||
# Public re-export surface of the ``core`` package.  Keep this list in sync
# with the imports above.
# NOTE(review): ``APIConnectionError`` and ``drop_prefix_image_data`` are
# imported above but not exported here — confirm that is intentional.
__all__ = [
    "BaseModel",
    "construct_type",
    "BaseAPI",
    "NOT_GIVEN",
    "Headers",
    "NotGiven",
    "Body",
    "IncEx",
    "ModelT",
    "Query",
    "FileTypes",
    "PYDANTIC_V2",
    "ConfigDict",
    "GenericModel",
    "get_args",
    "is_union",
    "parse_obj",
    "get_origin",
    "is_literal_type",
    "get_model_config",
    "get_model_fields",
    "field_get_default",
    "is_file_content",
    "ZhipuAIError",
    "APIStatusError",
    "APIRequestFailedError",
    "APIAuthenticationError",
    "APIReachLimitError",
    "APIInternalError",
    "APIServerFlowExceedError",
    "APIResponseError",
    "APIResponseValidationError",
    "APITimeoutError",
    "make_request_options",
    "HttpClient",
    "ZHIPUAI_DEFAULT_TIMEOUT",
    "ZHIPUAI_DEFAULT_MAX_RETRIES",
    "ZHIPUAI_DEFAULT_LIMITS",
    "is_list",
    "is_mapping",
    "parse_date",
    "parse_datetime",
    "is_given",
    "maybe_transform",
    "deepcopy_minimal",
    "extract_files",
    "StreamResponse",
]
|
||||
@ -0,0 +1,209 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, overload
|
||||
|
||||
import pydantic
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import Self
|
||||
|
||||
from ._base_type import StrBytesIntFloat
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_ModelT = TypeVar("_ModelT", bound=pydantic.BaseModel)
|
||||
|
||||
# --------------- Pydantic v2 compatibility ---------------

# Pyright incorrectly reports some of our functions as overriding a method when they don't
# pyright: reportIncompatibleMethodOverride=false

# True when the installed pydantic is a 2.x release; gates every shim below.
PYDANTIC_V2 = pydantic.VERSION.startswith("2.")

# v1 re-exports
# Under TYPE_CHECKING we declare stub signatures so type checkers see one
# stable API regardless of which pydantic version supplies the runtime
# implementation in the `else` branch below.
if TYPE_CHECKING:

    def parse_date(value: date | StrBytesIntFloat) -> date: ...

    def parse_datetime(value: Union[datetime, StrBytesIntFloat]) -> datetime: ...

    def get_args(t: type[Any]) -> tuple[Any, ...]: ...

    def is_union(tp: type[Any] | None) -> bool: ...

    def get_origin(t: type[Any]) -> type[Any] | None: ...

    def is_literal_type(type_: type[Any]) -> bool: ...

    def is_typeddict(type_: type[Any]) -> bool: ...

else:
    # At runtime, pydantic v2 ships the v1 helpers under `pydantic.v1.*`,
    # while pydantic v1 exposes them at the top level.
    if PYDANTIC_V2:
        from pydantic.v1.typing import (  # noqa: I001
            get_args as get_args,  # noqa: PLC0414
            is_union as is_union,  # noqa: PLC0414
            get_origin as get_origin,  # noqa: PLC0414
            is_typeddict as is_typeddict,  # noqa: PLC0414
            is_literal_type as is_literal_type,  # noqa: PLC0414
        )
        from pydantic.v1.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime  # noqa: PLC0414
    else:
        from pydantic.typing import (  # noqa: I001
            get_args as get_args,  # noqa: PLC0414
            is_union as is_union,  # noqa: PLC0414
            get_origin as get_origin,  # noqa: PLC0414
            is_typeddict as is_typeddict,  # noqa: PLC0414
            is_literal_type as is_literal_type,  # noqa: PLC0414
        )
        from pydantic.datetime_parse import parse_date as parse_date, parse_datetime as parse_datetime  # noqa: PLC0414
|
||||
|
||||
|
||||
# refactored config
if TYPE_CHECKING:
    from pydantic import ConfigDict
else:
    if PYDANTIC_V2:
        from pydantic import ConfigDict
    else:
        # TODO: provide an error message here?
        # pydantic v1 has no ConfigDict; exporting None keeps the import
        # surface identical, but any call site using it under v1 will fail
        # at runtime rather than at import time.
        ConfigDict = None
|
||||
|
||||
|
||||
# renamed methods / properties
def parse_obj(model: type[_ModelT], value: object) -> _ModelT:
    """Validate ``value`` into an instance of ``model`` across pydantic versions."""
    if not PYDANTIC_V2:
        # v1 spelling of the same operation.
        # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
        return cast(_ModelT, model.parse_obj(value))
    return model.model_validate(value)
|
||||
|
||||
|
||||
def field_is_required(field: FieldInfo) -> bool:
    """Return whether the given model field has no default and must be supplied."""
    if not PYDANTIC_V2:
        return field.required  # type: ignore
    return field.is_required()
|
||||
|
||||
|
||||
def field_get_default(field: FieldInfo) -> Any:
    """Return the field's default value, or ``None`` when it has no default.

    Pydantic v2 reports "no default" with the ``PydanticUndefined`` sentinel;
    normalize that to ``None`` so callers never see the sentinel.
    """
    value = field.get_default()
    if PYDANTIC_V2:
        from pydantic_core import PydanticUndefined

        # Sentinel check by identity, not equality: PydanticUndefined is a
        # singleton, and `==` would invoke the equality protocol needlessly.
        if value is PydanticUndefined:
            return None
    return value
|
||||
|
||||
|
||||
def field_outer_type(field: FieldInfo) -> Any:
    """Return the outer (possibly generic) type annotation of ``field``."""
    if not PYDANTIC_V2:
        return field.outer_type_  # type: ignore
    return field.annotation
|
||||
|
||||
|
||||
def get_model_config(model: type[pydantic.BaseModel]) -> Any:
    """Fetch the model's config (v2 ``model_config`` dict / v1 ``__config__`` class)."""
    if not PYDANTIC_V2:
        return model.__config__  # type: ignore
    return model.model_config
|
||||
|
||||
|
||||
def get_model_fields(model: type[pydantic.BaseModel]) -> dict[str, FieldInfo]:
    """Return the name -> ``FieldInfo`` mapping for ``model``, across pydantic versions."""
    if not PYDANTIC_V2:
        return model.__fields__  # type: ignore
    return model.model_fields
|
||||
|
||||
|
||||
def model_copy(model: _ModelT) -> _ModelT:
    """Shallow-copy a model instance (v2 ``model_copy`` / v1 ``copy``)."""
    if not PYDANTIC_V2:
        return model.copy()  # type: ignore
    return model.model_copy()
|
||||
|
||||
|
||||
def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str:
    """Serialize ``model`` to a JSON string, optionally pretty-printed."""
    if not PYDANTIC_V2:
        return model.json(indent=indent)  # type: ignore
    return model.model_dump_json(indent=indent)
|
||||
|
||||
|
||||
def model_dump(
    model: pydantic.BaseModel,
    *,
    exclude_unset: bool = False,
    exclude_defaults: bool = False,
) -> dict[str, Any]:
    """Convert ``model`` to a plain dict, bridging the v1/v2 method rename."""
    dump_kwargs = {
        "exclude_unset": exclude_unset,
        "exclude_defaults": exclude_defaults,
    }
    if PYDANTIC_V2:
        return model.model_dump(**dump_kwargs)
    dumped = model.dict(**dump_kwargs)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
    return cast("dict[str, Any]", dumped)
|
||||
|
||||
|
||||
def model_parse(model: type[_ModelT], data: Any) -> _ModelT:
    """Parse ``data`` into an instance of ``model`` (v2 ``model_validate`` / v1 ``parse_obj``)."""
    if not PYDANTIC_V2:
        return model.parse_obj(data)  # pyright: ignore[reportDeprecated]
    return model.model_validate(data)
|
||||
|
||||
|
||||
# generic models
# Provide a single ``GenericModel`` name whatever pydantic version is installed.
if TYPE_CHECKING:

    class GenericModel(pydantic.BaseModel): ...

else:
    if PYDANTIC_V2:
        # there no longer needs to be a distinction in v2 but
        # we still have to create our own subclass to avoid
        # inconsistent MRO ordering errors
        class GenericModel(pydantic.BaseModel): ...

    else:
        import pydantic.generics

        # v1 requires the dedicated generics base class for parametrized models.
        class GenericModel(pydantic.generics.GenericModel, pydantic.BaseModel): ...
|
||||
|
||||
|
||||
# cached properties
if TYPE_CHECKING:
    cached_property = property

    # we define a separate type (copied from typeshed)
    # that represents that `cached_property` is `set`able
    # at runtime, which differs from `@property`.
    #
    # this is a separate type as editors likely special case
    # `@property` and we don't want to cause issues just to have
    # more helpful internal types.

    class typed_cached_property(Generic[_T]):  # noqa: N801
        # Typing-only stub: these bodies never execute; the runtime object is
        # bound in the `else` branch below.
        func: Callable[[Any], _T]
        attrname: str | None

        def __init__(self, func: Callable[[Any], _T]) -> None: ...

        @overload
        def __get__(self, instance: None, owner: type[Any] | None = None) -> Self: ...

        @overload
        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T: ...

        def __get__(self, instance: object, owner: type[Any] | None = None) -> _T | Self:
            raise NotImplementedError()

        def __set_name__(self, owner: type[Any], name: str) -> None: ...

        # __set__ is not defined at runtime, but @cached_property is designed to be settable
        def __set__(self, instance: object, value: _T) -> None: ...
else:
    try:
        from functools import cached_property
    except ImportError:
        # Fallback for interpreters without functools.cached_property
        # (pre-3.8): the third-party `cached_property` package.
        from cached_property import cached_property

    # At runtime the "typed" variant is just the real cached_property.
    typed_cached_property = cached_property
|
||||
@ -0,0 +1,671 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import date, datetime
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, TypeGuard, TypeVar, cast
|
||||
|
||||
import pydantic
|
||||
import pydantic.generics
|
||||
from pydantic.fields import FieldInfo
|
||||
from typing_extensions import (
|
||||
ParamSpec,
|
||||
Protocol,
|
||||
override,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from ._base_compat import (
|
||||
PYDANTIC_V2,
|
||||
ConfigDict,
|
||||
field_get_default,
|
||||
get_args,
|
||||
get_model_config,
|
||||
get_model_fields,
|
||||
get_origin,
|
||||
is_literal_type,
|
||||
is_union,
|
||||
parse_obj,
|
||||
)
|
||||
from ._base_compat import (
|
||||
GenericModel as BaseGenericModel,
|
||||
)
|
||||
from ._base_type import (
|
||||
IncEx,
|
||||
ModelT,
|
||||
)
|
||||
from ._utils import (
|
||||
PropertyInfo,
|
||||
coerce_boolean,
|
||||
extract_type_arg,
|
||||
is_annotated_type,
|
||||
is_list,
|
||||
is_mapping,
|
||||
parse_date,
|
||||
parse_datetime,
|
||||
strip_annotated_type,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pydantic_core.core_schema import LiteralSchema, ModelField, ModelFieldsSchema
|
||||
|
||||
__all__ = ["BaseModel", "GenericModel"]
|
||||
_BaseModelT = TypeVar("_BaseModelT", bound="BaseModel")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@runtime_checkable
class _ConfigProtocol(Protocol):
    # Structural type matching pydantic v1 style ``Config`` classes (attribute
    # access), as opposed to v2's dict-like ``ConfigDict``.
    allow_population_by_field_name: bool
|
||||
|
||||
|
||||
class BaseModel(pydantic.BaseModel):
    """Base model for API response types.

    Adds three behaviors on top of ``pydantic.BaseModel``:

    - ``extra="allow"`` so unknown response fields are kept rather than rejected;
    - pydantic v1/v2 compatibility shims (``model_fields_set``, ``model_dump``
      and ``model_dump_json`` aliases under v1);
    - a validation-free, recursive ``construct()`` used to build instances
      directly from raw API payloads.
    """

    if PYDANTIC_V2:
        model_config: ClassVar[ConfigDict] = ConfigDict(
            extra="allow", defer_build=coerce_boolean(os.environ.get("DEFER_PYDANTIC_BUILD", "true"))
        )
    else:

        @property
        @override
        def model_fields_set(self) -> set[str]:
            # a forwards-compat shim for pydantic v2
            return self.__fields_set__  # type: ignore

        class Config(pydantic.BaseConfig):  # pyright: ignore[reportDeprecated]
            extra: Any = pydantic.Extra.allow  # type: ignore

    def to_dict(
        self,
        *,
        mode: Literal["json", "python"] = "python",
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> dict[str, object]:
        """Recursively generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
            mode:
                If mode is 'json', the dictionary will only contain JSON serializable types. e.g. `datetime` will be turned into a string, `"2024-3-22T18:11:19.117000Z"`.
                If mode is 'python', the dictionary may contain any Python objects. e.g. `datetime(2024, 3, 22)`

            use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
            exclude_unset: Whether to exclude fields that have not been explicitly set.
            exclude_defaults: Whether to exclude fields that are set to their default value from the output.
            exclude_none: Whether to exclude fields that have a value of `None` from the output.
            warnings: Whether to log warnings when invalid fields are encountered. This is only supported in Pydantic v2.
        """  # noqa: E501
        return self.model_dump(
            mode=mode,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )

    def to_json(
        self,
        *,
        indent: int | None = 2,
        use_api_names: bool = True,
        exclude_unset: bool = True,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        warnings: bool = True,
    ) -> str:
        """Generates a JSON string representing this model as it would be received from or sent to the API (but with indentation).

        By default, fields that were not set by the API will not be included,
        and keys will match the API response, *not* the property names from the model.

        For example, if the API responds with `"fooBar": true` but we've defined a `foo_bar: bool` property,
        the output will use the `"fooBar"` key (unless `use_api_names=False` is passed).

        Args:
            indent: Indentation to use in the JSON output. If `None` is passed, the output will be compact. Defaults to `2`
            use_api_names: Whether to use the key that the API responded with or the property name. Defaults to `True`.
            exclude_unset: Whether to exclude fields that have not been explicitly set.
            exclude_defaults: Whether to exclude fields that have the default value.
            exclude_none: Whether to exclude fields that have a value of `None`.
            warnings: Whether to show any warnings that occurred during serialization. This is only supported in Pydantic v2.
        """  # noqa: E501
        return self.model_dump_json(
            indent=indent,
            by_alias=use_api_names,
            exclude_unset=exclude_unset,
            exclude_defaults=exclude_defaults,
            exclude_none=exclude_none,
            warnings=warnings,
        )

    @override
    def __str__(self) -> str:
        # mypy complains about an invalid self arg
        return f'{self.__repr_name__()}({self.__repr_str__(", ")})'  # type: ignore[misc]

    # Override the 'construct' method in a way that supports recursive parsing without validation.
    # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836.
    @classmethod
    @override
    def construct(
        cls: type[ModelT],
        _fields_set: set[str] | None = None,
        **values: object,
    ) -> ModelT:
        m = cls.__new__(cls)
        fields_values: dict[str, object] = {}

        config = get_model_config(cls)
        populate_by_name = (
            config.allow_population_by_field_name
            if isinstance(config, _ConfigProtocol)
            else config.get("populate_by_name")
        )

        if _fields_set is None:
            _fields_set = set()

        model_fields = get_model_fields(cls)
        for name, field in model_fields.items():
            # Prefer the API alias; fall back to the attribute name when the
            # alias is missing from the payload and populate-by-name is on.
            key = field.alias
            if key is None or (key not in values and populate_by_name):
                key = name

            if key in values:
                fields_values[name] = _construct_field(value=values[key], field=field, key=key)
                _fields_set.add(name)
            else:
                fields_values[name] = field_get_default(field)

        # Anything not declared on the model is kept as "extra" data.
        _extra = {}
        for key, value in values.items():
            if key not in model_fields:
                if PYDANTIC_V2:
                    _extra[key] = value
                else:
                    _fields_set.add(key)
                    fields_values[key] = value

        object.__setattr__(m, "__dict__", fields_values)  # noqa: PLC2801

        if PYDANTIC_V2:
            # these properties are copied from Pydantic's `model_construct()` method
            object.__setattr__(m, "__pydantic_private__", None)  # noqa: PLC2801
            object.__setattr__(m, "__pydantic_extra__", _extra)  # noqa: PLC2801
            object.__setattr__(m, "__pydantic_fields_set__", _fields_set)  # noqa: PLC2801
        else:
            # init_private_attributes() does not exist in v2
            m._init_private_attributes()  # type: ignore

            # copied from Pydantic v1's `construct()` method
            object.__setattr__(m, "__fields_set__", _fields_set)  # noqa: PLC2801

        return m

    if not TYPE_CHECKING:
        # type checkers incorrectly complain about this assignment
        # because the type signatures are technically different
        # although not in practice
        model_construct = construct

    if not PYDANTIC_V2:
        # we define aliases for some of the new pydantic v2 methods so
        # that we can just document these methods without having to specify
        # a specific pydantic version as some users may not know which
        # pydantic version they are currently using

        @override
        def model_dump(
            self,
            *,
            mode: Literal["json", "python"] | str = "python",
            include: IncEx = None,
            exclude: IncEx = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            context: dict[str, Any] | None = None,
            serialize_as_any: bool = False,
        ) -> dict[str, Any]:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump

            Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.

            Args:
                mode: The mode in which `to_python` should run.
                    If mode is 'json', the dictionary will only contain JSON serializable types.
                    If mode is 'python', the dictionary may contain any Python objects.
                include: A list of fields to include in the output.
                exclude: A list of fields to exclude from the output.
                by_alias: Whether to use the field's alias in the dictionary key if defined.
                exclude_unset: Whether to exclude fields that are unset or None from the output.
                exclude_defaults: Whether to exclude fields that are set to their default value from the output.
                exclude_none: Whether to exclude fields that have a value of `None` from the output.
                round_trip: Whether to enable serialization and deserialization round-trip support.
                warnings: Whether to log warnings when invalid fields are encountered.

            Returns:
                A dictionary representation of the model.
            """
            # Reject v2-only options explicitly instead of silently ignoring
            # them.  Identity checks (not `!= False`, flake8 E712) against the
            # documented defaults.
            if mode != "python":
                raise ValueError("mode is only supported in Pydantic v2")
            if round_trip is not False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings is not True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any is not False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            return super().dict(  # pyright: ignore[reportDeprecated]
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )

        @override
        def model_dump_json(
            self,
            *,
            indent: int | None = None,
            include: IncEx = None,
            exclude: IncEx = None,
            by_alias: bool = False,
            exclude_unset: bool = False,
            exclude_defaults: bool = False,
            exclude_none: bool = False,
            round_trip: bool = False,
            warnings: bool | Literal["none", "warn", "error"] = True,
            context: dict[str, Any] | None = None,
            serialize_as_any: bool = False,
        ) -> str:
            """Usage docs: https://docs.pydantic.dev/2.4/concepts/serialization/#modelmodel_dump_json

            Generates a JSON representation of the model using Pydantic's `to_json` method.

            Args:
                indent: Indentation to use in the JSON output. If None is passed, the output will be compact.
                include: Field(s) to include in the JSON output. Can take either a string or set of strings.
                exclude: Field(s) to exclude from the JSON output. Can take either a string or set of strings.
                by_alias: Whether to serialize using field aliases.
                exclude_unset: Whether to exclude fields that have not been explicitly set.
                exclude_defaults: Whether to exclude fields that have the default value.
                exclude_none: Whether to exclude fields that have a value of `None`.
                round_trip: Whether to use serialization/deserialization between JSON and class instance.
                warnings: Whether to show any warnings that occurred during serialization.

            Returns:
                A JSON string representation of the model.
            """
            # Same v2-only-option guards as model_dump() above.
            if round_trip is not False:
                raise ValueError("round_trip is only supported in Pydantic v2")
            if warnings is not True:
                raise ValueError("warnings is only supported in Pydantic v2")
            if context is not None:
                raise ValueError("context is only supported in Pydantic v2")
            if serialize_as_any is not False:
                raise ValueError("serialize_as_any is only supported in Pydantic v2")
            return super().json(  # type: ignore[reportDeprecated]
                indent=indent,
                include=include,
                exclude=exclude,
                by_alias=by_alias,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
                exclude_none=exclude_none,
            )
|
||||
|
||||
|
||||
def _construct_field(value: object, field: FieldInfo, key: str) -> object:
    """Recursively construct one field's value; ``None`` falls back to the field default."""
    if value is None:
        return field_get_default(field)

    # v2 stores the annotation directly; v1 keeps the outer (generic) type.
    type_ = field.annotation if PYDANTIC_V2 else cast(type, field.outer_type_)  # type: ignore
    if type_ is None:
        raise RuntimeError(f"Unexpected field type is None for {key}")

    return construct_type(value=value, type_=type_)
|
||||
|
||||
|
||||
def is_basemodel(type_: type) -> bool:
    """Returns whether or not the given type is either a `BaseModel` or a union of `BaseModel`"""
    if not is_union(type_):
        return is_basemodel_type(type_)
    # A union qualifies when any of its variants does.
    return any(is_basemodel(variant) for variant in get_args(type_))
|
||||
|
||||
|
||||
def is_basemodel_type(type_: type) -> TypeGuard[type[BaseModel] | type[GenericModel]]:
    """True when ``type_`` (or its generic origin) subclasses one of our model bases."""
    origin = get_origin(type_) or type_
    return issubclass(origin, (BaseModel, GenericModel))
|
||||
|
||||
|
||||
def build(
    base_model_cls: Callable[P, _BaseModelT],
    *args: P.args,
    **kwargs: P.kwargs,
) -> _BaseModelT:
    """Construct a BaseModel class without validation.

    This is useful for cases where you need to instantiate a `BaseModel`
    from an API response as this provides type-safe params which isn't supported
    by helpers like `construct_type()`.

    ```py
    build(MyModel, my_field_a="foo", my_field_b=123)
    ```
    """
    # Positional args cannot be mapped to field names, so they are rejected
    # outright rather than silently dropped.
    if args:
        raise TypeError(
            "Received positional arguments which are not supported; Keyword arguments must be used instead",
        )

    return cast(_BaseModelT, construct_type(type_=base_model_cls, value=kwargs))
|
||||
|
||||
|
||||
def construct_type_unchecked(*, value: object, type_: type[_T]) -> _T:
    """Loose coercion to the expected type with construction of nested values.

    Note: the returned value from this function is not guaranteed to match the
    given type.
    """
    constructed = construct_type(value=value, type_=type_)
    return cast(_T, constructed)
|
||||
|
||||
|
||||
def construct_type(*, value: object, type_: type) -> object:
    """Loose coercion to the expected type with construction of nested values.

    If the given value does not match the expected type then it is returned as-is.
    """
    # we allow `object` as the input type because otherwise, passing things like
    # `Literal['value']` will be reported as a type error by type checkers
    type_ = cast("type[object]", type_)

    # unwrap `Annotated[T, ...]` -> `T`, keeping the metadata (used below for
    # discriminated-union detection)
    if is_annotated_type(type_):
        meta: tuple[Any, ...] = get_args(type_)[1:]
        type_ = extract_type_arg(type_, 0)
    else:
        meta = ()
    # we need to use the origin class for any types that are subscripted generics
    # e.g. Dict[str, object]
    origin = get_origin(type_) or type_
    args = get_args(type_)

    if is_union(origin):
        # First try strict validation of the whole union.
        try:
            return validate_type(type_=cast("type[object]", type_), value=value)
        except Exception:
            pass

        # if the type is a discriminated union then we want to construct the right variant
        # in the union, even if the data doesn't match exactly, otherwise we'd break code
        # that relies on the constructed class types, e.g.
        #
        # class FooType:
        #     kind: Literal['foo']
        #     value: str
        #
        # class BarType:
        #     kind: Literal['bar']
        #     value: int
        #
        # without this block, if the data we get is something like `{'kind': 'bar', 'value': 'foo'}` then
        # we'd end up constructing `FooType` when it should be `BarType`.
        discriminator = _build_discriminated_union_meta(union=type_, meta_annotations=meta)
        if discriminator and is_mapping(value):
            variant_value = value.get(discriminator.field_alias_from or discriminator.field_name)
            if variant_value and isinstance(variant_value, str):
                variant_type = discriminator.mapping.get(variant_value)
                if variant_type:
                    return construct_type(type_=variant_type, value=value)

        # if the data is not valid, use the first variant that doesn't fail while deserializing
        for variant in args:
            try:
                return construct_type(value=value, type_=variant)
            except Exception:
                continue

        raise RuntimeError(f"Could not convert data into a valid instance of {type_}")
    if origin == dict:
        if not is_mapping(value):
            return value

        _, items_type = get_args(type_)  # Dict[_, items_type]
        return {key: construct_type(value=item, type_=items_type) for key, item in value.items()}

    if not is_literal_type(type_) and (issubclass(origin, BaseModel) or issubclass(origin, GenericModel)):
        # Lists of models and single model mappings are built via the
        # validation-free `construct()` defined on BaseModel.
        if is_list(value):
            return [cast(Any, type_).construct(**entry) if is_mapping(entry) else entry for entry in value]

        if is_mapping(value):
            if issubclass(type_, BaseModel):
                return type_.construct(**value)  # type: ignore[arg-type]

            return cast(Any, type_).construct(**value)

    if origin == list:
        if not is_list(value):
            return value

        inner_type = args[0]  # List[inner_type]
        return [construct_type(value=entry, type_=inner_type) for entry in value]

    if origin == float:
        if isinstance(value, int):
            coerced = float(value)
            # Only coerce when the float round-trips exactly (guards against
            # precision loss on very large ints).
            if coerced != value:
                return value
            return coerced

        return value

    if type_ == datetime:
        try:
            return parse_datetime(value)  # type: ignore
        except Exception:
            return value

    if type_ == date:
        try:
            return parse_date(value)  # type: ignore
        except Exception:
            return value

    return value
|
||||
|
||||
|
||||
@runtime_checkable
class CachedDiscriminatorType(Protocol):
    # Structural marker: a union object onto which previously computed
    # discriminator details have been cached (see _build_discriminated_union_meta).
    __discriminator__: DiscriminatorDetails
|
||||
|
||||
|
||||
class DiscriminatorDetails:
    """Describes how to pick the correct variant of a discriminated union from raw data."""

    field_name: str
    """The name of the discriminator field in the variant class, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo']
    ```

    Will result in field_name='type'
    """

    field_alias_from: str | None
    """The name of the discriminator field in the API response, e.g.

    ```py
    class Foo(BaseModel):
        type: Literal['foo'] = Field(alias='type_from_api')
    ```

    Will result in field_alias_from='type_from_api'
    """

    mapping: dict[str, type]
    """Mapping of discriminator value to variant type, e.g.

    {'foo': FooVariant, 'bar': BarVariant}
    """

    def __init__(
        self,
        *,
        mapping: dict[str, type],
        discriminator_field: str,
        discriminator_alias: str | None,
    ) -> None:
        """Plain data holder; see the class attribute docs for field semantics."""
        self.mapping = mapping
        self.field_name = discriminator_field
        self.field_alias_from = discriminator_alias
|
||||
|
||||
|
||||
def _build_discriminated_union_meta(*, union: type, meta_annotations: tuple[Any, ...]) -> DiscriminatorDetails | None:
    """Build (and cache on the union object) discriminator metadata for a union type.

    Returns ``None`` when the annotations carry no ``PropertyInfo.discriminator``
    or no variant contributes a literal value for the discriminator field.
    """
    # Fast path: reuse details cached on the union object by a previous call.
    if isinstance(union, CachedDiscriminatorType):
        return union.__discriminator__

    discriminator_field_name: str | None = None

    for annotation in meta_annotations:
        if isinstance(annotation, PropertyInfo) and annotation.discriminator is not None:
            discriminator_field_name = annotation.discriminator
            break

    if not discriminator_field_name:
        return None

    mapping: dict[str, type] = {}
    discriminator_alias: str | None = None

    for variant in get_args(union):
        variant = strip_annotated_type(variant)
        if is_basemodel_type(variant):
            if PYDANTIC_V2:
                # v2: read the discriminator's literal values from the core schema.
                field = _extract_field_schema_pv2(variant, discriminator_field_name)
                if not field:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field.get("serialization_alias")

                field_schema = field["schema"]

                if field_schema["type"] == "literal":
                    for entry in cast("LiteralSchema", field_schema)["expected"]:
                        if isinstance(entry, str):
                            mapping[entry] = variant
            else:
                # v1: read them from the deprecated __fields__ mapping instead.
                field_info = cast("dict[str, FieldInfo]", variant.__fields__).get(discriminator_field_name)  # pyright: ignore[reportDeprecated, reportUnnecessaryCast]
                if not field_info:
                    continue

                # Note: if one variant defines an alias then they all should
                discriminator_alias = field_info.alias

                if field_info.annotation and is_literal_type(field_info.annotation):
                    for entry in get_args(field_info.annotation):
                        if isinstance(entry, str):
                            mapping[entry] = variant

    if not mapping:
        return None

    details = DiscriminatorDetails(
        mapping=mapping,
        discriminator_field=discriminator_field_name,
        discriminator_alias=discriminator_alias,
    )
    # Cache on the union object so repeated construction is cheap.
    cast(CachedDiscriminatorType, union).__discriminator__ = details
    return details
|
||||
|
||||
|
||||
def _extract_field_schema_pv2(model: type[BaseModel], field_name: str) -> ModelField | None:
    """Dig the core-schema entry for ``field_name`` out of a pydantic v2 model.

    Returns ``None`` whenever the schema does not have the expected
    ``model -> model-fields -> fields`` shape or the field is absent.
    """
    schema = model.__pydantic_core_schema__
    if schema["type"] != "model":
        return None

    fields_schema = schema["schema"]
    if fields_schema["type"] != "model-fields":
        return None

    fields_schema = cast("ModelFieldsSchema", fields_schema)

    field = fields_schema["fields"].get(field_name)
    if not field:
        return None

    return cast("ModelField", field)  # pyright: ignore[reportUnnecessaryCast]
|
||||
|
||||
|
||||
def validate_type(*, type_: type[_T], value: object) -> _T:
    """Strict validation that the given value matches the expected type"""
    is_model_class = inspect.isclass(type_) and issubclass(type_, pydantic.BaseModel)
    if is_model_class:
        validated = parse_obj(type_, value)
    else:
        validated = _validate_non_model_type(type_=type_, value=value)
    return cast(_T, validated)
|
||||
|
||||
|
||||
# our use of subclasssing here causes weirdness for type checkers,
# so we just pretend that we don't subclass
if TYPE_CHECKING:
    GenericModel = BaseModel
else:

    class GenericModel(BaseGenericModel, BaseModel):
        # Runtime class combining the version-appropriate generic base with our
        # BaseModel behavior (validation-free construct(), extra="allow", ...).
        pass
|
||||
|
||||
|
||||
if PYDANTIC_V2:
    # Pydantic v2: delegate non-model validation to the real TypeAdapter,
    # which supports coercion and arbitrary type forms.
    from pydantic import TypeAdapter

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        """Validate *value* against *type_* using Pydantic v2's TypeAdapter."""
        return TypeAdapter(type_).validate_python(value)

elif not TYPE_CHECKING:

    class TypeAdapter(Generic[_T]):
        """Used as a placeholder to easily convert runtime types to a Pydantic format
        to provide validation.

        For example:
        ```py
        validated = RootModel[int](__root__="5").__root__
        # validated: 5
        ```
        """

        def __init__(self, type_: type[_T]):
            # The concrete runtime type the adapter validates against.
            self.type_ = type_

        def validate_python(self, value: Any) -> _T:
            # NOTE: unlike Pydantic's TypeAdapter this fallback performs a
            # strict isinstance check only — no coercion is attempted.
            if not isinstance(value, self.type_):
                raise ValueError(f"Invalid type: {value} is not of type {self.type_}")
            return value

    def _validate_non_model_type(*, type_: type[_T], value: object) -> _T:
        """Validate *value* against *type_* using the isinstance-based placeholder."""
        return TypeAdapter(type_).validate_python(value)
|
||||
@ -0,0 +1,207 @@
|
||||
from __future__ import annotations

import os
from collections.abc import AsyncIterator, Iterator
from typing import Any

import anyio
import httpx
|
||||
|
||||
|
||||
class HttpxResponseContent:
    """Abstract accessor interface over an ``httpx.Response`` body.

    Every accessor raises ``NotImplementedError`` here; concrete subclasses
    (e.g. ``HttpxBinaryResponseContent``) override the subset that makes
    sense for their content type, leaving the rest raising.
    """

    @property
    def content(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def text(self) -> str:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def encoding(self) -> str | None:
        raise NotImplementedError("This method is not implemented for this class.")

    @property
    def charset_encoding(self) -> str | None:
        raise NotImplementedError("This method is not implemented for this class.")

    def json(self, **kwargs: Any) -> Any:
        raise NotImplementedError("This method is not implemented for this class.")

    def read(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_lines(self) -> Iterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    def close(self) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aread(self) -> bytes:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_lines(self) -> AsyncIterator[str]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        raise NotImplementedError("This method is not implemented for this class.")

    async def astream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        raise NotImplementedError("This method is not implemented for this class.")

    async def aclose(self) -> None:
        raise NotImplementedError("This method is not implemented for this class.")
|
||||
|
||||
|
||||
class HttpxBinaryResponseContent(HttpxResponseContent):
    """Binary view over an ``httpx.Response``.

    Byte-oriented accessors delegate to the wrapped response; text-oriented
    accessors (``text``, ``json``, ``iter_text``, ``iter_lines`` and their
    async variants) raise, since the payload is treated as opaque bytes.
    """

    # The wrapped httpx response whose body is exposed.
    response: httpx.Response

    def __init__(self, response: httpx.Response) -> None:
        self.response = response

    @property
    def content(self) -> bytes:
        return self.response.content

    @property
    def encoding(self) -> str | None:
        return self.response.encoding

    @property
    def charset_encoding(self) -> str | None:
        return self.response.charset_encoding

    def read(self) -> bytes:
        return self.response.read()

    # NOTE(review): this shadows the parent's `text` *property* with a plain
    # method; both raise here, but the access style differs — confirm intended.
    def text(self) -> str:
        raise NotImplementedError("Not implemented for binary response content")

    def json(self, **kwargs: Any) -> Any:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_lines(self) -> Iterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    async def aiter_lines(self) -> AsyncIterator[str]:
        raise NotImplementedError("Not implemented for binary response content")

    def iter_bytes(self, chunk_size: int | None = None) -> Iterator[bytes]:
        return self.response.iter_bytes(chunk_size)

    def iter_raw(self, chunk_size: int | None = None) -> Iterator[bytes]:
        return self.response.iter_raw(chunk_size)

    def write_to_file(
        self,
        file: str | os.PathLike[str],
    ) -> None:
        """Write the output to the given file.

        Accepts a filename or any path-like object, e.g. pathlib.Path

        Note: if you want to stream the data to the file instead of writing
        all at once then you should use `.with_streaming_response` when making
        the API request, e.g. `client.with_streaming_response.foo().stream_to_file('my_filename.txt')`
        """
        with open(file, mode="wb") as f:
            for data in self.response.iter_bytes():
                f.write(data)

    def stream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Stream the response body to *file* in chunks of *chunk_size* bytes."""
        with open(file, mode="wb") as f:
            for data in self.response.iter_bytes(chunk_size):
                f.write(data)

    def close(self) -> None:
        return self.response.close()

    async def aread(self) -> bytes:
        return await self.response.aread()

    async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        return self.response.aiter_bytes(chunk_size)

    async def aiter_raw(self, chunk_size: int | None = None) -> AsyncIterator[bytes]:
        return self.response.aiter_raw(chunk_size)

    async def astream_to_file(
        self,
        file: str | os.PathLike[str],
        *,
        chunk_size: int | None = None,
    ) -> None:
        """Asynchronously stream the response body to *file*.

        NOTE(review): requires `anyio` to be importable at module level; it is
        not in this file's visible import block — without it this raises
        NameError. (`anyio` is a dependency of httpx, so it should be present.)
        """
        path = anyio.Path(file)
        async with await path.open(mode="wb") as f:
            async for data in self.response.aiter_bytes(chunk_size):
                await f.write(data)

    async def aclose(self) -> None:
        return await self.response.aclose()
|
||||
|
||||
|
||||
class HttpxTextBinaryResponseContent(HttpxBinaryResponseContent):
    """Binary response view that additionally exposes text accessors.

    Used for payloads that are bytes on the wire but textual in nature
    (e.g. ``.jsonl`` / ``.xlsx`` downloads selected in ``LegacyAPIResponse``):
    re-enables ``text``, ``json`` and the line/text iterators by delegating
    to the wrapped ``httpx.Response``.
    """

    response: httpx.Response

    @property
    def text(self) -> str:
        return self.response.text

    def json(self, **kwargs: Any) -> Any:
        return self.response.json(**kwargs)

    def iter_text(self, chunk_size: int | None = None) -> Iterator[str]:
        return self.response.iter_text(chunk_size)

    def iter_lines(self) -> Iterator[str]:
        return self.response.iter_lines()

    async def aiter_text(self, chunk_size: int | None = None) -> AsyncIterator[str]:
        return self.response.aiter_text(chunk_size)

    async def aiter_lines(self) -> AsyncIterator[str]:
        return self.response.aiter_lines()
|
||||
@ -0,0 +1,341 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, cast, get_origin, overload
|
||||
|
||||
import httpx
|
||||
import pydantic
|
||||
from typing_extensions import ParamSpec, override
|
||||
|
||||
from ._base_models import BaseModel, is_basemodel
|
||||
from ._base_type import NoneType
|
||||
from ._constants import RAW_RESPONSE_HEADER
|
||||
from ._errors import APIResponseValidationError
|
||||
from ._legacy_binary_response import HttpxResponseContent, HttpxTextBinaryResponseContent
|
||||
from ._sse_client import StreamResponse, extract_stream_chunk_type, is_stream_class_type
|
||||
from ._utils import extract_type_arg, is_annotated_type, is_given
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._http_client import HttpClient
|
||||
from ._request_opt import FinalRequestOptions
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
_T = TypeVar("_T")
|
||||
|
||||
log: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LegacyAPIResponse(Generic[R]):
    """This is a legacy class as it will be replaced by `APIResponse`
    and `AsyncAPIResponse` in the `_response.py` file in the next major
    release.

    For the sync client this will mostly be the same with the exception
    of `content` & `text` will be methods instead of properties. In the
    async client, all methods will be async.

    A migration script will be provided & the migration in general should
    be smooth.
    """

    # Target type requests asked the response to be parsed into.
    _cast_type: type[R]
    # Client that issued the request; used for model parsing and defaults.
    _client: HttpClient
    # Per-target-type cache of already-parsed results (see `parse`).
    _parsed_by_type: dict[type[Any], Any]
    # Whether the request was made with `stream=True`.
    _stream: bool
    # Optional stream wrapper class used when streaming.
    _stream_cls: type[StreamResponse[Any]] | None
    # The finalized request options the request was made with.
    _options: FinalRequestOptions

    # The raw underlying httpx response.
    http_response: httpx.Response

    def __init__(
        self,
        *,
        raw: httpx.Response,
        cast_type: type[R],
        client: HttpClient,
        stream: bool,
        stream_cls: type[StreamResponse[Any]] | None,
        options: FinalRequestOptions,
    ) -> None:
        self._cast_type = cast_type
        self._client = client
        self._parsed_by_type = {}
        self._stream = stream
        self._stream_cls = stream_cls
        self._options = options
        self.http_response = raw

    @property
    def request_id(self) -> str | None:
        """The server-assigned request id, taken from the `x-request-id` header."""
        return self.http_response.headers.get("x-request-id")  # type: ignore[no-any-return]

    @overload
    def parse(self, *, to: type[_T]) -> _T: ...

    @overload
    def parse(self) -> R: ...

    def parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Returns the rich python representation of this response's data.

        NOTE: For the async client: this will become a coroutine in the next major version.

        For lower-level control, see `.read()`, `.json()`, `.iter_bytes()`.

        You can customise the type that the response is parsed into through
        the `to` argument, e.g.

        ```py
        from zhipuai import BaseModel


        class MyModel(BaseModel):
            foo: str


        obj = response.parse(to=MyModel)
        print(obj.foo)
        ```

        We support parsing:
        - `BaseModel`
        - `dict`
        - `list`
        - `Union`
        - `str`
        - `int`
        - `float`
        - `httpx.Response`
        """
        # Results are cached per target type so repeated `parse()` calls
        # (with or without `to=`) don't re-deserialize the body.
        cache_key = to if to is not None else self._cast_type
        cached = self._parsed_by_type.get(cache_key)
        if cached is not None:
            return cached  # type: ignore[no-any-return]

        parsed = self._parse(to=to)
        if is_given(self._options.post_parser):
            parsed = self._options.post_parser(parsed)

        self._parsed_by_type[cache_key] = parsed
        return parsed

    @property
    def headers(self) -> httpx.Headers:
        """The response headers."""
        return self.http_response.headers

    @property
    def http_request(self) -> httpx.Request:
        """The originating httpx request."""
        return self.http_response.request

    @property
    def status_code(self) -> int:
        """The HTTP status code of the response."""
        return self.http_response.status_code

    @property
    def url(self) -> httpx.URL:
        """The URL the request was made to."""
        return self.http_response.url

    @property
    def method(self) -> str:
        """The HTTP method of the originating request."""
        return self.http_request.method

    @property
    def content(self) -> bytes:
        """Return the binary response content.

        NOTE: this will be removed in favour of `.read()` in the
        next major version.
        """
        return self.http_response.content

    @property
    def text(self) -> str:
        """Return the decoded response content.

        NOTE: this will be turned into a method in the next major version.
        """
        return self.http_response.text

    @property
    def http_version(self) -> str:
        """The HTTP protocol version used, e.g. `"HTTP/1.1"`."""
        return self.http_response.http_version

    @property
    def is_closed(self) -> bool:
        """Whether the underlying response has been closed."""
        return self.http_response.is_closed

    @property
    def elapsed(self) -> datetime.timedelta:
        """The time taken for the complete request/response cycle to complete."""
        return self.http_response.elapsed

    def _parse(self, *, to: type[_T] | None = None) -> R | _T:
        """Deserialize the response body into `to` (or the stored cast type)."""
        # unwrap `Annotated[T, ...]` -> `T`
        if to and is_annotated_type(to):
            to = extract_type_arg(to, 0)

        # Streaming requests are wrapped in a StreamResponse subclass instead
        # of being parsed eagerly; precedence: explicit `to`, then the
        # per-request `_stream_cls`, then the client's default stream class.
        if self._stream:
            if to:
                if not is_stream_class_type(to):
                    raise TypeError(f"Expected custom parse type to be a subclass of {StreamResponse}")

                return cast(
                    _T,
                    to(
                        cast_type=extract_stream_chunk_type(
                            to,
                            failure_message="Expected custom stream type to be passed with a type argument, e.g. StreamResponse[ChunkType]",  # noqa: E501
                        ),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            if self._stream_cls:
                return cast(
                    R,
                    self._stream_cls(
                        cast_type=extract_stream_chunk_type(self._stream_cls),
                        response=self.http_response,
                        client=cast(Any, self._client),
                    ),
                )

            stream_cls = cast("type[StreamResponse[Any]] | None", self._client._default_stream_cls)
            if stream_cls is None:
                raise MissingStreamClassError()

            return cast(
                R,
                stream_cls(
                    cast_type=self._cast_type,
                    response=self.http_response,
                    client=cast(Any, self._client),
                ),
            )

        cast_type = to if to is not None else self._cast_type

        # unwrap `Annotated[T, ...]` -> `T`
        if is_annotated_type(cast_type):
            cast_type = extract_type_arg(cast_type, 0)

        if cast_type is NoneType:
            return cast(R, None)

        response = self.http_response
        # Primitive targets are decoded straight from the response text.
        if cast_type == str:
            return cast(R, response.text)

        if cast_type == int:
            return cast(R, int(response.text))

        if cast_type == float:
            return cast(R, float(response.text))

        origin = get_origin(cast_type) or cast_type

        if inspect.isclass(origin) and issubclass(origin, HttpxResponseContent):
            # in the response, e.g. mime file
            *_, filename = response.headers.get("content-disposition", "").split("filename=")
            # NOTE(review): a quoted content-disposition filename
            # (filename="x.jsonl") keeps its quotes here, so the endswith
            # checks below would miss it — confirm whether the server ever
            # quotes filenames.
            # If the downloaded file is .jsonl/.xlsx, use the text-capable
            # HttpxTextBinaryResponseContent wrapper.
            if filename and filename.endswith(".jsonl") or filename and filename.endswith(".xlsx"):
                return cast(R, HttpxTextBinaryResponseContent(response))
            else:
                return cast(R, cast_type(response))  # type: ignore

        if origin == LegacyAPIResponse:
            raise RuntimeError("Unexpected state - cast_type is `APIResponse`")

        if inspect.isclass(origin) and issubclass(origin, httpx.Response):
            # Because of the invariance of our ResponseT TypeVar, users can subclass httpx.Response
            # and pass that class to our request functions. We cannot change the variance to be either
            # covariant or contravariant as that makes our usage of ResponseT illegal. We could construct
            # the response class ourselves but that is something that should be supported directly in httpx
            # as it would be easy to incorrectly construct the Response object due to the multitude of arguments.
            if cast_type != httpx.Response:
                raise ValueError("Subclasses of httpx.Response cannot be passed to `cast_type`")
            return cast(R, response)

        if inspect.isclass(origin) and not issubclass(origin, BaseModel) and issubclass(origin, pydantic.BaseModel):
            raise TypeError("Pydantic models must subclass our base model type, e.g. `from openai import BaseModel`")

        if (
            cast_type is not object
            and origin is not list
            and origin is not dict
            and origin is not Union
            and not issubclass(origin, BaseModel)
        ):
            raise RuntimeError(
                f"Unsupported type, expected {cast_type} to be a subclass of {BaseModel}, {dict}, {list}, {Union}, {NoneType}, {str} or {httpx.Response}."  # noqa: E501
            )

        # split is required to handle cases where additional information is included
        # in the response, e.g. application/json; charset=utf-8
        content_type, *_ = response.headers.get("content-type", "*").split(";")
        if content_type != "application/json":
            if is_basemodel(cast_type):
                # Some endpoints return JSON with a non-JSON content type;
                # attempt to parse anyway before giving up.
                try:
                    data = response.json()
                except Exception as exc:
                    log.debug("Could not read JSON from response data due to %s - %s", type(exc), exc)
                else:
                    return self._client._process_response_data(
                        data=data,
                        cast_type=cast_type,  # type: ignore
                        response=response,
                    )

            if self._client._strict_response_validation:
                raise APIResponseValidationError(
                    response=response,
                    message=f"Expected Content-Type response header to be `application/json` but received `{content_type}` instead.",  # noqa: E501
                    json_data=response.text,
                )

            # If the API responds with content that isn't JSON then we just return
            # the (decoded) text without performing any parsing so that you can still
            # handle the response however you need to.
            return response.text  # type: ignore

        data = response.json()

        return self._client._process_response_data(
            data=data,
            cast_type=cast_type,  # type: ignore
            response=response,
        )

    @override
    def __repr__(self) -> str:
        return f"<APIResponse [{self.status_code} {self.http_response.reason_phrase}] type={self._cast_type}>"
|
||||
|
||||
|
||||
class MissingStreamClassError(TypeError):
    """Raised when `stream=True` is requested but no `stream_cls` is available."""

    def __init__(self) -> None:
        message = "The `stream` argument was set to `True` but the `stream_cls` argument was not given. See `openai._streaming` for reference"  # noqa: E501
        super().__init__(message)
|
||||
|
||||
|
||||
def to_raw_response_wrapper(func: Callable[P, R]) -> Callable[P, LegacyAPIResponse[R]]:
    """Higher order function that takes one of our bound API methods and wraps it
    to support returning the raw `APIResponse` object directly.

    The wrapper injects the raw-response marker header so the client skips
    parsing and hands back the `LegacyAPIResponse` itself.
    """

    @functools.wraps(func)
    def wrapped(*args: P.args, **kwargs: P.kwargs) -> LegacyAPIResponse[R]:
        # Copy the caller's extra headers (never mutate their dict) and add
        # the marker that requests a raw response.
        headers: dict[str, str] = dict(cast(Any, kwargs.get("extra_headers")) or {})
        headers[RAW_RESPONSE_HEADER] = "true"
        kwargs["extra_headers"] = headers
        return cast(LegacyAPIResponse[R], func(*args, **kwargs))

    return wrapped
|
||||
@ -1,48 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar, Union
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Union, cast
|
||||
|
||||
import pydantic.generics
|
||||
from httpx import Timeout
|
||||
from pydantic import ConfigDict
|
||||
from typing_extensions import TypedDict, Unpack
|
||||
from typing_extensions import Required, TypedDict, Unpack, final
|
||||
|
||||
from ._base_type import Body, Headers, HttpxRequestFiles, NotGiven, Query
|
||||
from ._utils import remove_notgiven_indict
|
||||
from ._base_compat import PYDANTIC_V2, ConfigDict
|
||||
from ._base_type import AnyMapping, Body, Headers, HttpxRequestFiles, NotGiven, Query
|
||||
from ._constants import RAW_RESPONSE_HEADER
|
||||
from ._utils import is_given, strip_not_given
|
||||
|
||||
|
||||
class UserRequestInput(TypedDict, total=False):
|
||||
headers: Headers
|
||||
max_retries: int
|
||||
timeout: float | Timeout | None
|
||||
params: Query
|
||||
extra_json: AnyMapping
|
||||
|
||||
|
||||
class FinalRequestOptionsInput(TypedDict, total=False):
|
||||
method: Required[str]
|
||||
url: Required[str]
|
||||
params: Query
|
||||
headers: Headers
|
||||
params: Query | None
|
||||
max_retries: int
|
||||
timeout: float | Timeout | None
|
||||
files: HttpxRequestFiles | None
|
||||
json_data: Body
|
||||
extra_json: AnyMapping
|
||||
|
||||
|
||||
class ClientRequestParam:
|
||||
@final
|
||||
class FinalRequestOptions(pydantic.BaseModel):
|
||||
method: str
|
||||
url: str
|
||||
max_retries: Union[int, NotGiven] = NotGiven()
|
||||
timeout: Union[float, NotGiven] = NotGiven()
|
||||
params: Query = {}
|
||||
headers: Union[Headers, NotGiven] = NotGiven()
|
||||
json_data: Union[Body, None] = None
|
||||
max_retries: Union[int, NotGiven] = NotGiven()
|
||||
timeout: Union[float, Timeout, None, NotGiven] = NotGiven()
|
||||
files: Union[HttpxRequestFiles, None] = None
|
||||
params: Query = {}
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
|
||||
idempotency_key: Union[str, None] = None
|
||||
post_parser: Union[Callable[[Any], Any], NotGiven] = NotGiven()
|
||||
|
||||
# It should be noted that we cannot use `json` here as that would override
|
||||
# a BaseModel method in an incompatible fashion.
|
||||
json_data: Union[Body, None] = None
|
||||
extra_json: Union[AnyMapping, None] = None
|
||||
|
||||
if PYDANTIC_V2:
|
||||
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
|
||||
else:
|
||||
|
||||
def get_max_retries(self, max_retries) -> int:
|
||||
class Config(pydantic.BaseConfig): # pyright: ignore[reportDeprecated]
|
||||
arbitrary_types_allowed: bool = True
|
||||
|
||||
def get_max_retries(self, max_retries: int) -> int:
|
||||
if isinstance(self.max_retries, NotGiven):
|
||||
return max_retries
|
||||
return self.max_retries
|
||||
|
||||
def _strip_raw_response_header(self) -> None:
|
||||
if not is_given(self.headers):
|
||||
return
|
||||
|
||||
if self.headers.get(RAW_RESPONSE_HEADER):
|
||||
self.headers = {**self.headers}
|
||||
self.headers.pop(RAW_RESPONSE_HEADER)
|
||||
|
||||
# override the `construct` method so that we can run custom transformations.
|
||||
# this is necessary as we don't want to do any actual runtime type checking
|
||||
# (which means we can't use validators) but we do want to ensure that `NotGiven`
|
||||
# values are not present
|
||||
#
|
||||
# type ignore required because we're adding explicit types to `**values`
|
||||
@classmethod
|
||||
def construct( # type: ignore
|
||||
cls,
|
||||
_fields_set: set[str] | None = None,
|
||||
**values: Unpack[UserRequestInput],
|
||||
) -> ClientRequestParam:
|
||||
kwargs: dict[str, Any] = {key: remove_notgiven_indict(value) for key, value in values.items()}
|
||||
client = cls()
|
||||
client.__dict__.update(kwargs)
|
||||
|
||||
return client
|
||||
) -> FinalRequestOptions:
|
||||
kwargs: dict[str, Any] = {
|
||||
# we unconditionally call `strip_not_given` on any value
|
||||
# as it will just ignore any non-mapping types
|
||||
key: strip_not_given(value)
|
||||
for key, value in values.items()
|
||||
}
|
||||
if PYDANTIC_V2:
|
||||
return super().model_construct(_fields_set, **kwargs)
|
||||
return cast(FinalRequestOptions, super().construct(_fields_set, **kwargs)) # pyright: ignore[reportDeprecated]
|
||||
|
||||
model_construct = construct
|
||||
if not TYPE_CHECKING:
|
||||
# type checkers incorrectly complain about this assignment
|
||||
model_construct = construct
|
||||
|
||||
@ -1,19 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable, Mapping
|
||||
from typing import TypeVar
|
||||
|
||||
from ._base_type import NotGiven
|
||||
|
||||
|
||||
def remove_notgiven_indict(obj):
|
||||
if obj is None or (not isinstance(obj, Mapping)):
|
||||
return obj
|
||||
return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
|
||||
return [item for sublist in t for item in sublist]
|
||||
@ -0,0 +1,52 @@
|
||||
from ._utils import ( # noqa: I001
|
||||
remove_notgiven_indict as remove_notgiven_indict, # noqa: PLC0414
|
||||
flatten as flatten, # noqa: PLC0414
|
||||
is_dict as is_dict, # noqa: PLC0414
|
||||
is_list as is_list, # noqa: PLC0414
|
||||
is_given as is_given, # noqa: PLC0414
|
||||
is_tuple as is_tuple, # noqa: PLC0414
|
||||
is_mapping as is_mapping, # noqa: PLC0414
|
||||
is_tuple_t as is_tuple_t, # noqa: PLC0414
|
||||
parse_date as parse_date, # noqa: PLC0414
|
||||
is_iterable as is_iterable, # noqa: PLC0414
|
||||
is_sequence as is_sequence, # noqa: PLC0414
|
||||
coerce_float as coerce_float, # noqa: PLC0414
|
||||
is_mapping_t as is_mapping_t, # noqa: PLC0414
|
||||
removeprefix as removeprefix, # noqa: PLC0414
|
||||
removesuffix as removesuffix, # noqa: PLC0414
|
||||
extract_files as extract_files, # noqa: PLC0414
|
||||
is_sequence_t as is_sequence_t, # noqa: PLC0414
|
||||
required_args as required_args, # noqa: PLC0414
|
||||
coerce_boolean as coerce_boolean, # noqa: PLC0414
|
||||
coerce_integer as coerce_integer, # noqa: PLC0414
|
||||
file_from_path as file_from_path, # noqa: PLC0414
|
||||
parse_datetime as parse_datetime, # noqa: PLC0414
|
||||
strip_not_given as strip_not_given, # noqa: PLC0414
|
||||
deepcopy_minimal as deepcopy_minimal, # noqa: PLC0414
|
||||
get_async_library as get_async_library, # noqa: PLC0414
|
||||
maybe_coerce_float as maybe_coerce_float, # noqa: PLC0414
|
||||
get_required_header as get_required_header, # noqa: PLC0414
|
||||
maybe_coerce_boolean as maybe_coerce_boolean, # noqa: PLC0414
|
||||
maybe_coerce_integer as maybe_coerce_integer, # noqa: PLC0414
|
||||
drop_prefix_image_data as drop_prefix_image_data, # noqa: PLC0414
|
||||
)
|
||||
|
||||
|
||||
from ._typing import (
|
||||
is_list_type as is_list_type, # noqa: PLC0414
|
||||
is_union_type as is_union_type, # noqa: PLC0414
|
||||
extract_type_arg as extract_type_arg, # noqa: PLC0414
|
||||
is_iterable_type as is_iterable_type, # noqa: PLC0414
|
||||
is_required_type as is_required_type, # noqa: PLC0414
|
||||
is_annotated_type as is_annotated_type, # noqa: PLC0414
|
||||
strip_annotated_type as strip_annotated_type, # noqa: PLC0414
|
||||
extract_type_var_from_base as extract_type_var_from_base, # noqa: PLC0414
|
||||
)
|
||||
|
||||
from ._transform import (
|
||||
PropertyInfo as PropertyInfo, # noqa: PLC0414
|
||||
transform as transform, # noqa: PLC0414
|
||||
async_transform as async_transform, # noqa: PLC0414
|
||||
maybe_transform as maybe_transform, # noqa: PLC0414
|
||||
async_maybe_transform as async_maybe_transform, # noqa: PLC0414
|
||||
)
|
||||
@ -0,0 +1,383 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
import pathlib
|
||||
from collections.abc import Mapping
|
||||
from datetime import date, datetime
|
||||
from typing import Any, Literal, TypeVar, cast, get_args, get_type_hints
|
||||
|
||||
import anyio
|
||||
import pydantic
|
||||
from typing_extensions import override
|
||||
|
||||
from .._base_compat import is_typeddict, model_dump
|
||||
from .._files import is_base64_file_input
|
||||
from ._typing import (
|
||||
extract_type_arg,
|
||||
is_annotated_type,
|
||||
is_iterable_type,
|
||||
is_list_type,
|
||||
is_required_type,
|
||||
is_union_type,
|
||||
strip_annotated_type,
|
||||
)
|
||||
from ._utils import (
|
||||
is_iterable,
|
||||
is_list,
|
||||
is_mapping,
|
||||
)
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
# TODO: support for drilling globals() and locals()
|
||||
# TODO: ensure works correctly with forward references in all cases
|
||||
|
||||
|
||||
PropertyFormat = Literal["iso8601", "base64", "custom"]
|
||||
|
||||
|
||||
class PropertyInfo:
    """Metadata class to be used in Annotated types to provide information about a given type.

    For example:

    class MyParams(TypedDict):
        account_holder_name: Annotated[str, PropertyInfo(alias='accountHolderName')]

    This means that {'account_holder_name': 'Robert'} will be transformed to {'accountHolderName': 'Robert'} before being sent to the API.
    """  # noqa: E501

    # Wire name to use instead of the Python attribute name.
    alias: str | None
    # Serialization format for values ("iso8601", "base64" or "custom").
    format: PropertyFormat | None
    # strftime-style template used when format == "custom".
    format_template: str | None
    # Field name used to discriminate union variants.
    discriminator: str | None

    def __init__(
        self,
        *,
        alias: str | None = None,
        format: PropertyFormat | None = None,
        format_template: str | None = None,
        discriminator: str | None = None,
    ) -> None:
        self.alias = alias
        self.format = format
        self.format_template = format_template
        self.discriminator = discriminator

    @override
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(alias='{self.alias}', format={self.format}, format_template='{self.format_template}', discriminator='{self.discriminator}')"  # noqa: E501
|
||||
|
||||
|
||||
def maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Wrapper over `transform()` that allows `None` to be passed.

    See `transform()` for more details.
    """
    return None if data is None else transform(data, expected_type)
|
||||
|
||||
|
||||
# Wrapper over _transform_recursive providing fake types
def transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that does not have type information given will be included as is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    return cast(_T, _transform_recursive(data, annotation=cast(type, expected_type)))
|
||||
|
||||
|
||||
def _get_annotated_type(type_: type) -> type | None:
    """Return the `Annotated[...]` form of `type_`, or `None` when there is none.

    `Required[Annotated[T, ...]]` is unwrapped to `Annotated[T, ...]` first.
    """
    if is_required_type(type_):
        # Peel off the `Required[...]` wrapper to reach the annotated type.
        type_ = get_args(type_)[0]

    return type_ if is_annotated_type(type_) else None
|
||||
|
||||
|
||||
def _maybe_transform_key(key: str, type_: type) -> str:
    """Return the wire-format name for `key` declared via `PropertyInfo(alias=...)`.

    Only `Annotated` types carrying `PropertyInfo` metadata are considered;
    any other annotation leaves the key unchanged.
    """
    annotated = _get_annotated_type(type_)
    if annotated is not None:
        # The first type argument is the underlying type; the rest is metadata.
        for meta in get_args(annotated)[1:]:
            if isinstance(meta, PropertyInfo) and meta.alias is not None:
                return meta.alias

    return key
|
||||
|
||||
|
||||
def _transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        # TypedDict payloads are handled field-by-field (alias renames etc.).
        return _transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # Recurse into each element, carrying the container's element type so
        # per-element metadata is applied.
        inner_type = extract_type_arg(stripped_type, 0)
        return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = _transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # Serialize models; unset fields are omitted from the payload.
        return model_dump(data, exclude_unset=True)

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            # A format directive (iso8601/custom/base64) wins over raw passthrough.
            return _format_data(data, annotation.format, annotation.format_template)

    return data
|
||||
|
||||
|
||||
def _format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
if isinstance(data, date | datetime):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
|
||||
if format_ == "custom" and format_template is not None:
|
||||
return data.strftime(format_template)
|
||||
|
||||
if format_ == "base64" and is_base64_file_input(data):
|
||||
binary: str | bytes | None = None
|
||||
|
||||
if isinstance(data, pathlib.Path):
|
||||
binary = data.read_bytes()
|
||||
elif isinstance(data, io.IOBase):
|
||||
binary = data.read()
|
||||
|
||||
if isinstance(binary, str): # type: ignore[unreachable]
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Transform every annotated key of a TypedDict payload; unknown keys pass through."""
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for key, value in data.items():
        hint = hints.get(key)
        if hint is None:
            # No annotation for this field -> nothing to rename or reformat.
            out[key] = value
        else:
            out[_maybe_transform_key(key, hint)] = _transform_recursive(value, annotation=hint)
    return out
|
||||
|
||||
|
||||
async def async_maybe_transform(
    data: object,
    expected_type: object,
) -> Any | None:
    """Null-tolerant wrapper over `async_transform()`.

    `None` is returned unchanged; anything else is forwarded to
    `async_transform()`.  See `async_transform()` for details.
    """
    if data is None:
        return None
    return await async_transform(data, expected_type)
|
||||
|
||||
|
||||
async def async_transform(
    data: _T,
    expected_type: object,
) -> _T:
    """Transform dictionaries based off of type information from the given type, for example:

    ```py
    class Params(TypedDict, total=False):
        card_id: Required[Annotated[str, PropertyInfo(alias="cardID")]]


    transformed = transform({"card_id": "<my card ID>"}, Params)
    # {'cardID': '<my card ID>'}
    ```

    Any keys / data that does not have type information given will be included as is.

    It should be noted that the transformations that this function does are not represented in the type system.
    """
    # Async twin of `transform()`; file reads in format handling are awaited.
    result = await _async_transform_recursive(data, annotation=cast(type, expected_type))
    return cast(_T, result)
|
||||
|
||||
|
||||
async def _async_transform_recursive(
    data: object,
    *,
    annotation: type,
    inner_type: type | None = None,
) -> object:
    """Transform the given data against the expected type.

    Async twin of `_transform_recursive` — identical logic, awaiting the
    async helpers for TypedDict fields and format handling.

    Args:
        annotation: The direct type annotation given to the particular piece of data.
            This may or may not be wrapped in metadata types, e.g. `Required[T]`, `Annotated[T, ...]` etc

        inner_type: If applicable, this is the "inside" type. This is useful in certain cases where the outside type
            is a container type such as `List[T]`. In that case `inner_type` should be set to `T` so that each entry in
            the list can be transformed using the metadata from the container type.

            Defaults to the same value as the `annotation` argument.
    """
    if inner_type is None:
        inner_type = annotation

    stripped_type = strip_annotated_type(inner_type)
    if is_typeddict(stripped_type) and is_mapping(data):
        # TypedDict payloads are handled field-by-field (alias renames etc.).
        return await _async_transform_typeddict(data, stripped_type)

    if (
        # List[T]
        (is_list_type(stripped_type) and is_list(data))
        # Iterable[T]
        or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
    ):
        # Recurse into each element, carrying the container's element type.
        inner_type = extract_type_arg(stripped_type, 0)
        return [await _async_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]

    if is_union_type(stripped_type):
        # For union types we run the transformation against all subtypes to ensure that everything is transformed.
        #
        # TODO: there may be edge cases where the same normalized field name will transform to two different names
        # in different subtypes.
        for subtype in get_args(stripped_type):
            data = await _async_transform_recursive(data, annotation=annotation, inner_type=subtype)
        return data

    if isinstance(data, pydantic.BaseModel):
        # Serialize models; unset fields are omitted from the payload.
        return model_dump(data, exclude_unset=True)

    annotated_type = _get_annotated_type(annotation)
    if annotated_type is None:
        return data

    # ignore the first argument as it is the actual type
    annotations = get_args(annotated_type)[1:]
    for annotation in annotations:
        if isinstance(annotation, PropertyInfo) and annotation.format is not None:
            # A format directive (iso8601/custom/base64) wins over raw passthrough.
            return await _async_format_data(data, annotation.format, annotation.format_template)

    return data
|
||||
|
||||
|
||||
async def _async_format_data(data: object, format_: PropertyFormat, format_template: str | None) -> object:
|
||||
if isinstance(data, date | datetime):
|
||||
if format_ == "iso8601":
|
||||
return data.isoformat()
|
||||
|
||||
if format_ == "custom" and format_template is not None:
|
||||
return data.strftime(format_template)
|
||||
|
||||
if format_ == "base64" and is_base64_file_input(data):
|
||||
binary: str | bytes | None = None
|
||||
|
||||
if isinstance(data, pathlib.Path):
|
||||
binary = await anyio.Path(data).read_bytes()
|
||||
elif isinstance(data, io.IOBase):
|
||||
binary = data.read()
|
||||
|
||||
if isinstance(binary, str): # type: ignore[unreachable]
|
||||
binary = binary.encode()
|
||||
|
||||
if not isinstance(binary, bytes):
|
||||
raise RuntimeError(f"Could not read bytes from {data}; Received {type(binary)}")
|
||||
|
||||
return base64.b64encode(binary).decode("ascii")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
async def _async_transform_typeddict(
    data: Mapping[str, object],
    expected_type: type,
) -> Mapping[str, object]:
    """Async twin of `_transform_typeddict`: transform annotated keys, pass unknown ones through."""
    hints = get_type_hints(expected_type, include_extras=True)
    out: dict[str, object] = {}
    for key, value in data.items():
        hint = hints.get(key)
        if hint is None:
            # No annotation for this field -> nothing to rename or reformat.
            out[key] = value
        else:
            out[_maybe_transform_key(key, hint)] = await _async_transform_recursive(value, annotation=hint)
    return out
|
||||
@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import abc as _c_abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Annotated, Any, TypeVar, cast, get_args, get_origin
|
||||
|
||||
from typing_extensions import Required
|
||||
|
||||
from .._base_compat import is_union as _is_union
|
||||
from .._base_type import InheritsGeneric
|
||||
|
||||
|
||||
def is_annotated_type(typ: type) -> bool:
    """Whether `typ` is an `Annotated[...]` special form."""
    origin = get_origin(typ)
    return origin == Annotated
|
||||
|
||||
|
||||
def is_list_type(typ: type) -> bool:
    """Whether `typ` is `list` itself or a parameterized `list[...]`/`List[...]`."""
    effective = get_origin(typ) or typ
    return effective == list
|
||||
|
||||
|
||||
def is_iterable_type(typ: type) -> bool:
    """If the given type is `typing.Iterable[T]` (or `collections.abc.Iterable[T]`).

    `get_origin` normalizes `typing.Iterable[T]` to `collections.abc.Iterable`,
    so a single comparison covers both spellings.
    """
    origin = get_origin(typ) or typ
    # NOTE: the previous `origin in {Iterable, _c_abc.Iterable}` check compared
    # against the same object twice — both names are bound to
    # `collections.abc.Iterable` by this module's imports.
    return origin == Iterable
|
||||
|
||||
|
||||
def is_union_type(typ: type) -> bool:
    """Whether `typ` is a union type (delegates to the compat helper `_is_union`)."""
    origin = get_origin(typ)
    return _is_union(origin)
|
||||
|
||||
|
||||
def is_required_type(typ: type) -> bool:
|
||||
return get_origin(typ) == Required
|
||||
|
||||
|
||||
def is_typevar(typ: type) -> bool:
    """Whether `typ` is a `TypeVar` instance.

    Uses `isinstance` rather than the previous `type(typ) == TypeVar`
    comparison — the idiomatic type check, which type checkers also
    understand better than the exact-type equality.
    """
    # type ignore is required because type checkers
    # think this expression will always return False
    return isinstance(typ, TypeVar)  # type: ignore
|
||||
|
||||
|
||||
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
|
||||
# Extracts T from Annotated[T, ...] or from Required[Annotated[T, ...]]
def strip_annotated_type(typ: type) -> type:
    """Peel `Required`/`Annotated` wrappers until the bare type remains."""
    # Iterative form of the unwrap: each pass removes one wrapper layer.
    while is_required_type(typ) or is_annotated_type(typ):
        typ = cast(type, get_args(typ)[0])
    return typ
|
||||
|
||||
|
||||
def extract_type_arg(typ: type, index: int) -> type:
    """Return the `index`-th type argument of `typ`, e.g. `int` from `list[int]`.

    Raises:
        RuntimeError: when `typ` has no type argument at `index`.
    """
    type_args = get_args(typ)
    try:
        return cast(type, type_args[index])
    except IndexError as exc:
        raise RuntimeError(f"Expected type {typ} to have a type argument at index {index} but it did not") from exc
|
||||
|
||||
|
||||
def extract_type_var_from_base(
    typ: type,
    *,
    generic_bases: tuple[type, ...],
    index: int,
    failure_message: str | None = None,
) -> type:
    """Given a type like `Foo[T]`, returns the generic type variable `T`.

    This also handles the case where a concrete subclass is given, e.g.
    ```py
    class MyResponse(Foo[bytes]):
        ...

    extract_type_var(MyResponse, bases=(Foo,), index=0) -> bytes
    ```

    And where a generic subclass is given:
    ```py
    _T = TypeVar('_T')
    class MyResponse(Foo[_T]):
        ...

    extract_type_var(MyResponse[bytes], bases=(Foo,), index=0) -> bytes
    ```
    """
    # Normalize `Foo[bytes]` to `Foo` so identity checks against the bases work.
    cls = cast(object, get_origin(typ) or typ)
    if cls in generic_bases:
        # we're given the class directly
        return extract_type_arg(typ, index)

    # if a subclass is given
    # ---
    # this is needed as __orig_bases__ is not present in the typeshed stubs
    # because it is intended to be for internal use only, however there does
    # not seem to be a way to resolve generic TypeVars for inherited subclasses
    # without using it.
    if isinstance(cls, InheritsGeneric):
        target_base_class: Any | None = None
        # Find which of the subclass's original bases is one of our generic bases.
        for base in cls.__orig_bases__:
            if base.__origin__ in generic_bases:
                target_base_class = base
                break

        if target_base_class is None:
            raise RuntimeError(
                "Could not find the generic base class;\n"
                "This should never happen;\n"
                f"Does {cls} inherit from one of {generic_bases} ?"
            )

        extracted = extract_type_arg(target_base_class, index)
        if is_typevar(extracted):
            # If the extracted type argument is itself a type variable
            # then that means the subclass itself is generic, so we have
            # to resolve the type argument from the class itself, not
            # the base class.
            #
            # Note: if there is more than 1 type argument, the subclass could
            # change the ordering of the type arguments, this is not currently
            # supported.
            return extract_type_arg(typ, index)

        return extracted

    raise RuntimeError(failure_message or f"Could not resolve inner type variable at index {index} for {typ}")
|
||||
@ -0,0 +1,409 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import re
|
||||
from collections.abc import Callable, Iterable, Mapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
Any,
|
||||
TypeGuard,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
import sniffio
|
||||
|
||||
from .._base_compat import parse_date as parse_date # noqa: PLC0414
|
||||
from .._base_compat import parse_datetime as parse_datetime # noqa: PLC0414
|
||||
from .._base_type import FileTypes, Headers, HeadersLike, NotGiven, NotGivenOr
|
||||
|
||||
|
||||
def remove_notgiven_indict(obj):
    """Drop top-level entries whose value is a `NotGiven` sentinel.

    Non-mapping inputs (including `None`) are returned untouched.
    """
    if obj is None or not isinstance(obj, Mapping):
        return obj
    return {k: v for k, v in obj.items() if not isinstance(v, NotGiven)}
|
||||
|
||||
|
||||
# Generic placeholders used throughout this module's helper signatures.
_T = TypeVar("_T")
_TupleT = TypeVar("_TupleT", bound=tuple[object, ...])  # any tuple shape
_MappingT = TypeVar("_MappingT", bound=Mapping[str, object])
_SequenceT = TypeVar("_SequenceT", bound=Sequence[object])
CallableT = TypeVar("CallableT", bound=Callable[..., Any])  # decorator targets
|
||||
|
||||
|
||||
def flatten(t: Iterable[Iterable[_T]]) -> list[_T]:
    """Concatenate one level of nesting: `[[a, b], [c]] -> [a, b, c]`."""
    flat: list[_T] = []
    for chunk in t:
        flat.extend(chunk)
    return flat
|
||||
|
||||
|
||||
def extract_files(
    # TODO: this needs to take Dict but variance issues.....
    # create protocol type ?
    query: Mapping[str, object],
    *,
    paths: Sequence[Sequence[str]],
) -> list[tuple[str, FileTypes]]:
    """Recursively extract files from the given dictionary based on specified paths.

    A path may look like this ['foo', 'files', '<array>', 'data'].

    Note: this mutates the given dictionary.
    """
    found: list[tuple[str, FileTypes]] = []
    for path in paths:
        found.extend(_extract_items(query, path, index=0, flattened_key=None))
    return found
|
||||
|
||||
|
||||
def _extract_items(
    obj: object,
    path: Sequence[str],
    *,
    index: int,
    flattened_key: str | None,
) -> list[tuple[str, FileTypes]]:
    """Recursive worker for `extract_files`.

    Walks `obj` following `path[index:]`, accumulating the flattened field
    name (e.g. ``foo[files][]``) in `flattened_key`, and returns the
    (field name, file content) pairs found at the end of the path.
    """
    try:
        key = path[index]
    except IndexError:
        if isinstance(obj, NotGiven):
            # no value was provided - we can safely ignore
            return []

        # cyclical import
        from .._files import assert_is_file_content

        # We have exhausted the path, return the entry we found.
        assert_is_file_content(obj, key=flattened_key)
        assert flattened_key is not None
        return [(flattened_key, cast(FileTypes, obj))]

    index += 1
    if is_dict(obj):
        try:
            # We are at the last entry in the path so we must remove the field
            if (len(path)) == index:
                item = obj.pop(key)
            else:
                item = obj[key]
        except KeyError:
            # Key was not present in the dictionary, this is not indicative of an error
            # as the given path may not point to a required field. We also do not want
            # to enforce required fields as the API may differ from the spec in some cases.
            return []
        # Extend the flattened field name with the key we just descended into.
        if flattened_key is None:
            flattened_key = key
        else:
            flattened_key += f"[{key}]"
        return _extract_items(
            item,
            path,
            index=index,
            flattened_key=flattened_key,
        )
    elif is_list(obj):
        # Lists are only traversed through an explicit `<array>` path segment.
        if key != "<array>":
            return []

        return flatten(
            [
                _extract_items(
                    item,
                    path,
                    index=index,
                    flattened_key=flattened_key + "[]" if flattened_key is not None else "[]",
                )
                for item in obj
            ]
        )

    # Something unexpected was passed, just ignore it.
    return []
|
||||
|
||||
|
||||
def is_given(obj: NotGivenOr[_T]) -> TypeGuard[_T]:
    """Narrow a `NotGivenOr[...]` value: True when an actual value was supplied."""
    given = not isinstance(obj, NotGiven)
    return given
|
||||
|
||||
|
||||
# Type safe methods for narrowing types with TypeVars.
|
||||
# The default narrowing for isinstance(obj, dict) is dict[unknown, unknown],
|
||||
# however this cause Pyright to rightfully report errors. As we know we don't
|
||||
# care about the contained types we can safely use `object` in it's place.
|
||||
#
|
||||
# There are two separate functions defined, `is_*` and `is_*_t` for different use cases.
|
||||
# `is_*` is for when you're dealing with an unknown input
|
||||
# `is_*_t` is for when you're narrowing a known union type to a specific subset
|
||||
|
||||
|
||||
def is_tuple(obj: object) -> TypeGuard[tuple[object, ...]]:
|
||||
return isinstance(obj, tuple)
|
||||
|
||||
|
||||
def is_tuple_t(obj: _TupleT | object) -> TypeGuard[_TupleT]:
|
||||
return isinstance(obj, tuple)
|
||||
|
||||
|
||||
def is_sequence(obj: object) -> TypeGuard[Sequence[object]]:
|
||||
return isinstance(obj, Sequence)
|
||||
|
||||
|
||||
def is_sequence_t(obj: _SequenceT | object) -> TypeGuard[_SequenceT]:
|
||||
return isinstance(obj, Sequence)
|
||||
|
||||
|
||||
def is_mapping(obj: object) -> TypeGuard[Mapping[str, object]]:
|
||||
return isinstance(obj, Mapping)
|
||||
|
||||
|
||||
def is_mapping_t(obj: _MappingT | object) -> TypeGuard[_MappingT]:
|
||||
return isinstance(obj, Mapping)
|
||||
|
||||
|
||||
def is_dict(obj: object) -> TypeGuard[dict[object, object]]:
|
||||
return isinstance(obj, dict)
|
||||
|
||||
|
||||
def is_list(obj: object) -> TypeGuard[list[object]]:
|
||||
return isinstance(obj, list)
|
||||
|
||||
|
||||
def is_iterable(obj: object) -> TypeGuard[Iterable[object]]:
|
||||
return isinstance(obj, Iterable)
|
||||
|
||||
|
||||
def deepcopy_minimal(item: _T) -> _T:
    """Minimal reimplementation of copy.deepcopy() that will only copy certain object types:

    - mappings, e.g. `dict`
    - list

    This is done for performance reasons.  Any other value is returned as-is
    (shared, not copied).
    """
    if is_mapping(item):
        return cast(_T, {key: deepcopy_minimal(val) for key, val in item.items()})
    if is_list(item):
        return cast(_T, [deepcopy_minimal(member) for member in item])
    return item
|
||||
|
||||
|
||||
# copied from https://github.com/Rapptz/RoboDanny
def human_join(seq: Sequence[str], *, delim: str = ", ", final: str = "or") -> str:
    """Join items for human-readable display: "", "a", "a or b", "a, b or c"."""
    count = len(seq)
    if count == 0:
        return ""
    if count == 1:
        return seq[0]
    if count == 2:
        return f"{seq[0]} {final} {seq[1]}"
    # Three or more: delimit all but the last, then attach it with `final`.
    return delim.join(seq[:-1]) + f" {final} {seq[-1]}"
|
||||
|
||||
|
||||
def quote(string: str) -> str:
    """Add single quotation marks around the given string. Does *not* do any escaping."""
    quoted = f"'{string}'"
    return quoted
|
||||
|
||||
|
||||
def required_args(*variants: Sequence[str]) -> Callable[[CallableT], CallableT]:
    """Decorator to enforce a given set of arguments or variants of arguments are passed to the decorated function.

    Useful for enforcing runtime validation of overloaded functions.

    Example usage:
    ```py
    @overload
    def foo(*, a: str) -> str:
        ...


    @overload
    def foo(*, b: bool) -> str:
        ...


    # This enforces the same constraints that a static type checker would
    # i.e. that either a or b must be passed to the function
    @required_args(["a"], ["b"])
    def foo(*, a: str | None = None, b: bool | None = None) -> str:
        ...
    ```
    """

    def inner(func: CallableT) -> CallableT:
        params = inspect.signature(func).parameters
        # Names that may be filled positionally, in declaration order, so that
        # positional call arguments can be mapped back to parameter names.
        positional = [
            name
            for name, param in params.items()
            if param.kind
            in {
                param.POSITIONAL_ONLY,
                param.POSITIONAL_OR_KEYWORD,
            }
        ]

        @functools.wraps(func)
        def wrapper(*args: object, **kwargs: object) -> object:
            # Collect the set of parameter names the caller actually supplied.
            given_params: set[str] = set()
            for i, _ in enumerate(args):
                try:
                    given_params.add(positional[i])
                except IndexError:
                    raise TypeError(
                        f"{func.__name__}() takes {len(positional)} argument(s) but {len(args)} were given"
                    ) from None

            given_params.update(kwargs.keys())

            # The call is accepted if any variant is fully satisfied.
            for variant in variants:
                matches = all(param in given_params for param in variant)
                if matches:
                    break
            else:  # no break
                if len(variants) > 1:
                    variations = human_join(
                        ["(" + human_join([quote(arg) for arg in variant], final="and") + ")" for variant in variants]
                    )
                    msg = f"Missing required arguments; Expected either {variations} arguments to be given"
                else:
                    # TODO: this error message is not deterministic
                    missing = list(set(variants[0]) - given_params)
                    if len(missing) > 1:
                        msg = f"Missing required arguments: {human_join([quote(arg) for arg in missing])}"
                    else:
                        msg = f"Missing required argument: {quote(missing[0])}"
                raise TypeError(msg)
            return func(*args, **kwargs)

        return wrapper  # type: ignore

    return inner
|
||||
|
||||
|
||||
_K = TypeVar("_K")
_V = TypeVar("_V")


@overload
def strip_not_given(obj: None) -> None: ...


@overload
def strip_not_given(obj: Mapping[_K, _V | NotGiven]) -> dict[_K, _V]: ...


@overload
def strip_not_given(obj: object) -> object: ...


def strip_not_given(obj: object | None) -> object:
    """Remove all top-level keys where their values are instances of `NotGiven`."""
    # Only mappings are filtered; everything else (including None) passes through.
    if obj is None or not is_mapping(obj):
        return obj
    return {key: value for key, value in obj.items() if not isinstance(value, NotGiven)}
|
||||
|
||||
|
||||
def coerce_integer(val: str) -> int:
    """Parse a base-10 integer string."""
    return int(val, base=10)


def coerce_float(val: str) -> float:
    """Parse a floating-point string."""
    return float(val)


def coerce_boolean(val: str) -> bool:
    """Interpret the common truthy spellings "true"/"1"/"on"; anything else is False."""
    return val in {"true", "1", "on"}


def maybe_coerce_integer(val: str | None) -> int | None:
    """`coerce_integer`, passing `None` straight through."""
    return None if val is None else coerce_integer(val)


def maybe_coerce_float(val: str | None) -> float | None:
    """`coerce_float`, passing `None` straight through."""
    return None if val is None else coerce_float(val)


def maybe_coerce_boolean(val: str | None) -> bool | None:
    """`coerce_boolean`, passing `None` straight through."""
    return None if val is None else coerce_boolean(val)
|
||||
|
||||
|
||||
def removeprefix(string: str, prefix: str) -> str:
    """Remove `prefix` from the start of `string`, if present.

    Delegates to `str.removeprefix` (available since Python 3.9, which this
    codebase already requires — see the runtime `tuple[object, ...]` TypeVar
    bound above).  Kept as a function for backward compatibility with
    existing callers.
    """
    return string.removeprefix(prefix)
|
||||
|
||||
|
||||
def removesuffix(string: str, suffix: str) -> str:
    """Remove `suffix` from the end of `string`, if present.

    Delegates to `str.removesuffix` (Python 3.9+).  This also fixes a bug in
    the previous slicing implementation: for an empty suffix it computed
    `string[:-0]`, an empty slice, and wrongly returned "" instead of the
    original string.
    """
    return string.removesuffix(suffix)
|
||||
|
||||
|
||||
def file_from_path(path: str) -> FileTypes:
    """Load `path` from disk into an upload tuple of (basename, raw bytes)."""
    file_name = os.path.basename(path)
    contents = Path(path).read_bytes()
    return (file_name, contents)
|
||||
|
||||
|
||||
def get_required_header(headers: HeadersLike, header: str) -> str:
    """Look up `header` case-insensitively in `headers`; raise ValueError if absent."""
    lower_header = header.lower()
    if isinstance(headers, Mapping):
        headers = cast(Headers, headers)
        for name, value in headers.items():
            if name.lower() == lower_header and isinstance(value, str):
                return value

    # to deal with the case where the header looks like Stainless-Event-Id
    intercaps_header = re.sub(r"([^\w])(\w)", lambda pat: pat.group(1) + pat.group(2).upper(), header.capitalize())

    # Fall back to `.get` with a handful of common capitalization variants.
    for normalized_header in [header, lower_header, header.upper(), intercaps_header]:
        value = headers.get(normalized_header)
        if value:
            return value

    raise ValueError(f"Could not find {header} header")
|
||||
|
||||
|
||||
def get_async_library() -> str:
    """Return the name of the currently-running async library via `sniffio`.

    Returns the string "false" when detection fails (e.g. no async library
    is running).
    """
    try:
        return sniffio.current_async_library()
    except Exception:
        # Broad catch is deliberate: any detection failure maps to the
        # "false" sentinel rather than propagating.
        return "false"
|
||||
|
||||
|
||||
def drop_prefix_image_data(content: Union[str, list[dict]]) -> Union[str, list[dict]]:
    """Strip the ``data:image/...;base64,`` prefix from image-url entries.

    Args:
        content: Either a plain string (returned untouched) or a list of
            message-content dicts; entries with ``type == "image_url"`` have
            their ``image_url.url`` value rewritten in place.

    Returns:
        The same `content` object (list entries are mutated in place).
    """
    if isinstance(content, list):
        for data in content:
            if data.get("type") != "image_url":
                continue
            # Guard against malformed entries: previously a missing
            # `image_url` dict or `url` key raised AttributeError.
            image_url = data.get("image_url") or {}
            url = image_url.get("url", "")
            if url.startswith("data:image/"):
                data["image_url"]["url"] = url.split("base64,")[-1]

    return content
|
||||
@ -0,0 +1,78 @@
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggerNameFilter(logging.Filter):
    """Log-record filter that currently admits every record.

    Name-based filtering is disabled for now; the class is kept so logging
    configurations can keep referencing it unchanged.
    """

    def filter(self, record):
        # Accept all records unconditionally.
        return True
|
||||
|
||||
|
||||
def get_log_file(log_path: str, sub_dir: str):
    """Create `log_path/sub_dir` and return the log-file path inside it.

    `sub_dir` should contain a timestamp so every run gets a fresh
    directory; `exist_ok=False` makes accidental reuse fail loudly.
    """
    target_dir = os.path.join(log_path, sub_dir)
    os.makedirs(target_dir, exist_ok=False)
    return os.path.join(target_dir, "zhipuai.log")
|
||||
|
||||
|
||||
def get_config_dict(log_level: str, log_file_path: str, log_backup_count: int, log_max_bytes: int) -> dict:
    """Build a `logging.config.dictConfig`-style configuration dict.

    A console handler and a rotating-file handler share one formatter; the
    file handler rotates at `log_max_bytes` bytes keeping `log_backup_count`
    backups.
    """
    # for windows, the path should be a raw string.
    log_file_path = log_file_path.encode("unicode-escape").decode() if os.name == "nt" else log_file_path
    log_level = log_level.upper()
    config_dict = {
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "formatter": {"format": ("%(asctime)s %(name)-12s %(process)d %(levelname)-8s %(message)s")},
        },
        "filters": {
            # References LoggerNameFilter in this module (currently a no-op filter).
            "logger_name_filter": {
                "()": __name__ + ".LoggerNameFilter",
            },
        },
        "handlers": {
            "stream_handler": {
                "class": "logging.StreamHandler",
                "formatter": "formatter",
                "level": log_level,
                # "stream": "ext://sys.stdout",
                # "filters": ["logger_name_filter"],
            },
            "file_handler": {
                "class": "logging.handlers.RotatingFileHandler",
                "formatter": "formatter",
                "level": log_level,
                "filename": log_file_path,
                "mode": "a",
                "maxBytes": log_max_bytes,
                "backupCount": log_backup_count,
                "encoding": "utf8",
            },
        },
        "loggers": {
            # Dedicated logger with propagation disabled so records are not
            # duplicated by the root logger's handlers.
            "loom_core": {
                "handlers": ["stream_handler", "file_handler"],
                "level": log_level,
                "propagate": False,
            }
        },
        "root": {
            "level": log_level,
            "handlers": ["stream_handler", "file_handler"],
        },
    }
    return config_dict
|
||||
|
||||
|
||||
def get_timestamp_ms():
    """Current wall-clock time in whole milliseconds (rounded, not truncated)."""
    seconds = time.time()
    return int(round(seconds * 1000))
|
||||
@ -0,0 +1,62 @@
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Any, Generic, Optional, TypeVar, cast
|
||||
|
||||
from typing_extensions import Protocol, override, runtime_checkable
|
||||
|
||||
from ._http_client import BasePage, BaseSyncPage, PageInfo
|
||||
|
||||
__all__ = ["SyncPage", "SyncCursorPage"]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
@runtime_checkable
class CursorPageItem(Protocol):
    """Structural type for page items that can act as a pagination cursor."""

    # Cursor value; `None` means the item cannot be used to request the next page.
    id: Optional[str]
|
||||
|
||||
|
||||
class SyncPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
    """Note: no pagination actually occurs yet, this is for forwards-compatibility."""

    data: list[_T]
    object: str

    @override
    def _get_page_items(self) -> list[_T]:
        # Normalize a missing/empty payload to an empty list.
        return self.data or []

    @override
    def next_page_info(self) -> None:
        """
        This page represents a response that isn't actually paginated at the API level
        so there will never be a next page.
        """
        return None
|
||||
|
||||
|
||||
class SyncCursorPage(BaseSyncPage[_T], BasePage[_T], Generic[_T]):
    """Cursor-based page: the next request continues after the last item's id."""

    data: list[_T]

    @override
    def _get_page_items(self) -> list[_T]:
        # Normalize a missing/empty payload to an empty list.
        return self.data or []

    @override
    def next_page_info(self) -> Optional[PageInfo]:
        items = self.data
        if not items:
            return None

        last = cast(Any, items[-1])
        if not isinstance(last, CursorPageItem) or last.id is None:
            # TODO emit warning log
            # The last item cannot act as a cursor, so no next page can be requested.
            return None

        return PageInfo(params={"after": last.id})
|
||||
@ -0,0 +1,5 @@
|
||||
from .assistant_completion import AssistantCompletion
|
||||
|
||||
__all__ = [
|
||||
"AssistantCompletion",
|
||||
]
|
||||
@ -0,0 +1,7 @@
|
||||
from typing import TypedDict
|
||||
|
||||
|
||||
class ConversationParameters(TypedDict, total=False):
    """Query parameters for listing assistant conversations."""

    assistant_id: str  # assistant (agent) ID
    page: int  # current page number
    page_size: int  # number of items per page
||||
@ -0,0 +1,29 @@
|
||||
from ...core import BaseModel
|
||||
|
||||
__all__ = ["ConversationUsageListResp"]
|
||||
|
||||
|
||||
class Usage(BaseModel):
    """Token accounting for a single conversation."""

    prompt_tokens: int  # number of tokens in the user's input
    completion_tokens: int  # number of tokens in the model's completion (source said "input" — likely a typo)
    total_tokens: int  # total number of tokens
|
||||
|
||||
|
||||
class ConversationUsage(BaseModel):
    """Usage record for one conversation."""

    id: str  # conversation ID
    assistant_id: str  # assistant (agent) ID
    create_time: int  # creation time
    update_time: int  # last-update time
    usage: Usage  # token statistics for the conversation
|
||||
|
||||
|
||||
class ConversationUsageList(BaseModel):
    """One page of conversation usage records."""

    assistant_id: str  # assistant (agent) ID
    has_more: bool  # whether further pages are available
    conversation_list: list[ConversationUsage]  # usage records returned on this page
|
||||
|
||||
|
||||
class ConversationUsageListResp(BaseModel):
    """API response envelope for the conversation-usage listing endpoint."""

    code: int  # status code
    msg: str  # status message
    data: ConversationUsageList  # response payload
|
||||
@ -0,0 +1,3 @@
|
||||
from .message_content import MessageContent
|
||||
|
||||
__all__ = ["MessageContent"]
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import Annotated, TypeAlias, Union
|
||||
|
||||
from ....core._utils import PropertyInfo
|
||||
from .text_content_block import TextContentBlock
|
||||
from .tools_delta_block import ToolsDeltaBlock
|
||||
|
||||
__all__ = ["MessageContent"]
|
||||
|
||||
|
||||
# Discriminated union of streamed assistant message content blocks; the
# ``type`` field selects the concrete model during deserialization.
MessageContent: TypeAlias = Annotated[
    Union[ToolsDeltaBlock, TextContentBlock],
    PropertyInfo(discriminator="type"),
]
|
||||
@ -0,0 +1,14 @@
|
||||
from typing import Literal
|
||||
|
||||
from ....core import BaseModel
|
||||
|
||||
__all__ = ["TextContentBlock"]
|
||||
|
||||
|
||||
class TextContentBlock(BaseModel):
    """A block of assistant-generated text content."""

    content: str  # the text content

    role: str = "assistant"

    type: Literal["content"] = "content"
    """Always `content`."""
|
||||
@ -0,0 +1,27 @@
|
||||
from typing import Literal
|
||||
|
||||
__all__ = ["CodeInterpreterToolBlock"]
|
||||
|
||||
from .....core import BaseModel
|
||||
|
||||
|
||||
class CodeInterpreterToolOutput(BaseModel):
    """Output produced by the code tool."""

    type: str  # output type; currently only "logs"
    logs: str  # log output of the executed code
    error_msg: str  # error message, if any
|
||||
|
||||
|
||||
class CodeInterpreter(BaseModel):
    """A code-interpreter invocation."""

    input: str  # generated code snippet fed to the code sandbox
    outputs: list[CodeInterpreterToolOutput]  # results produced by executing the code
|
||||
|
||||
|
||||
class CodeInterpreterToolBlock(BaseModel):
    """Tool-call block for the code interpreter."""

    code_interpreter: CodeInterpreter  # the code-interpreter invocation
    type: Literal["code_interpreter"]  # tool type; always `code_interpreter`
|
||||
@ -0,0 +1,21 @@
|
||||
from typing import Literal
|
||||
|
||||
from .....core import BaseModel
|
||||
|
||||
__all__ = ["DrawingToolBlock"]
|
||||
|
||||
|
||||
class DrawingToolOutput(BaseModel):
    """Output of the drawing tool."""

    image: str  # generated image (presumably a URL or base64 payload — TODO confirm)
|
||||
|
||||
|
||||
class DrawingTool(BaseModel):
    """A drawing-tool invocation: prompt in, images out."""

    input: str  # prompt passed to the drawing tool
    outputs: list[DrawingToolOutput]  # generated images
|
||||
|
||||
|
||||
class DrawingToolBlock(BaseModel):
    """Tool-call block for the drawing tool."""

    drawing_tool: DrawingTool  # the drawing-tool invocation

    type: Literal["drawing_tool"]
    """Always `drawing_tool`."""
|
||||
@ -0,0 +1,22 @@
|
||||
from typing import Literal, Union
|
||||
|
||||
__all__ = ["FunctionToolBlock"]
|
||||
|
||||
from .....core import BaseModel
|
||||
|
||||
|
||||
class FunctionToolOutput(BaseModel):
    """Output returned by a function (tool) call."""

    content: str  # textual result of the function call
|
||||
|
||||
|
||||
class FunctionTool(BaseModel):
    """A function invocation with its arguments and results."""

    name: str  # function name
    arguments: Union[str, dict]  # call arguments (serialized string or parsed dict)
    outputs: list[FunctionToolOutput]  # results of the call
|
||||
|
||||
|
||||
class FunctionToolBlock(BaseModel):
    """Tool-call block for a function invocation."""

    function: FunctionTool  # the function invocation

    type: Literal["function"]
    """Always `function`."""
|
||||
@ -0,0 +1,41 @@
|
||||
from typing import Literal
|
||||
|
||||
from .....core import BaseModel
|
||||
|
||||
|
||||
class RetrievalToolOutput(BaseModel):
    """
    This class represents the output of a retrieval tool.

    Attributes:
    - text (str): The text snippet retrieved from the knowledge base.
    - document (str): The name of the document from which the text snippet was
      retrieved, returned only in intelligent configuration.
    """

    text: str
    document: str
|
||||
|
||||
|
||||
class RetrievalTool(BaseModel):
    """
    This class represents the outputs of a retrieval tool.

    Attributes:
    - outputs (List[RetrievalToolOutput]): A list of text snippets and their
      respective document names retrieved from the knowledge base.
    """

    outputs: list[RetrievalToolOutput]
|
||||
|
||||
|
||||
class RetrievalToolBlock(BaseModel):
    """
    This class represents a block for invoking the retrieval tool.

    Attributes:
    - retrieval (RetrievalTool): An instance of the RetrievalTool class
      containing the retrieval outputs.
    - type (Literal["retrieval"]): The type of tool being used, always set
      to "retrieval".
    """

    retrieval: RetrievalTool
    type: Literal["retrieval"]
    """Always `retrieval`."""
|
||||
@ -0,0 +1,16 @@
|
||||
from typing import Annotated, TypeAlias, Union
|
||||
|
||||
from .....core._utils import PropertyInfo
|
||||
from .code_interpreter_delta_block import CodeInterpreterToolBlock
|
||||
from .drawing_tool_delta_block import DrawingToolBlock
|
||||
from .function_delta_block import FunctionToolBlock
|
||||
from .retrieval_delta_black import RetrievalToolBlock
|
||||
from .web_browser_delta_block import WebBrowserToolBlock
|
||||
|
||||
__all__ = ["ToolsType"]
|
||||
|
||||
|
||||
# Discriminated union of all tool-call block variants; the ``type`` field
# selects the concrete model during deserialization.
ToolsType: TypeAlias = Annotated[
    Union[DrawingToolBlock, CodeInterpreterToolBlock, WebBrowserToolBlock, RetrievalToolBlock, FunctionToolBlock],
    PropertyInfo(discriminator="type"),
]
|
||||
@ -0,0 +1,48 @@
|
||||
from typing import Literal
|
||||
|
||||
from .....core import BaseModel
|
||||
|
||||
__all__ = ["WebBrowserToolBlock"]
|
||||
|
||||
|
||||
class WebBrowserOutput(BaseModel):
    """
    A single web-browser search result.

    Attributes:
    - title (str): The title of the search result.
    - link (str): The URL link to the search result's webpage.
    - content (str): The textual content extracted from the search result.
    - error_msg (str): Any error message encountered during the search or
      retrieval process.
    """

    title: str
    link: str
    content: str
    error_msg: str
|
||||
|
||||
|
||||
class WebBrowser(BaseModel):
    """
    The input and outputs of a web-browser search.

    Attributes:
    - input (str): The input query for the web browser search.
    - outputs (List[WebBrowserOutput]): A list of search results returned by
      the web browser.
    """

    input: str
    outputs: list[WebBrowserOutput]
|
||||
|
||||
|
||||
class WebBrowserToolBlock(BaseModel):
    """
    A block for invoking the web-browser tool.

    Attributes:
    - web_browser (WebBrowser): An instance of the WebBrowser class containing
      the search input and outputs.
    - type (Literal["web_browser"]): The type of tool being used, always set
      to "web_browser".
    """

    web_browser: WebBrowser
    type: Literal["web_browser"]
|
||||
@ -0,0 +1,16 @@
|
||||
from typing import Literal
|
||||
|
||||
from ....core import BaseModel
|
||||
from .tools.tools_type import ToolsType
|
||||
|
||||
__all__ = ["ToolsDeltaBlock"]
|
||||
|
||||
|
||||
class ToolsDeltaBlock(BaseModel):
    """A streamed message block carrying tool-call results."""

    tool_calls: list[ToolsType]
    """The tool calls (and their outputs) contained in this block."""

    role: str = "tool"

    type: Literal["tool_calls"] = "tool_calls"
    """Always `tool_calls`."""
|
||||
@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
__all__ = ["BatchCreateParams"]
|
||||
|
||||
|
||||
class BatchCreateParams(TypedDict, total=False):
    """Parameters accepted when creating a batch job."""

    completion_window: Required[str]
    """The time frame within which the batch should be processed.

    Currently only `24h` is supported.
    """

    endpoint: Required[Literal["/v1/chat/completions", "/v1/embeddings"]]
    """The endpoint to be used for all requests in the batch.

    Currently `/v1/chat/completions` and `/v1/embeddings` are supported.
    """

    input_file_id: Required[str]
    """The ID of an uploaded file that contains requests for the new batch.

    See [upload file](https://platform.openai.com/docs/api-reference/files/create)
    for how to upload a file.

    Your input file must be formatted as a
    [JSONL file](https://platform.openai.com/docs/api-reference/batch/requestInput),
    and must be uploaded with the purpose `batch`.
    """

    metadata: Optional[dict[str, str]]
    """Optional custom metadata for the batch."""

    auto_delete_input_file: Optional[bool]
    # NOTE(review): undocumented in this file — presumably whether the input
    # file is deleted automatically after the batch completes; confirm
    # against the API documentation.
|
||||
@ -0,0 +1,21 @@
|
||||
# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..core import BaseModel
|
||||
|
||||
__all__ = ["BatchError"]
|
||||
|
||||
|
||||
class BatchError(BaseModel):
    """An error reported for a single line of a batch input file."""

    code: Optional[str] = None
    """The defined business error code."""

    line: Optional[int] = None
    """The line number within the input file."""

    message: Optional[str] = None
    """A description of the error in the conversation file."""

    param: Optional[str] = None
    """The parameter name, if applicable."""
|
||||
@ -0,0 +1,14 @@
|
||||
from ..core import BaseModel
|
||||
|
||||
__all__ = ["BatchRequestCounts"]
|
||||
|
||||
|
||||
class BatchRequestCounts(BaseModel):
    """Counts of the requests in a batch, broken down by outcome."""

    completed: int
    """The number of requests that have completed."""

    failed: int
    """The number of requests that failed."""

    total: int
    """The total number of requests."""
|
||||
@ -0,0 +1,5 @@
|
||||
from .file_deleted import FileDeleted
|
||||
from .file_object import FileObject, ListOfFileObject
|
||||
from .upload_detail import UploadDetail
|
||||
|
||||
__all__ = ["FileObject", "ListOfFileObject", "UploadDetail", "FileDeleted"]
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import Literal
|
||||
|
||||
from ...core import BaseModel
|
||||
|
||||
__all__ = ["FileDeleted"]
|
||||
|
||||
|
||||
class FileDeleted(BaseModel):
    """Result of a file deletion request."""

    id: str  # ID of the deleted file

    deleted: bool  # whether the deletion succeeded

    object: Literal["file"]  # object type tag; always "file"
|
||||
@ -1,8 +1,8 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from ...core import BaseModel
|
||||
|
||||
__all__ = ["FileObject"]
|
||||
__all__ = ["FileObject", "ListOfFileObject"]
|
||||
|
||||
|
||||
class FileObject(BaseModel):
|
||||
@ -0,0 +1,13 @@
|
||||
from typing import Optional
|
||||
|
||||
from ...core import BaseModel
|
||||
|
||||
|
||||
class UploadDetail(BaseModel):
    """Details for ingesting a document by URL."""

    url: str  # location of the document to ingest
    knowledge_type: int  # document type code (enum defined server-side — TODO confirm values)
    file_name: Optional[str] = None  # display name for the uploaded file
    sentence_size: Optional[int] = None  # presumably the chunk size used when splitting the document — confirm
    custom_separator: Optional[list[str]] = None  # custom separators used when splitting the text
    callback_url: Optional[str] = None  # URL notified about processing — presumably on completion; confirm
    callback_header: Optional[dict[str, str]] = None  # extra headers sent with the callback request
|
||||
@ -0,0 +1 @@
|
||||
from .fine_tuned_models import FineTunedModelsStatus
|
||||
@ -0,0 +1,8 @@
|
||||
from .knowledge import KnowledgeInfo
|
||||
from .knowledge_used import KnowledgeStatistics, KnowledgeUsed
|
||||
|
||||
__all__ = [
|
||||
"KnowledgeInfo",
|
||||
"KnowledgeStatistics",
|
||||
"KnowledgeUsed",
|
||||
]
|
||||
@ -0,0 +1,8 @@
|
||||
from .document import DocumentData, DocumentFailedInfo, DocumentObject, DocumentSuccessinfo
|
||||
|
||||
__all__ = [
|
||||
"DocumentData",
|
||||
"DocumentObject",
|
||||
"DocumentSuccessinfo",
|
||||
"DocumentFailedInfo",
|
||||
]
|
||||
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import BaseModel
|
||||
from . import DocumentData
|
||||
|
||||
__all__ = ["DocumentPage"]
|
||||
|
||||
|
||||
class DocumentPage(BaseModel):
    """A page of document records.

    Note: the field is named ``list`` to mirror the API payload; it shadows
    the ``list`` builtin within this class body.
    """

    list: list[DocumentData]  # documents on this page
    object: str  # object type tag
|
||||
@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from ...core import BaseModel
|
||||
|
||||
__all__ = ["KnowledgeInfo"]
|
||||
|
||||
|
||||
class KnowledgeInfo(BaseModel):
    """Metadata describing a knowledge base."""

    id: Optional[str] = None
    """Unique ID of the knowledge base."""
    embedding_id: Optional[str] = (
        None  # embedding model bound to the knowledge base; see the model list in the internal API docs
    )
    name: Optional[str] = None  # knowledge-base name (max 100 characters)
    customer_identifier: Optional[str] = None  # user identifier (max 32 characters)
    description: Optional[str] = None  # knowledge-base description (max 500 characters)
    background: Optional[str] = None  # background colour (enum): 'blue', 'red', 'orange', 'purple', 'sky'
    icon: Optional[str] = (
        None  # icon (enum): question, book, seal, wrench, tag, horn, house
    )
    bucket_id: Optional[str] = None  # bucket ID (max 32 characters)
|
||||
@ -0,0 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ...core import BaseModel
|
||||
from . import KnowledgeInfo
|
||||
|
||||
__all__ = ["KnowledgePage"]
|
||||
|
||||
|
||||
class KnowledgePage(BaseModel):
    """A page of knowledge-base records.

    Note: the field is named ``list`` to mirror the API payload; it shadows
    the ``list`` builtin within this class body.
    """

    list: list[KnowledgeInfo]  # knowledge bases on this page
    object: str  # object type tag
|
||||
@ -0,0 +1,21 @@
|
||||
from typing import Optional
|
||||
|
||||
from ...core import BaseModel
|
||||
|
||||
__all__ = ["KnowledgeStatistics", "KnowledgeUsed"]
|
||||
|
||||
|
||||
class KnowledgeStatistics(BaseModel):
    """
    Usage statistics.
    """

    word_num: Optional[int] = None  # number of words
    length: Optional[int] = None  # total length (units unspecified — TODO confirm)
|
||||
|
||||
|
||||
class KnowledgeUsed(BaseModel):
    """Knowledge-base capacity usage."""

    used: Optional[KnowledgeStatistics] = None
    """Amount already used."""
    total: Optional[KnowledgeStatistics] = None
    """Total knowledge-base capacity."""
|
||||
@ -0,0 +1,3 @@
|
||||
from .sensitive_word_check import SensitiveWordCheckRequest
|
||||
|
||||
__all__ = ["SensitiveWordCheckRequest"]
|
||||
@ -0,0 +1,9 @@
|
||||
from .web_search import (
|
||||
SearchIntent,
|
||||
SearchRecommend,
|
||||
SearchResult,
|
||||
WebSearch,
|
||||
)
|
||||
from .web_search_chunk import WebSearchChunk
|
||||
|
||||
__all__ = ["WebSearch", "SearchIntent", "SearchResult", "SearchRecommend", "WebSearchChunk"]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue