chore: refactor the http executor node (#5212)
parent
1e28a8c033
commit
f7900f298f
@ -1,65 +1,48 @@
|
||||
"""
|
||||
Proxy requests to avoid SSRF
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from httpx import get as _get
|
||||
from httpx import head as _head
|
||||
from httpx import options as _options
|
||||
from httpx import patch as _patch
|
||||
from httpx import post as _post
|
||||
from httpx import put as _put
|
||||
from requests import delete as _delete
|
||||
import httpx
|
||||
|
||||
SSRF_PROXY_ALL_URL = os.getenv('SSRF_PROXY_ALL_URL', '')
|
||||
SSRF_PROXY_HTTP_URL = os.getenv('SSRF_PROXY_HTTP_URL', '')
|
||||
SSRF_PROXY_HTTPS_URL = os.getenv('SSRF_PROXY_HTTPS_URL', '')
|
||||
|
||||
requests_proxies = {
|
||||
'http': SSRF_PROXY_HTTP_URL,
|
||||
'https': SSRF_PROXY_HTTPS_URL
|
||||
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
|
||||
|
||||
httpx_proxies = {
|
||||
proxies = {
|
||||
'http://': SSRF_PROXY_HTTP_URL,
|
||||
'https://': SSRF_PROXY_HTTPS_URL
|
||||
} if SSRF_PROXY_HTTP_URL and SSRF_PROXY_HTTPS_URL else None
|
||||
|
||||
def get(url, *args, **kwargs):
|
||||
return _get(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def post(url, *args, **kwargs):
|
||||
return _post(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def put(url, *args, **kwargs):
|
||||
return _put(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def patch(url, *args, **kwargs):
|
||||
return _patch(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def delete(url, *args, **kwargs):
|
||||
if 'follow_redirects' in kwargs:
|
||||
if kwargs['follow_redirects']:
|
||||
kwargs['allow_redirects'] = kwargs['follow_redirects']
|
||||
kwargs.pop('follow_redirects')
|
||||
if 'timeout' in kwargs:
|
||||
timeout = kwargs['timeout']
|
||||
if timeout is None:
|
||||
kwargs.pop('timeout')
|
||||
elif isinstance(timeout, tuple):
|
||||
# check length of tuple
|
||||
if len(timeout) == 2:
|
||||
kwargs['timeout'] = timeout
|
||||
elif len(timeout) == 1:
|
||||
kwargs['timeout'] = timeout[0]
|
||||
elif len(timeout) > 2:
|
||||
kwargs['timeout'] = (timeout[0], timeout[1])
|
||||
else:
|
||||
kwargs['timeout'] = (timeout, timeout)
|
||||
return _delete(url=url, *args, proxies=requests_proxies, **kwargs)
|
||||
|
||||
def head(url, *args, **kwargs):
|
||||
return _head(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def options(url, *args, **kwargs):
|
||||
return _options(url=url, *args, proxies=httpx_proxies, **kwargs)
|
||||
|
||||
def make_request(method, url, **kwargs):
|
||||
if SSRF_PROXY_ALL_URL:
|
||||
return httpx.request(method=method, url=url, proxy=SSRF_PROXY_ALL_URL, **kwargs)
|
||||
elif proxies:
|
||||
return httpx.request(method=method, url=url, proxies=proxies, **kwargs)
|
||||
else:
|
||||
return httpx.request(method=method, url=url, **kwargs)
|
||||
|
||||
|
||||
def get(url, **kwargs):
|
||||
return make_request('GET', url, **kwargs)
|
||||
|
||||
|
||||
def post(url, **kwargs):
|
||||
return make_request('POST', url, **kwargs)
|
||||
|
||||
|
||||
def put(url, **kwargs):
|
||||
return make_request('PUT', url, **kwargs)
|
||||
|
||||
|
||||
def patch(url, **kwargs):
|
||||
return make_request('PATCH', url, **kwargs)
|
||||
|
||||
|
||||
def delete(url, **kwargs):
|
||||
return make_request('DELETE', url, **kwargs)
|
||||
|
||||
|
||||
def head(url, **kwargs):
|
||||
return make_request('HEAD', url, **kwargs)
|
||||
|
||||
@ -0,0 +1,36 @@
|
||||
import json
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
|
||||
class MockedHttp:
|
||||
def httpx_request(method: Literal['GET', 'POST', 'PUT', 'DELETE', 'PATCH', 'HEAD'],
|
||||
url: str, **kwargs) -> httpx.Response:
|
||||
"""
|
||||
Mocked httpx.request
|
||||
"""
|
||||
request = httpx.Request(
|
||||
method,
|
||||
url,
|
||||
params=kwargs.get('params'),
|
||||
headers=kwargs.get('headers'),
|
||||
cookies=kwargs.get('cookies')
|
||||
)
|
||||
data = kwargs.get('data', None)
|
||||
resp = json.dumps(data).encode('utf-8') if data else b'OK'
|
||||
response = httpx.Response(
|
||||
status_code=200,
|
||||
request=request,
|
||||
content=resp,
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def setup_http_mock(request, monkeypatch: MonkeyPatch):
|
||||
monkeypatch.setattr(httpx, "request", MockedHttp.httpx_request)
|
||||
yield
|
||||
monkeypatch.undo()
|
||||
@ -0,0 +1,39 @@
|
||||
from core.tools.tool.api_tool import ApiTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from tests.integration_tests.tools.__mock.http import setup_http_mock
|
||||
|
||||
tool_bundle = {
|
||||
'server_url': 'http://www.example.com/{path_param}',
|
||||
'method': 'post',
|
||||
'author': '',
|
||||
'openapi': {'parameters': [{'in': 'path', 'name': 'path_param'},
|
||||
{'in': 'query', 'name': 'query_param'},
|
||||
{'in': 'cookie', 'name': 'cookie_param'},
|
||||
{'in': 'header', 'name': 'header_param'},
|
||||
],
|
||||
'requestBody': {
|
||||
'content': {'application/json': {'schema': {'properties': {'body_param': {'type': 'string'}}}}}}
|
||||
},
|
||||
'parameters': []
|
||||
}
|
||||
parameters = {
|
||||
'path_param': 'p_param',
|
||||
'query_param': 'q_param',
|
||||
'cookie_param': 'c_param',
|
||||
'header_param': 'h_param',
|
||||
'body_param': 'b_param',
|
||||
}
|
||||
|
||||
|
||||
def test_api_tool(setup_http_mock):
|
||||
tool = ApiTool(api_bundle=tool_bundle, runtime=Tool.Runtime(credentials={'auth_type': 'none'}))
|
||||
headers = tool.assembling_request(parameters)
|
||||
response = tool.do_http_request(tool.api_bundle.server_url, tool.api_bundle.method, headers, parameters)
|
||||
|
||||
assert response.status_code == 200
|
||||
assert '/p_param' == response.request.url.path
|
||||
assert b'query_param=q_param' == response.request.url.query
|
||||
assert 'h_param' == response.request.headers.get('header_param')
|
||||
assert 'application/json' == response.request.headers.get('content-type')
|
||||
assert 'cookie_param=c_param' == response.request.headers.get('cookie')
|
||||
assert 'b_param' in response.content.decode()
|
||||
Loading…
Reference in New Issue