diff --git a/backend.py b/backend.py
index 486f474..a75e82c 100644
--- a/backend.py
+++ b/backend.py
@@ -6,13 +6,15 @@ import asyncio
from starlette.middleware.cors import CORSMiddleware
+from middlewares.handle import handle_middleware
from rag import rag_generate_rule
+from util.log_util import logger
+
@asynccontextmanager
async def lifespan(app : FastAPI):
# logger.info(f'{AppConfig.app_name}开始启动,当前运行环境{os.environ.get('APP_ENV', '未读取到环境变量APP_ENV')}')
print(f'开始启动')
- print(f'启动成功')
yield
print(f'关闭成功')
@@ -21,13 +23,23 @@ app = FastAPI(title="RAG SQL Generator API", description="调用RAG生成SQL语
rule_router = APIRouter(prefix="/rule")
# 添加CORS中间件 - 允许所有来源(开发环境)
-app.add_middleware(
- CORSMiddleware,
- allow_origins=["*"], # 在生产环境中应该指定具体的域名
- allow_credentials=True,
- allow_methods=["*"],
- allow_headers=["*"],
-)
+# app.add_middleware(
+# CORSMiddleware,
+# allow_origins=["*"], # 在生产环境中应该指定具体的域名
+# allow_credentials=True,
+# allow_methods=["*"],
+# allow_headers=["*"],
+# )
+
+# 加载中间件
+handle_middleware(app)
+
+
+
+
+
+
+
class QueryRequest(BaseModel):
query: str
@@ -48,6 +60,10 @@ async def generate_sql(request: QueryRequest):
result = await rag_generate_rule(request.query)
return QueryResponse(sql=result)
except Exception as e:
+ # print(e)
+ # print(e.__context__)
+ logger.info(e)
+ logger.info(e.__context__)
raise HTTPException(status_code=500, detail=f"生成SQL时出错: {str(e)}")
@rule_router.get("/health")
diff --git a/env.py b/env.py
index 92991fb..c6baec6 100644
--- a/env.py
+++ b/env.py
@@ -64,7 +64,6 @@ class GetConfig:
# 加载配置
load_dotenv(env_file)
print(f'当前运行环境为:{run_env}')
- print(os.getenv("DB_HOST"))
# 实例化获取配置类
get_config = GetConfig()
diff --git a/middlewares/cors_middleware.py b/middlewares/cors_middleware.py
new file mode 100644
index 0000000..55508e7
--- /dev/null
+++ b/middlewares/cors_middleware.py
@@ -0,0 +1,25 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+
+
+def add_cors_middleware(app: FastAPI):
+ """
+ 添加跨域中间件
+
+ :param app: FastAPI对象
+ :return:
+ """
+ # 前端页面url
+ origins = [
+ 'http://localhost:80',
+ 'http://127.0.0.1:80',
+ ]
+
+ # 后台api允许跨域
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=origins,
+ allow_credentials=True,
+ allow_methods=['*'],
+ allow_headers=['*'],
+ )
diff --git a/middlewares/gzip_middleware.py b/middlewares/gzip_middleware.py
new file mode 100644
index 0000000..eb371ce
--- /dev/null
+++ b/middlewares/gzip_middleware.py
@@ -0,0 +1,12 @@
+from fastapi import FastAPI
+from starlette.middleware.gzip import GZipMiddleware
+
+
+def add_gzip_middleware(app: FastAPI):
+ """
+ 添加gzip压缩中间件
+
+ :param app: FastAPI对象
+ :return:
+ """
+ app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=9)
diff --git a/middlewares/handle.py b/middlewares/handle.py
new file mode 100644
index 0000000..1c6943e
--- /dev/null
+++ b/middlewares/handle.py
@@ -0,0 +1,16 @@
+from fastapi import FastAPI
+from middlewares.cors_middleware import add_cors_middleware
+from middlewares.gzip_middleware import add_gzip_middleware
+from middlewares.trace_middleware import add_trace_middleware
+
+def handle_middleware(app: FastAPI):
+ """
+ 全局中间件处理
+ """
+ # 加载跨域中间件
+ add_cors_middleware(app)
+ # 加载gzip压缩中间件
+ add_gzip_middleware(app)
+ # 加载trace中间件
+ add_trace_middleware(app)
+
diff --git a/middlewares/trace_middleware/__init__.py b/middlewares/trace_middleware/__init__.py
new file mode 100644
index 0000000..76f8d85
--- /dev/null
+++ b/middlewares/trace_middleware/__init__.py
@@ -0,0 +1,17 @@
+from fastapi import FastAPI
+from .ctx import TraceCtx
+from .middle import TraceASGIMiddleware
+
+__all__ = ('TraceASGIMiddleware', 'TraceCtx')
+
+__version__ = '0.1.0'
+
+
+def add_trace_middleware(app: FastAPI):
+ """
+ 添加trace中间件
+
+ :param app: FastAPI对象
+ :return:
+ """
+ app.add_middleware(TraceASGIMiddleware)
diff --git a/middlewares/trace_middleware/ctx.py b/middlewares/trace_middleware/ctx.py
new file mode 100644
index 0000000..558a5c9
--- /dev/null
+++ b/middlewares/trace_middleware/ctx.py
@@ -0,0 +1,23 @@
+# -*- coding: utf-8 -*-
+"""
+@author: peng
+@file: ctx.py
+@time: 2025/1/17 16:57
+"""
+
+import contextvars
+from uuid import uuid4
+
+CTX_REQUEST_ID: contextvars.ContextVar[str] = contextvars.ContextVar('request-id', default='')
+
+
+class TraceCtx:
+ @staticmethod
+ def set_id():
+ _id = uuid4().hex
+ CTX_REQUEST_ID.set(_id)
+ return _id
+
+ @staticmethod
+ def get_id():
+ return CTX_REQUEST_ID.get()
diff --git a/middlewares/trace_middleware/middle.py b/middlewares/trace_middleware/middle.py
new file mode 100644
index 0000000..a071692
--- /dev/null
+++ b/middlewares/trace_middleware/middle.py
@@ -0,0 +1,47 @@
+# -*- coding: utf-8 -*-
+"""
+@author: peng
+@file: middle.py
+@time: 2025/1/17 16:57
+"""
+
+from functools import wraps
+from starlette.types import ASGIApp, Message, Receive, Scope, Send
+from .span import get_current_span, Span
+
+
+class TraceASGIMiddleware:
+ """
+ fastapi-example:
+ app = FastAPI()
+ app.add_middleware(TraceASGIMiddleware)
+ """
+
+ def __init__(self, app: ASGIApp) -> None:
+ self.app = app
+
+ @staticmethod
+ async def my_receive(receive: Receive, span: Span):
+ await span.request_before()
+
+ @wraps(receive)
+ async def my_receive():
+ message = await receive()
+ await span.request_after(message)
+ return message
+
+ return my_receive
+
+ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+ if scope['type'] != 'http':
+ await self.app(scope, receive, send)
+ return
+
+ async with get_current_span(scope) as span:
+ handle_outgoing_receive = await self.my_receive(receive, span)
+
+ async def handle_outgoing_request(message: 'Message') -> None:
+ await span.response(message)
+ await send(message)
+
+ await self.app(scope, handle_outgoing_receive, handle_outgoing_request)
diff --git a/middlewares/trace_middleware/span.py b/middlewares/trace_middleware/span.py
new file mode 100644
index 0000000..1e38eab
--- /dev/null
+++ b/middlewares/trace_middleware/span.py
@@ -0,0 +1,52 @@
+# -*- coding: utf-8 -*-
+"""
+@author: peng
+@file: span.py
+@time: 2025/1/17 16:57
+"""
+
+from contextlib import asynccontextmanager
+from starlette.types import Scope, Message
+from .ctx import TraceCtx
+
+
+class Span:
+ """
+ 整个http生命周期:
+ request(before) --> request(after) --> response(before) --> response(after)
+ """
+
+ def __init__(self, scope: Scope):
+ self.scope = scope
+
+ async def request_before(self):
+ """
+ request_before: 处理header信息等, 如记录请求体信息
+ """
+ TraceCtx.set_id()
+
+ async def request_after(self, message: Message):
+ """
+ request_after: 处理请求bytes, 如记录请求参数
+
+ example:
+ message: {'type': 'http.request', 'body': b'{\r\n "name": "\xe8\x8b\x8f\xe8\x8b\x8f\xe8\x8b\x8f"\r\n}', 'more_body': False}
+ """
+ return message
+
+ async def response(self, message: Message):
+ """
+ if message['type'] == "http.response.start": -----> request-before
+ pass
+ if message['type'] == "http.response.body": -----> request-after
+ message.get('body', b'')
+ pass
+ """
+ if message['type'] == 'http.response.start':
+ message['headers'].append((b'request-id', TraceCtx.get_id().encode()))
+ return message
+
+
+@asynccontextmanager
+async def get_current_span(scope: Scope):
+ yield Span(scope)
diff --git a/rag.py b/rag.py
index b5b13ae..b78a313 100644
--- a/rag.py
+++ b/rag.py
@@ -1,31 +1,53 @@
from util import use_pgvector, use_opanai, use_mysql
+from util.log_util import logger
from util.use_mysql import search_desc_by_table_names
async def rag_generate_rule(query : str):
if not query:
return "请输入问题"
- # 连接数据库
- pgvector_conn = use_pgvector.connect_to_db()
+
+ # 询问大模型抽取用户问题中表名
+ prompt = [
+ {"role": "system", "content": f"""
+ 需要从需要从用户问题中抽取表名,请返回一个列表,格式为:[表名1, 表名2, ...]。
+
+ 例如:用户问题:"请查询表A和表B的交集",则返回:[A, B]。
+
+ 请根据用户问题,抽取表名,请勿返回其他内容。
+
+ """},
+ {"role": "user", "content": query}
+ ]
+ tables_in_query = await use_opanai.generation_rule(prompt)
+ tables_in_query = tables_in_query.strip('[]')
+ tables_in_query = [item.strip() for item in tables_in_query.split(',')]
# 将问题向量化
query_emb = await use_opanai.generation_vector("query")
+ # 连接pgvector数据库
+ pgvector_conn = use_pgvector.connect_to_db()
+
# 根据问题关联数据库表,得到table_name_list
similar_docs = use_pgvector.search_similar_table(pgvector_conn, query_emb, limit=4)
rerank_docs = await use_opanai.rerank_documents(query, similar_docs, top_n=2)
table_name_list = [similar_docs[int(index.strip())-1][0] for index in rerank_docs.strip('[]').split(',')]
- print(f"[table_name_list]: {table_name_list}")
+ table_name_list.extend(tables_in_query)
+ # print(f"【table_name_list】: {table_name_list}")
+ logger.info(f"【table_name_list】: {table_name_list}")
# 获得相关表的schema
db_client = await use_mysql.get_db()
schema = await search_desc_by_table_names(table_name_list, db_client)
- print(f"[schema]: {schema}")
+ # print(f"【schema】: {schema}")
+ logger.info(f"【schema】: {schema}")
# 根据问题搜索相关案例
similar_case = use_pgvector.search_similar_case(pgvector_conn, query_emb, limit=3)
- print(f"[similar_case]: {similar_case}")
+ # print(f"【similar_case】: {similar_case}")
+ logger.info(f"【similar_case】: {similar_case}")
# 询问大模型生成SQL
prompt = [
@@ -40,7 +62,7 @@ async def rag_generate_rule(query : str):
"""},
{"role": "user", "content": query}
]
- print(f"[prompt]: {prompt}")
+ # print(f"【prompt】: {prompt}")
ans = await use_opanai.generation_rule(prompt)
return ans
diff --git a/util/log_util.py b/util/log_util.py
new file mode 100644
index 0000000..4cd4222
--- /dev/null
+++ b/util/log_util.py
@@ -0,0 +1,63 @@
+import os
+import sys
+import time
+from loguru import logger as _logger
+from typing import Dict
+from middlewares.trace_middleware import TraceCtx
+
+
+class LoggerInitializer:
+ def __init__(self):
+ self.log_path = os.path.join(os.getcwd(), 'logs')
+ self.__ensure_log_directory_exists()
+ # self.log_path_error = os.path.join(self.log_path, f'{time.strftime("%Y-%m-%d")}_app.log')
+ self.log_path_error = os.path.join(self.log_path, f'app.log')
+
+ def __ensure_log_directory_exists(self):
+ """
+ 确保日志目录存在,如果不存在则创建
+ """
+ if not os.path.exists(self.log_path):
+ os.mkdir(self.log_path)
+
+ @staticmethod
+ def __filter(log: Dict):
+ """
+ 自定义日志过滤器,添加trace_id
+ """
+ log['trace_id'] = TraceCtx.get_id()
+ return log
+
+ def init_log(self):
+ """
+ 初始化日志配置
+ """
+ # 自定义日志格式
+ format_str = (
+ '{time:YYYY-MM-DD HH:mm:ss.SSS} | '
+ '{trace_id} | '
+ '{level: <8} | '
+ '{name}:{function}:{line} - '
+ '{message}'
+ )
+ _logger.remove()
+ # 移除后重新添加sys.stderr, 目的: 控制台输出与文件日志内容和结构一致
+ _logger.add(sys.stderr, filter=self.__filter, format=format_str, enqueue=True)
+ _logger.add(
+ self.log_path_error,
+ filter=self.__filter,
+ format=format_str,
+ # rotation='50MB',
+ rotation='daily', # 每天轮换一次
+ retention="15 days", # 保留最近15天的日志
+ encoding='utf-8',
+ enqueue=True,
+ compression=None,
+ )
+
+ return _logger
+
+
+# 初始化日志处理器
+log_initializer = LoggerInitializer()
+logger = log_initializer.init_log()