fix: miss usage of os.path.join for URL assembly and add tests on yarl (#4224)

pull/4269/head
Bowen Liang 2 years ago committed by GitHub
parent 01555463d2
commit 228de1f12a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,6 +1,5 @@
import logging import logging
from collections.abc import Generator from collections.abc import Generator
from os.path import join
from typing import Optional, cast from typing import Optional, cast
from httpx import Timeout from httpx import Timeout
@ -19,6 +18,7 @@ from openai import (
) )
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall from openai.types.chat.chat_completion_message import FunctionCall
from yarl import URL
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import ( from core.model_runtime.entities.message_entities import (
@ -265,7 +265,7 @@ class ChatGLMLargeLanguageModel(LargeLanguageModel):
client_kwargs = { client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0), "timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1", "api_key": "1",
"base_url": join(credentials['api_base'], 'v1') "base_url": str(URL(credentials['api_base']) / 'v1')
} }
return client_kwargs return client_kwargs

@ -1,8 +1,8 @@
from base64 import b64decode from base64 import b64decode
from os.path import join
from typing import Any, Union from typing import Any, Union
from openai import OpenAI from openai import OpenAI
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
@ -23,7 +23,7 @@ class DallE2Tool(BuiltinTool):
if not openai_base_url: if not openai_base_url:
openai_base_url = None openai_base_url = None
else: else:
openai_base_url = join(openai_base_url, 'v1') openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI( client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'], api_key=self.runtime.credentials['openai_api_key'],

@ -1,8 +1,8 @@
from base64 import b64decode from base64 import b64decode
from os.path import join
from typing import Any, Union from typing import Any, Union
from openai import OpenAI from openai import OpenAI
from yarl import URL
from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.entities.tool_entities import ToolInvokeMessage
from core.tools.tool.builtin_tool import BuiltinTool from core.tools.tool.builtin_tool import BuiltinTool
@ -23,7 +23,7 @@ class DallE3Tool(BuiltinTool):
if not openai_base_url: if not openai_base_url:
openai_base_url = None openai_base_url = None
else: else:
openai_base_url = join(openai_base_url, 'v1') openai_base_url = str(URL(openai_base_url) / 'v1')
client = OpenAI( client = OpenAI(
api_key=self.runtime.credentials['openai_api_key'], api_key=self.runtime.credentials['openai_api_key'],

@ -0,0 +1,23 @@
import pytest
from yarl import URL
def test_yarl_urls():
expected_1 = 'https://dify.ai/api'
assert str(URL('https://dify.ai') / 'api') == expected_1
assert str(URL('https://dify.ai/') / 'api') == expected_1
expected_2 = 'http://dify.ai:12345/api'
assert str(URL('http://dify.ai:12345') / 'api') == expected_2
assert str(URL('http://dify.ai:12345/') / 'api') == expected_2
expected_3 = 'https://dify.ai/api/v1'
assert str(URL('https://dify.ai') / 'api' / 'v1') == expected_3
assert str(URL('https://dify.ai') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/') / 'api/v1') == expected_3
assert str(URL('https://dify.ai/api') / 'v1') == expected_3
assert str(URL('https://dify.ai/api/') / 'v1') == expected_3
with pytest.raises(ValueError) as e1:
str(URL('https://dify.ai') / '/api')
assert str(e1.value) == "Appending path '/api' starting from slash is forbidden"
Loading…
Cancel
Save