feat(backend): 添加中间件和日志功能

- 新增跨域、gzip压缩和trace中间件
- 实现请求ID生成和日志记录功能
- 优化环境变量加载和数据库连接逻辑
- 重构部分代码以提高可维护性
main
lijiazheng 6 months ago
parent f0f777dcc7
commit 560887d306

@ -6,13 +6,15 @@ import asyncio
from starlette.middleware.cors import CORSMiddleware from starlette.middleware.cors import CORSMiddleware
from middlewares.handle import handle_middleware
from rag import rag_generate_rule from rag import rag_generate_rule
from util.log_util import logger
@asynccontextmanager @asynccontextmanager
async def lifespan(app : FastAPI): async def lifespan(app : FastAPI):
# logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}') # logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}')
print(f'开始启动') print(f'开始启动')
print(f'启动成功')
yield yield
print(f'关闭成功') print(f'关闭成功')
@ -21,13 +23,23 @@ app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语
rule_router = APIRouter(prefix="/rule") rule_router = APIRouter(prefix="/rule")
# 添加CORS中间件 - 允许所有来源(开发环境) # 添加CORS中间件 - 允许所有来源(开发环境)
app.add_middleware( # app.add_middleware(
CORSMiddleware, # CORSMiddleware,
allow_origins=["*"], # 在生产环境中应该指定具体的域名 # allow_origins=["*"], # 在生产环境中应该指定具体的域名
allow_credentials=True, # allow_credentials=True,
allow_methods=["*"], # allow_methods=["*"],
allow_headers=["*"], # allow_headers=["*"],
) # )
# 加载中间件
handle_middleware(app)
class QueryRequest(BaseModel): class QueryRequest(BaseModel):
query: str query: str
@ -48,6 +60,10 @@ async def generate_sql(request: QueryRequest):
result = await rag_generate_rule(request.query) result = await rag_generate_rule(request.query)
return QueryResponse(sql=result) return QueryResponse(sql=result)
except Exception as e: except Exception as e:
# print(e)
# print(e.__context__)
logger.info(e)
logger.info(e.__context__)
raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}") raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}")
@rule_router.get("/health") @rule_router.get("/health")

@ -64,7 +64,6 @@ class GetConfig:
# 加载配置 # 加载配置
load_dotenv(env_file) load_dotenv(env_file)
print(f'当前运行环境为:{run_env}') print(f'当前运行环境为:{run_env}')
print(os.getenv("DB_HOST"))
# 实例化获取配置类 # 实例化获取配置类
get_config = GetConfig() get_config = GetConfig()

@ -0,0 +1,25 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
def add_cors_middleware(app: FastAPI):
    """
    Attach the CORS middleware so the listed front-end origins may call the API.
    :param app: FastAPI application instance
    :return:
    """
    # Front-end page URLs that are allowed to make cross-origin requests
    allowed_origins = [
        'http://localhost:80',
        'http://127.0.0.1:80',
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=['*'],
        allow_headers=['*'],
    )

@ -0,0 +1,12 @@
from fastapi import FastAPI
from starlette.middleware.gzip import GZipMiddleware
def add_gzip_middleware(app: FastAPI):
    """
    Attach gzip response compression to the application.
    :param app: FastAPI application instance
    :return:
    """
    app.add_middleware(
        GZipMiddleware,
        minimum_size=1000,  # only compress responses of at least 1000 bytes
        compresslevel=9,    # highest compression ratio (slowest)
    )

@ -0,0 +1,16 @@
from fastapi import FastAPI
from middlewares.cors_middleware import add_cors_middleware
from middlewares.gzip_middleware import add_gzip_middleware
from middlewares.trace_middleware import add_trace_middleware
def handle_middleware(app: FastAPI):
    """
    Register every global middleware on *app*. Registration order matters:
    CORS first, then gzip compression, then request-id tracing.
    """
    registrars = (
        add_cors_middleware,   # cross-origin resource sharing
        add_gzip_middleware,   # gzip response compression
        add_trace_middleware,  # per-request trace/request-id
    )
    for register in registrars:
        register(app)

@ -0,0 +1,17 @@
from fastapi import FastAPI
from .ctx import TraceCtx
from .middle import TraceASGIMiddleware
__all__ = ('TraceASGIMiddleware', 'TraceCtx')
__version__ = '0.1.0'


def add_trace_middleware(app: FastAPI):
    """
    Attach the trace middleware that tags every HTTP response
    with a generated request id.
    :param app: FastAPI application instance
    :return:
    """
    app.add_middleware(TraceASGIMiddleware)

@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: ctx.py
@time: 2025/1/17 16:57
"""
import contextvars
from uuid import uuid4
CTX_REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar('request-id', default='')


class TraceCtx:
    """Stores the per-request trace id in a context variable, so it is
    isolated per task/request and readable anywhere in the call chain."""

    @staticmethod
    def set_id():
        """Generate a fresh 32-char hex id, store it, and return it."""
        request_id = uuid4().hex
        CTX_REQUEST_ID.set(request_id)
        return request_id

    @staticmethod
    def get_id():
        """Return the id for the current context ('' if never set)."""
        return CTX_REQUEST_ID.get()

@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: middle.py
@time: 2025/1/17 16:57
"""
from functools import wraps
from starlette.types import ASGIApp, Message, Receive, Scope, Send
from .span import get_current_span, Span
class TraceASGIMiddleware:
    """
    ASGI middleware that opens a trace Span around each HTTP request.

    fastapi-example:
        app = FastAPI()
        app.add_middleware(TraceASGIMiddleware)
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    @staticmethod
    async def my_receive(receive: Receive, span: Span):
        # Run the before-request hook once (this allocates the trace id),
        # then hand back a wrapped receive that notifies the span for
        # every incoming message.
        await span.request_before()

        @wraps(receive)
        async def wrapped_receive():
            message = await receive()
            await span.request_after(message)
            return message

        return wrapped_receive

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # Only trace plain HTTP; pass lifespan/websocket traffic straight through.
        if scope['type'] != 'http':
            await self.app(scope, receive, send)
            return
        async with get_current_span(scope) as span:
            traced_receive = await self.my_receive(receive, span)

            async def traced_send(message: 'Message') -> None:
                await span.response(message)
                await send(message)

            await self.app(scope, traced_receive, traced_send)

@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
"""
@author: peng
@file: span.py
@time: 2025/1/17 16:57
"""
from contextlib import asynccontextmanager
from starlette.types import Scope, Message
from .ctx import TraceCtx
class Span:
    """
    Tracks one HTTP lifecycle:
    request(before) --> request(after) --> response(before) --> response(after)
    """

    def __init__(self, scope: Scope):
        self.scope = scope

    async def request_before(self):
        """
        Runs before the request body is read; allocates the trace id
        (e.g. the place to record header information).
        """
        TraceCtx.set_id()

    async def request_after(self, message: Message):
        """
        Runs for each incoming request message (e.g. to record body bytes).
        example:
            message: {'type': 'http.request', 'body': b'...', 'more_body': False}
        """
        return message

    async def response(self, message: Message):
        """
        Runs for each outgoing message; stamps the request-id header onto
        the 'http.response.start' message, leaves body messages untouched.
        """
        if message['type'] != 'http.response.start':
            return message
        message['headers'].append((b'request-id', TraceCtx.get_id().encode()))
        return message
@asynccontextmanager
async def get_current_span(scope: Scope):
    """Async context manager yielding a Span bound to *scope*."""
    span = Span(scope)
    yield span

@ -1,31 +1,53 @@
from util import use_pgvector, use_opanai, use_mysql from util import use_pgvector, use_opanai, use_mysql
from util.log_util import logger
from util.use_mysql import search_desc_by_table_names from util.use_mysql import search_desc_by_table_names
async def rag_generate_rule(query : str): async def rag_generate_rule(query : str):
if not query: if not query:
return "请输入问题" return "请输入问题"
# 连接数据库
pgvector_conn = use_pgvector.connect_to_db() # 询问大模型抽取用户问题中表名
prompt = [
{"role": "system", "content": f"""
需要从需要从用户问题中抽取表名请返回一个列表格式为[表名1, 表名2, ...]
例如用户问题"请查询表A和表B的交集"则返回[A, B]
请根据用户问题抽取表名请勿返回其他内容
"""},
{"role": "user", "content": query}
]
tables_in_query = await use_opanai.generation_rule(prompt)
tables_in_query = tables_in_query.strip('[]')
tables_in_query = [item.strip() for item in tables_in_query.split(',')]
# 将问题向量化 # 将问题向量化
query_emb = await use_opanai.generation_vector("query") query_emb = await use_opanai.generation_vector("query")
# 连接pgvector数据库
pgvector_conn = use_pgvector.connect_to_db()
# 根据问题关联数据库表得到table_name_list # 根据问题关联数据库表得到table_name_list
similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4) similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4)
rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2) rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2)
table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')] table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')]
print(f"[table_name_list]: {table_name_list}") table_name_list.extend(tables_in_query)
# print(f"【table_name_list】: {table_name_list}")
logger.info(f"【table_name_list】: {table_name_list}")
# 获得相关表的schema # 获得相关表的schema
db_client = await use_mysql.get_db() db_client = await use_mysql.get_db()
schema = await search_desc_by_table_names(table_name_list, db_client) schema = await search_desc_by_table_names(table_name_list, db_client)
print(f"[schema]: {schema}") # print(f"【schema】: {schema}")
logger.info(f"【schema】: {schema}")
# 根据问题搜索相关案例 # 根据问题搜索相关案例
similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3) similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3)
print(f"[similar_case]: {similar_case}") # print(f"【similar_case】: {similar_case}")
logger.info(f"【similar_case】: {similar_case}")
# 询问大模型生成SQL # 询问大模型生成SQL
prompt = [ prompt = [
@ -40,7 +62,7 @@ async def rag_generate_rule(query : str):
"""}, """},
{"role": "user", "content": query} {"role": "user", "content": query}
] ]
print(f"[prompt]: {prompt}") # print(f"【prompt】: {prompt}")
ans = await use_opanai.generation_rule(prompt) ans = await use_opanai.generation_rule(prompt)
return ans return ans

@ -0,0 +1,63 @@
import os
import sys
import time
from loguru import logger as _logger
from typing import Dict
from middlewares.trace_middleware import TraceCtx
class LoggerInitializer:
    """Configures loguru with a console sink and a daily-rotating file sink,
    both tagged with the current request's trace id from TraceCtx."""

    def __init__(self):
        # All log files live under <cwd>/logs; the file sink writes app.log.
        self.log_path = os.path.join(os.getcwd(), 'logs')
        self.__ensure_log_directory_exists()
        self.log_path_error = os.path.join(self.log_path, 'app.log')

    def __ensure_log_directory_exists(self):
        """
        Create the log directory if it does not exist.
        makedirs(exist_ok=True) avoids the check-then-create race that
        `if not exists: mkdir` has when several workers start at once.
        """
        os.makedirs(self.log_path, exist_ok=True)

    @staticmethod
    def __filter(log: Dict):
        """
        Log-record filter that injects the per-request trace_id so the
        format string below can render it.
        """
        log['trace_id'] = TraceCtx.get_id()
        return log

    def init_log(self):
        """
        Initialize the logging configuration and return the configured logger.
        """
        # Custom log format: timestamp | trace id | level | location - message
        format_str = (
            '<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | '
            '<cyan>{trace_id}</cyan> | '
            '<level>{level: <8}</level> | '
            '<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - '
            '<level>{message}</level>'
        )
        _logger.remove()
        # Re-add sys.stderr so console output matches the file log's
        # content and structure exactly.
        _logger.add(sys.stderr, filter=self.__filter, format=format_str, enqueue=True)
        _logger.add(
            self.log_path_error,
            filter=self.__filter,
            format=format_str,
            rotation='daily',      # rotate once per day
            retention="15 days",   # keep the last 15 days of logs
            encoding='utf-8',
            enqueue=True,          # thread/process-safe queued writes
            compression=None,
        )
        return _logger
# Initialize the log handler once at import time; other modules use
# `from util.log_util import logger` to share this configured instance.
log_initializer = LoggerInitializer()
logger = log_initializer.init_log()
Loading…
Cancel
Save