From 560887d30663663339518581d1b7d7d17b45eb56 Mon Sep 17 00:00:00 2001 From: lijiazheng Date: Tue, 12 Aug 2025 17:39:19 +0800 Subject: [PATCH] =?UTF-8?q?feat(backend):=20=E6=B7=BB=E5=8A=A0=E4=B8=AD?= =?UTF-8?q?=E9=97=B4=E4=BB=B6=E5=92=8C=E6=97=A5=E5=BF=97=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增跨域、gzip压缩和trace中间件 - 实现请求ID生成和日志记录功能 - 优化环境变量加载和数据库连接逻辑- 重构部分代码以提高可维护性 --- backend.py | 32 +++++++++--- env.py | 1 - middlewares/cors_middleware.py | 25 ++++++++++ middlewares/gzip_middleware.py | 12 +++++ middlewares/handle.py | 16 ++++++ middlewares/trace_middleware/__init__.py | 17 +++++++ middlewares/trace_middleware/ctx.py | 23 +++++++++ middlewares/trace_middleware/middle.py | 47 ++++++++++++++++++ middlewares/trace_middleware/span.py | 52 +++++++++++++++++++ rag.py | 34 ++++++++++--- util/log_util.py | 63 ++++++++++++++++++++++++ 11 files changed, 307 insertions(+), 15 deletions(-) create mode 100644 middlewares/cors_middleware.py create mode 100644 middlewares/gzip_middleware.py create mode 100644 middlewares/handle.py create mode 100644 middlewares/trace_middleware/__init__.py create mode 100644 middlewares/trace_middleware/ctx.py create mode 100644 middlewares/trace_middleware/middle.py create mode 100644 middlewares/trace_middleware/span.py create mode 100644 util/log_util.py diff --git a/backend.py b/backend.py index 486f474..a75e82c 100644 --- a/backend.py +++ b/backend.py @@ -6,13 +6,15 @@ import asyncio from starlette.middleware.cors import CORSMiddleware +from middlewares.handle import handle_middleware from rag import rag_generate_rule +from util.log_util import logger + @asynccontextmanager async def lifespan(app : FastAPI): # logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}') print(f'开始启动') - print(f'启动成功') yield print(f'关闭成功') @@ -21,13 +23,23 @@ app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语 rule_router = APIRouter(prefix="/rule") # 添加CORS中间件 - 允许所有来源(开发环境) -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], # 在生产环境中应该指定具体的域名 - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) +# app.add_middleware( +# CORSMiddleware, +# allow_origins=["*"], # 在生产环境中应该指定具体的域名 +# allow_credentials=True, +# allow_methods=["*"], +# allow_headers=["*"], +# ) + +# 加载中间件 +handle_middleware(app) + + + + + + + class QueryRequest(BaseModel): query: str @@ -48,6 +60,10 @@ async def generate_sql(request: QueryRequest): result = await rag_generate_rule(request.query) return QueryResponse(sql=result) except Exception as e: + # print(e) + # print(e.__context__) + logger.info(e) + logger.info(e.__context__) raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}") @rule_router.get("/health") diff --git a/env.py b/env.py index 92991fb..c6baec6 100644 --- a/env.py +++ b/env.py @@ -64,7 +64,6 @@ class GetConfig: # 加载配置 load_dotenv(env_file) print(f'当前运行环境为:{run_env}') - print(os.getenv("DB_HOST")) # 实例化获取配置类 get_config = GetConfig() diff --git a/middlewares/cors_middleware.py b/middlewares/cors_middleware.py new file mode 100644 index 0000000..55508e7 --- /dev/null +++ b/middlewares/cors_middleware.py @@ -0,0 +1,25 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + + +def add_cors_middleware(app: FastAPI): + """ + 添加跨域中间件 + + :param app: FastAPI对象 + :return: + """ + # 前端页面url + origins = [ + 'http://localhost:80', + 'http://127.0.0.1:80', + ] + + # 后台api允许跨域 + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=['*'], + allow_headers=['*'], + ) diff --git a/middlewares/gzip_middleware.py b/middlewares/gzip_middleware.py new file mode 100644 index 0000000..eb371ce --- /dev/null +++ b/middlewares/gzip_middleware.py @@ -0,0 +1,12 @@ +from fastapi import FastAPI +from starlette.middleware.gzip import GZipMiddleware + + +def add_gzip_middleware(app: FastAPI): + """ + 添加gzip压缩中间件 + + :param app: FastAPI对象 + :return: + """ + app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=9) diff --git a/middlewares/handle.py b/middlewares/handle.py new file mode 100644 index 0000000..1c6943e --- /dev/null +++ b/middlewares/handle.py @@ -0,0 +1,16 @@ +from fastapi import FastAPI +from middlewares.cors_middleware import add_cors_middleware +from middlewares.gzip_middleware import add_gzip_middleware +from middlewares.trace_middleware import add_trace_middleware + +def handle_middleware(app: FastAPI): + """ + 全局中间件处理 + """ + # 加载跨域中间件 + add_cors_middleware(app) + # 加载gzip压缩中间件 + add_gzip_middleware(app) + # 加载trace中间件 + add_trace_middleware(app) + diff --git a/middlewares/trace_middleware/__init__.py b/middlewares/trace_middleware/__init__.py new file mode 100644 index 0000000..76f8d85 --- /dev/null +++ b/middlewares/trace_middleware/__init__.py @@ -0,0 +1,17 @@ +from fastapi import FastAPI +from .ctx import TraceCtx +from .middle import TraceASGIMiddleware + +__all__ = ('TraceASGIMiddleware', 'TraceCtx') + +__version__ = '0.1.0' + + +def add_trace_middleware(app: FastAPI): + """ + 添加trace中间件 + + :param app: FastAPI对象 + :return: + """ + app.add_middleware(TraceASGIMiddleware) diff --git a/middlewares/trace_middleware/ctx.py b/middlewares/trace_middleware/ctx.py new file mode 100644 index 0000000..558a5c9 --- /dev/null +++ b/middlewares/trace_middleware/ctx.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- +""" +@author: peng +@file: ctx.py +@time: 2025/1/17 16:57 +""" + +import contextvars +from uuid import uuid4 + +CTX_REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar('request-id', default='') + + +class TraceCtx: + @staticmethod + def set_id(): + _id = uuid4().hex + CTX_REQUEST_ID.set(_id) + return _id + + @staticmethod + def get_id(): + return CTX_REQUEST_ID.get() diff --git a/middlewares/trace_middleware/middle.py b/middlewares/trace_middleware/middle.py new file mode 100644 index 0000000..a071692 --- /dev/null +++ b/middlewares/trace_middleware/middle.py @@ -0,0 +1,47 @@ +# -*- coding: utf-8 -*- +""" +@author: peng +@file: middle.py +@time: 2025/1/17 16:57 +""" + +from functools import wraps +from starlette.types import ASGIApp, Message, Receive, Scope, Send +from .span import get_current_span, Span + + +class TraceASGIMiddleware: + """ + fastapi-example: + app = FastAPI() + app.add_middleware(TraceASGIMiddleware) + """ + + def __init__(self, app: ASGIApp) -> None: + self.app = app + + @staticmethod + async def my_receive(receive: Receive, span: Span): + await span.request_before() + + @wraps(receive) + async def my_receive(): + message = await receive() + await span.request_after(message) + return message + + return my_receive + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope['type'] != 'http': + await self.app(scope, receive, send) + return + + async with get_current_span(scope) as span: + handle_outgoing_receive = await self.my_receive(receive, span) + + async def handle_outgoing_request(message: 'Message') -> None: + await span.response(message) + await send(message) + + await self.app(scope, handle_outgoing_receive, handle_outgoing_request) diff --git a/middlewares/trace_middleware/span.py b/middlewares/trace_middleware/span.py new file mode 100644 index 0000000..1e38eab --- /dev/null +++ b/middlewares/trace_middleware/span.py @@ -0,0 +1,52 @@ +# -*- coding: utf-8 -*- +""" +@author: peng +@file: span.py +@time: 2025/1/17 16:57 +""" + +from contextlib import asynccontextmanager +from starlette.types import Scope, Message +from .ctx import TraceCtx + + +class Span: + """ + 整个http生命周期: + request(before) --> request(after) --> response(before) --> response(after) + """ + + def __init__(self, scope: Scope): + self.scope = scope + + async def request_before(self): + """ + request_before: 处理header信息等, 如记录请求体信息 + """ + TraceCtx.set_id() + + async def request_after(self, message: Message): + """ + request_after: 处理请求bytes, 如记录请求参数 + + example: + message: {'type': 'http.request', 'body': b'{\r\n "name": "\xe8\x8b\x8f\xe8\x8b\x8f\xe8\x8b\x8f"\r\n}', 'more_body': False} + """ + return message + + async def response(self, message: Message): + """ + if message['type'] == "http.response.start": -----> request-before + pass + if message['type'] == "http.response.body": -----> request-after + message.get('body', b'') + pass + """ + if message['type'] == 'http.response.start': + message['headers'].append((b'request-id', TraceCtx.get_id().encode())) + return message + + +@asynccontextmanager +async def get_current_span(scope: Scope): + yield Span(scope) diff --git a/rag.py b/rag.py index b5b13ae..b78a313 100644 --- a/rag.py +++ b/rag.py @@ -1,31 +1,53 @@ from util import use_pgvector, use_opanai, use_mysql +from util.log_util import logger from util.use_mysql import search_desc_by_table_names async def rag_generate_rule(query : str): if not query: return "请输入问题" - # 连接数据库 - pgvector_conn = use_pgvector.connect_to_db() + + # 询问大模型抽取用户问题中表名 + prompt = [ + {"role": "system", "content": f""" + 需要从需要从用户问题中抽取表名,请返回一个列表,格式为:[表名1, 表名2, ...]。 + + 例如:用户问题:"请查询表A和表B的交集",则返回:[A, B]。 + + 请根据用户问题,抽取表名,请勿返回其他内容。 + + """}, + {"role": "user", "content": query} + ] + tables_in_query = await use_opanai.generation_rule(prompt) + tables_in_query = tables_in_query.strip('[]') + tables_in_query = [item.strip() for item in tables_in_query.split(',')] # 将问题向量化 query_emb = await use_opanai.generation_vector("query") + # 连接pgvector数据库 + pgvector_conn = use_pgvector.connect_to_db() + # 根据问题关联数据库表,得到table_name_list similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4) rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2) table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')] - print(f"[table_name_list]: {table_name_list}") + table_name_list.extend(tables_in_query) + # print(f"【table_name_list】: {table_name_list}") + logger.info(f"【table_name_list】: {table_name_list}") # 获得相关表的schema db_client = await use_mysql.get_db() schema = await search_desc_by_table_names(table_name_list, db_client) - print(f"[schema]: {schema}") + # print(f"【schema】: {schema}") + logger.info(f"【schema】: {schema}") # 根据问题搜索相关案例 similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3) - print(f"[similar_case]: {similar_case}") + # print(f"【similar_case】: {similar_case}") + logger.info(f"【similar_case】: {similar_case}") # 询问大模型生成SQL prompt = [ @@ -40,7 +62,7 @@ async def rag_generate_rule(query : str): """}, {"role": "user", "content": query} ] - print(f"[prompt]: {prompt}") + # print(f"【prompt】: {prompt}") ans = await use_opanai.generation_rule(prompt) return ans diff --git a/util/log_util.py b/util/log_util.py new file mode 100644 index 0000000..4cd4222 --- /dev/null +++ b/util/log_util.py @@ -0,0 +1,63 @@ +import os +import sys +import time +from loguru import logger as _logger +from typing import Dict +from middlewares.trace_middleware import TraceCtx + + +class LoggerInitializer: + def __init__(self): + self.log_path = os.path.join(os.getcwd(), 'logs') + self.__ensure_log_directory_exists() + # self.log_path_error = os.path.join(self.log_path, f'{time.strftime("%Y-%m-%d")}_app.log') + self.log_path_error = os.path.join(self.log_path, f'app.log') + + def __ensure_log_directory_exists(self): + """ + 确保日志目录存在,如果不存在则创建 + """ + if not os.path.exists(self.log_path): + os.mkdir(self.log_path) + + @staticmethod + def __filter(log: Dict): + """ + 自定义日志过滤器,添加trace_id + """ + log['trace_id'] = TraceCtx.get_id() + return log + + def init_log(self): + """ + 初始化日志配置 + """ + # 自定义日志格式 + format_str = ( + '{time:YYYY-MM-DD HH:mm:ss.SSS} | ' + '{trace_id} | ' + '{level: <8} | ' + '{name}:{function}:{line} - ' + '{message}' + ) + _logger.remove() + # 移除后重新添加sys.stderr, 目的: 控制台输出与文件日志内容和结构一致 + _logger.add(sys.stderr, filter=self.__filter, format=format_str, enqueue=True) + _logger.add( + self.log_path_error, + filter=self.__filter, + format=format_str, + # rotation='50MB', + rotation='daily', # 每天轮换一次 + retention="15 days", # 保留最近15天的日志 + encoding='utf-8', + enqueue=True, + compression=None, + ) + + return _logger + + +# 初始化日志处理器 +log_initializer = LoggerInitializer() +logger = log_initializer.init_log()