Co-authored-by: hashjang <hash@geek.com> Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>pull/19871/head
parent
e0e8cd6ca3
commit
6a74c97a0a
@ -0,0 +1,73 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import flask
|
||||
import werkzeug.http
|
||||
from flask import Flask
|
||||
from flask.signals import request_finished, request_started
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_content_type_json(content_type: str) -> bool:
|
||||
if not content_type:
|
||||
return False
|
||||
content_type_no_option, _ = werkzeug.http.parse_options_header(content_type)
|
||||
return content_type_no_option.lower() == "application/json"
|
||||
|
||||
|
||||
def _log_request_started(_sender, **_extra):
|
||||
"""Log the start of a request."""
|
||||
if not _logger.isEnabledFor(logging.DEBUG):
|
||||
return
|
||||
|
||||
request = flask.request
|
||||
if not (_is_content_type_json(request.content_type) and request.data):
|
||||
_logger.debug("Received Request %s -> %s", request.method, request.path)
|
||||
return
|
||||
try:
|
||||
json_data = json.loads(request.data)
|
||||
except (TypeError, ValueError):
|
||||
_logger.exception("Failed to parse JSON request")
|
||||
return
|
||||
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||
_logger.debug(
|
||||
"Received Request %s -> %s, Request Body:\n%s",
|
||||
request.method,
|
||||
request.path,
|
||||
formatted_json,
|
||||
)
|
||||
|
||||
|
||||
def _log_request_finished(_sender, response, **_extra):
|
||||
"""Log the end of a request."""
|
||||
if not _logger.isEnabledFor(logging.DEBUG) or response is None:
|
||||
return
|
||||
|
||||
if not _is_content_type_json(response.content_type):
|
||||
_logger.debug("Response %s %s", response.status, response.content_type)
|
||||
return
|
||||
|
||||
response_data = response.get_data(as_text=True)
|
||||
try:
|
||||
json_data = json.loads(response_data)
|
||||
except (TypeError, ValueError):
|
||||
_logger.exception("Failed to parse JSON response")
|
||||
return
|
||||
formatted_json = json.dumps(json_data, ensure_ascii=False, indent=2)
|
||||
_logger.debug(
|
||||
"Response %s %s, Response Body:\n%s",
|
||||
response.status,
|
||||
response.content_type,
|
||||
formatted_json,
|
||||
)
|
||||
|
||||
|
||||
def init_app(app: Flask):
|
||||
"""Initialize the request logging extension."""
|
||||
if not dify_config.ENABLE_REQUEST_LOGGING:
|
||||
return
|
||||
request_started.connect(_log_request_started, app)
|
||||
request_finished.connect(_log_request_finished, app)
|
||||
@ -0,0 +1,265 @@
|
||||
import json
|
||||
import logging
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from flask import Flask, Response
|
||||
|
||||
from configs import dify_config
|
||||
from extensions import ext_request_logging
|
||||
from extensions.ext_request_logging import _is_content_type_json, _log_request_finished, init_app
|
||||
|
||||
|
||||
def test_is_content_type_json():
|
||||
"""
|
||||
Test the _is_content_type_json function.
|
||||
"""
|
||||
|
||||
assert _is_content_type_json("application/json") is True
|
||||
# content type header with charset option.
|
||||
assert _is_content_type_json("application/json; charset=utf-8") is True
|
||||
# content type header with charset option, in uppercase.
|
||||
assert _is_content_type_json("APPLICATION/JSON; CHARSET=UTF-8") is True
|
||||
assert _is_content_type_json("text/html") is False
|
||||
assert _is_content_type_json("") is False
|
||||
|
||||
|
||||
_KEY_NEEDLE = "needle"
|
||||
_VALUE_NEEDLE = _KEY_NEEDLE[::-1]
|
||||
_RESPONSE_NEEDLE = "response"
|
||||
|
||||
|
||||
def _get_test_app():
|
||||
app = Flask(__name__)
|
||||
|
||||
@app.route("/", methods=["GET", "POST"])
|
||||
def handler():
|
||||
return _RESPONSE_NEEDLE
|
||||
|
||||
return app
|
||||
|
||||
|
||||
# NOTE(QuantumGhost): Due to the design of Flask, we need to use monkey patch to write tests.
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_request_receiver(monkeypatch) -> mock.Mock:
|
||||
mock_log_request_started = mock.Mock()
|
||||
monkeypatch.setattr(ext_request_logging, "_log_request_started", mock_log_request_started)
|
||||
return mock_log_request_started
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response_receiver(monkeypatch) -> mock.Mock:
|
||||
mock_log_request_finished = mock.Mock()
|
||||
monkeypatch.setattr(ext_request_logging, "_log_request_finished", mock_log_request_finished)
|
||||
return mock_log_request_finished
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_logger(monkeypatch) -> logging.Logger:
|
||||
_logger = mock.MagicMock(spec=logging.Logger)
|
||||
monkeypatch.setattr(ext_request_logging, "_logger", _logger)
|
||||
return _logger
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_request_logging(monkeypatch):
|
||||
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", True)
|
||||
|
||||
|
||||
class TestRequestLoggingExtension:
|
||||
def test_receiver_should_not_be_invoked_if_configuration_is_disabled(
|
||||
self,
|
||||
monkeypatch,
|
||||
mock_request_receiver,
|
||||
mock_response_receiver,
|
||||
):
|
||||
monkeypatch.setattr(dify_config, "ENABLE_REQUEST_LOGGING", False)
|
||||
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.get("/")
|
||||
|
||||
mock_request_receiver.assert_not_called()
|
||||
mock_response_receiver.assert_not_called()
|
||||
|
||||
def test_receiver_should_be_called_if_enabled(
|
||||
self,
|
||||
enable_request_logging,
|
||||
mock_request_receiver,
|
||||
mock_response_receiver,
|
||||
):
|
||||
"""
|
||||
Test the request logging extension with JSON data.
|
||||
"""
|
||||
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
|
||||
mock_request_receiver.assert_called_once()
|
||||
mock_response_receiver.assert_called_once()
|
||||
|
||||
|
||||
class TestLoggingLevel:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_logging_should_be_skipped_if_level_is_above_debug(self, enable_request_logging, mock_logger):
|
||||
mock_logger.isEnabledFor.return_value = False
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
mock_logger.debug.assert_not_called()
|
||||
|
||||
|
||||
class TestRequestReceiverLogging:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_non_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", data="plain text")
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
assert "Request Body" not in call_args[0]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", json={_KEY_NEEDLE: _VALUE_NEEDLE})
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert "Request Body" in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
assert _KEY_NEEDLE in call_args[3]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_empty_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post("/", headers={"Content-Type": "application/json"})
|
||||
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Received Request" in call_args[0]
|
||||
assert "Request Body" not in call_args[0]
|
||||
assert call_args[1] == "POST"
|
||||
assert call_args[2] == "/"
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert mock_logger.debug.call_count == 0
|
||||
assert mock_logger.exception.call_count == 1
|
||||
|
||||
exception_call_args = mock_logger.exception.call_args[0]
|
||||
assert exception_call_args[0] == "Failed to parse JSON request"
|
||||
|
||||
|
||||
class TestResponseReceiverLogging:
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_non_json_response(self, enable_request_logging, mock_logger):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
response = Response(
|
||||
"OK",
|
||||
headers={"Content-Type": "text/plain"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Response" in call_args[0]
|
||||
assert "200" in call_args[1]
|
||||
assert call_args[2] == "text/plain"
|
||||
assert "Response Body" not in call_args[0]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_response(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
response = Response(
|
||||
json.dumps({_KEY_NEEDLE: _VALUE_NEEDLE}),
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 1
|
||||
call_args = mock_logger.debug.call_args[0]
|
||||
assert "Response" in call_args[0]
|
||||
assert "Response Body" in call_args[0]
|
||||
assert "200" in call_args[1]
|
||||
assert call_args[2] == "application/json"
|
||||
assert _KEY_NEEDLE in call_args[3]
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_json_request_with_invalid_json_as_body(self, enable_request_logging, mock_logger, mock_response_receiver):
|
||||
mock_logger.isEnabledFor.return_value = True
|
||||
app = _get_test_app()
|
||||
|
||||
response = Response(
|
||||
"{",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
_log_request_finished(app, response)
|
||||
assert mock_logger.debug.call_count == 0
|
||||
assert mock_logger.exception.call_count == 1
|
||||
|
||||
exception_call_args = mock_logger.exception.call_args[0]
|
||||
assert exception_call_args[0] == "Failed to parse JSON response"
|
||||
|
||||
|
||||
class TestResponseUnmodified:
|
||||
def test_when_request_logging_disabled(self):
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
response = client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert response.text == _RESPONSE_NEEDLE
|
||||
assert response.status_code == 200
|
||||
|
||||
@pytest.mark.usefixtures("enable_request_logging")
|
||||
def test_when_request_logging_enabled(self, enable_request_logging):
|
||||
app = _get_test_app()
|
||||
init_app(app)
|
||||
|
||||
with app.test_client() as client:
|
||||
response = client.post(
|
||||
"/",
|
||||
headers={"Content-Type": "application/json"},
|
||||
data="{",
|
||||
)
|
||||
assert response.text == _RESPONSE_NEEDLE
|
||||
assert response.status_code == 200
|
||||
Loading…
Reference in New Issue