Merge branch 'main' into feat/r2
commit
f7a4e5d1a6
@ -0,0 +1,65 @@
|
||||
import contextvars
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar
|
||||
|
||||
from flask import Flask, g, has_request_context
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preserve_flask_contexts(
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> Iterator[None]:
|
||||
"""
|
||||
A context manager that handles:
|
||||
1. flask-login's UserProxy copy
|
||||
2. ContextVars copy
|
||||
3. flask_app.app_context()
|
||||
|
||||
This context manager ensures that the Flask application context is properly set up,
|
||||
the current user is preserved across context boundaries, and any provided context variables
|
||||
are set within the new context.
|
||||
|
||||
Note:
|
||||
This manager aims to allow use current_user cross thread and app context,
|
||||
but it's not the recommend use, it's better to pass user directly in parameters.
|
||||
|
||||
Args:
|
||||
flask_app: The Flask application instance
|
||||
context_vars: contextvars.Context object containing context variables to be set in the new context
|
||||
|
||||
Yields:
|
||||
None
|
||||
|
||||
Example:
|
||||
```python
|
||||
with preserve_flask_contexts(flask_app, context_vars=context_vars):
|
||||
# Code that needs Flask app context and context variables
|
||||
# Current user will be preserved if available
|
||||
```
|
||||
"""
|
||||
# Set context variables if provided
|
||||
if context_vars:
|
||||
for var, val in context_vars.items():
|
||||
var.set(val)
|
||||
|
||||
# Save current user before entering new app context
|
||||
saved_user = None
|
||||
if has_request_context() and hasattr(g, "_login_user"):
|
||||
saved_user = g._login_user
|
||||
|
||||
# Enter Flask app context
|
||||
with flask_app.app_context():
|
||||
try:
|
||||
# Restore user in new app context if it was saved
|
||||
if saved_user is not None:
|
||||
g._login_user = saved_user
|
||||
|
||||
# Yield control back to the caller
|
||||
yield
|
||||
finally:
|
||||
# Any cleanup can be added here if needed
|
||||
pass
|
||||
@ -0,0 +1,124 @@
|
||||
import contextvars
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from flask import Flask
|
||||
from flask_login import LoginManager, UserMixin, current_user, login_user
|
||||
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
|
||||
|
||||
class User(UserMixin):
|
||||
"""Simple User class for testing."""
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
|
||||
def get_id(self) -> str:
|
||||
return self.id
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def login_app(app: Flask) -> Flask:
|
||||
"""Set up a Flask app with flask-login."""
|
||||
# Set a secret key for the app
|
||||
app.config["SECRET_KEY"] = "test-secret-key"
|
||||
|
||||
login_manager = LoginManager()
|
||||
login_manager.init_app(app)
|
||||
|
||||
@login_manager.user_loader
|
||||
def load_user(user_id: str) -> Optional[User]:
|
||||
if user_id == "test_user":
|
||||
return User("test_user")
|
||||
return None
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_user() -> User:
|
||||
"""Create a test user."""
|
||||
return User("test_user")
|
||||
|
||||
|
||||
def test_current_user_not_accessible_across_threads(login_app: Flask, test_user: User):
|
||||
"""
|
||||
Test that current_user is not accessible in a different thread without preserve_flask_contexts.
|
||||
|
||||
This test demonstrates that without the preserve_flask_contexts, we cannot access
|
||||
current_user in a different thread, even with app_context.
|
||||
"""
|
||||
# Log in the user in the main thread
|
||||
with login_app.test_request_context():
|
||||
login_user(test_user)
|
||||
assert current_user.is_authenticated
|
||||
assert current_user.id == "test_user"
|
||||
|
||||
# Store the result of the thread execution
|
||||
result = {"user_accessible": True, "error": None}
|
||||
|
||||
# Define a function to run in a separate thread
|
||||
def check_user_in_thread():
|
||||
try:
|
||||
# Try to access current_user in a different thread with app_context
|
||||
with login_app.app_context():
|
||||
# This should fail because current_user is not accessible across threads
|
||||
# without preserve_flask_contexts
|
||||
result["user_accessible"] = current_user.is_authenticated
|
||||
except Exception as e:
|
||||
result["error"] = str(e) # type: ignore
|
||||
|
||||
# Run the function in a separate thread
|
||||
thread = threading.Thread(target=check_user_in_thread)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
# Verify that we got an error or current_user is not authenticated
|
||||
assert result["error"] is not None or (result["user_accessible"] is not None and not result["user_accessible"])
|
||||
|
||||
|
||||
def test_current_user_accessible_with_preserve_flask_contexts(login_app: Flask, test_user: User):
|
||||
"""
|
||||
Test that current_user is accessible in a different thread with preserve_flask_contexts.
|
||||
|
||||
This test demonstrates that with the preserve_flask_contexts, we can access
|
||||
current_user in a different thread.
|
||||
"""
|
||||
# Log in the user in the main thread
|
||||
with login_app.test_request_context():
|
||||
login_user(test_user)
|
||||
assert current_user.is_authenticated
|
||||
assert current_user.id == "test_user"
|
||||
|
||||
# Save the context variables
|
||||
context_vars = contextvars.copy_context()
|
||||
|
||||
# Store the result of the thread execution
|
||||
result = {"user_accessible": False, "user_id": None, "error": None}
|
||||
|
||||
# Define a function to run in a separate thread
|
||||
def check_user_in_thread_with_manager():
|
||||
try:
|
||||
# Use preserve_flask_contexts to access current_user in a different thread
|
||||
with preserve_flask_contexts(login_app, context_vars):
|
||||
from flask_login import current_user
|
||||
|
||||
if current_user:
|
||||
result["user_accessible"] = True
|
||||
result["user_id"] = current_user.id
|
||||
else:
|
||||
result["user_accessible"] = False
|
||||
except Exception as e:
|
||||
result["error"] = str(e) # type: ignore
|
||||
|
||||
# Run the function in a separate thread
|
||||
thread = threading.Thread(target=check_user_in_thread_with_manager)
|
||||
thread.start()
|
||||
thread.join()
|
||||
|
||||
# Verify that current_user is accessible and has the correct ID
|
||||
assert result["error"] is None
|
||||
assert result["user_accessible"] is True
|
||||
assert result["user_id"] == "test_user"
|
||||
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue