feat: 增加smart_wrap装饰器,用于同步函数,使其能在异步代码中使用而不阻塞事件循环

main
chenzhirong 3 months ago
parent 436ae236e1
commit c5854db3ad

@ -0,0 +1,75 @@
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable
"""
这个装饰器是用于同步代码使其支持异步调用避免阻塞事件事件循环
使用示例代码:
@smart_wrap
def task(seconds):
time.sleep(seconds)
return "任务完成"
async def task():
await task(100)
return "over"
使用装饰器smart_wrap后该函数可以在异步(协程)函数中执行且不阻塞
"""
class SmartWrapper:
"""
智能包装器
自动识别环境判断同步代码是否需要进行异步包装
"""
def __init__(self, func: Callable):
self.func = func
self._is_coroutine = asyncio.iscoroutinefunction(func)
def __call__(self, *args, **kwargs):
"""智能调用:自动检测并适配调用环境"""
# 检测调用环境
try:
# 获取当前运行的事件循环
loop = asyncio.get_running_loop()
in_async_context = True
except RuntimeError:
in_async_context = False
# 检测调用方式是否被await
caller_frame = inspect.currentframe().f_back
caller_code = caller_frame.f_code
is_awaited = 'await' in caller_code.co_names or 'async' in caller_code.co_names
if in_async_context and is_awaited:
# 异步环境 + await调用 → 异步执行
return self._async_call(*args, **kwargs)
elif not in_async_context:
# 同步环境 → 同步执行
return self.func(*args, **kwargs)
else:
# 异步环境但没有await → 返回协程
return self._async_call(*args, **kwargs)
async def _async_call(self, *args, **kwargs):
"""异步执行"""
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None, # 使用默认线程池
lambda: self.func(*args, **kwargs)
)
def sync(self, *args, **kwargs):
"""强制同步执行"""
return self.func(*args, **kwargs)
async def async_mode(self, *args, **kwargs):
"""强制异步执行"""
return await self._async_call(*args, **kwargs)
# 便捷装饰器
def smart_wrap(func: Callable):
return SmartWrapper(func)

@ -1,4 +1,6 @@
import inspect import inspect
from contextlib import asynccontextmanager
from functools import wraps
from typing import Any from typing import Any
from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_ from sqlalchemy import Executable, Result, Select, Delete, Update, column, and_
@ -138,6 +140,25 @@ AsyncSessionLocal = async_sessionmaker(
) )
# 获取数据库session的独立方法
@asynccontextmanager
async def get_db_session():
"""获取数据库session的上下文管理器"""
async with AsyncSessionLocal() as session:
async with session.begin():
yield session
def with_db_session(func):
@wraps(func)
async def wrapper(*args, **kwargs):
async with get_db_session() as session:
result = await func(db_session=session, *args, **kwargs)
return result
return wrapper
# 关闭引擎 # 关闭引擎
async def close_engine(): async def close_engine():
await engine.dispose() await engine.dispose()

Loading…
Cancel
Save