Merge main
commit 9c7bcd5abc

@@ -0,0 +1,54 @@
name: Check i18n Files and Create PR

on:
  pull_request:
    types: [closed]
    branches: [main]

jobs:
  check-and-update:
    if: github.event.pull_request.merged == true
    runs-on: ubuntu-latest
    defaults:
      run:
        working-directory: web
    steps:
      - uses: actions/checkout@v4
        with:
          fetch-depth: 2 # last 2 commits

      - name: Check for file changes in i18n/en-US
        id: check_files
        run: |
          recent_commit_sha=$(git rev-parse HEAD)
          second_recent_commit_sha=$(git rev-parse HEAD~1)
          changed_files=$(git diff --name-only "$second_recent_commit_sha" "$recent_commit_sha" -- 'i18n/en-US/*.ts')
          echo "Changed files: $changed_files"
          if [ -n "$changed_files" ]; then
            echo "FILES_CHANGED=true" >> "$GITHUB_ENV"
          else
            echo "FILES_CHANGED=false" >> "$GITHUB_ENV"
          fi

      - name: Set up Node.js
        if: env.FILES_CHANGED == 'true'
        uses: actions/setup-node@v2
        with:
          node-version: 'lts/*'

      - name: Install dependencies
        if: env.FILES_CHANGED == 'true'
        run: yarn install --frozen-lockfile

      - name: Run npm script
        if: env.FILES_CHANGED == 'true'
        run: npm run auto-gen-i18n

      - name: Create Pull Request
        if: env.FILES_CHANGED == 'true'
        uses: peter-evans/create-pull-request@v6
        with:
          commit-message: Update i18n files based on en-US changes
          title: 'chore: translate i18n files'
          body: This PR was automatically created to update i18n files based on changes in en-US locale.
          branch: chore/automated-i18n-updates
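The changed-files check above can be exercised locally before pushing. A minimal sketch (assumption: run from the repository root of a clone with at least two commits; the workflow runs inside web/, so the pathspec is prefixed accordingly here):

# local_i18n_check.py — hypothetical local replica of the "check_files" step
import subprocess

changed = subprocess.run(
    ["git", "diff", "--name-only", "HEAD~1", "HEAD", "--", "web/i18n/en-US/*.ts"],
    capture_output=True, text=True, check=True,
).stdout.split()
print(f"FILES_CHANGED={'true' if changed else 'false'}")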
@@ -0,0 +1,29 @@
from typing import Optional

from pydantic import BaseModel, Field


class HuaweiCloudOBSStorageConfig(BaseModel):
    """
    Huawei Cloud OBS storage configs
    """

    HUAWEI_OBS_BUCKET_NAME: Optional[str] = Field(
        description="Huawei Cloud OBS bucket name",
        default=None,
    )

    HUAWEI_OBS_ACCESS_KEY: Optional[str] = Field(
        description="Huawei Cloud OBS Access key",
        default=None,
    )

    HUAWEI_OBS_SECRET_KEY: Optional[str] = Field(
        description="Huawei Cloud OBS Secret key",
        default=None,
    )

    HUAWEI_OBS_SERVER: Optional[str] = Field(
        description="Huawei Cloud OBS server URL",
        default=None,
    )
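A brief usage sketch (hypothetical values; assumes pydantic v2 for .model_dump()). These are plain pydantic BaseModels, so nothing reads the environment automatically — the caller supplies values, e.g. from os.environ:

import os

config = HuaweiCloudOBSStorageConfig(
    HUAWEI_OBS_BUCKET_NAME=os.environ.get("HUAWEI_OBS_BUCKET_NAME"),
    HUAWEI_OBS_SERVER=os.environ.get("HUAWEI_OBS_SERVER"),
)
print(config.model_dump())  # unset fields fall back to their None defaults

The VolcengineTOSStorageConfig class below follows the same pattern.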
@@ -0,0 +1,34 @@
from typing import Optional

from pydantic import BaseModel, Field


class VolcengineTOSStorageConfig(BaseModel):
    """
    Volcengine TOS storage configs
    """

    VOLCENGINE_TOS_BUCKET_NAME: Optional[str] = Field(
        description="Volcengine TOS Bucket Name",
        default=None,
    )

    VOLCENGINE_TOS_ACCESS_KEY: Optional[str] = Field(
        description="Volcengine TOS Access Key",
        default=None,
    )

    VOLCENGINE_TOS_SECRET_KEY: Optional[str] = Field(
        description="Volcengine TOS Secret Key",
        default=None,
    )

    VOLCENGINE_TOS_ENDPOINT: Optional[str] = Field(
        description="Volcengine TOS Endpoint URL",
        default=None,
    )

    VOLCENGINE_TOS_REGION: Optional[str] = Field(
        description="Volcengine TOS Region",
        default=None,
    )
File diff suppressed because one or more lines are too long

@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
@@ -0,0 +1 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="61.1 180.15 377.8 139.718"><path d="M431.911 245.181c3.842 0 6.989 1.952 6.989 4.337v14.776c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-14.776c0-2.385 3.144-4.337 6.99-4.337ZM404.135 250.955c3.846 0 6.989 1.952 6.989 4.337v32.528c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-32.528c0-2.385 3.147-4.337 6.989-4.337ZM376.363 257.688c3.842 0 6.989 1.952 6.989 4.337v36.562c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-36.562c0-2.386 3.147-4.337 6.993-4.337ZM348.587 263.26c3.846 0 6.989 1.952 6.989 4.337v36.159c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-36.159c0-2.385 3.147-4.337 6.989-4.337ZM320.811 268.177c3.846 0 6.989 1.952 6.989 4.337v31.318c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-31.318c0-2.385 3.147-4.337 6.989-4.337ZM293.179 288.148c3.846 0 6.989 1.952 6.989 4.337v9.935c0 2.384-3.147 4.336-6.989 4.336s-6.99-1.951-6.99-4.336v-9.935c0-2.386 3.144-4.337 6.99-4.337Z" style="fill:#b1b3b4;fill-rule:evenodd"></path><path d="M431.911 205.441c3.842 0 6.989 1.952 6.989 4.337v24.459c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.952-6.99-4.337v-24.459c0-2.385 3.144-4.337 6.99-4.337ZM404.135 189.026c3.846 0 6.989 1.952 6.989 4.337v43.622c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-43.622c0-2.385 3.147-4.337 6.989-4.337ZM376.363 182.848c3.842 0 6.989 1.953 6.989 4.337v56.937c0 2.384-3.147 4.337-6.989 4.337-3.846 0-6.993-1.952-6.993-4.337v-56.937c0-2.385 3.147-4.337 6.993-4.337ZM348.587 180.15c3.846 0 6.989 1.952 6.989 4.337v66.619c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.952-6.989-4.337v-66.619c0-2.385 3.147-4.337 6.989-4.337ZM320.811 181.84c3.846 0 6.989 1.952 6.989 4.337v67.627c0 2.385-3.143 4.337-6.989 4.337-3.842 0-6.989-1.951-6.989-4.337v-67.627c0-2.386 3.147-4.337 6.989-4.337ZM293.179 186.076c3.846 0 6.989 1.952 6.989 4.337v84.37c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.951-6.99-4.337v-84.37c0-2.386 3.144-4.337 6.99-4.337ZM264.829 193.262c3.846 0 6.989 1.953 6.989 4.337v95.667c0 2.385-3.143 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-95.667c0-2.385 3.147-4.337 6.99-4.337ZM237.057 205.441c3.842 0 6.989 1.953 6.989 4.337v92.036c0 2.385-3.147 4.337-6.989 4.337-3.846 0-6.99-1.951-6.99-4.337v-92.036c0-2.385 3.144-4.337 6.99-4.337ZM209.281 221.302c3.846 0 6.989 1.952 6.989 4.337v80.134c0 2.385-3.147 4.337-6.989 4.337s-6.99-1.952-6.99-4.337v-80.134c0-2.386 3.144-4.337 6.99-4.337ZM181.505 232.271c3.846 0 6.993 1.952 6.993 4.336v78.924c0 2.385-3.147 4.337-6.993 4.337-3.842 0-6.989-1.951-6.989-4.337v-78.924c0-2.385 3.147-4.336 6.989-4.336ZM153.873 241.348c3.846 0 6.989 1.953 6.989 4.337v42.009c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-42.009c0-2.385 3.147-4.337 6.99-4.337ZM125.266 200.398c3.842 0 6.989 1.953 6.989 4.337v58.55c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.951-6.99-4.337v-58.55c0-2.385 3.144-4.337 6.99-4.337ZM96.7 204.231c3.842 0 6.989 1.953 6.989 4.337v18.004c0 2.384-3.147 4.337-6.989 4.337s-6.989-1.952-6.989-4.337v-18.004c0-2.385 3.143-4.337 6.989-4.337ZM68.089 201.81c3.846 0 6.99 1.953 6.99 4.337v8.12c0 2.384-3.147 4.336-6.99 4.336-3.842 0-6.989-1.951-6.989-4.336v-8.12c0-2.385 3.143-4.337 6.989-4.337ZM153.873 194.94c3.846 0 6.989 1.953 6.989 4.337v6.102c0 2.384-3.147 4.337-6.989 4.337-3.843 0-6.99-1.952-6.99-4.337v-6.102c0-2.385 3.147-4.337 6.99-4.337Z" style="fill:#000;fill-rule:evenodd"></path></svg>
@@ -0,0 +1,76 @@
provider: fishaudio
label:
  en_US: Fish Audio
description:
  en_US: Models provided by Fish Audio, currently only support TTS.
  zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
icon_small:
  en_US: fishaudio_s_en.svg
icon_large:
  en_US: fishaudio_l_en.svg
background: "#E5E7EB"
help:
  title:
    en_US: Get your API key from Fish Audio
    zh_Hans: 从 Fish Audio 获取你的 API Key
  url:
    en_US: https://fish.audio/go-api/
supported_model_types:
  - tts
configurate_methods:
  - predefined-model
provider_credential_schema:
  credential_form_schemas:
    - variable: api_key
      label:
        en_US: API Key
      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 API Key
        en_US: Enter your API Key
    - variable: api_base
      label:
        en_US: API URL
      type: text-input
      required: false
      default: https://api.fish.audio
      placeholder:
        en_US: Enter your API URL
        zh_Hans: 在此输入您的 API URL
    - variable: use_public_models
      label:
        en_US: Use Public Models
      type: select
      required: false
      default: "false"
      placeholder:
        en_US: Toggle to use public models
        zh_Hans: 切换以使用公共模型
      options:
        - value: "true"
          label:
            en_US: Allow Public Models
            zh_Hans: 使用公共模型
        - value: "false"
          label:
            en_US: Private Models Only
            zh_Hans: 仅使用私有模型
    - variable: latency
      label:
        en_US: Latency
      type: select
      required: false
      default: "normal"
      placeholder:
        en_US: Toggle to choose latency
        zh_Hans: 切换以调整延迟
      options:
        - value: "balanced"
          label:
            en_US: Low (may affect quality)
            zh_Hans: 低延迟 (可能降低质量)
        - value: "normal"
          label:
            en_US: Normal
            zh_Hans: 标准
@@ -0,0 +1,5 @@
model: tts-default
model_type: tts
model_properties:
  word_limit: 1000
  audio_type: 'mp3'
@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
@@ -0,0 +1 @@
<svg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 231 30' preserveAspectRatio='xMinYMid'><path d='M99.61,19.52h15.24l-8.05-13L92,30H85.27l18-28.17a4.29,4.29,0,0,1,7-.05L128.32,30h-6.73l-3.17-5.25H103l-3.36-5.23m69.93,5.23V0.28h-5.72V27.16a2.76,2.76,0,0,0,.85,2,2.89,2.89,0,0,0,2.08.87h26l3.39-5.25H169.54M75,20.38A10,10,0,0,0,75,.28H50V30h5.71V5.54H74.65a4.81,4.81,0,0,1,0,9.62H58.54L75.6,30h8.29L72.43,20.38H75M14.88,30H32.15a14.86,14.86,0,0,0,0-29.71H14.88a14.86,14.86,0,1,0,0,29.71m16.88-5.23H15.26a9.62,9.62,0,0,1,0-19.23h16.5a9.62,9.62,0,1,1,0,19.23M140.25,30h17.63l3.34-5.23H140.64a9.62,9.62,0,1,1,0-19.23h16.75l3.38-5.25H140.25a14.86,14.86,0,1,0,0,29.71m69.87-5.23a9.62,9.62,0,0,1-9.26-7h24.42l3.36-5.24H200.86a9.61,9.61,0,0,1,9.26-7h16.76l3.35-5.25h-20.5a14.86,14.86,0,0,0,0,29.71h17.63l3.35-5.23h-20.6' transform='translate(-0.02 0)' style='fill:#C74634'/></svg>
@@ -0,0 +1,52 @@
model: cohere.command-r-16k
label:
  en_US: cohere.command-r-16k v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.004'
  output: '0.004'
  unit: '0.0001'
  currency: USD
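For orientation, pricing blocks like the one above are commonly interpreted as cost = tokens * unit * input (an assumption about the runtime's accounting; this diff does not define the formula):

# 1,000,000 input tokens at input='0.004', unit='0.0001'
tokens = 1_000_000
print(tokens * 0.0001 * 0.004)  # 0.4 (USD)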
@@ -0,0 +1,52 @@
model: cohere.command-r-plus
label:
  en_US: cohere.command-r-plus v1.2
model_type: llm
features:
  - multi-tool-call
  - agent-thought
  - stream-tool-call
model_properties:
  mode: chat
  context_size: 128000
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 1.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: 0
    max: 1
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: 0
    max: 1
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 4000
pricing:
  input: '0.0219'
  output: '0.0219'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,461 @@
import base64
import copy
import json
import logging
from collections.abc import Generator
from typing import Optional, Union

import oci
from oci.generative_ai_inference.models.base_chat_response import BaseChatResponse

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageContentType,
    PromptMessageTool,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

logger = logging.getLogger(__name__)

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.command-r-plus",
        "servingType": "ON_DEMAND"
    },
    "chatRequest": {
        "apiFormat": "COHERE",
        # per-request fields ("preambleOverride", "message", "chatHistory")
        # are filled in by _generate
        "maxTokens": 600,
        "isStream": False,
        "frequencyPenalty": 0,
        "presencePenalty": 0,
        "temperature": 1,
        "topP": 0.75
    }
}

oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCILargeLanguageModel(LargeLanguageModel):
    # https://docs.oracle.com/en-us/iaas/Content/generative-ai/pretrained-models.htm
    _supported_models = {
        "meta.llama-3-70b-instruct": {
            "system": True,
            "multimodal": False,
            "tool_call": False,
            "stream_tool_call": False,
        },
        "cohere.command-r-16k": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
        "cohere.command-r-plus": {
            "system": True,
            "multimodal": False,
            "tool_call": True,
            "stream_tool_call": False,
        },
    }

    def _is_tool_call_supported(self, model_id: str, stream: bool = False) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["stream_tool_call"] if stream else feature["tool_call"]

    def _is_multimodal_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["multimodal"]

    def _is_system_prompt_supported(self, model_id: str) -> bool:
        feature = self._supported_models.get(model_id)
        if not feature:
            return False
        return feature["system"]
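    # Illustration (hypothetical REPL session) of how the capability table
    # above gates behavior; results follow _supported_models:
    #   self._is_tool_call_supported("cohere.command-r-plus")               # True
    #   self._is_tool_call_supported("cohere.command-r-plus", stream=True)  # False
    #   self._is_tool_call_supported("meta.llama-3-70b-instruct")           # False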
    def _invoke(self, model: str, credentials: dict,
                prompt_messages: list[PromptMessage], model_parameters: dict,
                tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                stream: bool = True, user: Optional[str] = None) \
            -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param tools: tools for tool calling
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # invoke model
        return self._generate(model, credentials, prompt_messages, model_parameters, tools, stop, stream, user)

    def get_num_tokens(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                       tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of tokens for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of tokens
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return self._get_num_tokens_by_gpt2(prompt)

    def get_num_characters(self, model: str, credentials: dict, prompt_messages: list[PromptMessage],
                           tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
        Get number of characters for given prompt messages

        :param model: model name
        :param credentials: model credentials
        :param prompt_messages: prompt messages
        :param tools: tools for tool calling
        :return: number of characters
        """
        prompt = self._convert_messages_to_prompt(prompt_messages)

        return len(prompt)

    def _convert_messages_to_prompt(self, messages: list[PromptMessage]) -> str:
        """
        :param messages: List of PromptMessage to combine.
        :return: Combined string with necessary human_prompt and ai_prompt tags.
        """
        messages = messages.copy()  # don't mutate the original list

        text = "".join(
            self._convert_one_message_to_text(message)
            for message in messages
        )

        return text.rstrip()

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            ping_message = SystemPromptMessage(content="ping")
            self._generate(model, credentials, [ping_message], {"maxTokens": 5})
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _generate(self, model: str, credentials: dict,
                  prompt_messages: list[PromptMessage], model_parameters: dict,
                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[list[str]] = None,
                  stream: bool = True, user: Optional[str] = None
                  ) -> Union[LLMResult, Generator]:
        """
        Invoke large language model

        :param model: model name
        :param credentials: credentials kwargs
        :param prompt_messages: prompt messages
        :param model_parameters: model parameters
        :param stop: stop words
        :param stream: is stream response
        :param user: unique user id
        :return: full response or stream response chunk generator result
        """
        # initialize client
        # ref: https://docs.oracle.com/en-us/iaas/api/#/en/generative-ai-inference/20231130/ChatResult/Chat
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError(
                    "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")

        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
        # build the chat request from the template
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model

        chat_history = []
        system_prompts = []
        request_args["chatRequest"]["maxTokens"] = model_parameters.pop('maxTokens', 600)
        request_args["chatRequest"].update(model_parameters)
        frequency_penalty = model_parameters.get("frequencyPenalty", 0)
        presence_penalty = model_parameters.get("presencePenalty", 0)
        if frequency_penalty > 0 and presence_penalty > 0:
            raise InvokeBadRequestError("Cannot set both frequency penalty and presence penalty")

        # reject tool calls for models (or stream modes) that do not support them
        valid_value = self._is_tool_call_supported(model, stream)
        if tools is not None and len(tools) > 0:
            if not valid_value:
                raise InvokeBadRequestError("Does not support function calling")
        if model.startswith("cohere"):
            for message in prompt_messages[:-1]:
                text = ""
                if isinstance(message.content, str):
                    text = message.content
                if isinstance(message, UserPromptMessage):
                    chat_history.append({"role": "USER", "message": text})
                else:
                    chat_history.append({"role": "CHATBOT", "message": text})
                if isinstance(message, SystemPromptMessage):
                    if isinstance(message.content, str):
                        system_prompts.append(message.content)
            args = {"apiFormat": "COHERE",
                    "preambleOverride": ' '.join(system_prompts),
                    "message": prompt_messages[-1].content,
                    "chatHistory": chat_history}
            request_args["chatRequest"].update(args)
        elif model.startswith("meta"):
            meta_messages = []
            for message in prompt_messages:
                text = message.content
                meta_messages.append({"role": message.role.name, "content": [{"type": "TEXT", "text": text}]})
            args = {"apiFormat": "GENERIC",
                    "messages": meta_messages,
                    "numGenerations": 1,
                    "topK": -1}
            request_args["chatRequest"].update(args)

        if stream:
            request_args["chatRequest"]["isStream"] = True
        response = client.chat(request_args)

        if stream:
            return self._handle_generate_stream_response(model, credentials, response, prompt_messages)

        return self._handle_generate_response(model, credentials, response, prompt_messages)

    def _handle_generate_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
        """
        Handle llm response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response
        """
        # transform assistant message to prompt message
        assistant_prompt_message = AssistantPromptMessage(
            content=response.data.chat_response.text
        )

        # calculate num tokens (character counts are used as the approximation here)
        prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
        completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

        # transform usage
        usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

        # transform response
        result = LLMResult(
            model=model,
            prompt_messages=prompt_messages,
            message=assistant_prompt_message,
            usage=usage,
        )

        return result

    def _handle_generate_stream_response(self, model: str, credentials: dict, response: BaseChatResponse,
                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response

        :param model: model name
        :param credentials: credentials
        :param response: response
        :param prompt_messages: prompt messages
        :return: llm response chunk generator result
        """
        index = -1
        # guard against a finish chunk arriving before any content chunk
        assistant_prompt_message = AssistantPromptMessage(content='')
        events = response.data.events()
        for stream in events:
            chunk = json.loads(stream.data)
            # chunk example: {'apiFormat': 'COHERE', 'text': 'Hello'}

            if "finishReason" not in chunk:
                assistant_prompt_message = AssistantPromptMessage(
                    content=''
                )
                if model.startswith("cohere"):
                    if chunk["text"]:
                        assistant_prompt_message.content += chunk["text"]
                elif model.startswith("meta"):
                    assistant_prompt_message.content += chunk["message"]["content"][0]["text"]
                index += 1
                # transform assistant message to prompt message
                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message
                    )
                )
            else:
                # calculate num tokens (character counts are used as the approximation here)
                prompt_tokens = self.get_num_characters(model, credentials, prompt_messages)
                completion_tokens = self.get_num_characters(model, credentials, [assistant_prompt_message])

                # transform usage
                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)

                yield LLMResultChunk(
                    model=model,
                    prompt_messages=prompt_messages,
                    delta=LLMResultChunkDelta(
                        index=index,
                        message=assistant_prompt_message,
                        finish_reason=str(chunk["finishReason"]),
                        usage=usage
                    )
                )

    def _convert_one_message_to_text(self, message: PromptMessage) -> str:
        """
        Convert a single message to a string.

        :param message: PromptMessage to convert.
        :return: String representation of the message.
        """
        human_prompt = "\n\nuser:"
        ai_prompt = "\n\nmodel:"

        content = message.content
        if isinstance(content, list):
            content = "".join(
                c.data for c in content if c.type != PromptMessageContentType.IMAGE
            )

        if isinstance(message, UserPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, AssistantPromptMessage):
            message_text = f"{ai_prompt} {content}"
        elif isinstance(message, SystemPromptMessage):
            message_text = f"{human_prompt} {content}"
        elif isinstance(message, ToolPromptMessage):
            message_text = f"{human_prompt} {content}"
        else:
            raise ValueError(f"Got unknown type {message}")

        return message_text

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [],
            InvokeServerUnavailableError: [],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [],
            InvokeBadRequestError: []
        }
@@ -0,0 +1,51 @@
model: meta.llama-3-70b-instruct
label:
  zh_Hans: meta.llama-3-70b-instruct
  en_US: meta.llama-3-70b-instruct
model_type: llm
features:
  - agent-thought
model_properties:
  mode: chat
  context_size: 131072
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 1
    max: 2.0
  - name: topP
    use_template: top_p
    default: 0.75
    min: 0
    max: 1
  - name: topK
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    help:
      zh_Hans: 仅从每个后续标记的前 K 个选项中采样。
      en_US: Only sample from the top K options for each subsequent token.
    required: false
    default: 0
    min: 0
    max: 500
  - name: presencePenalty
    use_template: presence_penalty
    min: -2
    max: 2
    default: 0
  - name: frequencyPenalty
    use_template: frequency_penalty
    min: -2
    max: 2
    default: 0
  - name: maxTokens
    use_template: max_tokens
    default: 600
    max: 8000
pricing:
  input: '0.015'
  output: '0.015'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,34 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class OCIGENAIProvider(ModelProvider):

    def validate_provider_credentials(self, credentials: dict) -> None:
        """
        Validate provider credentials

        If validation fails, raise an exception.

        :param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
        """
        try:
            model_instance = self.get_model_instance(ModelType.LLM)

            # use the `cohere.command-r-plus` model for validation
            model_instance.validate_credentials(
                model='cohere.command-r-plus',
                credentials=credentials
            )
        except CredentialsValidateFailedError as ex:
            raise ex
        except Exception as ex:
            logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
            raise ex
@@ -0,0 +1,42 @@
provider: oci
label:
  en_US: OCIGenerativeAI
description:
  en_US: Models provided by OCI, such as Cohere Command R and Cohere Command R+.
  zh_Hans: OCI 提供的模型,例如 Cohere Command R 和 Cohere Command R+。
icon_small:
  en_US: icon_s_en.svg
icon_large:
  en_US: icon_l_en.svg
background: "#FFFFFF"
help:
  title:
    en_US: Get your API Key from OCI
    zh_Hans: 从 OCI 获取 API Key
  url:
    en_US: https://docs.cloud.oracle.com/Content/API/Concepts/sdkconfig.htm
supported_model_types:
  - llm
  - text-embedding
  #- rerank
configurate_methods:
  - predefined-model
  #- customizable-model
provider_credential_schema:
  credential_form_schemas:
    - variable: oci_config_content
      label:
        en_US: OCI API key config file content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 OCI API key config 文件的内容(base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8')))
        en_US: Enter the content of your OCI API key config file, i.e. base64.b64encode("user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid".encode('utf-8'))
    - variable: oci_key_content
      label:
        en_US: OCI API key file content
      type: text-input
      required: true
      placeholder:
        zh_Hans: 在此输入您的 OCI API key 文件的内容(base64.b64encode("pem file content".encode('utf-8')))
        en_US: Enter the content of your OCI API key file, i.e. base64.b64encode("pem file content".encode('utf-8'))
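A sketch of producing the oci_config_content value described above (field order as given in the placeholder text; the OCIDs below are placeholders, not real values):

import base64

fields = "/".join([
    "ocid1.user.oc1..example",         # user_ocid
    "aa:bb:cc:dd:ee:ff",               # fingerprint
    "ocid1.tenancy.oc1..example",      # tenancy_ocid
    "us-chicago-1",                    # region
    "ocid1.compartment.oc1..example",  # compartment_ocid
])
print(base64.b64encode(fields.encode("utf-8")).decode())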
@@ -0,0 +1,5 @@
- cohere.embed-english-light-v2.0
- cohere.embed-english-light-v3.0
- cohere.embed-english-v3.0
- cohere.embed-multilingual-light-v3.0
- cohere.embed-multilingual-v3.0
@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v2.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,9 @@
model: cohere.embed-english-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,9 @@
model: cohere.embed-english-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-light-v3.0
model_type: text-embedding
model_properties:
  context_size: 384
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,9 @@
model: cohere.embed-multilingual-v3.0
model_type: text-embedding
model_properties:
  context_size: 1024
  max_chunks: 48
pricing:
  input: '0.001'
  unit: '0.0001'
  currency: USD
@@ -0,0 +1,242 @@
import base64
import copy
import time
from typing import Optional

import oci

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

request_template = {
    "compartmentId": "",
    "servingMode": {
        "modelId": "cohere.embed-english-light-v3.0",
        "servingType": "ON_DEMAND"
    },
    "truncate": "NONE",
    "inputs": [""]
}

oci_config_template = {
    "user": "",
    "fingerprint": "",
    "tenancy": "",
    "region": "",
    "compartment_id": "",
    "key_content": ""
}


class OCITextEmbeddingModel(TextEmbeddingModel):
    """
    Model class for OCI text embedding model.
    """

    def _invoke(self, model: str, credentials: dict,
                texts: list[str], user: Optional[str] = None) \
            -> TextEmbeddingResult:
        """
        Invoke text embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :param user: unique user id
        :return: embeddings result
        """
        # get model properties
        context_size = self._get_context_size(model, credentials)
        max_chunks = self._get_max_chunks(model, credentials)

        inputs = []
        indices = []
        used_tokens = 0

        for i, text in enumerate(texts):
            # Here token count is only an approximation based on the GPT2 tokenizer
            num_tokens = self._get_num_tokens_by_gpt2(text)

            if num_tokens >= context_size:
                # if num tokens is larger than context length, only use the start;
                # truncate proportionally so the kept prefix fits the context window
                cutoff = int(len(text) * (context_size / num_tokens))
                inputs.append(text[0:cutoff])
            else:
                inputs.append(text)
            indices += [i]

        batched_embeddings = []
        _iter = range(0, len(inputs), max_chunks)

        for i in _iter:
            # call embedding model
            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=inputs[i: i + max_chunks]
            )

            used_tokens += embedding_used_tokens
            batched_embeddings += embeddings_batch

        # calc usage
        usage = self._calc_response_usage(
            model=model,
            credentials=credentials,
            tokens=used_tokens
        )

        return TextEmbeddingResult(
            embeddings=batched_embeddings,
            usage=usage,
            model=model
        )

    def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of tokens for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of tokens
        """
        return sum(self._get_num_tokens_by_gpt2(text) for text in texts)

    def get_num_characters(self, model: str, credentials: dict, texts: list[str]) -> int:
        """
        Get number of characters for given texts

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: number of characters
        """
        characters = 0
        for text in texts:
            characters += len(text)
        return characters

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        try:
            # call embedding model
            self._embedding_invoke(
                model=model,
                credentials=credentials,
                texts=['ping']
            )
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))

    def _embedding_invoke(self, model: str, credentials: dict, texts: list[str]) -> tuple[list[list[float]], int]:
        """
        Invoke embedding model

        :param model: model name
        :param credentials: model credentials
        :param texts: texts to embed
        :return: embeddings and used tokens
        """
        # initialize client
        oci_config = copy.deepcopy(oci_config_template)
        if "oci_config_content" in credentials:
            oci_config_content = base64.b64decode(credentials.get('oci_config_content')).decode('utf-8')
            config_items = oci_config_content.split("/")
            if len(config_items) != 5:
                raise CredentialsValidateFailedError(
                    "oci_config_content should be base64.b64encode('user_ocid/fingerprint/tenancy_ocid/region/compartment_ocid'.encode('utf-8'))")
            oci_config["user"] = config_items[0]
            oci_config["fingerprint"] = config_items[1]
            oci_config["tenancy"] = config_items[2]
            oci_config["region"] = config_items[3]
            oci_config["compartment_id"] = config_items[4]
        else:
            raise CredentialsValidateFailedError("need to set oci_config_content in credentials")
        if "oci_key_content" in credentials:
            oci_key_content = base64.b64decode(credentials.get('oci_key_content')).decode('utf-8')
            oci_config["key_content"] = oci_key_content.encode(encoding="utf-8")
        else:
            raise CredentialsValidateFailedError("need to set oci_key_content in credentials")
        # oci_config = oci.config.from_file('~/.oci/config', credentials.get('oci_api_profile'))
        compartment_id = oci_config["compartment_id"]
        client = oci.generative_ai_inference.GenerativeAiInferenceClient(config=oci_config)
        # call embedding model
        request_args = copy.deepcopy(request_template)
        request_args["compartmentId"] = compartment_id
        request_args["servingMode"]["modelId"] = model
        request_args["inputs"] = texts
        response = client.embed_text(request_args)
        # note: the "tokens" returned here are actually character counts
        return response.data.embeddings, self.get_num_characters(model=model, credentials=credentials, texts=texts)

    def _calc_response_usage(self, model: str, credentials: dict, tokens: int) -> EmbeddingUsage:
        """
        Calculate response usage

        :param model: model name
        :param credentials: model credentials
        :param tokens: input tokens
        :return: usage
        """
        # get input price info
        input_price_info = self.get_price(
            model=model,
            credentials=credentials,
            price_type=PriceType.INPUT,
            tokens=tokens
        )

        # transform usage
        usage = EmbeddingUsage(
            tokens=tokens,
            total_tokens=tokens,
            unit_price=input_price_info.unit_price,
            price_unit=input_price_info.unit,
            total_price=input_price_info.total_amount,
            currency=input_price_info.currency,
            latency=time.perf_counter() - self.started_at
        )

        return usage

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                KeyError
            ]
        }
@@ -0,0 +1,142 @@
import json
import logging
from typing import IO, Any, Optional

import boto3

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel
from core.model_runtime.model_providers.sagemaker.sagemaker import generate_presigned_url

logger = logging.getLogger(__name__)


class SageMakerSpeech2TextModel(Speech2TextModel):
    """
    Model class for SageMaker speech-to-text model.
    """
    sagemaker_client: Any = None
    s3_client: Any = None

    def _invoke(self, model: str, credentials: dict,
                file: IO[bytes], user: Optional[str] = None) \
            -> str:
        """
        Invoke speech2text model

        :param model: model name
        :param credentials: model credentials
        :param file: audio file
        :param user: unique user id
        :return: text for given audio file
        """
        asr_text = None

        try:
            if not self.sagemaker_client:
                access_key = credentials.get('aws_access_key_id')
                secret_key = credentials.get('aws_secret_access_key')
                aws_region = credentials.get('aws_region')
                if aws_region:
                    if access_key and secret_key:
                        self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                             aws_access_key_id=access_key,
                                                             aws_secret_access_key=secret_key,
                                                             region_name=aws_region)
                        self.s3_client = boto3.client("s3",
                                                      aws_access_key_id=access_key,
                                                      aws_secret_access_key=secret_key,
                                                      region_name=aws_region)
                    else:
                        self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                        self.s3_client = boto3.client("s3", region_name=aws_region)
                else:
                    self.sagemaker_client = boto3.client("sagemaker-runtime")
                    self.s3_client = boto3.client("s3")

            s3_prefix = 'dify/speech2text/'
            sagemaker_endpoint = credentials.get('sagemaker_endpoint')
            bucket = credentials.get('audio_s3_cache_bucket')

            # cache the audio in S3 and pass a presigned URL to the endpoint
            s3_presign_url = generate_presigned_url(self.s3_client, file, bucket, s3_prefix)
            payload = {
                "audio_s3_presign_uri": s3_presign_url
            }

            response_model = self.sagemaker_client.invoke_endpoint(
                EndpointName=sagemaker_endpoint,
                Body=json.dumps(payload),
                ContentType="application/json"
            )
            json_str = response_model['Body'].read().decode('utf8')
            json_obj = json.loads(json_str)
            asr_text = json_obj['text']
        except Exception as e:
            logger.exception(f'Exception: {e}')

        return asr_text

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        used to define customizable model schema
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.SPEECH2TEXT,
            model_properties={},
            parameter_rules=[]
        )

        return entity
@@ -0,0 +1,287 @@
import concurrent.futures
import copy
import json
import logging
from enum import Enum
from typing import Any, Optional

import boto3
import requests

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType
from core.model_runtime.errors.invoke import (
    InvokeAuthorizationError,
    InvokeBadRequestError,
    InvokeConnectionError,
    InvokeError,
    InvokeRateLimitError,
    InvokeServerUnavailableError,
)
from core.model_runtime.model_providers.__base.tts_model import TTSModel

logger = logging.getLogger(__name__)


class TTSModelType(Enum):
    PresetVoice = "PresetVoice"
    CloneVoice = "CloneVoice"
    CloneVoice_CrossLingual = "CloneVoice_CrossLingual"
    InstructVoice = "InstructVoice"


class SageMakerText2SpeechModel(TTSModel):

    sagemaker_client: Any = None
    s3_client: Any = None
    comprehend_client: Any = None

    def __init__(self):
        # preset voices; custom voices are not yet supported
        self.model_voices = {
            '__default': {
                'all': [
                    {'name': 'Default', 'value': 'default'},
                ]
            },
            'CosyVoice': {
                'zh-Hans': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'zh-Hant': [
                    {'name': '中文男', 'value': '中文男'},
                    {'name': '中文女', 'value': '中文女'},
                    {'name': '粤语女', 'value': '粤语女'},
                ],
                'en-US': [
                    {'name': '英文男', 'value': '英文男'},
                    {'name': '英文女', 'value': '英文女'},
                ],
                'ja-JP': [
                    {'name': '日语男', 'value': '日语男'},
                ],
                'ko-KR': [
                    {'name': '韩语女', 'value': '韩语女'},
                ]
            }
        }

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials

        :param model: model name
        :param credentials: model credentials
        :return:
        """
        pass

    def _detect_lang_code(self, content: str, map_dict: Optional[dict] = None):
        map_dict = map_dict or {
            "zh": "<|zh|>",
            "en": "<|en|>",
            "ja": "<|jp|>",
            "zh-TW": "<|yue|>",
            "ko": "<|ko|>"
        }

        response = self.comprehend_client.detect_dominant_language(Text=content)
        language_code = response['Languages'][0]['LanguageCode']

        return map_dict.get(language_code, '<|zh|>')

    def _build_tts_payload(self, model_type: str, content_text: str, model_role: str,
                           prompt_text: str, prompt_audio: str, instruct_text: str):
        if model_type == TTSModelType.PresetVoice.value and model_role:
            return {"tts_text": content_text, "role": model_role}
        if model_type == TTSModelType.CloneVoice.value and prompt_text and prompt_audio:
            return {"tts_text": content_text, "prompt_text": prompt_text, "prompt_audio": prompt_audio}
        if model_type == TTSModelType.CloneVoice_CrossLingual.value and prompt_audio:
            lang_tag = self._detect_lang_code(content_text)
            return {"tts_text": f"{content_text}", "prompt_audio": prompt_audio, "lang_tag": lang_tag}
        if model_type == TTSModelType.InstructVoice.value and instruct_text and model_role:
            return {"tts_text": content_text, "role": model_role, "instruct_text": instruct_text}

        raise RuntimeError(f"Invalid params for {model_type}")
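    # Usage sketch (hypothetical argument values) for the payload shapes above:
    #   self._build_tts_payload(
    #       "PresetVoice", "Hello!", "中文女", None, None, None)
    #   -> {"tts_text": "Hello!", "role": "中文女"}
    # Each model type accepts only its matching subset of arguments; anything
    # else falls through to the RuntimeError.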
|
||||

    def _invoke(self, model: str, tenant_id: str, credentials: dict, content_text: str, voice: str,
                user: Optional[str] = None):
        """
        _invoke text2speech model

        :param model: model name
        :param tenant_id: user tenant id
        :param credentials: model credentials
        :param voice: model timbre
        :param content_text: text content to be synthesized
        :param user: unique user id
        :return: generator yielding the synthesized audio in chunks
        """
        if not self.sagemaker_client:
            access_key = credentials.get('aws_access_key_id')
            secret_key = credentials.get('aws_secret_access_key')
            aws_region = credentials.get('aws_region')
            if aws_region:
                if access_key and secret_key:
                    self.sagemaker_client = boto3.client("sagemaker-runtime",
                                                         aws_access_key_id=access_key,
                                                         aws_secret_access_key=secret_key,
                                                         region_name=aws_region)
                    self.s3_client = boto3.client("s3",
                                                  aws_access_key_id=access_key,
                                                  aws_secret_access_key=secret_key,
                                                  region_name=aws_region)
                    self.comprehend_client = boto3.client('comprehend',
                                                          aws_access_key_id=access_key,
                                                          aws_secret_access_key=secret_key,
                                                          region_name=aws_region)
                else:
                    # No explicit keys: rely on boto3's default credential chain.
                    self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region)
                    self.s3_client = boto3.client("s3", region_name=aws_region)
                    self.comprehend_client = boto3.client('comprehend', region_name=aws_region)
            else:
                self.sagemaker_client = boto3.client("sagemaker-runtime")
                self.s3_client = boto3.client("s3")
                self.comprehend_client = boto3.client('comprehend')

        model_type = credentials.get('audio_model_type', 'PresetVoice')
        prompt_text = credentials.get('prompt_text')
        prompt_audio = credentials.get('prompt_audio')
        instruct_text = credentials.get('instruct_text')
        sagemaker_endpoint = credentials.get('sagemaker_endpoint')
        payload = self._build_tts_payload(
            model_type,
            content_text,
            voice,
            prompt_text,
            prompt_audio,
            instruct_text
        )

        return self._tts_invoke_streaming(model_type, payload, sagemaker_endpoint)
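    # Credentials consumed by _invoke above, for reference (all keys appear in
    # the code; which ones are required depends on the mode):
    #   aws_access_key_id / aws_secret_access_key / aws_region - optional;
    #       boto3's default credential chain is used when absent
    #   sagemaker_endpoint - name of the deployed TTS endpoint
    #   audio_model_type   - one of the TTSModelType values (default: PresetVoice)
    #   prompt_text / prompt_audio / instruct_text - needed only by the
    #       clone / instruct modes (see _build_tts_payload)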

    def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
        """
        Define the schema of a customizable model.
        """
        entity = AIModelEntity(
            model=model,
            label=I18nObject(
                en_US=model
            ),
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_type=ModelType.TTS,
            model_properties={},
            parameter_rules=[]
        )

        return entity

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                InvokeConnectionError
            ],
            InvokeServerUnavailableError: [
                InvokeServerUnavailableError
            ],
            InvokeRateLimitError: [
                InvokeRateLimitError
            ],
            InvokeAuthorizationError: [
                InvokeAuthorizationError
            ],
            InvokeBadRequestError: [
                InvokeBadRequestError,
                KeyError,
                ValueError
            ]
        }
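    # Example: a KeyError or ValueError raised while parsing the endpoint
    # response surfaces to callers as InvokeBadRequestError via this mapping.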

    def _get_model_default_voice(self, model: str, credentials: dict) -> Any:
        return ""

    def _get_model_word_limit(self, model: str, credentials: dict) -> int:
        # Text longer than this is split into sentences before invocation.
        return 15

    def _get_model_audio_type(self, model: str, credentials: dict) -> str:
        return "mp3"

    def _get_model_workers_limit(self, model: str, credentials: dict) -> int:
        return 5

    def get_tts_model_voices(self, model: str, credentials: dict, language: Optional[str] = None) -> list:
        # Only the CosyVoice preset list is exposed for now. The substring
        # check below matches the 'CosyVoice' entry and skips '__default',
        # which serves as the fallback at the end.
        audio_model_name = 'CosyVoice'
        for key, voices in self.model_voices.items():
            if key in audio_model_name:
                if language and language in voices:
                    return voices[language]
                elif 'all' in voices:
                    return voices['all']

        return self.model_voices['__default']['all']

    def _invoke_sagemaker(self, payload: dict, endpoint: str):
        response_model = self.sagemaker_client.invoke_endpoint(
            EndpointName=endpoint,
            Body=json.dumps(payload),
            ContentType="application/json",
        )
        json_str = response_model['Body'].read().decode('utf8')
        json_obj = json.loads(json_str)
        return json_obj

    def _tts_invoke_streaming(self, model_type: str, payload: dict, sagemaker_endpoint: str) -> Any:
        """
        Invoke the text2speech endpoint and stream the audio back

        :param model_type: TTS model type (see TTSModelType)
        :param payload: request payload built by _build_tts_payload
        :param sagemaker_endpoint: name of the SageMaker endpoint to invoke
        :return: generator yielding the synthesized audio in 1 KiB chunks
        """
        try:
            lang_tag = ''
            if model_type == TTSModelType.CloneVoice_CrossLingual.value:
                lang_tag = payload.pop('lang_tag')

            word_limit = self._get_model_word_limit(model='', credentials={})
            content_text = payload.get("tts_text")
            if len(content_text) > word_limit:
                # Split long text into sentences and synthesize them concurrently.
                split_sentences = self._split_text_into_sentences(content_text, max_length=word_limit)
                sentences = [f"{lang_tag}{s}" for s in split_sentences if s]
                len_sent = len(sentences)
                payloads = [copy.deepcopy(payload) for _ in range(len_sent)]
                for idx in range(len_sent):
                    payloads[idx]["tts_text"] = sentences[idx]

                with concurrent.futures.ThreadPoolExecutor(max_workers=min(4, len_sent)) as executor:
                    futures = [
                        executor.submit(
                            self._invoke_sagemaker,
                            payload=payload,
                            endpoint=sagemaker_endpoint,
                        )
                        for payload in payloads
                    ]

                    # Consume results in submission order so audio chunks are
                    # yielded in the same order as the input sentences.
                    for future in futures:
                        resp = future.result()
                        audio_bytes = requests.get(resp.get('s3_presign_url')).content
                        for i in range(0, len(audio_bytes), 1024):
                            yield audio_bytes[i:i + 1024]
            else:
                resp = self._invoke_sagemaker(payload, sagemaker_endpoint)
                audio_bytes = requests.get(resp.get('s3_presign_url')).content

                for i in range(0, len(audio_bytes), 1024):
                    yield audio_bytes[i:i + 1024]
        except Exception as ex:
            raise InvokeBadRequestError(str(ex))
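
For reference, a minimal usage sketch of the class above. The endpoint name, region, and output path are hypothetical placeholders, and calling _invoke directly is a shortcut past the runtime's normal dispatch:

# Hedged sketch: placeholder endpoint/region; assumes valid AWS credentials
# are available via boto3's default chain.
model = SageMakerText2SpeechModel()
credentials = {
    "aws_region": "us-east-1",                      # placeholder
    "sagemaker_endpoint": "my-cosyvoice-endpoint",  # placeholder
    "audio_model_type": "PresetVoice",
}
audio_stream = model._invoke(
    model="cosyvoice",
    tenant_id="tenant-1",
    credentials=credentials,
    content_text="你好,世界",
    voice="中文女",  # one of the preset voices defined above
)
# _get_model_audio_type reports mp3, so write the streamed chunks out as such.
with open("out.mp3", "wb") as f:
    for chunk in audio_stream:
        f.write(chunk)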
@ -0,0 +1,33 @@
model: spark-lite
label:
  en_US: Spark Lite
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Nucleus sampling threshold. Determines the randomness of the output; higher values make the output more random, i.e., the same question is more likely to receive different answers.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum number of tokens in the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one of k candidates (with non-equal probability).
    required: false
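
The spark-* rule files in this diff all share the layout above. Below is a minimal sketch, assuming PyYAML is available, of how a client could load such a file and clamp a user-supplied value to a rule's declared min/max; the file path and helper name are illustrative, not part of this PR:

import yaml

def clamp_param(rule_file: str, name: str, value: float) -> float:
    """Clamp a user-supplied value to the min/max declared in a model YAML."""
    with open(rule_file) as f:
        spec = yaml.safe_load(f)
    for rule in spec.get("parameter_rules", []):
        if rule.get("name") == name:
            lo = rule.get("min", value)
            hi = rule.get("max", value)
            return max(lo, min(hi, value))
    return value  # unknown parameter: pass through unchanged

# e.g. clamp_param("spark-lite.yaml", "max_tokens", 10000) -> 4096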
@ -0,0 +1,33 @@
model: spark-max
label:
  en_US: Spark Max
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Nucleus sampling threshold. Determines the randomness of the output; higher values make the output more random, i.e., the same question is more likely to receive different answers.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 8192
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum number of tokens in the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one of k candidates (with non-equal probability).
    required: false
@ -0,0 +1,33 @@
model: spark-pro-128k
label:
  en_US: Spark Pro-128K
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Nucleus sampling threshold. Determines the randomness of the output; higher values make the output more random, i.e., the same question is more likely to receive different answers.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 4096
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum number of tokens in the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one of k candidates (with non-equal probability).
    required: false
@ -0,0 +1,33 @@
model: spark-pro
label:
  en_US: Spark Pro
model_type: llm
model_properties:
  mode: chat
parameter_rules:
  - name: temperature
    use_template: temperature
    default: 0.5
    help:
      zh_Hans: 核采样阈值。用于决定结果随机性,取值越高随机性越强即相同的问题得到的不同答案的可能性越高。
      en_US: Nucleus sampling threshold. Determines the randomness of the output; higher values make the output more random, i.e., the same question is more likely to receive different answers.
  - name: max_tokens
    use_template: max_tokens
    default: 4096
    min: 1
    max: 8192
    help:
      zh_Hans: 模型回答的tokens的最大长度。
      en_US: Maximum number of tokens in the model response.
  - name: top_k
    label:
      zh_Hans: 取样数量
      en_US: Top k
    type: int
    default: 4
    min: 1
    max: 6
    help:
      zh_Hans: 从 k 个候选中随机选择一个(非等概率)。
      en_US: Randomly select one of k candidates (with non-equal probability).
    required: false