From 1555d637a8e80460d8106394235022ae2d3900b9 Mon Sep 17 00:00:00 2001
From: zhuyunxiang <493405455@qq.com>
Date: Tue, 16 May 2023 18:52:54 +0800
Subject: [PATCH] fix openai base env bug

---
 api/core/embedding/openai_embedding.py   | 4 ++--
 api/core/llm/moderation.py               | 5 +++++
 api/core/llm/provider/openai_provider.py | 5 ++++-
 3 files changed, 11 insertions(+), 3 deletions(-)

diff --git a/api/core/embedding/openai_embedding.py b/api/core/embedding/openai_embedding.py
index 134e8504d5..91c8aa8c5e 100644
--- a/api/core/embedding/openai_embedding.py
+++ b/api/core/embedding/openai_embedding.py
@@ -113,8 +113,8 @@ class OpenAIEmbedding(BaseEmbedding):
         self.deployment_name = deployment_name
         self.openai_api_key = openai_api_key
         # Use proxy openai base
-        if current_app.config['API_URL'] is not None:
-            openai.api_base = current_app.config['API_URL']
+        if current_app.config['OPENAI_API_BASE'] is not None:
+            openai.api_base = current_app.config['OPENAI_API_BASE']
 
     @handle_llm_exceptions
     def _get_query_embedding(self, query: str) -> List[float]:
diff --git a/api/core/llm/moderation.py b/api/core/llm/moderation.py
index d18d6fc5c2..be66513fa2 100644
--- a/api/core/llm/moderation.py
+++ b/api/core/llm/moderation.py
@@ -1,4 +1,5 @@
 import openai
+from flask import current_app
 
 from models.provider import ProviderName
 
@@ -8,6 +9,10 @@ class Moderation:
         self.provider = provider
         self.api_key = api_key
 
+        # Use proxy openai base
+        if current_app.config['OPENAI_API_BASE'] is not None:
+            openai.api_base = current_app.config['OPENAI_API_BASE']
+
         if self.provider == ProviderName.OPENAI.value:
             self.client = openai.Moderation
 
diff --git a/api/core/llm/provider/openai_provider.py b/api/core/llm/provider/openai_provider.py
index 8257ad3aab..9b3f9ab876 100644
--- a/api/core/llm/provider/openai_provider.py
+++ b/api/core/llm/provider/openai_provider.py
@@ -3,7 +3,7 @@ from typing import Optional, Union
 
 import openai
 from openai.error import AuthenticationError, OpenAIError
-
+from flask import current_app
 from core.llm.moderation import Moderation
 from core.llm.provider.base import BaseProvider
 from core.llm.provider.errors import ValidateFailedError
@@ -12,6 +12,9 @@ from models.provider import ProviderName
 
 class OpenAIProvider(BaseProvider):
     def get_models(self, model_id: Optional[str] = None) -> list[dict]:
+        # Use proxy openai base
+        if current_app.config['OPENAI_API_BASE'] is not None:
+            openai.api_base = current_app.config['OPENAI_API_BASE']
         credentials = self.get_credentials(model_id)
         response = openai.Model.list(**credentials)