Merge branch 'main' into fix/note-node-zoom-issue

pull/7399/head
Yi 2 years ago
commit d688bebb1a

@ -85,7 +85,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
tools=tools, stop=stop, stream=stream, user=user, tools=tools, stop=stop, stream=stream, user=user,
extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter( extra_model_kwargs=XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
) )
) )
@ -106,7 +107,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
) )
if 'completion_type' not in credentials: if 'completion_type' not in credentials:
if 'chat' in extra_param.model_ability: if 'chat' in extra_param.model_ability:
@ -396,7 +398,8 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
else: else:
extra_args = XinferenceHelper.get_xinference_extra_parameter( extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key')
) )
if 'chat' in extra_args.model_ability: if 'chat' in extra_args.model_ability:
@ -464,6 +467,7 @@ class XinferenceAILargeLanguageModel(LargeLanguageModel):
xinference_client = Client( xinference_client = Client(
base_url=credentials['server_url'], base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_model = xinference_client.get_model(credentials['model_uid']) xinference_model = xinference_client.get_model(credentials['model_uid'])

@ -108,7 +108,8 @@ class XinferenceRerankModel(RerankModel):
# initialize client # initialize client
client = Client( client = Client(
base_url=credentials['server_url'] base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_client = client.get_model(model_uid=credentials['model_uid']) xinference_client = client.get_model(model_uid=credentials['model_uid'])

@ -52,7 +52,8 @@ class XinferenceSpeech2TextModel(Speech2TextModel):
# initialize client # initialize client
client = Client( client = Client(
base_url=credentials['server_url'] base_url=credentials['server_url'],
api_key=credentials.get('api_key'),
) )
xinference_client = client.get_model(model_uid=credentials['model_uid']) xinference_client = client.get_model(model_uid=credentials['model_uid'])

@ -110,14 +110,22 @@ class XinferenceTextEmbeddingModel(TextEmbeddingModel):
server_url = credentials['server_url'] server_url = credentials['server_url']
model_uid = credentials['model_uid'] model_uid = credentials['model_uid']
extra_args = XinferenceHelper.get_xinference_extra_parameter(server_url=server_url, model_uid=model_uid) api_key = credentials.get('api_key')
extra_args = XinferenceHelper.get_xinference_extra_parameter(
server_url=server_url,
model_uid=model_uid,
api_key=api_key,
)
if extra_args.max_tokens: if extra_args.max_tokens:
credentials['max_tokens'] = extra_args.max_tokens credentials['max_tokens'] = extra_args.max_tokens
if server_url.endswith('/'): if server_url.endswith('/'):
server_url = server_url[:-1] server_url = server_url[:-1]
client = Client(base_url=server_url) client = Client(
base_url=server_url,
api_key=api_key,
)
try: try:
handle = client.get_model(model_uid=model_uid) handle = client.get_model(model_uid=model_uid)

@ -81,7 +81,8 @@ class XinferenceText2SpeechModel(TTSModel):
extra_param = XinferenceHelper.get_xinference_extra_parameter( extra_param = XinferenceHelper.get_xinference_extra_parameter(
server_url=credentials['server_url'], server_url=credentials['server_url'],
model_uid=credentials['model_uid'] model_uid=credentials['model_uid'],
api_key=credentials.get('api_key'),
) )
if 'text-to-audio' not in extra_param.model_ability: if 'text-to-audio' not in extra_param.model_ability:
@ -203,7 +204,11 @@ class XinferenceText2SpeechModel(TTSModel):
credentials['server_url'] = credentials['server_url'][:-1] credentials['server_url'] = credentials['server_url'][:-1]
try: try:
handle = RESTfulAudioModelHandle(credentials['model_uid'], credentials['server_url'], auth_headers={}) api_key = credentials.get('api_key')
auth_headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
handle = RESTfulAudioModelHandle(
credentials['model_uid'], credentials['server_url'], auth_headers=auth_headers
)
model_support_voice = [x.get("value") for x in model_support_voice = [x.get("value") for x in
self.get_tts_model_voices(model=model, credentials=credentials)] self.get_tts_model_voices(model=model, credentials=credentials)]

@ -35,13 +35,13 @@ cache_lock = Lock()
class XinferenceHelper: class XinferenceHelper:
@staticmethod @staticmethod
def get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: def get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
XinferenceHelper._clean_cache() XinferenceHelper._clean_cache()
with cache_lock: with cache_lock:
if model_uid not in cache: if model_uid not in cache:
cache[model_uid] = { cache[model_uid] = {
'expires': time() + 300, 'expires': time() + 300,
'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid) 'value': XinferenceHelper._get_xinference_extra_parameter(server_url, model_uid, api_key)
} }
return cache[model_uid]['value'] return cache[model_uid]['value']
@ -56,7 +56,7 @@ class XinferenceHelper:
pass pass
@staticmethod @staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter: def _get_xinference_extra_parameter(server_url: str, model_uid: str, api_key: str) -> XinferenceModelExtraParameter:
""" """
get xinference model extra parameter like model_format and model_handle_type get xinference model extra parameter like model_format and model_handle_type
""" """
@ -70,9 +70,10 @@ class XinferenceHelper:
session = Session() session = Session()
session.mount('http://', HTTPAdapter(max_retries=3)) session.mount('http://', HTTPAdapter(max_retries=3))
session.mount('https://', HTTPAdapter(max_retries=3)) session.mount('https://', HTTPAdapter(max_retries=3))
headers = {'Authorization': f'Bearer {api_key}'} if api_key else {}
try: try:
response = session.get(url, timeout=10) response = session.get(url, headers=headers, timeout=10)
except (MissingSchema, ConnectionError, Timeout) as e: except (MissingSchema, ConnectionError, Timeout) as e:
raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}') raise RuntimeError(f'get xinference model extra parameter failed, url: {url}, error: {e}')
if response.status_code != 200: if response.status_code != 200:

2
api/poetry.lock generated

@ -9584,4 +9584,4 @@ cffi = ["cffi (>=1.11)"]
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = ">=3.10,<3.13" python-versions = ">=3.10,<3.13"
content-hash = "165e4af9cfbce83ee831dd0e82159446ef595d7a7850ee8644c8e2d24dd7040d" content-hash = "a74c7b6a72145d5074aa84581df6e543ea422810caf0ba1561cd2d35497243ca"

@ -156,6 +156,7 @@ markdown = "~3.5.1"
novita-client = "^0.5.6" novita-client = "^0.5.6"
numpy = "~1.26.4" numpy = "~1.26.4"
openai = "~1.29.0" openai = "~1.29.0"
openpyxl = "~3.1.5"
oss2 = "2.18.5" oss2 = "2.18.5"
pandas = { version = "~2.2.2", extras = ["performance", "excel"] } pandas = { version = "~2.2.2", extras = ["performance", "excel"] }
psycopg2-binary = "~2.9.6" psycopg2-binary = "~2.9.6"
@ -173,7 +174,6 @@ readabilipy = "0.2.0"
redis = { version = "~5.0.3", extras = ["hiredis"] } redis = { version = "~5.0.3", extras = ["hiredis"] }
replicate = "~0.22.0" replicate = "~0.22.0"
resend = "~0.7.0" resend = "~0.7.0"
safetensors = "~0.4.3"
scikit-learn = "^1.5.1" scikit-learn = "^1.5.1"
sentry-sdk = { version = "~1.44.1", extras = ["flask"] } sentry-sdk = { version = "~1.44.1", extras = ["flask"] }
sqlalchemy = "~2.0.29" sqlalchemy = "~2.0.29"
@ -187,10 +187,16 @@ werkzeug = "~3.0.1"
xinference-client = "0.13.3" xinference-client = "0.13.3"
yarl = "~1.9.4" yarl = "~1.9.4"
zhipuai = "1.0.7" zhipuai = "1.0.7"
rank-bm25 = "~0.2.2" # Before adding a new dependency, consider placing it in alphabetical order (a-z) and in a suitable group.
openpyxl = "^3.1.5"
############################################################
# Related transparent dependencies with pinned version
# required by main implementations
############################################################
[tool.poetry.group.indriect.dependencies]
kaleido = "0.2.1" kaleido = "0.2.1"
elasticsearch = "8.14.0" rank-bm25 = "~0.2.2"
safetensors = "~0.4.3"
############################################################ ############################################################
# Tool dependencies required by tool implementations # Tool dependencies required by tool implementations
@ -198,6 +204,7 @@ elasticsearch = "8.14.0"
[tool.poetry.group.tool.dependencies] [tool.poetry.group.tool.dependencies]
arxiv = "2.1.0" arxiv = "2.1.0"
cloudscraper = "1.2.71"
matplotlib = "~3.8.2" matplotlib = "~3.8.2"
newspaper3k = "0.2.8" newspaper3k = "0.2.8"
duckduckgo-search = "^6.2.6" duckduckgo-search = "^6.2.6"
@ -209,26 +216,25 @@ twilio = "~9.0.4"
vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] } vanna = { version = "0.5.5", extras = ["postgres", "mysql", "clickhouse", "duckdb"] }
wikipedia = "1.4.0" wikipedia = "1.4.0"
yfinance = "~0.2.40" yfinance = "~0.2.40"
cloudscraper = "1.2.71"
############################################################ ############################################################
# VDB dependencies required by vector store clients # VDB dependencies required by vector store clients
############################################################ ############################################################
[tool.poetry.group.vdb.dependencies] [tool.poetry.group.vdb.dependencies]
alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
chromadb = "0.5.1" chromadb = "0.5.1"
clickhouse-connect = "~0.7.16"
elasticsearch = "8.14.0"
oracledb = "~2.2.1" oracledb = "~2.2.1"
pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] } pgvecto-rs = { version = "~0.2.1", extras = ['sqlalchemy'] }
pgvector = "0.2.5" pgvector = "0.2.5"
pymilvus = "~2.4.4" pymilvus = "~2.4.4"
pymysql = "1.1.1"
tcvectordb = "1.3.2" tcvectordb = "1.3.2"
tidb-vector = "0.0.9" tidb-vector = "0.0.9"
qdrant-client = "1.7.3" qdrant-client = "1.7.3"
weaviate-client = "~3.21.0" weaviate-client = "~3.21.0"
alibabacloud_gpdb20160503 = "~3.8.0"
alibabacloud_tea_openapi = "~0.3.9"
clickhouse-connect = "~0.7.16"
############################################################ ############################################################
# Dev dependencies for running tests # Dev dependencies for running tests
@ -252,5 +258,5 @@ pytest-mock = "~3.14.0"
optional = true optional = true
[tool.poetry.group.lint.dependencies] [tool.poetry.group.lint.dependencies]
ruff = "~0.6.1"
dotenv-linter = "~0.5.0" dotenv-linter = "~0.5.0"
ruff = "~0.6.1"

Loading…
Cancel
Save